init commit
LICENSE (new file, 37 lines)
@@ -0,0 +1,37 @@
NVIDIA License

1. Definitions

“Licensor” means any person or entity that distributes its Work.
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.

2. License Grant

2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works; (b) you comply with Other Licenses, and (c) you identify the specific derivative works that are subject to Your Terms and Other Licenses, as applicable. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. As used herein, “non-commercially” means for non-commercial research purposes only, and excludes any military, surveillance, service of nuclear technology or biometric processing purposes.

3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.

3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.

3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.

3.7 Components Under Other Licenses. The Work may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as open source software licenses and other license terms, including but not limited to the Meta OPT-IML 175B License Agreement (“Other Licenses”). The components are subject to the applicable Other Licenses, including any proprietary notices, disclaimers, requirements and extended use rights; except that this Agreement will prevail regarding the use of third-party software, unless a third-party software license requires its license terms to prevail.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
README.md (new file, 66 lines)
@@ -0,0 +1,66 @@
<img src="assets/github_banner.gif" width="100%" />

<div align="center">
<p style="font-size: 1.2em;">
<a href="https://nitrogen.minedojo.org/"><strong>Website</strong></a> |
<a href="https://huggingface.co/nvidia/NitroGen"><strong>Model</strong></a> |
<a href="https://huggingface.co/datasets/nvidia/NitroGen"><strong>Dataset</strong></a> |
<a href="https://nitrogen.minedojo.org/assets/documents/nitrogen.pdf"><strong>Paper</strong></a>
</p>
</div>


# NitroGen

NitroGen is an open foundation model for generalist gaming agents. This multi-game model takes pixel input and predicts gamepad actions.

NitroGen is trained through behavior cloning on the largest video-action gameplay dataset, assembled exclusively from internet videos. It can be adapted via post-training to unseen games.

# Installation

## Prerequisites

We **do not distribute game environments**; you must use your own copies of the games. This repository only supports running the agent on **Windows games**. You can serve the model from a Linux machine for inference, but the game itself ultimately has to run on Windows. We have tested on Windows 11 with Python ≥ 3.12.

## Setup

Install this repo:

```bash
git clone https://github.com/MineDojo/NitroGen.git
cd NitroGen
pip install -e .
```

Download the NitroGen checkpoint from [HuggingFace](https://huggingface.co/nvidia/NitroGen):

```bash
hf download nvidia/NitroGen ng.pt
```

# Getting Started

First, start an inference server for the model:

```bash
python scripts/serve.py <path_to_ng.pt>
```

Then, run the agent on the game of your choice:

```bash
python scripts/play.py --process '<game_executable_name>.exe'
```

The `--process` parameter must be the exact executable name of the game you want to play. You can find it by right-clicking the game process in Windows Task Manager (Ctrl+Shift+Esc) and selecting `Properties`; the process name is shown in the `General` tab and ends with `.exe`.
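
If you are unsure of the exact name, you can also enumerate running executables programmatically. A minimal sketch using `psutil` (which this repo already uses internally; the snippet itself is illustrative and not part of the NitroGen scripts):

```python
# List the names of running .exe processes to find your game's executable.
import psutil

for proc in psutil.process_iter(["pid", "name"]):
    name = proc.info["name"] or ""
    if name.lower().endswith(".exe"):
        print(proc.info["pid"], name)
```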

<!-- TODO # Paper and Citation

If you find our work useful, please consider citing us!

```bibtex
@article{,
  title   = {},
  author  = {},
  year    = {},
  journal = {}
}
``` -->

**Disclaimer**: This project is strictly for research purposes and is not an official NVIDIA product.
assets/github_banner.gif (new binary file, 28 MiB; binary content not shown)
nitrogen/cfg.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from pydantic import BaseModel, Field

from nitrogen.flow_matching_transformer.nitrogen import NitroGen_Config
from nitrogen.mm_tokenizers import NitrogenTokenizerConfig


class ModalityConfig(BaseModel):
    frame_per_sample: int = 1  # number of context frames per sample
    frame_spacing: int | None = None  # how many frames to skip between context frames. If None, use action_per_chunk
    action_per_chunk: int = 8
    action_shift: int = 1  # how many actions to skip between frame[i] and action_chunk[i]
    action_interleaving: bool = False  # if True, action chunks are interleaved with context frames and used by the model to predict the next actions
    token_set: str = "new"

    def model_post_init(self, __context):
        if self.frame_spacing is None:
            # Use object.__setattr__ because the model is frozen
            object.__setattr__(self, "frame_spacing", self.action_per_chunk)
        assert self.action_shift >= 1, "action_shift must be at least 1 for correct action indexing"


class CkptConfig(BaseModel):
    experiment_name: str = Field(..., description="Name of the experiment")

    model_cfg: NitroGen_Config = Field(..., description="Model configuration.")
    tokenizer_cfg: NitrogenTokenizerConfig = Field(..., description="Tokenizer configuration.")
    modality_cfg: ModalityConfig = Field(..., description="Modality configuration for the dataset mixture.")
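A minimal usage sketch of `ModalityConfig` (assuming only the definitions above), showing the `frame_spacing` fallback in `model_post_init`:

```python
from nitrogen.cfg import ModalityConfig

# frame_spacing left unset: model_post_init copies action_per_chunk (8).
cfg = ModalityConfig(frame_per_sample=4, action_per_chunk=8)
assert cfg.frame_spacing == 8

# An explicit spacing is kept as-is.
assert ModalityConfig(frame_spacing=2).frame_spacing == 2
```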
nitrogen/flow_matching_transformer/modules.py (new file, 434 lines)
@@ -0,0 +1,434 @@
from typing import Optional

from pydantic import BaseModel, Field
import torch
import torch.nn.functional as F
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import (
    SinusoidalPositionalEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from torch import nn


class TimestepEncoder(nn.Module):
    def __init__(self, embedding_dim, compute_dtype=torch.float32):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timesteps):
        dtype = next(self.parameters()).dtype
        timesteps_proj = self.time_proj(timesteps).to(dtype)
        timesteps_emb = self.timestep_embedder(timesteps_proj)  # (N, D)
        return timesteps_emb


class AdaLayerNorm(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()
        self.chunk_dim = chunk_dim
        output_dim = embedding_dim * 2
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self,
        x: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        temb = self.linear(self.silu(temb))
        scale, shift = temb.chunk(2, dim=1)
        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return x


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.norm_type = norm_type

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embeddings` is defined, `num_positional_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(
                dim, max_seq_length=num_positional_embeddings
            )
        else:
            self.pos_embed = None

        # Define two blocks. Each block has its own normalization layer.
        # 1. Self-/cross-attention
        if norm_type == "ada_norm":
            self.norm1 = AdaLayerNorm(dim)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Feed-forward
        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )
        if final_dropout:
            self.final_dropout = nn.Dropout(dropout)
        else:
            self.final_dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:

        # 1. Self-/cross-attention
        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, temb)
        else:
            norm_hidden_states = self.norm1(hidden_states)

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            # encoder_attention_mask=encoder_attention_mask,
        )
        if self.final_dropout:
            attn_output = self.final_dropout(attn_output)

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
        return hidden_states


class DiTConfig(BaseModel):
    num_attention_heads: int = Field(default=8)
    attention_head_dim: int = Field(default=64)
    output_dim: int = Field(default=26)
    num_layers: int = Field(default=12)
    dropout: float = Field(default=0.1)
    attention_bias: bool = Field(default=True)
    activation_fn: str = Field(default="gelu-approximate")
    num_embeds_ada_norm: Optional[int] = Field(default=1000)
    upcast_attention: bool = Field(default=False)
    norm_type: str = Field(default="ada_norm")
    norm_elementwise_affine: bool = Field(default=False)
    norm_eps: float = Field(default=1e-5)
    max_num_positional_embeddings: int = Field(default=512)
    compute_dtype: str = Field(default="float32")
    final_dropout: bool = Field(default=True)
    positional_embeddings: Optional[str] = Field(default="sinusoidal")
    interleave_self_attention: bool = Field(default=False)
    cross_attention_dim: Optional[int] = Field(default=None, description="Dimension of the cross-attention embeddings. If None, no cross-attention is used.")


class DiT(ModelMixin):
    _supports_gradient_checkpointing = True

    def __init__(self, config: DiTConfig):
        super().__init__()

        self.config = config

        self.compute_dtype = getattr(torch, self.config.compute_dtype)

        self.attention_head_dim = self.config.attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        # Timestep encoder
        self.timestep_encoder = TimestepEncoder(
            embedding_dim=self.inner_dim, compute_dtype=self.compute_dtype
        )

        all_blocks = []
        for idx in range(self.config.num_layers):

            # With interleave_self_attention, every odd block drops cross-attention.
            use_self_attn = idx % 2 == 1 and self.config.interleave_self_attention
            curr_cross_attention_dim = self.config.cross_attention_dim if not use_self_attn else None

            all_blocks += [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=self.config.norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    positional_embeddings=self.config.positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=self.config.final_dropout,
                    cross_attention_dim=curr_cross_attention_dim,
                )
            ]
        self.transformer_blocks = nn.ModuleList(all_blocks)

        # Output blocks
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
        print(
            "Total number of DiT parameters: ",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        encoder_hidden_states: torch.Tensor,  # Shape: (B, S, D)
        timestep: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_all_hidden_states: bool = False,
    ):
        # Encode timesteps
        temb = self.timestep_encoder(timestep)

        # Single pass through the transformer blocks
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        all_hidden_states = [hidden_states]

        for idx, block in enumerate(self.transformer_blocks):
            if idx % 2 == 1 and self.config.interleave_self_attention:
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    temb=temb,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=None,
                    temb=temb,
                )
            all_hidden_states.append(hidden_states)

        # Output processing
        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        if return_all_hidden_states:
            return self.proj_out_2(hidden_states), all_hidden_states
        else:
            return self.proj_out_2(hidden_states)


class SelfAttentionTransformerConfig(BaseModel):
    num_attention_heads: int = Field(default=8)
    attention_head_dim: int = Field(default=64)
    output_dim: int = Field(default=26)
    num_layers: int = Field(default=12)
    dropout: float = Field(default=0.1)
    attention_bias: bool = Field(default=True)
    activation_fn: str = Field(default="gelu-approximate")
    num_embeds_ada_norm: Optional[int] = Field(default=1000)
    upcast_attention: bool = Field(default=False)
    max_num_positional_embeddings: int = Field(default=512)
    compute_dtype: str = Field(default="float32")
    final_dropout: bool = Field(default=True)
    positional_embeddings: Optional[str] = Field(default="sinusoidal")
    interleave_self_attention: bool = Field(default=False)


class SelfAttentionTransformer(ModelMixin):
    _supports_gradient_checkpointing = True

    def __init__(self, config: SelfAttentionTransformerConfig):
        super().__init__()

        self.config = config

        self.attention_head_dim = self.config.attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    positional_embeddings=self.config.positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=self.config.final_dropout,
                )
                for _ in range(self.config.num_layers)
            ]
        )
        print(
            "Total number of SelfAttentionTransformer parameters: ",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        return_all_hidden_states: bool = False,
    ):

        # Single pass through the transformer blocks
        hidden_states = hidden_states.contiguous()
        all_hidden_states = [hidden_states]

        for idx, block in enumerate(self.transformer_blocks):
            hidden_states = block(hidden_states)
            all_hidden_states.append(hidden_states)

        if return_all_hidden_states:
            return hidden_states, all_hidden_states
        else:
            return hidden_states


class CrossAttentionTransformer(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        output_dim: int = 26,
        num_layers: int = 12,
        dropout: float = 0.1,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        max_num_positional_embeddings: int = 512,
        compute_dtype=torch.float32,
        final_dropout: bool = True,
        positional_embeddings: Optional[str] = "sinusoidal",
        interleave_self_attention=False,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=final_dropout,
                )
                for _ in range(self.config.num_layers)
            ]
        )
        print(
            "Total number of CrossAttentionTransformer parameters: ",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        encoder_hidden_states: torch.Tensor,  # Shape: (B, S, D)
    ):

        # Single pass through the transformer blocks
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()
        for idx, block in enumerate(self.transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        return hidden_states
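To make the DiT wiring concrete, here is a minimal smoke-test sketch. All sizes are illustrative assumptions (a tiny 2-layer model), not values from a released checkpoint:

```python
import torch
from nitrogen.flow_matching_transformer.modules import DiT, DiTConfig

# 4 heads x 32 dims per head => inner_dim = 128.
cfg = DiTConfig(
    num_attention_heads=4,
    attention_head_dim=32,
    num_layers=2,
    output_dim=26,
    cross_attention_dim=128,
)
model = DiT(config=cfg)

B, T, S = 2, 8, 16                  # batch, action tokens, context tokens
sa_embs = torch.randn(B, T, 128)    # state-action stream (hidden_states)
vl_embs = torch.randn(B, S, 128)    # vision-language stream (encoder_hidden_states)
t = torch.randint(0, 1000, (B,))    # discretized flow-matching timesteps

out = model(hidden_states=sa_embs, encoder_hidden_states=vl_embs, timestep=t)
print(out.shape)                    # torch.Size([2, 8, 26])
```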
nitrogen/flow_matching_transformer/nitrogen.py (new file, 755 lines)
@@ -0,0 +1,755 @@
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
from pathlib import Path

import yaml
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.distributions import Beta
from transformers import SiglipVisionModel, AutoModel

from .modules import DiT, DiTConfig, SelfAttentionTransformer, SelfAttentionTransformerConfig

_PAD_TOKEN = 0
_IMG_TOKEN = 1
_IMG_SEP_TOKEN = 5
_LANG_TOKEN = 2
_PROPRIO_TOKEN = 3
_ACT_TOKEN = 4
_GAME_ID_TOKEN = 6


class NitroGen_Config(BaseModel):
    model_type: str = Field(default="nitrogen", frozen=True)

    add_pos_embed: bool = Field(default=False, description="Whether to add positional embedding")
    model_dtype: str = Field(default="float32", description="Model data type.")
    diffusion_model_cfg: DiTConfig = Field(..., description="Diffusion model configuration.")
    vl_self_attention_cfg: SelfAttentionTransformerConfig = Field(..., description="VL self-attention configuration.")
    hidden_size: int = Field(default=1024, description="Input embedding dimension.")
    max_seq_len: int = Field(default=1024, description="Maximum sequence length.")
    action_dim: int = Field(default=None, description="Action dimension.")
    action_horizon: int = Field(default=None, description="Action horizon.")
    noise_beta_alpha: float = Field(default=1.5, description="Flow matching noise Beta distribution alpha.")
    noise_beta_beta: float = Field(default=1.0, description="Flow matching noise Beta distribution beta.")
    noise_s: float = Field(default=0.999, description="Flow matching noise Beta distribution s.")
    num_timestep_buckets: int = Field(default=1000, description="Number of timestep discretization buckets.")
    num_inference_timesteps: int = Field(default=None, description="Number of inference steps for noise diffusion.")
    max_num_embodiments: int = Field(default=1, description="Number of embodiments.")
    vision_encoder_name: str = Field(default="google/siglip-large-patch16-256", description="Vision encoder name.")
    vision_hidden_size: int = Field(default=768, description="Siglip hidden size.")
    add_view_embed: bool = Field(default=False, description="Whether to add view embedding.")

    tune_vision_tower: bool = Field(default=True, description="Tune vision tower if True.")
    tune_mm_projector: bool = Field(default=True, description="Tune mm projector if True.")
    tune_diffusion_model: bool = Field(default=True, description="Tune diffusion model if True.")
    tune_multi_projector: bool = Field(default=True, description="Tune multi projector if True.")
    tune_vl_mixing: bool = Field(default=True, description="Tune vl mixing if True.")

    @classmethod
    def from_yaml(cls, yaml_path: str | Path) -> "NitroGen_Config":
        """Load configuration from a YAML file."""
        with open(yaml_path, "r") as f:
            config_dict = yaml.safe_load(f)
        return cls.model_validate(config_dict)


def swish(x):
    return x * torch.sigmoid(x)


class SinusoidalPositionalEncoding(nn.Module):
    """
    Produces a sinusoidal encoding of shape (B, T, w)
    given timesteps of shape (B, T).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timesteps):
        # timesteps: shape (B, T)
        # We compute sin/cos frequencies across the embedding dimension
        timesteps = timesteps.float()  # ensure float

        B, T = timesteps.shape
        device = timesteps.device

        half_dim = self.embedding_dim // 2
        # typical log-space frequencies for sinusoidal encoding
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            torch.log(torch.tensor(10000.0)) / half_dim
        )
        # Expand timesteps to (B, T, 1) then multiply
        freqs = timesteps.unsqueeze(-1) * exponent.exp()  # (B, T, half_dim)

        sin = torch.sin(freqs)
        cos = torch.cos(freqs)
        enc = torch.cat([sin, cos], dim=-1)  # (B, T, w)

        return enc


class CategorySpecificLinear(nn.Module):
    def __init__(self, num_categories, input_dim, hidden_dim):
        super().__init__()
        self.num_categories = num_categories
        # For each category, we have separate weights and biases.
        self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))

    def forward(self, x, cat_ids):
        selected_W = self.W[cat_ids]
        selected_b = self.b[cat_ids]
        return torch.bmm(x, selected_W) + selected_b.unsqueeze(1)


class CategorySpecificMLP(nn.Module):
    def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.num_categories = num_categories
        self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
        self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)

    def forward(self, x, cat_ids):
        hidden = F.relu(self.layer1(x, cat_ids))
        return self.layer2(hidden, cat_ids)


class MultiEmbodimentActionEncoder(nn.Module):
    def __init__(self, action_dim, hidden_size, num_embodiments):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_embodiments = num_embodiments

        # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)  # (d -> w)
        self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)  # (2w -> w)
        self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)  # (w -> w)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps, cat_ids):
        """
        actions:   shape (B, T, action_dim)
        timesteps: shape (B,) -- a single scalar per batch item
        cat_ids:   shape (B,)
        returns:   shape (B, T, hidden_size)
        """
        B, T, _ = actions.shape

        # 1) Expand each batch's single scalar time 'tau' across all T steps
        #    so that shape => (B, T)
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            # shape (B,) => (B, T)
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # 2) Standard action MLP step for shape => (B, T, w)
        a_emb = self.W1(actions, cat_ids)

        # 3) Get the sinusoidal encoding (B, T, w)
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)

        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = swish(self.W2(x, cat_ids))

        # 5) Finally W3 => (B, T, w)
        x = self.W3(x, cat_ids)
        return x


class NitroGen(torch.nn.Module):
    config_class = NitroGen_Config
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: NitroGen_Config,
        game_mapping: dict[str, int] | None = None,  # Used to add a game ID token
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.vision_hidden_size = config.vision_hidden_size

        if "siglip" in config.vision_encoder_name:
            model = SiglipVisionModel.from_pretrained(config.vision_encoder_name)
            self.vision_encoder = model.vision_model
            self.vision_encoder_type = "siglip"
        else:
            self.vision_encoder = AutoModel.from_pretrained(config.vision_encoder_name)
            self.vision_encoder_type = "hf_auto"
        self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
        self.num_timestep_buckets = config.num_timestep_buckets
        # self.model = instantiate(config.diffusion_model_cfg)
        self.model = DiT(config=config.diffusion_model_cfg)
        self.action_dim = config.action_dim
        self.action_horizon = config.action_horizon
        self.num_inference_timesteps = config.num_inference_timesteps

        # self.vl_self_attention_model = instantiate(config.vl_self_attention_cfg)
        self.vl_self_attention_model = SelfAttentionTransformer(config=config.vl_self_attention_cfg)

        # if config.qformer_cfg is not None:
        #     self.qformer = instantiate(config.qformer_cfg)
        # else:
        #     self.qformer = nn.Identity()

        # self.state_encoder = CategorySpecificMLP(
        #     num_categories=config.max_num_embodiments,
        #     input_dim=config.max_state_dim,
        #     hidden_dim=self.hidden_size,
        #     output_dim=self.hidden_size,
        # )
        self.action_encoder = MultiEmbodimentActionEncoder(
            action_dim=config.action_dim,
            hidden_size=self.hidden_size,
            num_embodiments=config.max_num_embodiments,
        )

        self.action_decoder = CategorySpecificMLP(
            num_categories=config.max_num_embodiments,
            input_dim=self.hidden_size,
            hidden_dim=self.hidden_size,
            output_dim=self.action_dim,
        )
        # self.mm_vision_select_layer = config.mm_vision_select_layer
        # if config.mm_projector_cfg is not None:
        #     self.mm_projector = instantiate(config.mm_projector_cfg)
        # else:
        self.mm_projector = None
        if config.add_pos_embed:
            self.position_embedding = nn.Embedding(config.max_seq_len, self.hidden_size)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        # if config.add_view_embed:
        #     self.view_embedding = nn.Embedding(config.max_num_views, self.hidden_size)
        #     nn.init.normal_(self.view_embedding.weight, mean=0.0, std=0.02)

        # self.vision_projector = None
        # if config.vision_hidden_size != self.hidden_size:
        #     self.vision_projector = nn.Sequential(
        #         nn.Linear(config.vision_hidden_size, self.hidden_size),
        #         nn.LayerNorm(self.hidden_size),
        #     )

        self.game_mapping = game_mapping
        # Create an embedding table for game IDs.
        # Game ID tokens will be placed inside the vision-language tokens,
        # so they need to be projected to the same dimension.
        if self.game_mapping is not None:
            # 0 = unconditional
            self.game_embedding = nn.Embedding(
                len(self.game_mapping),
                self.vision_hidden_size,
                padding_idx=0,
                scale_grad_by_freq=True,
            )

        self.set_trainable_parameters(
            tune_multi_projector=config.tune_multi_projector,
            tune_diffusion_model=config.tune_diffusion_model,
            tune_vision_tower=config.tune_vision_tower,
            tune_mm_projector=config.tune_mm_projector,
            tune_vl_mixing=config.tune_vl_mixing,
        )

        print(
            f"Total number of parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad):e}"
        )

    def set_trainable_parameters(
        self,
        tune_multi_projector: bool = True,
        tune_diffusion_model: bool = True,
        tune_vision_tower: bool = True,
        tune_mm_projector: bool = True,
        tune_vl_mixing: bool = True,
    ):
        self.tune_multi_projector = tune_multi_projector
        self.tune_diffusion_model = tune_diffusion_model
        self.tune_vision_tower = tune_vision_tower
        self.tune_mm_projector = tune_mm_projector
        self.tune_vl_mixing = tune_vl_mixing

        for param in self.parameters():
            param.requires_grad = True
        # ### Always freeze language encoder
        # self.siglip_model.text_model.requires_grad_(False)
        # # Freeze unused parameters in siglip vision encoder
        # self.siglip_model.logit_scale.requires_grad = False
        # self.siglip_model.logit_bias.requires_grad = False

        # For siglip, freeze submodules that should not be trained
        # (one encoder layer and the pooling head).
        if self.vision_encoder_type == "siglip":
            for param in self.vision_encoder.encoder.layers[11].parameters():
                param.requires_grad = False
            for param in self.vision_encoder.head.parameters():
                param.requires_grad = False

        # Freeze parameters
        if not tune_multi_projector:
            # self.state_encoder.requires_grad_(False)
            self.action_encoder.requires_grad_(False)
            self.action_decoder.requires_grad_(False)
            if self.config.add_pos_embed:
                self.position_embedding.requires_grad_(False)
            if self.config.add_view_embed:
                self.view_embedding.requires_grad_(False)
        if not tune_diffusion_model:
            self.model.requires_grad_(False)
        if not tune_vision_tower:
            self.vision_encoder.requires_grad_(False)
        if self.mm_projector is not None and not tune_mm_projector:
            self.mm_projector.requires_grad_(False)
        if not tune_vl_mixing:
            self.vl_self_attention_model.requires_grad_(False)

        print(f"Tune action head multi_projector: {self.tune_multi_projector}")
        print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
        print(f"Tune action head vision tower: {self.tune_vision_tower}")
        print(f"Tune action head mm_projector: {self.tune_mm_projector}")
        print(f"Tune action head vl_mixing: {self.tune_vl_mixing}")
        # Check if any parameters are still trainable. If not, print a warning.
        if not any(p.requires_grad for p in self.parameters()):
            print("Warning: No action head trainable parameters found.")

    def set_frozen_modules_to_eval_mode(self):
        """
        Huggingface will call model.train() at each training_step. To ensure
        the expected behaviors for modules like dropout, batchnorm, etc., we
        need to call model.eval() for the frozen modules.
        """
        if self.training:
            # self.siglip_model.text_model.eval()
            if not self.tune_multi_projector:
                # self.state_encoder.eval()
                self.action_encoder.eval()
                self.action_decoder.eval()
                if self.config.add_pos_embed:
                    self.position_embedding.eval()
                # if self.config.add_view_embed:
                #     self.view_embedding.eval()
            if not self.tune_diffusion_model:
                self.model.eval()
            if not self.tune_vision_tower:
                self.vision_encoder.eval()
            if self.mm_projector is not None and not self.tune_mm_projector:
                self.mm_projector.eval()
            if not self.tune_vl_mixing:
                self.vl_self_attention_model.eval()

    # This earlier version is believed to be incorrect:
    # def sample_time(self, batch_size, device, dtype):
    #     sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
    #     return (self.config.noise_s - sample) / self.config.noise_s

    def sample_time(self, batch_size, device, dtype):
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        return (1 - sample) * self.config.noise_s

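    # Note on the distribution (illustrative): with alpha=1.5, beta=1.0 and
    # s=0.999, t = (1 - Beta(1.5, 1.0).sample()) * s lies in [0, s] with
    # E[t] = (1 - 1.5 / 2.5) * s ≈ 0.4, so training is biased toward small t,
    # i.e. interpolants closer to pure noise (t == 0) than to clean actions.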
    def encode_images(self, images):  # , view_ids):
        batch_size, num_frames, channels, height, width = images.shape
        images = images.reshape(-1, channels, height, width)

        image_features = self.vision_encoder(images)["last_hidden_state"]
        image_features = rearrange(image_features, "(b f) n d -> b f n d", f=num_frames)

        # if self.vision_projector is not None:
        #     # change the hidden dimension of the vision features
        #     image_features = self.vision_projector(image_features)
        if self.mm_projector is not None:
            image_features = self.mm_projector(image_features)  # e.g. [B, 256, 1024] -> [B, 16, 1024]
        return image_features

    def prepare_input_embs(self, vl_token_ids, sa_token_ids, vision, action, dropped_images, game_ids=None):
        B, T = vl_token_ids.shape
        vl_embs = torch.full(
            size=(B, T, self.vision_hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
        )

        # Extract dimensions from vision tensor
        B, num_images, tokens_per_image, hidden_size = vision.shape

        # Create mask for _IMG_TOKEN positions
        vision_mask = (vl_token_ids == _IMG_TOKEN)  # [B, T]

        # Flatten vision tensor over the num_images dimension
        vision_flat = vision.reshape(B, -1, self.vision_hidden_size)  # [B, num_images * tokens_per_image, hidden_size]

        # Create a mask for the flattened vision dimension.
        # Each image contributes tokens_per_image tokens, so expand the mask accordingly.
        non_dropped_mask_expanded = (dropped_images == 0).unsqueeze(-1).repeat(1, 1, tokens_per_image).reshape(B, -1)  # [B, num_images * tokens_per_image]

        # Select only non-dropped vision embeddings.
        # This gives us the embeddings we need to place.
        valid_vision_embs = vision_flat[non_dropped_mask_expanded]  # [total_valid_tokens, vision_hidden_size]

        assert valid_vision_embs.shape[0] == vision_mask.sum().item(), (
            f"Number of valid vision embeddings {valid_vision_embs.shape[0]} does not match "
            f"the number of _IMG_TOKEN positions {vision_mask.sum().item()}"
        )
        # Now place these at the vision_mask positions.
        # Get indices where vision_mask is True.
        batch_indices, token_indices = vision_mask.nonzero(as_tuple=True)

        # Place the valid embeddings at the masked positions
        vl_embs[batch_indices, token_indices] = valid_vision_embs

        # Handle game ID tokens
        if self.game_mapping is not None and game_ids is not None:
            game_mask = vl_token_ids == _GAME_ID_TOKEN  # shape: (B, T)
            num_game_tokens = game_mask.sum().item()
            if num_game_tokens > 0:

                # Assert that each batch item has exactly one game token
                game_tokens_per_batch = game_mask.sum(dim=1)  # [B] - count of game tokens per batch item
                assert torch.all(game_tokens_per_batch == 1), (
                    f"Expected exactly 1 game token per batch item, but got: {game_tokens_per_batch.tolist()}. "
                    f"Each batch item must have exactly one _GAME_ID_TOKEN."
                )

                # Get game embeddings for each batch item
                game_embs = self.game_embedding(game_ids)  # [B, vision_hidden_size]
                batch_indices, token_indices = game_mask.nonzero(as_tuple=True)
                vl_embs[batch_indices, token_indices] = game_embs[batch_indices].to(dtype=vl_embs.dtype)

        # Project the image separator using the learnable sep embedding.
        sep_mask = vl_token_ids == _IMG_SEP_TOKEN  # shape: (B, T)
        num_sep = sep_mask.sum().item()
        if num_sep > 0:
            # Expand the separator embedding for each occurrence.
            repeated_sep = self.vis_sep_embedding.unsqueeze(0).expand(num_sep, self.hidden_size)
            # Assign the separator embeddings to the correct positions.
            vl_embs[sep_mask] = repeated_sep.to(dtype=vl_embs.dtype)

        B, T = sa_token_ids.shape
        sa_embs = torch.full(
            size=(B, T, self.hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
        )

        # Project state.
        # state_mask = sa_token_ids == _PROPRIO_TOKEN
        # state_mask = state_mask.unsqueeze(-1).expand_as(sa_embs)
        # sa_embs = sa_embs.masked_scatter(state_mask, state)

        # Project action.
        action_mask = sa_token_ids == _ACT_TOKEN
        action_mask = action_mask.unsqueeze(-1).expand_as(sa_embs)
        sa_embs = sa_embs.masked_scatter(action_mask, action)

        # Add positional embeddings
        pos_ids = torch.arange(T, dtype=torch.long, device=sa_token_ids.device)
        if self.config.add_pos_embed:
            pos_embs = self.position_embedding(pos_ids)  # (T, hidden_size)
            pos_embs = pos_embs.unsqueeze(0).expand(B, T, self.hidden_size)
            sa_embs = sa_embs + pos_embs
        return vl_embs, sa_embs

    def pack_actions(self, buttons, j_left, j_right):
        # Check that the first three dims of each input are the same
        assert buttons.shape[:3] == j_left.shape[:3] == j_right.shape[:3], (
            f"buttons shape: {buttons.shape}, "
            f"j_left shape: {j_left.shape}, "
            f"j_right shape: {j_right.shape}"
        )

        # Normalize the joysticks to [0, 1]
        j_left = (j_left + 1) / 2.
        j_right = (j_right + 1) / 2.

        # Concatenate the joysticks and buttons along the last dimension
        action = torch.cat([j_left, j_right, buttons], dim=-1)

        # Squeeze the second dimension: this is the number of chunks, which is 1 here
        action = action.squeeze(1)
        return action

    # def unpack_actions(self, actions):
    #     # Unpack the actions into j_left, j_right, buttons
    #     j_left = actions[:, :, :2]
    #     j_right = actions[:, :, 2:4]
    #     buttons = actions[:, :, 4:]

    #     # Denormalize the joysticks back to -1,1
    #     j_left = j_left * 2. - 1.
    #     j_right = j_right * 2. - 1.

    #     # Clip into [-1,1]
    #     j_left = torch.clamp(j_left, -1, 1)
    #     j_right = torch.clamp(j_right, -1, 1)

    #     # Threshold the buttons to 0,1
    #     buttons = (buttons > 0.5).float()
    #     return j_left, j_right, buttons

    # ========= ActionHead required ============
    def forward(self, data: dict) -> dict:
        self.set_frozen_modules_to_eval_mode()

        # data = action_input
        embodiment_id = data["embodiment_id"]

        # # Check which data is present.
        # has_real_action = action_input.has_real_action
        has_real_action = data["has_real_action"]

        # 1) Encode images/text/state
        visual_features = self.encode_images(data["images"])  # , data["view_ids"])
        # text_features = self.siglip_model.text_model(
        #     input_ids=data["lang_input_ids"]
        # ).last_hidden_state
        # state_features = self.state_encoder(data["state"], embodiment_id)

        # 2) Prepare noisy trajectory
        actions = data["actions"]
        noise = torch.randn_like(actions)
        t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
        t = t[:, None, None]  # shape (B,1,1) for broadcast

        noisy_trajectory = (1 - t) * noise + t * actions
        velocity = actions - noise
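        # Convention check: t == 1 is clean data and t == 0 is pure noise, so the
        # straight-line path x(t) = (1 - t) * noise + t * actions has constant
        # velocity d/dt x(t) = actions - noise, which is the regression target.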

        # 3) Convert (continuous) t -> discrete if needed
        t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()

        # 4) Get action encoder embeddings with correct time argument
        action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)

        # 5) Prepare full input to DiT (or your model)
        vl_embs, sa_embs = self.prepare_input_embs(
            data["vl_token_ids"],
            data["sa_token_ids"],
            visual_features,
            # text_features,
            # state_features,
            action_features,
            data["dropped_images"],
            game_ids=data.get("game_id"),
        )

        vl_embs = self.vl_self_attention_model(vl_embs)
        # vl_embs = self.qformer(vl_embs)
        model_output, all_hidden_states = self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=vl_embs,
            encoder_attention_mask=data["vl_attn_mask"],
            timestep=t_discretized,
            return_all_hidden_states=True,
        )
        pred = self.action_decoder(model_output, embodiment_id)
        pred_actions = pred[:, -actions.shape[1]:]

        # 6) Flow-matching / velocity-prediction MSE
        # Mask for variable-length trajectories
        mask = data["actions_mask"]  # shape => (B, seq_len_of_actions, ...)
        raw_loss = F.mse_loss(pred_actions, velocity, reduction="none")
        mask = has_real_action[:, None, None] * mask
        raw_loss = raw_loss * mask
        action_loss = (has_real_action[:, None, None] * raw_loss).sum() / (mask.sum() + 1e-6)

        loss = action_loss

        return {
            "loss": loss,
        }

    @torch.inference_mode()
    def get_action(self, data: dict, old_layout: bool = False) -> dict:
        """
        For i in [0..N-1]:
          1) t = i/N
          2) velocity = model(x(t), t)
          3) x(t + dt) = x(t) + dt * velocity
        """

        # data = action_input
        embodiment_id = data["embodiment_id"]

        batch_size = data["images"].shape[0]
        device = data["images"].device
        dtype = data["images"].dtype
        actions = torch.randn(
            size=(batch_size, self.config.action_horizon, self.config.action_dim),
            dtype=dtype,
            device=device,
        )

        # 1) Hyperparameters for flow sampling
        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps

        # 2) Encode static context (images, text, state) once, since it does not depend on actions
        visual_features = self.encode_images(data["images"])  # , data["view_ids"])
        # text_features = self.siglip_model.text_model(
        #     input_ids=data["lang_input_ids"]
        # ).last_hidden_state
        # state_features = self.state_encoder(data["state"], embodiment_id)

        # 3) Start denoising the actions
        for i in range(num_steps):
            # ---- (a) Discretize continuous time in [0,1]
            t_cont = i / float(num_steps)  # e.g. goes 0, 1/N, 2/N, ...
            t_discretized = int(t_cont * self.num_timestep_buckets)

            # ---- (b) Build embeddings (actions included)
            # Pass the *current* actions at time t into the action encoder
            action_features = self.action_encoder(
                actions,
                (torch.ones(actions.shape[0]) * t_discretized).to(device),
                embodiment_id,
            )
            vl_embs, sa_embs = self.prepare_input_embs(
                data["vl_token_ids"],
                data["sa_token_ids"],
                visual_features,
                # text_features,
                # state_features,
                action_features,
                data["dropped_images"],
                game_ids=data["game_ids"],
            )
            vl_embs = self.vl_self_attention_model(vl_embs)
            # vl_embs = self.qformer(vl_embs)
            # ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                encoder_attention_mask=data["vl_attn_mask"],
                timestep=timesteps,
            )
            pred = self.action_decoder(model_output, embodiment_id)
            pred_velocity = pred[:, -actions.shape[1]:]

            # ---- (d) Naive Euler step: x(t + dt) = x(t) + dt * velocity
            actions = actions + dt * pred_velocity

        return {
            "action_tensor": actions,
        }

    @torch.inference_mode()
    def get_action_with_cfg(self, data_cond: dict, data_uncond: dict, cfg_scale: float = 1.0) -> dict:
        """
        Use a form of classifier-free guidance to sample actions. This can only be used on
        models that were trained on multiple frames of actions. The idea is that we sample
        velocity with and without the frame history, and then we push the sampled actions
        towards the ones that were sampled with the frame history.

        data_cond = conditional input (with history)
        data_uncond = unconditional input (without history)

        This function works with any kind of conditioning, not just history.

        For i in [0..N-1]:
          1) t = i/N
          2) velocity = v_cond + cfg_scale * (v_cond - v_uncond), where
             v_cond = model(x(t), t, history) and v_uncond = model(x(t), t, None)
          3) x(t + dt) = x(t) + dt * velocity
        """

        # data = action_input
        embodiment_id = data_cond["embodiment_id"]

        batch_size = data_cond["images"].shape[0]
        device = data_cond["images"].device
        dtype = data_cond["images"].dtype
        actions = torch.randn(
            size=(batch_size, self.config.action_horizon, self.config.action_dim),
            dtype=dtype,
            device=device,
        )

        # 1) Hyperparameters for flow sampling
        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps

        # 2) Encode static context (images, text, state) once, since it does not depend on actions
        visual_features_cond = self.encode_images(data_cond["images"])
        visual_features_uncond = self.encode_images(data_uncond["images"])
        # text_features = self.siglip_model.text_model(
        #     input_ids=data["lang_input_ids"]
        # ).last_hidden_state
        # state_features = self.state_encoder(data["state"], embodiment_id)

        # 3) Start denoising the actions
        for i in range(num_steps):
            # ---- (a) Discretize continuous time in [0,1]
            t_cont = i / float(num_steps)  # e.g. goes 0, 1/N, 2/N, ...
            t_discretized = int(t_cont * self.num_timestep_buckets)

            # ---- (b) Build embeddings (actions included)
            # Pass the *current* actions at time t into the action encoder
            action_features = self.action_encoder(
                actions,
                (torch.ones(actions.shape[0]) * t_discretized).to(device),
                embodiment_id,
            )

            # Predict velocity with history
            vl_embs, sa_embs = self.prepare_input_embs(
                data_cond["vl_token_ids"],
                data_cond["sa_token_ids"],
                visual_features_cond,
                action_features,
                data_cond["dropped_images"],
            )
            vl_embs = self.vl_self_attention_model(vl_embs)
            # ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                encoder_attention_mask=data_cond["vl_attn_mask"],
                timestep=timesteps,
            )
            pred = self.action_decoder(model_output, embodiment_id)
            pred_velocity_cond = pred[:, -actions.shape[1]:]

            # Predict velocity without history
            vl_embs, sa_embs = self.prepare_input_embs(
                data_uncond["vl_token_ids"],
                data_uncond["sa_token_ids"],
                visual_features_uncond,
                action_features,
                data_uncond["dropped_images"],
            )
            vl_embs = self.vl_self_attention_model(vl_embs)
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                encoder_attention_mask=data_uncond["vl_attn_mask"],
                timestep=timesteps,
            )
            pred = self.action_decoder(model_output, embodiment_id)
            pred_velocity_uncond = pred[:, -actions.shape[1]:]

            # ---- (d) Combine velocities with cfg_scale
            pred_velocity = pred_velocity_cond + cfg_scale * (pred_velocity_cond - pred_velocity_uncond)

            # ---- (e) Naive Euler step: x(t + dt) = x(t) + dt * velocity
            actions = actions + dt * pred_velocity

        return {
            "action_tensor": actions,
        }

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype
577
nitrogen/game_env.py
Normal file
@ -0,0 +1,577 @@
|
||||
import time
|
||||
import platform
|
||||
|
||||
import pyautogui
|
||||
import dxcam
|
||||
import pywinctl as pwc
|
||||
import xspeedhack as xsh
|
||||
from gymnasium import Env
|
||||
from gymnasium.spaces import Box, Dict, Discrete
|
||||
from PIL import Image
|
||||
|
||||
import vgamepad as vg
|
||||
|
||||
import psutil
|
||||
|
||||
assert platform.system().lower() == "windows", "This module is only supported on Windows."
|
||||
import win32process
|
||||
import win32gui
|
||||
import win32api
|
||||
import win32con
|
||||
|
||||
def get_process_info(process_name):
|
||||
"""
|
||||
Get process information for a given process name on Windows.
|
||||
|
||||
Args:
|
||||
process_name (str): Name of the process (e.g., "isaac-ng.exe")
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing PID, window_name, and architecture for the
|
||||
first matching process. Raises ValueError if no process is found.
|
||||
"""
|
||||
results = []
|
||||
|
||||
# Find all processes with the given name
|
||||
for proc in psutil.process_iter(['pid', 'name']):
|
||||
try:
|
||||
if proc.info['name'].lower() == process_name.lower():
|
||||
pid = proc.info['pid']
|
||||
|
||||
# Get architecture
|
||||
try:
|
||||
# Check if process is 32-bit or 64-bit
|
||||
process_handle = win32api.OpenProcess(
|
||||
win32con.PROCESS_QUERY_INFORMATION,
|
||||
False,
|
||||
pid
|
||||
)
|
||||
is_wow64 = win32process.IsWow64Process(process_handle)
|
||||
win32api.CloseHandle(process_handle)
|
||||
|
||||
# On 64-bit Windows: WOW64 means "Windows 32-bit on Windows 64-bit", i.e. a 32-bit process
|
||||
architecture = "x86" if is_wow64 else "x64"
|
||||
except Exception:
|
||||
architecture = "unknown"
|
||||
|
||||
# Find windows associated with this PID
|
||||
windows = []
|
||||
|
||||
def enum_window_callback(hwnd, pid_to_find):
|
||||
_, found_pid = win32process.GetWindowThreadProcessId(hwnd)
|
||||
if found_pid == pid_to_find:
|
||||
window_text = win32gui.GetWindowText(hwnd)
|
||||
if window_text and win32gui.IsWindowVisible(hwnd):
|
||||
windows.append({
|
||||
'hwnd': hwnd,
|
||||
'title': window_text,
|
||||
'visible': win32gui.IsWindowVisible(hwnd)
|
||||
})
|
||||
return True
|
||||
|
||||
# Find all windows for this PID
|
||||
try:
|
||||
win32gui.EnumWindows(enum_window_callback, pid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Choose the best window
|
||||
window_name = None
|
||||
if windows:
|
||||
if len(windows) > 1:
|
||||
print(f"Multiple windows found for PID {pid}: {[win['title'] for win in windows]}")
|
||||
print("Using heuristics to select the correct window...")
|
||||
# Filter out common proxy/helper windows
|
||||
proxy_keywords = ['d3dproxywindow', 'proxy', 'helper', 'overlay']
|
||||
|
||||
# First try to find a visible window without proxy keywords
|
||||
for win in windows:
|
||||
if not any(keyword in win['title'].lower() for keyword in proxy_keywords):
|
||||
window_name = win['title']
|
||||
break
|
||||
|
||||
# If no good window found, just use the first one
|
||||
if window_name is None and windows:
|
||||
window_name = windows[0]['title']
|
||||
|
||||
results.append({
|
||||
'pid': pid,
|
||||
'window_name': window_name,
|
||||
'architecture': architecture
|
||||
})
|
||||
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
continue
|
||||
|
||||
if len(results) == 0:
|
||||
raise ValueError(f"No process found with name: {process_name}")
|
||||
elif len(results) > 1:
|
||||
print(f"Warning: Multiple processes found with name '{process_name}'. Returning first match.")
|
||||
|
||||
return results[0]
|
||||
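# Usage sketch (hypothetical process name; the target game must be running):
#   info = get_process_info("isaac-ng.exe")
#   print(info["pid"], info["architecture"], info["window_name"])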
|
||||
|
||||
XBOX_MAPPING = {
|
||||
"DPAD_UP": "XUSB_GAMEPAD_DPAD_UP",
|
||||
"DPAD_DOWN": "XUSB_GAMEPAD_DPAD_DOWN",
|
||||
"DPAD_LEFT": "XUSB_GAMEPAD_DPAD_LEFT",
|
||||
"DPAD_RIGHT": "XUSB_GAMEPAD_DPAD_RIGHT",
|
||||
"START": "XUSB_GAMEPAD_START",
|
||||
"BACK": "XUSB_GAMEPAD_BACK",
|
||||
"LEFT_SHOULDER": "XUSB_GAMEPAD_LEFT_SHOULDER",
|
||||
"RIGHT_SHOULDER": "XUSB_GAMEPAD_RIGHT_SHOULDER",
|
||||
"GUIDE": "XUSB_GAMEPAD_GUIDE",
|
||||
"WEST": "XUSB_GAMEPAD_X",
|
||||
"SOUTH": "XUSB_GAMEPAD_A",
|
||||
"EAST": "XUSB_GAMEPAD_B",
|
||||
"NORTH": "XUSB_GAMEPAD_Y",
|
||||
"LEFT_TRIGGER": "LEFT_TRIGGER",
|
||||
"RIGHT_TRIGGER": "RIGHT_TRIGGER",
|
||||
"AXIS_LEFTX": "LEFT_JOYSTICK",
|
||||
"AXIS_LEFTY": "LEFT_JOYSTICK",
|
||||
"AXIS_RIGHTX": "RIGHT_JOYSTICK",
|
||||
"AXIS_RIGHTY": "RIGHT_JOYSTICK",
|
||||
"LEFT_THUMB": "XUSB_GAMEPAD_LEFT_THUMB",
|
||||
"RIGHT_THUMB": "XUSB_GAMEPAD_RIGHT_THUMB",
|
||||
}
|
||||
|
||||
PS4_MAPPING = {
|
||||
"DPAD_UP": "DS4_BUTTON_DPAD_NORTH",
|
||||
"DPAD_DOWN": "DS4_BUTTON_DPAD_SOUTH",
|
||||
"DPAD_LEFT": "DS4_BUTTON_DPAD_WEST",
|
||||
"DPAD_RIGHT": "DS4_BUTTON_DPAD_EAST",
|
||||
"START": "DS4_BUTTON_OPTIONS",
|
||||
"BACK": "DS4_BUTTON_SHARE",
|
||||
"LEFT_SHOULDER": "DS4_BUTTON_SHOULDER_LEFT",
|
||||
"RIGHT_SHOULDER": "DS4_BUTTON_SHOULDER_RIGHT",
|
||||
"GUIDE": "DS4_BUTTON_GUIDE",
|
||||
"WEST": "DS4_BUTTON_SQUARE",
|
||||
"SOUTH": "DS4_BUTTON_CROSS",
|
||||
"EAST": "DS4_BUTTON_CIRCLE",
|
||||
"NORTH": "DS4_BUTTON_TRIANGLE",
|
||||
"LEFT_TRIGGER": "LEFT_TRIGGER",
|
||||
"RIGHT_TRIGGER": "RIGHT_TRIGGER",
|
||||
"AXIS_LEFTX": "LEFT_JOYSTICK",
|
||||
"AXIS_LEFTY": "LEFT_JOYSTICK",
|
||||
"AXIS_RIGHTX": "RIGHT_JOYSTICK",
|
||||
"AXIS_RIGHTY": "RIGHT_JOYSTICK",
|
||||
"LEFT_THUMB": "DS4_BUTTON_THUMB_LEFT",
|
||||
"RIGHT_THUMB": "DS4_BUTTON_THUMB_RIGHT",
|
||||
}
|
||||
|
||||
|
||||
class GamepadEmulator:
|
||||
def __init__(self, controller_type="xbox", system="windows"):
|
||||
"""
|
||||
Initialize the GamepadEmulator with a specific controller type and system.
|
||||
|
||||
Parameters:
|
||||
controller_type (str): The type of controller to emulate ("xbox" or "ps4").
|
||||
system (str): The operating system to use, which affects joystick value handling.
|
||||
"""
|
||||
self.controller_type = controller_type
|
||||
self.system = system
|
||||
if controller_type == "xbox":
|
||||
self.gamepad = vg.VX360Gamepad()
|
||||
self.mapping = XBOX_MAPPING
|
||||
elif controller_type == "ps4":
|
||||
self.gamepad = vg.VDS4Gamepad()
|
||||
self.mapping = PS4_MAPPING
|
||||
else:
|
||||
raise ValueError("Unsupported controller type")
|
||||
|
||||
# Initialize joystick values to keep track of the current state
|
||||
self.left_joystick_x: int = 0
|
||||
self.left_joystick_y: int = 0
|
||||
self.right_joystick_x: int = 0
|
||||
self.right_joystick_y: int = 0
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
Perform actions based on the provided action dictionary.
|
||||
|
||||
Parameters:
|
||||
action (dict): Dictionary of an action to be performed. Keys are control names,
|
||||
and values are their respective states.
|
||||
"""
|
||||
self.gamepad.reset()
|
||||
|
||||
# Handle buttons
|
||||
for control in [
|
||||
"EAST",
|
||||
"SOUTH",
|
||||
"NORTH",
|
||||
"WEST",
|
||||
"BACK",
|
||||
"GUIDE",
|
||||
"START",
|
||||
"DPAD_DOWN",
|
||||
"DPAD_LEFT",
|
||||
"DPAD_RIGHT",
|
||||
"DPAD_UP",
|
||||
"LEFT_SHOULDER",
|
||||
"RIGHT_SHOULDER",
|
||||
"LEFT_THUMB",
|
||||
"RIGHT_THUMB",
|
||||
]:
|
||||
if control in action:
|
||||
if action[control]:
|
||||
self.press_button(control)
|
||||
else:
|
||||
self.release_button(control)
|
||||
|
||||
# Handle triggers
|
||||
if "LEFT_TRIGGER" in action:
|
||||
self.set_trigger("LEFT_TRIGGER", action["LEFT_TRIGGER"][0])
|
||||
if "RIGHT_TRIGGER" in action:
|
||||
self.set_trigger("RIGHT_TRIGGER", action["RIGHT_TRIGGER"][0])
|
||||
|
||||
# Handle joysticks
|
||||
if "AXIS_LEFTX" in action and "AXIS_LEFTY" in action:
|
||||
self.set_joystick("AXIS_LEFTX", action["AXIS_LEFTX"][0])
|
||||
self.set_joystick("AXIS_LEFTY", action["AXIS_LEFTY"][0])
|
||||
|
||||
if "AXIS_RIGHTX" in action and "AXIS_RIGHTY" in action:
|
||||
self.set_joystick("AXIS_RIGHTX", action["AXIS_RIGHTX"][0])
|
||||
self.set_joystick("AXIS_RIGHTY", action["AXIS_RIGHTY"][0])
|
||||
|
||||
self.gamepad.update()
|
||||
|
||||
def press_button(self, button):
|
||||
"""
|
||||
Press a button on the gamepad.
|
||||
|
||||
Parameters:
|
||||
button (str): The unified name of the button to press.
|
||||
"""
|
||||
button_mapped = self.mapping.get(button)
|
||||
if self.controller_type == "xbox":
|
||||
self.gamepad.press_button(button=getattr(vg.XUSB_BUTTON, button_mapped))
|
||||
elif self.controller_type == "ps4":
|
||||
self.gamepad.press_button(button=getattr(vg.DS4_BUTTONS, button_mapped))
|
||||
else:
|
||||
raise ValueError("Unsupported controller type")
|
||||
|
||||
def release_button(self, button):
|
||||
"""
|
||||
Release a button on the gamepad.
|
||||
|
||||
Parameters:
|
||||
button (str): The unified name of the button to release.
|
||||
"""
|
||||
button_mapped = self.mapping.get(button)
|
||||
if self.controller_type == "xbox":
|
||||
self.gamepad.release_button(button=getattr(vg.XUSB_BUTTON, button_mapped))
|
||||
elif self.controller_type == "ps4":
|
||||
self.gamepad.release_button(button=getattr(vg.DS4_BUTTONS, button_mapped))
|
||||
else:
|
||||
raise ValueError("Unsupported controller type")
|
||||
|
||||
def set_trigger(self, trigger, value):
|
||||
"""
|
||||
Set the value of a trigger on the gamepad.
|
||||
|
||||
Parameters:
|
||||
trigger (str): The unified name of the trigger.
|
||||
value (float): The value to set the trigger to (between 0 and 1).
|
||||
"""
|
||||
value = int(round(value * 255))  # vgamepad triggers take 0-255; the input is documented as 0-1
|
||||
trigger_mapped = self.mapping.get(trigger)
|
||||
if trigger_mapped == "LEFT_TRIGGER":
|
||||
self.gamepad.left_trigger(value=value)
|
||||
elif trigger_mapped == "RIGHT_TRIGGER":
|
||||
self.gamepad.right_trigger(value=value)
|
||||
else:
|
||||
raise ValueError("Unsupported trigger action")
|
||||
|
||||
def set_joystick(self, joystick, value):
|
||||
"""
|
||||
Set the position of a joystick on the gamepad.
|
||||
|
||||
Parameters:
|
||||
joystick (str): The name of the joystick axis.
|
||||
value (float): The value to set the joystick axis to (between -32768 and 32767)
|
||||
"""
|
||||
if joystick == "AXIS_LEFTX":
|
||||
self.left_joystick_x = value
|
||||
self.gamepad.left_joystick(x_value=self.left_joystick_x, y_value=self.left_joystick_y)
|
||||
elif joystick == "AXIS_LEFTY":
|
||||
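# Flip the Y axis; for 16-bit signed values, -v - 1 maps
# [-32768, 32767] onto itself (e.g. 32767 -> -32768).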
if self.system == "windows":
|
||||
value = -value - 1
|
||||
self.left_joystick_y = value
|
||||
self.gamepad.left_joystick(x_value=self.left_joystick_x, y_value=self.left_joystick_y)
|
||||
elif joystick == "AXIS_RIGHTX":
|
||||
self.right_joystick_x = value
|
||||
self.gamepad.right_joystick(
|
||||
x_value=self.right_joystick_x, y_value=self.right_joystick_y
|
||||
)
|
||||
elif joystick == "AXIS_RIGHTY":
|
||||
if self.system == "windows":
|
||||
value = -value - 1
|
||||
self.right_joystick_y = value
|
||||
self.gamepad.right_joystick(
|
||||
x_value=self.right_joystick_x, y_value=self.right_joystick_y
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported joystick action")
|
||||
|
||||
def wakeup(self, duration=0.1):
|
||||
"""
|
||||
Wake up the controller by pressing a button.
|
||||
|
||||
Parameters:
|
||||
duration (float): Duration to press the button.
|
||||
"""
|
||||
self.press_button("LEFT_THUMB")  # go through the mapping so this also works for the ps4 gamepad
|
||||
self.gamepad.update()
|
||||
time.sleep(duration)
|
||||
self.gamepad.reset()
|
||||
self.gamepad.update()
|
||||
time.sleep(duration)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the gamepad to its default state.
|
||||
"""
|
||||
self.gamepad.reset()
|
||||
self.gamepad.update()
|
||||
|
||||
class PyautoguiScreenshotBackend:
|
||||
|
||||
def __init__(self, bbox):
|
||||
self.bbox = bbox
|
||||
|
||||
def screenshot(self):
|
||||
return pyautogui.screenshot(region=self.bbox)
|
||||
|
||||
class DxcamScreenshotBackend:
|
||||
def __init__(self, bbox):
|
||||
# bbox arrives as (left, top, width, height); dxcam.grab() expects a
# (left, top, right, bottom) region, so convert it once here.
left, top, width, height = bbox
self.region = (left, top, left + width, top + height)
self.size = (width, height)
|
||||
self.camera = dxcam.create()
|
||||
self.last_screenshot = None
|
||||
|
||||
def screenshot(self):
|
||||
screenshot = self.camera.grab(region=self.region)
|
||||
if screenshot is None:
|
||||
print("DXCAM failed to capture frame, trying to use the latest screenshot")
|
||||
if self.last_screenshot is not None:
|
||||
return self.last_screenshot
|
||||
else:
|
||||
return Image.new("RGB", self.size, (0, 0, 0))
|
||||
screenshot = Image.fromarray(screenshot)
|
||||
self.last_screenshot = screenshot
|
||||
return screenshot
|
||||
|
||||
|
||||
class GamepadEnv(Env):
|
||||
"""
|
||||
Base class for creating a game environment controlled with a gamepad.
|
||||
|
||||
Attributes:
|
||||
game (str): Name of the game to interact with.
|
||||
image_height (int): Height of the observation space.
|
||||
image_width (int): Width of the observation space.
|
||||
controller_type (str): Platform for the gamepad emulator ("xbox" or "ps4").
|
||||
game_speed (float): Speed multiplier for the game.
|
||||
env_fps (int): Number of actions to perform per second at normal speed.
|
||||
async_mode (bool): Whether to pause/unpause the game during each step.
|
||||
"""
|
||||
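# Usage sketch (hypothetical process name; the game must already be running):
#   env = GamepadEnv(game="celeste.exe", env_fps=10)
#   obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
#   env.close()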
|
||||
def __init__(
|
||||
self,
|
||||
game,
|
||||
image_height=1440,
|
||||
image_width=2560,
|
||||
controller_type="xbox",
|
||||
game_speed=1.0,
|
||||
env_fps=10,
|
||||
async_mode=True,
|
||||
screenshot_backend="dxcam",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Assert that system is windows
|
||||
os_name = platform.system().lower()
|
||||
assert os_name == "windows", "This environment is currently only supported on Windows."
|
||||
assert controller_type in ["xbox", "ps4"], "controller_type must be either 'xbox' or 'ps4'"
|
||||
assert screenshot_backend in ["pyautogui", "dxcam"], "Screenshot backend must be either 'pyautogui' or 'dxcam'"
|
||||
|
||||
self.game = game
|
||||
self.image_height = int(image_height)
|
||||
self.image_width = int(image_width)
|
||||
self.game_speed = game_speed
|
||||
self.env_fps = env_fps
|
||||
self.step_duration = self.calculate_step_duration()
|
||||
self.async_mode = async_mode
|
||||
|
||||
self.gamepad_emulator = GamepadEmulator(controller_type=controller_type, system=os_name)
|
||||
proc_info = get_process_info(game)
|
||||
|
||||
self.game_pid = proc_info["pid"]
|
||||
self.game_arch = proc_info["architecture"]
|
||||
self.game_window_name = proc_info["window_name"]
|
||||
|
||||
print(f"Game process found: {self.game} (PID: {self.game_pid}, Arch: {self.game_arch}, Window: {self.game_window_name})")
|
||||
|
||||
if self.game_pid is None:
|
||||
raise Exception(f"Could not find PID for game: {game}")
|
||||
|
||||
|
||||
self.observation_space = Box(
|
||||
low=0, high=255, shape=(self.image_height, self.image_width, 3), dtype="uint8"
|
||||
)
|
||||
|
||||
# Define a unified action space
|
||||
self.action_space = Dict(
|
||||
{
|
||||
"BACK": Discrete(2),
|
||||
"GUIDE": Discrete(2),
|
||||
"RIGHT_SHOULDER": Discrete(2),
|
||||
"RIGHT_TRIGGER": Box(low=0.0, high=1.0, shape=(1,)),
|
||||
"LEFT_TRIGGER": Box(low=0.0, high=1.0, shape=(1,)),
|
||||
"LEFT_SHOULDER": Discrete(2),
|
||||
"AXIS_RIGHTX": Box(low=-32768.0, high=32767, shape=(1,)),
|
||||
"AXIS_RIGHTY": Box(low=-32768.0, high=32767, shape=(1,)),
|
||||
"AXIS_LEFTX": Box(low=-32768.0, high=32767, shape=(1,)),
|
||||
"AXIS_LEFTY": Box(low=-32768.0, high=32767, shape=(1,)),
|
||||
"LEFT_THUMB": Discrete(2),
|
||||
"RIGHT_THUMB": Discrete(2),
|
||||
"DPAD_UP": Discrete(2),
|
||||
"DPAD_RIGHT": Discrete(2),
|
||||
"DPAD_DOWN": Discrete(2),
|
||||
"DPAD_LEFT": Discrete(2),
|
||||
"WEST": Discrete(2),
|
||||
"SOUTH": Discrete(2),
|
||||
"EAST": Discrete(2),
|
||||
"NORTH": Discrete(2),
|
||||
"START": Discrete(2),
|
||||
}
|
||||
)
|
||||
|
||||
# Determine window name
|
||||
windows = pwc.getAllWindows()
|
||||
self.game_window = None
|
||||
for window in windows:
|
||||
if window.title == self.game_window_name:
|
||||
self.game_window = window
|
||||
break
|
||||
|
||||
if not self.game_window:
|
||||
raise Exception(f"No window found with game name: {self.game}")
|
||||
|
||||
self.game_window.activate()
|
||||
l, t, r, b = self.game_window.left, self.game_window.top, self.game_window.right, self.game_window.bottom
|
||||
self.bbox = (l, t, r-l, b-t)
|
||||
|
||||
# Initialize speedhack client if using DLL injection
|
||||
self.speedhack_client = xsh.Client(process_id=self.game_pid, arch=self.game_arch)
|
||||
|
||||
# Get the screenshot backend
|
||||
if screenshot_backend == "dxcam":
|
||||
self.screenshot_backend = DxcamScreenshotBackend(self.bbox)
|
||||
elif screenshot_backend == "pyautogui":
|
||||
self.screenshot_backend = PyautoguiScreenshotBackend(self.bbox)
|
||||
else:
|
||||
raise ValueError("Unsupported screenshot backend. Use 'dxcam' or 'pyautogui'.")
|
||||
|
||||
|
||||
def calculate_step_duration(self):
|
||||
"""
|
||||
Calculate the step duration based on game speed and environment FPS.
|
||||
|
||||
Returns:
|
||||
float: Calculated step duration.
|
||||
|
||||
Example:
|
||||
If game_speed=1.0 and env_fps=10, then step_duration
|
||||
will be 0.1 seconds.
|
||||
"""
|
||||
return 1.0 / (self.env_fps * self.game_speed)
|
||||
|
||||
def unpause(self):
|
||||
"""
|
||||
Unpause the game using the specified method.
|
||||
"""
|
||||
self.speedhack_client.set_speed(1.0)
|
||||
|
||||
def pause(self):
|
||||
"""
|
||||
Pause the game using the specified method.
|
||||
"""
|
||||
self.speedhack_client.set_speed(0.0)
|
||||
|
||||
def perform_action(self, action, duration):
|
||||
"""
|
||||
Perform the action without handling the game pause/unpause.
|
||||
|
||||
Parameters:
|
||||
action (dict): Action to be performed.
|
||||
duration (float): Duration for the action step.
|
||||
"""
|
||||
self.gamepad_emulator.step(action)
|
||||
start = time.perf_counter()
|
||||
self.unpause()
|
||||
# Wait until the next step
|
||||
end = start + duration  # honor the per-call duration instead of always using self.step_duration
|
||||
now = time.perf_counter()
|
||||
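# Busy-wait instead of time.sleep(): sleep granularity on Windows can be
# ~15 ms, which would make the effective step length drift from the target.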
while now < end:
|
||||
now = time.perf_counter()
|
||||
self.pause()
|
||||
|
||||
def step(self, action, step_duration=None):
|
||||
"""
|
||||
Perform an action in the game environment and return the observation.
|
||||
|
||||
Parameters:
|
||||
action (dict): Dictionary of the action to be performed. Keys are control names,
|
||||
and values are their respective states.
|
||||
step_duration (float, optional): Duration for which the action should be performed.
|
||||
|
||||
Returns:
|
||||
tuple: (obs, reward, terminated, truncated, info) where obs is the observation of the game environment after performing the action.
|
||||
"""
|
||||
# Determine the duration for this step
|
||||
duration = step_duration if step_duration is not None else self.step_duration
|
||||
|
||||
self.perform_action(action, duration)
|
||||
|
||||
obs = self.render() # Render after pausing the game
|
||||
|
||||
reward = 0.0
|
||||
terminated = False
|
||||
truncated = False
|
||||
info = {}
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
"""
|
||||
Reset the environment to its initial state.
|
||||
|
||||
Parameters:
|
||||
seed (int, optional): Random seed.
|
||||
options (dict, optional): Additional options for reset.
|
||||
"""
|
||||
self.gamepad_emulator.wakeup(duration=0.1)
|
||||
time.sleep(1.0)
|
||||
return self.render(), {}  # Gymnasium's reset() returns (observation, info)
|
||||
def close(self):
|
||||
"""
|
||||
Close the environment and release any resources.
|
||||
"""
|
||||
pass # Implement env close logic here
|
||||
|
||||
def render(self):
|
||||
"""
|
||||
Render the current state of the game window as an observation.
|
||||
|
||||
Returns:
|
||||
Image: Observation of the game environment.
|
||||
"""
|
||||
screenshot = self.screenshot_backend.screenshot()
|
||||
screenshot = screenshot.resize((self.image_width, self.image_height))
|
||||
|
||||
return screenshot
|
||||
91
nitrogen/inference_client.py
Normal file
@ -0,0 +1,91 @@
|
||||
import time
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
class ModelClient:
|
||||
"""Client for model inference server."""
|
||||
|
||||
def __init__(self, host="localhost", port=5555):
|
||||
"""
|
||||
Initialize client connection.
|
||||
|
||||
Args:
|
||||
host: Server hostname or IP
|
||||
port: Server port
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_ms = 30000
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.REQ)
|
||||
self.socket.connect(f"tcp://{host}:{port}")
|
||||
self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) # Set receive timeout
|
||||
|
||||
print(f"Connected to model server at {host}:{port}")
|
||||
|
||||
def predict(self, image: np.ndarray) -> dict:
|
||||
"""
|
||||
Send an image and receive predicted actions.
|
||||
|
||||
Args:
|
||||
image: numpy array (H, W, 3) in RGB format
|
||||
|
||||
Returns:
|
||||
dict of predicted actions, containing:
|
||||
- j_left: (T, 2) array of left joystick positions
|
||||
- j_right: (T, 2) array of right joystick positions
|
||||
- buttons: (T, num_buttons) array of button states
|
||||
"""
|
||||
request = {
|
||||
"type": "predict",
|
||||
"image": image
|
||||
}
|
||||
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = pickle.loads(self.socket.recv())
|
||||
|
||||
if response["status"] != "ok":
|
||||
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
|
||||
|
||||
return response["pred"]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the server's session (clear buffers)."""
|
||||
request = {"type": "reset"}
|
||||
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = pickle.loads(self.socket.recv())
|
||||
|
||||
if response["status"] != "ok":
|
||||
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
|
||||
|
||||
print("Session reset")
|
||||
|
||||
def info(self) -> dict:
|
||||
"""Get session info from the server."""
|
||||
request = {"type": "info"}
|
||||
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = pickle.loads(self.socket.recv())
|
||||
|
||||
if response["status"] != "ok":
|
||||
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
|
||||
|
||||
return response["info"]
|
||||
|
||||
def close(self):
|
||||
"""Close the connection."""
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
print("Connection closed")
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Close connection when exiting context."""
|
||||
self.close()
|
||||
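# Usage sketch (assumes a matching inference server is already listening;
# the host/port here are just the defaults above):
#
#   with ModelClient(host="localhost", port=5555) as client:
#       client.reset()
#       pred = client.predict(np.zeros((256, 256, 3), dtype=np.uint8))
#       print(pred["buttons"].shape)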
276
nitrogen/inference_session.py
Normal file
@ -0,0 +1,276 @@
|
||||
import time
|
||||
import json
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoImageProcessor
|
||||
from nitrogen.flow_matching_transformer.nitrogen import NitroGen, NitroGen_Config
|
||||
from nitrogen.mm_tokenizers import NitrogenTokenizerConfig, NitrogenTokenizer, Tokenizer
|
||||
from nitrogen.cfg import CkptConfig
|
||||
from nitrogen.shared import PATH_REPO
|
||||
|
||||
def summarize_parameters(module, name='model', depth=0, max_depth=3):
|
||||
"""
|
||||
Print a tree-like summary of parameters in a PyTorch module.
|
||||
|
||||
Args:
|
||||
module: PyTorch module to summarize
|
||||
name: Name of the module (for root level)
|
||||
depth: Current depth in the tree
|
||||
max_depth: Maximum depth to traverse
|
||||
"""
|
||||
if depth > max_depth:
|
||||
return
|
||||
|
||||
# Count total parameters in this module
|
||||
total_params = sum(p.numel() for p in module.parameters())
|
||||
trainable_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
|
||||
|
||||
# Print indented summary
|
||||
indent = " " * depth
|
||||
print(f"{indent}{name}: {total_params:,} params ({trainable_params:,} trainable)")
|
||||
|
||||
# Recursively summarize submodules
|
||||
if depth < max_depth:
|
||||
for child_name, child_module in module.named_children():
|
||||
summarize_parameters(child_module, child_name, depth + 1, max_depth)
|
||||
|
||||
|
||||
def load_model(checkpoint_path: str):
|
||||
"""Load model and args from checkpoint."""
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
||||
ckpt_config = CkptConfig.model_validate(checkpoint["ckpt_config"])
|
||||
model_cfg = ckpt_config.model_cfg
|
||||
tokenizer_cfg = ckpt_config.tokenizer_cfg
|
||||
|
||||
print("Checkpoint args:")
|
||||
print(json.dumps(ckpt_config.model_dump(), indent=4))
|
||||
|
||||
# Initialize tokenizer and language model
|
||||
img_proc = AutoImageProcessor.from_pretrained(model_cfg.vision_encoder_name)
|
||||
|
||||
# Create VLM with pre-loaded language model
|
||||
if isinstance(model_cfg, NitroGen_Config):
|
||||
assert isinstance(tokenizer_cfg, NitrogenTokenizerConfig), \
|
||||
"NitroGen_Config requires NitrogenTokenizerConfig for tokenization"
|
||||
tokenizer_cfg.training = False
|
||||
if tokenizer_cfg.game_mapping_cfg is not None:
|
||||
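# Remap the training cluster's absolute paths baked into the checkpoint
# onto the local repo checkout.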
tokenizer_cfg.game_mapping_cfg.src_files = [
|
||||
x.replace("/mnt/amlfs-02/shared/gaming/gamingvla", str(PATH_REPO))
|
||||
for x in tokenizer_cfg.game_mapping_cfg.src_files
|
||||
]
|
||||
tokenizer = NitrogenTokenizer(tokenizer_cfg)
|
||||
game_mapping = tokenizer.game_mapping
|
||||
model = NitroGen(config=model_cfg, game_mapping=game_mapping)
|
||||
# model.num_inference_timesteps = 16
|
||||
action_downsample_ratio = 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported model config type: {type(model_cfg)}")
|
||||
|
||||
summarize_parameters(model, max_depth=3)
|
||||
|
||||
print(model)
|
||||
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
model.eval()
|
||||
tokenizer.eval()
|
||||
model.to("cuda")
|
||||
|
||||
return model, tokenizer, img_proc, ckpt_config, game_mapping, action_downsample_ratio
|
||||
|
||||
class InferenceSession:
|
||||
"""Manages state for a single inference session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
ckpt_path: str,
|
||||
tokenizer: Tokenizer,
|
||||
img_proc,
|
||||
ckpt_config: CkptConfig,
|
||||
game_mapping: dict,
|
||||
selected_game: str,
|
||||
old_layout: bool,
|
||||
cfg_scale: float,
|
||||
action_downsample_ratio: float,
|
||||
context_length=None
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.img_proc = img_proc
|
||||
self.ckpt_config = ckpt_config
|
||||
self.game_mapping = game_mapping
|
||||
self.selected_game = selected_game
|
||||
self.old_layout = old_layout
|
||||
self.cfg_scale = cfg_scale
|
||||
self.action_downsample_ratio = action_downsample_ratio
|
||||
self.ckpt_path = ckpt_path
|
||||
|
||||
# Load modality config
|
||||
self.modality_config = self.ckpt_config.modality_cfg
|
||||
|
||||
self.max_buffer_size = context_length if context_length is not None else self.modality_config.frame_per_sample
|
||||
self.action_interleaving = self.modality_config.action_interleaving
|
||||
self.is_flowmatching = isinstance(self.ckpt_config.model_cfg, NitroGen_Config)
|
||||
|
||||
# Buffers
|
||||
self.obs_buffer = deque(maxlen=self.max_buffer_size)
|
||||
self.action_buffer = deque(maxlen=self.max_buffer_size)
|
||||
|
||||
@classmethod
|
||||
def from_ckpt(cls, checkpoint_path: str, old_layout=False, cfg_scale=1.0, context_length=None):
|
||||
"""Create an InferenceSession from a checkpoint."""
|
||||
model, tokenizer, img_proc, ckpt_config, game_mapping, action_downsample_ratio = load_model(checkpoint_path)
|
||||
|
||||
if game_mapping is not None:
|
||||
# Ask user to pick a game from the list
|
||||
print("Available games in tokenizer mapping:")
|
||||
for game, idx in game_mapping.items():
|
||||
print(f"{idx:03d}: {game}")
|
||||
selected_game = input("Enter the game ID to use (leave empty for unconditional): ")
|
||||
if selected_game == "":
|
||||
selected_game = None
|
||||
else:
|
||||
selected_idx = int(selected_game)
|
||||
assert selected_idx in game_mapping.values(), f"Invalid game ID {selected_idx}"
|
||||
|
||||
candidates = [k for k,v in game_mapping.items() if v == selected_idx]
|
||||
assert len(candidates) == 1, f"Multiple games found for ID {selected_idx}: {candidates}"
|
||||
|
||||
selected_game = candidates[0]
|
||||
else:
|
||||
selected_game = None
|
||||
print("No game mapping available, proceeding without game conditioning")
|
||||
|
||||
return cls(
|
||||
model,
|
||||
checkpoint_path,
|
||||
tokenizer,
|
||||
img_proc,
|
||||
ckpt_config,
|
||||
game_mapping,
|
||||
selected_game,
|
||||
old_layout,
|
||||
cfg_scale,
|
||||
action_downsample_ratio,
|
||||
context_length
|
||||
)
|
||||
|
||||
def info(self):
|
||||
return {
|
||||
"ckpt_path": self.ckpt_path,
|
||||
"selected_game": self.selected_game,
|
||||
"old_layout": self.old_layout,
|
||||
"cfg_scale": self.cfg_scale,
|
||||
"context_length": self.max_buffer_size,
|
||||
"action_interleaving": self.action_interleaving,
|
||||
"is_flowmatching": self.is_flowmatching,
|
||||
"action_downsample_ratio": self.action_downsample_ratio,
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Reset all buffers."""
|
||||
self.obs_buffer.clear()
|
||||
self.action_buffer.clear()
|
||||
|
||||
def predict(self, obs):
|
||||
start_time = time.time()
|
||||
|
||||
current_frame = self.img_proc([obs], return_tensors="pt")["pixel_values"]
|
||||
self.obs_buffer.append(current_frame)
|
||||
|
||||
# Prepare model inputs
|
||||
pixel_values = torch.cat(list(self.obs_buffer), dim=0)
|
||||
|
||||
if self.action_interleaving and len(self.action_buffer) > 0:
|
||||
action_tensors = {
|
||||
key: torch.cat([a[key] for a in list(self.action_buffer)], dim=0)
|
||||
for key in ["buttons", "j_left", "j_right"]
|
||||
}
|
||||
else:
|
||||
action_tensors = {"buttons": None, "j_left": None, "j_right": None}
|
||||
|
||||
print("Running inference with the following inputs:")
|
||||
print(f"- pixel_values: {pixel_values.shape}")
|
||||
print("- action_tensors:")
|
||||
for k, v in action_tensors.items():
|
||||
if v is not None:
|
||||
print(f" - {k}: {v.shape}")
|
||||
else:
|
||||
print(f" - {k}: None")
|
||||
|
||||
# Run inference
|
||||
if self.is_flowmatching:
|
||||
predicted_actions = self._predict_flowmatching(pixel_values, action_tensors)
|
||||
else:
|
||||
predicted_actions = self._predict_ar(pixel_values, action_tensors)
|
||||
|
||||
# Add to action buffer
|
||||
self.action_buffer.append(predicted_actions)
|
||||
|
||||
inference_time = time.time() - start_time
|
||||
print(f"Inference time: {inference_time:.3f}s")
|
||||
|
||||
# Convert to list of action dicts
|
||||
n_actions = len(predicted_actions["buttons"])
|
||||
j_left = predicted_actions["j_left"].squeeze().cpu().numpy()
|
||||
j_right = predicted_actions["j_right"].squeeze().cpu().numpy()
|
||||
buttons = predicted_actions["buttons"].squeeze().cpu().numpy()
|
||||
|
||||
return {
|
||||
"j_left": j_left,
|
||||
"j_right": j_right,
|
||||
"buttons": buttons,
|
||||
}
|
||||
|
||||
def _predict_flowmatching(self, pixel_values, action_tensors):
|
||||
|
||||
available_frames = len(self.obs_buffer)
|
||||
frames = torch.zeros((self.max_buffer_size, *pixel_values.shape[1:]),
|
||||
dtype=pixel_values.dtype, device="cuda")
|
||||
frames[-available_frames:] = pixel_values
|
||||
dropped_frames = torch.zeros((self.max_buffer_size,), dtype=torch.bool, device="cuda")
|
||||
dropped_frames[:self.max_buffer_size - available_frames] = True
|
||||
|
||||
data_with_history = {
|
||||
"frames": frames,
|
||||
"dropped_frames": dropped_frames,
|
||||
"game": self.selected_game
|
||||
}
|
||||
tokenized_data_with_history = self.tokenizer.encode(data_with_history)
|
||||
|
||||
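# Unconditional branch: mark every frame except the newest as dropped and
# clear the game ID, so only the current observation conditions the model.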
frame_mask = torch.ones((self.max_buffer_size,), dtype=torch.bool, device="cuda")
|
||||
frame_mask[-1] = False
|
||||
data_without_history = {
|
||||
"frames": frames,
|
||||
"dropped_frames": frame_mask,
|
||||
"game": None
|
||||
}
|
||||
tokenized_data_without_history = self.tokenizer.encode(data_without_history)
|
||||
|
||||
# Convert to CUDA tensors with batch dimension
|
||||
for tokenized_data in [tokenized_data_with_history, tokenized_data_without_history]:
|
||||
for k, v in tokenized_data.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
tokenized_data[k] = v.unsqueeze(0).to("cuda")
|
||||
elif isinstance(v, np.ndarray):
|
||||
tokenized_data[k] = torch.tensor(v, device="cuda").unsqueeze(0)
|
||||
else:
|
||||
tokenized_data[k] = [v]
|
||||
|
||||
with torch.inference_mode():
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
if self.cfg_scale == 1.0:
|
||||
model_output = self.model.get_action(tokenized_data_with_history,
|
||||
old_layout=self.old_layout)
|
||||
else:
|
||||
model_output = self.model.get_action_with_cfg(
|
||||
tokenized_data_with_history,
|
||||
tokenized_data_without_history,
|
||||
cfg_scale=self.cfg_scale
|
||||
)
|
||||
predicted_actions = self.tokenizer.decode(model_output)
|
||||
|
||||
return predicted_actions
|
||||
252
nitrogen/inference_viz.py
Normal file
@ -0,0 +1,252 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import av
|
||||
|
||||
def create_viz(
|
||||
frame: np.ndarray,
|
||||
i: int,
|
||||
j_left: np.ndarray,
|
||||
j_right: np.ndarray,
|
||||
buttons: np.ndarray,
|
||||
token_set: list,
|
||||
):
|
||||
"""
|
||||
Visualize gamepad actions alongside a gameplay video frame.
|
||||
|
||||
Parameters:
|
||||
- frame: Video frame as numpy array
|
||||
- i: Current frame index
|
||||
- j_left: 16x2 array of left joystick positions (-1 to 1)
|
||||
- j_right: 16x2 array of right joystick positions (-1 to 1)
|
||||
- buttons: 16x17 array of button states (boolean)
|
||||
- token_set: List of button names
|
||||
|
||||
Returns:
|
||||
- Visualization as numpy array
|
||||
"""
|
||||
# Get frame dimensions
|
||||
frame_height, frame_width = frame.shape[:2]
|
||||
|
||||
# Create visualization area
|
||||
viz_width = min(500, frame_width)
|
||||
combined_width = frame_width + viz_width
|
||||
combined_height = frame_height
|
||||
|
||||
# Create combined image (frame + visualization)
|
||||
combined = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)
|
||||
|
||||
# Put the frame on the left side
|
||||
combined[:frame_height, :frame_width] = frame
|
||||
|
||||
# Starting position for visualizations
|
||||
viz_x = frame_width
|
||||
viz_y = 20
|
||||
|
||||
# Draw joysticks if data is provided
|
||||
if i < len(j_left) and i < len(j_right):
|
||||
# Add section title
|
||||
cv2.putText(combined, "JOYSTICKS",
|
||||
(viz_x + 10, viz_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
|
||||
|
||||
viz_y += 30 # Move down after title
|
||||
|
||||
# Size of joystick visualization
|
||||
joy_size = min(120, viz_width // 3)
|
||||
|
||||
# Horizontal positions of joysticks
|
||||
joy_left_x = viz_x + 30
|
||||
joy_right_x = viz_x + viz_width - joy_size - 30
|
||||
|
||||
# Draw joystick labels
|
||||
cv2.putText(combined, "Left", (joy_left_x, viz_y - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1)
|
||||
cv2.putText(combined, "Right", (joy_right_x, viz_y - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1)
|
||||
|
||||
# Draw joysticks
|
||||
draw_joystick(combined, joy_left_x, viz_y, joy_size, j_left[i])
|
||||
draw_joystick(combined, joy_right_x, viz_y, joy_size, j_right[i])
|
||||
|
||||
viz_y += joy_size + 40 # Move down after joysticks
|
||||
|
||||
# Draw buttons if data is provided
|
||||
if buttons is not None and i < len(buttons):
|
||||
# Add section title
|
||||
cv2.putText(combined, "BUTTON STATES",
|
||||
(viz_x + 10, viz_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
|
||||
|
||||
viz_y += 30 # Move down after title
|
||||
|
||||
# Size and position of button grid
|
||||
button_grid_x = viz_x + 20
|
||||
button_grid_y = viz_y
|
||||
button_size = 20
|
||||
|
||||
# Draw button grid
|
||||
draw_button_grid(combined, button_grid_x, button_grid_y,
|
||||
button_size, buttons, i, token_set)
|
||||
|
||||
return combined
|
||||
|
||||
def draw_joystick(img, x, y, size, position):
|
||||
"""Draw a joystick visualization at the specified position."""
|
||||
# Draw joystick background
|
||||
cv2.rectangle(img, (x, y), (x + size, y + size), (50, 50, 50), -1)
|
||||
cv2.rectangle(img, (x, y), (x + size, y + size), (100, 100, 100), 1)
|
||||
|
||||
# Calculate center point
|
||||
mid_x = x + size // 2
|
||||
mid_y = y + size // 2
|
||||
|
||||
# Draw center cross (0,0 coordinates)
|
||||
cv2.line(img, (x, mid_y), (x + size, mid_y), (150, 150, 150), 1)
|
||||
cv2.line(img, (mid_x, y), (mid_x, y + size), (150, 150, 150), 1)
|
||||
|
||||
# Draw 2x2 grid
|
||||
quarter_x = x + size // 4
|
||||
quarter_y = y + size // 4
|
||||
three_quarters_x = x + 3 * size // 4
|
||||
three_quarters_y = y + 3 * size // 4
|
||||
|
||||
# Draw grid lines
|
||||
cv2.line(img, (quarter_x, y), (quarter_x, y + size), (100, 100, 100), 1)
|
||||
cv2.line(img, (three_quarters_x, y), (three_quarters_x, y + size), (100, 100, 100), 1)
|
||||
cv2.line(img, (x, quarter_y), (x + size, quarter_y), (100, 100, 100), 1)
|
||||
cv2.line(img, (x, three_quarters_y), (x + size, three_quarters_y), (100, 100, 100), 1)
|
||||
|
||||
# Draw joystick position (clamp coordinates to valid range)
|
||||
px = max(-1, min(1, position[0]))
|
||||
py = max(-1, min(1, position[1]))
|
||||
|
||||
joy_x = int(mid_x + px * size // 2)
|
||||
joy_y = int(mid_y - py * size // 2) # Y is inverted in image coordinates
|
||||
|
||||
# Draw joystick position as a dot
|
||||
cv2.circle(img, (joy_x, joy_y), 5, (0, 0, 255), -1) # Red dot
|
||||
|
||||
def draw_button_grid(img, x, y, button_size, buttons, current_row, token_set):
|
||||
"""Draw the button state grid."""
|
||||
rows, cols = buttons.shape
|
||||
|
||||
# Ensure the grid fits in the visualization area
|
||||
available_width = img.shape[1] - x - 20
|
||||
if cols * button_size > available_width:
|
||||
button_size = max(10, available_width // cols)
|
||||
|
||||
# Draw column numbers at the top
|
||||
for col in range(cols):
|
||||
number_x = x + col * button_size + button_size // 2
|
||||
number_y = y - 5
|
||||
cv2.putText(img, str(col + 1), (number_x - 4, number_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)
|
||||
|
||||
# Draw button grid
|
||||
for row in range(rows):
|
||||
for col in range(cols):
|
||||
# Calculate button position
|
||||
bx = x + col * button_size
|
||||
by = y + row * button_size
|
||||
|
||||
# Draw button cell
|
||||
color = (0, 255, 0) if buttons[row, col] else (0, 0, 0) # Green if pressed, black otherwise
|
||||
cv2.rectangle(img, (bx, by), (bx + button_size, by + button_size), color, -1)
|
||||
|
||||
# Draw grid lines
|
||||
cv2.rectangle(img, (bx, by), (bx + button_size, by + button_size), (80, 80, 80), 1)
|
||||
|
||||
# Highlight current row
|
||||
highlight_y = y + current_row * button_size
|
||||
cv2.rectangle(img, (x, highlight_y), (x + cols * button_size, highlight_y + button_size),
|
||||
(0, 0, 255), 2) # Red highlight
|
||||
|
||||
# Draw button legend below the mosaic
|
||||
if token_set is not None:
|
||||
legend_y = y + rows * button_size + 20 # Starting Y position for legend
|
||||
legend_x = x # Starting X position for legend
|
||||
line_height = 15 # Height of each legend line
|
||||
|
||||
# Add legend title
|
||||
cv2.putText(img, "Button Legend:", (legend_x, legend_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)
|
||||
legend_y += line_height + 5 # Move down after title
|
||||
|
||||
# Calculate how many columns to use for the legend based on available space
|
||||
legend_cols = max(1, min(3, cols // 6)) # Use 1-3 columns depending on button count
|
||||
legend_items_per_col = (cols + legend_cols - 1) // legend_cols # Items per column with ceiling division
|
||||
|
||||
# Draw legend entries
|
||||
for col in range(min(cols, len(token_set))):
|
||||
# Calculate position in the legend grid
|
||||
legend_col = col // legend_items_per_col
|
||||
legend_row = col % legend_items_per_col
|
||||
|
||||
# Calculate position
|
||||
entry_x = legend_x + legend_col * (available_width // legend_cols)
|
||||
entry_y = legend_y + legend_row * line_height
|
||||
|
||||
# Add legend entry
|
||||
if col < len(token_set):
|
||||
cv2.putText(img, f"{col+1}. {token_set[col]}", (entry_x, entry_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)
|
||||
|
||||
class VideoRecorder:
|
||||
def __init__(self, output_file, fps=30, crf=28, preset="fast"):
|
||||
"""
|
||||
Initialize a video recorder using PyAV.
|
||||
|
||||
Args:
|
||||
output_file (str): Path to save the video file
|
||||
fps (int): Frames per second
|
||||
crf (int): Constant Rate Factor (0-51, higher means smaller file but lower quality)
|
||||
preset (str): Encoding preset (ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow)
|
||||
"""
|
||||
self.output_file = output_file
|
||||
self.fps = fps
|
||||
self.crf = str(crf)
|
||||
self.preset = preset
|
||||
self.container = av.open(output_file, mode="w")
|
||||
self.stream = None
|
||||
|
||||
def init_stream(self, width, height):
|
||||
"""Initialize the video stream with the frame dimensions."""
|
||||
self.stream = self.container.add_stream("h264", rate=self.fps)
|
||||
self.stream.width = width
|
||||
self.stream.height = height
|
||||
self.stream.pix_fmt = "yuv420p"
|
||||
self.stream.options = {
|
||||
"crf": self.crf,
|
||||
"preset": self.preset
|
||||
}
|
||||
|
||||
def add_frame(self, frame):
|
||||
"""
|
||||
Add a frame to the video.
|
||||
|
||||
Args:
|
||||
frame (numpy.ndarray): Frame as RGB numpy array
|
||||
"""
|
||||
if self.stream is None:
|
||||
self.init_stream(frame.shape[1], frame.shape[0])
|
||||
|
||||
av_frame = av.VideoFrame.from_ndarray(np.array(frame), format="rgb24")
|
||||
for packet in self.stream.encode(av_frame):
|
||||
self.container.mux(packet)
|
||||
|
||||
def close(self):
|
||||
"""Flush remaining packets and close the video file."""
|
||||
try:
|
||||
if self.stream is not None:
|
||||
for packet in self.stream.encode():
|
||||
self.container.mux(packet)
|
||||
finally:
|
||||
self.container.close()
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Close the recorder when exiting the context."""
|
||||
self.close()
|
||||
332
nitrogen/mm_tokenizers.py
Normal file
@ -0,0 +1,332 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import polars as pl
|
||||
from pydantic import BaseModel, Field
|
||||
from itertools import count
|
||||
|
||||
|
||||
DEBUG = int(os.getenv("DEBUG", 0))
|
||||
|
||||
class Tokenizer(ABC):
|
||||
@abstractmethod
|
||||
def encode(self, data: dict) -> dict:
|
||||
"""
|
||||
Transform the input data into a tokenized format.
|
||||
|
||||
Args:
|
||||
data (dict): Input data containing frames and actions.
|
||||
|
||||
Returns:
|
||||
dict: Tokenized data ready for model input.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, data: dict) -> dict:
|
||||
"""
|
||||
Reverse the tokenization process to retrieve original data.
|
||||
|
||||
Args:
|
||||
data (dict): Tokenized data.
|
||||
|
||||
Returns:
|
||||
dict: Original data structure.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self):
|
||||
"""
|
||||
Set the tokenizer to training mode.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def eval(self):
|
||||
"""
|
||||
Set the tokenizer to evaluation mode.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Set IDs for each token type.
|
||||
_PAD_TOKEN = 0
|
||||
_IMG_TOKEN = 1
|
||||
_IMG_SEP_TOKEN = 5 # New separator token
|
||||
_LANG_TOKEN = 2
|
||||
_PROPRIO_TOKEN = 3
|
||||
_ACT_TOKEN = 4
|
||||
_GAME_ID_TOKEN = 6
|
||||
|
||||
|
||||
_UNCONDITIONAL_ID = None # Special ID for unconditional game
|
||||
|
||||
class GameMappingConfig(BaseModel):
|
||||
src_files: list[str] = Field(default_factory=list, description="List of source parquet files to build game mapping.")
|
||||
|
||||
def get_game_mapping(cfg: GameMappingConfig) -> dict:
|
||||
game_set = set()
|
||||
for path in cfg.src_files:
|
||||
df = pl.read_parquet(path)
|
||||
for game in df['game_label'].unique():
|
||||
if game == _UNCONDITIONAL_ID:
|
||||
continue
|
||||
game_set.add(game)
|
||||
games = sorted(list(game_set))
|
||||
|
||||
# Set the 0th element to be the unconditional game ID
|
||||
games = [_UNCONDITIONAL_ID] + games
|
||||
return {game: idx for idx, game in enumerate(games)}
|
||||
|
||||
class NitrogenTokenizerConfig(BaseModel):
|
||||
tokenizer_id: Literal['nitrogen'] = Field(default='nitrogen', frozen=True)
|
||||
training: bool = Field(default=True, description="Whether to apply the transform in training mode.")
|
||||
num_visual_tokens_per_frame: int = Field(default=256, description="Number of visual tokens per frame.")
|
||||
max_action_dim: int = Field(default=25, description="Maximum action dimension.")
|
||||
max_sequence_length: int = Field(default=300, description="Maximum sequence length.")
|
||||
action_horizon: int = Field(default=16, description="Action horizon.")
|
||||
game_mapping_cfg: GameMappingConfig | None = Field(default=None, description="Game mapping configuration.")
|
||||
old_layout: bool = Field(default=False, description="Whether to use the old layout for actions. If True, the action layout is [j_left, j_right, buttons]. If False, it is [buttons, j_left, j_right].")
|
||||
|
||||
class NitrogenTokenizer(Tokenizer):
|
||||
"""
|
||||
Example transform that prepares video, language, state, and actions
|
||||
into a token-based format suitable for your model.
|
||||
|
||||
The sub-methods below (prefixed with `_prepare_`) mirror the original
|
||||
modular structure.
|
||||
"""
|
||||
|
||||
def __init__(self, config: NitrogenTokenizerConfig):
|
||||
self.training = config.training
|
||||
self.num_visual_tokens_per_frame = config.num_visual_tokens_per_frame
|
||||
self.max_action_dim = config.max_action_dim
|
||||
self.max_sequence_length = config.max_sequence_length
|
||||
self.action_horizon = config.action_horizon
|
||||
self.old_layout = config.old_layout
|
||||
|
||||
if config.game_mapping_cfg:
|
||||
self.game_mapping = get_game_mapping(config.game_mapping_cfg)
|
||||
import json
|
||||
with open("game_mapping.json", "w") as f:
|
||||
json.dump(self.game_mapping, f, indent=2)
|
||||
else:
|
||||
self.game_mapping = None
|
||||
|
||||
def train(self):
|
||||
self.training = True
|
||||
|
||||
def eval(self):
|
||||
self.training = False
|
||||
|
||||
def check_batch_size(self, data):
|
||||
# Use video key to determine batch size.
|
||||
video_ndim = data["images"].ndim
|
||||
if video_ndim == 4: # Interpret as [T*V, H, W, C]
|
||||
is_batched = False
|
||||
batch_size = 1
|
||||
elif video_ndim == 5: # Interpret as [B, T*V, H, W, C]
|
||||
is_batched = True
|
||||
batch_size = data["images"].shape[0]
|
||||
else:
|
||||
raise ValueError(f"Unsupported video number of dimensions: {video_ndim}")
|
||||
|
||||
return is_batched, batch_size
|
||||
|
||||
def _prepare_action(self, data: dict):
|
||||
"""
|
||||
Pad to max_action_dim, return masks.
|
||||
"""
|
||||
if "action" not in data:
|
||||
actions = np.zeros((self.action_horizon, self.max_action_dim))
|
||||
actions_mask = np.zeros((self.action_horizon, self.max_action_dim), dtype=bool)
|
||||
n_action_tokens = self.action_horizon
|
||||
return actions, actions_mask, n_action_tokens
|
||||
|
||||
actions = data["action"]
|
||||
assert actions.shape[0] == self.action_horizon, f"{actions.shape=}, {self.action_horizon=}"
|
||||
|
||||
n_action_tokens = actions.shape[0] # T
|
||||
n_action_dims = actions.shape[1]
|
||||
|
||||
assert (
|
||||
n_action_dims <= self.max_action_dim
|
||||
), f"Action dim {n_action_dims} exceeds max allowed {self.max_action_dim}."
|
||||
|
||||
# Pad the channel dimension
|
||||
actions = np.pad(actions, ((0, 0), (0, self.max_action_dim - n_action_dims)), "constant")
|
||||
|
||||
# Create mask: [T, max_action_dim]
|
||||
actions_mask = np.zeros((n_action_tokens, self.max_action_dim), dtype=bool)
|
||||
actions_mask[:, :n_action_dims] = True
|
||||
|
||||
return actions, actions_mask, n_action_tokens
|
||||
|
||||
def _build_token_ids(self, n_images, n_action_tokens): # n_lang_tokens, n_state_tokens):
|
||||
"""
|
||||
Build the 1D arrays of token ids for the vision-language and
|
||||
state-action streams. Return (vl_token_ids, sa_token_ids).
|
||||
"""
|
||||
vl_token_ids = []
|
||||
sa_token_ids = []
|
||||
|
||||
# 0.5) Add a Game ID placeholder
|
||||
if self.game_mapping:
|
||||
vl_token_ids.append(_GAME_ID_TOKEN)
|
||||
|
||||
# 1) Video placeholders
|
||||
for _ in range(n_images):
|
||||
vl_token_ids.extend([_IMG_TOKEN] * self.num_visual_tokens_per_frame)
|
||||
|
||||
# 2) Action tokens
|
||||
sa_token_ids.extend([_ACT_TOKEN] * n_action_tokens)
|
||||
|
||||
return np.array(vl_token_ids), np.array(sa_token_ids)
|
||||
|
||||
def _prepare_attention_mask(
|
||||
self,
|
||||
vl_token_ids: np.ndarray,
|
||||
):
|
||||
"""
|
||||
Build 1D attention mask for vision-language tokens.
|
||||
1 indicates valid token, 0 indicates padding token.
|
||||
State-action attention will be handled separately by the model.
|
||||
"""
|
||||
# Only create attention mask for vision-language tokens
|
||||
vl_seq_len = vl_token_ids.shape[0]
|
||||
vl_attn_mask = np.ones(vl_seq_len, dtype=bool) # All tokens are valid initially
|
||||
|
||||
# Pad vl_token_ids and vl_attn_mask to max_sequence_length
|
||||
if vl_seq_len > self.max_sequence_length:
|
||||
raise ValueError("VL sequence length exceeds the max sequence length!")
|
||||
|
||||
left_pad_len = self.max_sequence_length - vl_seq_len
|
||||
|
||||
# Pad token_ids (with PAD_TOKEN)
|
||||
vl_token_ids = np.pad(vl_token_ids, (left_pad_len, 0), constant_values=_PAD_TOKEN)
|
||||
|
||||
# Pad attention mask with 0 (padding tokens)
|
||||
vl_attn_mask = np.pad(vl_attn_mask, (left_pad_len, 0), constant_values=0)
|
||||
|
||||
return vl_token_ids, vl_attn_mask
|
||||
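# Worked example with the defaults (max_sequence_length=300) and a single
# visible frame: 1 game-ID token + 256 visual tokens = 257 VL tokens, so
# 43 PAD tokens are prepended and the first 43 mask entries are 0.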
|
||||
def pack_actions(self, buttons, j_left, j_right):
|
||||
# Check that the first two dims of each input is the same (number of chunks, control frequency)
|
||||
assert buttons.shape[:2] == j_left.shape[:2] == j_right.shape[:2], (
|
||||
f"buttons shape: {buttons.shape}, "
|
||||
f"j_left shape: {j_left.shape}, "
|
||||
f"j_right shape: {j_right.shape}"
|
||||
)
|
||||
|
||||
# Normalize the joysticks to 0,1
|
||||
j_left = (j_left + 1) / 2.
|
||||
j_right = (j_right + 1) / 2.
|
||||
|
||||
# Concatenate the buttons and joysticks along the last dimension
|
||||
action = np.concatenate([buttons,j_left,j_right],axis=-1, dtype=np.float32)
|
||||
|
||||
# Squeeze the first dimension of each input: this is the number of chunks, which is 1 here
|
||||
action = action.squeeze(0)
|
||||
return action
|
||||
|
||||
def unpack_actions(self, actions):
|
||||
if self.old_layout:
|
||||
# Unpack the actions into j_left, j_right, buttons
|
||||
j_left = actions[:, :, :2]
|
||||
j_right = actions[:, :, 2:4]
|
||||
buttons = actions[:, :, 4:]
|
||||
else:
|
||||
# Unpack the actions into j_left, j_right, buttons
|
||||
buttons = actions[:, :, :-4]
|
||||
j_left = actions[:, :, -4:-2]
|
||||
j_right = actions[:, :, -2:]
|
||||
|
||||
# Denormalize the joysticks back to -1,1
|
||||
j_left = j_left * 2. - 1.
|
||||
j_right = j_right * 2. - 1.
|
||||
|
||||
# Clip into [-1,1]
|
||||
j_left = torch.clamp(j_left, -1, 1)
|
||||
j_right = torch.clamp(j_right, -1, 1)
|
||||
|
||||
# Threshold the buttons to 0/1
|
||||
buttons = (buttons > 0.5).float()
|
||||
return j_left, j_right, buttons
|
||||
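# Worked example (new layout, assuming 17 buttons, i.e. action dim 21):
#   actions[..., :17]   -> buttons, thresholded at 0.5
#   actions[..., 17:19] -> j_left,  mapped from [0, 1] back to [-1, 1]
#   actions[..., 19:21] -> j_right, mapped from [0, 1] back to [-1, 1]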
|
||||
###########################################################################
|
||||
# apply
|
||||
###########################################################################
|
||||
def encode(self, data: dict) -> dict:
|
||||
"""
|
||||
Main entry point for the transform. We assume that `data` has
|
||||
data['frames'] and data['dropped_frames'], plus data['game'] when a
|
||||
game mapping is configured, and data['buttons'], data['j_left'], and
|
||||
data['j_right'] in training mode, in the shapes needed.
|
||||
"""
|
||||
|
||||
# 1) Pack buttons/joysticks into a single action tensor
|
||||
|
||||
|
||||
transformed_data = {**data} # Start with a copy of the input data
|
||||
|
||||
n_images = (~data["dropped_frames"]).sum()
|
||||
transformed_data["images"] = data["frames"]
|
||||
transformed_data["dropped_images"] = data["dropped_frames"]
|
||||
|
||||
if self.training:
|
||||
# Keep the original actions in the data for evaluation
|
||||
packed_actions = self.pack_actions(
|
||||
data["buttons"],
|
||||
data["j_left"],
|
||||
data["j_right"]
|
||||
)
|
||||
data["action"] = packed_actions
|
||||
|
||||
transformed_data["has_real_action"] = np.ones((), dtype=bool)
|
||||
|
||||
actions, actions_mask, n_action_tokens = self._prepare_action(data)
|
||||
transformed_data["actions"] = actions
|
||||
transformed_data["actions_mask"] = actions_mask
|
||||
|
||||
action_and_mask_keys = ["actions", "actions_mask"]
|
||||
assert all(
|
||||
transformed_data[key].shape == transformed_data["actions"].shape
|
||||
for key in action_and_mask_keys
|
||||
), f"Shape mismatch: {[(key, transformed_data[key].shape) for key in action_and_mask_keys]}"
|
||||
else:
|
||||
n_action_tokens = self.action_horizon
|
||||
|
||||
transformed_data["has_detection_target"] = np.zeros((), dtype=bool)
|
||||
|
||||
# 5) Build token_ids
|
||||
vl_token_ids, sa_token_ids = self._build_token_ids(
|
||||
n_images, n_action_tokens
|
||||
)
|
||||
|
||||
# 6) Build the attention mask only for vision-language tokens
|
||||
vl_token_ids, vl_attn_mask = self._prepare_attention_mask(vl_token_ids)
|
||||
|
||||
transformed_data["vl_token_ids"] = vl_token_ids
|
||||
transformed_data["sa_token_ids"] = sa_token_ids
|
||||
transformed_data["vl_attn_mask"] = vl_attn_mask
|
||||
transformed_data["embodiment_id"] = torch.tensor(0, dtype=torch.long)
|
||||
|
||||
if self.game_mapping:
|
||||
game_name = data["game"]
|
||||
assert game_name in self.game_mapping, f"Game '{game_name}' not found in game mapping."
|
||||
transformed_data["game_ids"] = torch.tensor(self.game_mapping[game_name], dtype=torch.long)
|
||||
else:
|
||||
transformed_data["game_ids"] = torch.tensor(0, dtype=torch.long)
|
||||
return transformed_data
|
||||
|
||||
def decode(self, data: dict) -> dict:
|
||||
j_left, j_right, buttons = self.unpack_actions(data["action_tensor"])
|
||||
|
||||
return {
|
||||
"j_left": j_left,
|
||||
"j_right": j_right,
|
||||
"buttons": buttons,
|
||||
}
|
||||
10
nitrogen/shared.py
Normal file
@ -0,0 +1,10 @@
|
||||
from pathlib import Path
|
||||
|
||||
BUTTON_ACTION_TOKENS = [
|
||||
'BACK', 'DPAD_DOWN', 'DPAD_LEFT', 'DPAD_RIGHT', 'DPAD_UP', 'EAST', 'GUIDE',
|
||||
'LEFT_SHOULDER', 'LEFT_THUMB', 'LEFT_TRIGGER', 'NORTH', 'RIGHT_BOTTOM', 'RIGHT_LEFT',
|
||||
'RIGHT_RIGHT', 'RIGHT_SHOULDER', 'RIGHT_THUMB', 'RIGHT_TRIGGER', 'RIGHT_UP', 'SOUTH',
|
||||
'START', 'WEST'
|
||||
]
|
||||
|
||||
PATH_REPO = Path(__file__).parent.parent.resolve()
|
||||
72
pyproject.toml
Normal file
@ -0,0 +1,72 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "nitrogen"
version = "1.0.0"
description = "VLM-based game playing agent"
readme = "README.md"
requires-python = ">=3.10"

# Default: everything
dependencies = [
    # Shared
    "numpy",
    "pyzmq",

    # Serve
    "torch",
    "pyyaml",
    "einops",
    "transformers",
    "pydantic",
    "diffusers",
    "polars",

    # Play (Windows-only deps marked)
    "pillow",
    "opencv-python",
    "pyautogui",
    "gymnasium",
    "psutil",
    "av",
    "dxcam; sys_platform == 'win32'",
    "pywinctl; sys_platform == 'win32'",
    "vgamepad; sys_platform == 'win32'",
    "pywin32; sys_platform == 'win32'",
    "xspeedhack; sys_platform == 'win32'",
]

[project.optional-dependencies]
serve = [
    "numpy",
    "pyzmq",
    "torch",
    "pyyaml",
    "einops",
    "transformers",
    "pydantic",
    "diffusers",
    "polars",
]

play = [
    "numpy",
    "pyzmq",
    "pillow",
    "opencv-python",
    "pyautogui",
    "gymnasium",
    "psutil",
    "av",
    # Windows-only, marked as in the default dependency list
    "dxcam; sys_platform == 'win32'",
    "pywinctl; sys_platform == 'win32'",
    "vgamepad; sys_platform == 'win32'",
    "pywin32; sys_platform == 'win32'",
    "xspeedhack; sys_platform == 'win32'",
]

[tool.setuptools.packages.find]
where = ["."]
exclude = ["scripts*"]
232
scripts/play.py
Normal file
@ -0,0 +1,232 @@
import time
import json
import argparse
from pathlib import Path
from collections import OrderedDict

import cv2
import numpy as np
from PIL import Image

from nitrogen.game_env import GamepadEnv
from nitrogen.shared import BUTTON_ACTION_TOKENS, PATH_REPO
from nitrogen.inference_viz import create_viz, VideoRecorder
from nitrogen.inference_client import ModelClient

parser = argparse.ArgumentParser(description="VLM Inference")
parser.add_argument("--process", type=str, default="celeste.exe", help="Game to play")
parser.add_argument("--allow-menu", action="store_true", help="Allow menu actions (disabled by default)")
parser.add_argument("--port", type=int, default=5555, help="Port for model server")

args = parser.parse_args()
policy = ModelClient(port=args.port)
policy.reset()
policy_info = policy.info()
action_downsample_ratio = policy_info["action_downsample_ratio"]

CKPT_NAME = Path(policy_info["ckpt_path"]).stem
NO_MENU = not args.allow_menu

PATH_DEBUG = PATH_REPO / "debug"
PATH_DEBUG.mkdir(parents=True, exist_ok=True)

PATH_OUT = (PATH_REPO / "out" / CKPT_NAME).resolve()
PATH_OUT.mkdir(parents=True, exist_ok=True)

BUTTON_PRESS_THRES = 0.5

# Find existing recordings in PATH_OUT (named 0001_DEBUG.mp4, 0002_DEBUG.mp4, ...)
# and continue numbering from max + 1.
video_files = sorted(PATH_OUT.glob("*_DEBUG.mp4"))
if video_files:
    existing_numbers = [f.name.split("_")[0] for f in video_files]
    existing_numbers = [int(n) for n in existing_numbers if n.isdigit()]
    next_number = max(existing_numbers) + 1 if existing_numbers else 1
else:
    next_number = 1

PATH_MP4_DEBUG = PATH_OUT / f"{next_number:04d}_DEBUG.mp4"
PATH_MP4_CLEAN = PATH_OUT / f"{next_number:04d}_CLEAN.mp4"
PATH_ACTIONS = PATH_OUT / f"{next_number:04d}_ACTIONS.json"

def preprocess_img(main_image):
    # PIL RGB -> OpenCV BGR, downscale to the model's 256x256 input, back to PIL RGB.
    main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
    final_image = cv2.resize(main_cv, (256, 256), interpolation=cv2.INTER_AREA)
    return Image.fromarray(cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB))

# Analog axes use int64 arrays (np.long is unavailable on NumPy 1.24-1.26).
zero_action = OrderedDict(
    [
        ("WEST", 0),
        ("SOUTH", 0),
        ("BACK", 0),
        ("DPAD_DOWN", 0),
        ("DPAD_LEFT", 0),
        ("DPAD_RIGHT", 0),
        ("DPAD_UP", 0),
        ("GUIDE", 0),
        ("AXIS_LEFTX", np.array([0], dtype=np.int64)),
        ("AXIS_LEFTY", np.array([0], dtype=np.int64)),
        ("LEFT_SHOULDER", 0),
        ("LEFT_TRIGGER", np.array([0], dtype=np.int64)),
        ("AXIS_RIGHTX", np.array([0], dtype=np.int64)),
        ("AXIS_RIGHTY", np.array([0], dtype=np.int64)),
        ("LEFT_THUMB", 0),
        ("RIGHT_THUMB", 0),
        ("RIGHT_SHOULDER", 0),
        ("RIGHT_TRIGGER", np.array([0], dtype=np.int64)),
        ("START", 0),
        ("EAST", 0),
        ("NORTH", 0),
    ]
)

TOKEN_SET = BUTTON_ACTION_TOKENS

print("Model loaded, starting environment...")
for i in range(3):
    print(f"{3 - i}...")
    time.sleep(1)

env = GamepadEnv(
    game=args.process,
    game_speed=1.0,
    env_fps=60,
    async_mode=True,
)

# These games require opening a menu once to initialize the virtual controller
if args.process in ("isaac-ng.exe", "Cuphead.exe"):
    print(f"GamepadEnv ready for {args.process} at {env.env_fps} FPS")
    input("Press enter to create a virtual controller and start rollouts...")
    for i in range(3):
        print(f"{3 - i}...")
        time.sleep(1)

    def press(button):
        # Tap a virtual-gamepad button: press, flush, short hold, release, flush.
        env.gamepad_emulator.press_button(button)
        env.gamepad_emulator.gamepad.update()
        time.sleep(0.05)
        env.gamepad_emulator.release_button(button)
        env.gamepad_emulator.gamepad.update()

    press("SOUTH")
    for k in range(5):
        press("EAST")
        time.sleep(0.3)

env.reset()
env.pause()

# Initial call to get state
obs, reward, terminated, truncated, info = env.step(action=zero_action)

step_count = 0

with VideoRecorder(str(PATH_MP4_DEBUG), fps=60, crf=32, preset="medium") as debug_recorder:
    with VideoRecorder(str(PATH_MP4_CLEAN), fps=60, crf=28, preset="medium") as clean_recorder:
        try:
            while True:
                obs = preprocess_img(obs)
                obs.save(PATH_DEBUG / f"{step_count:05d}.png")

                pred = policy.predict(obs)

                j_left, j_right, buttons = pred["j_left"], pred["j_right"], pred["buttons"]

                n = len(buttons)
                assert n == len(j_left) == len(j_right), "Mismatch in action lengths"

                env_actions = []

                for i in range(n):
                    move_action = zero_action.copy()

                    # Sticks are predicted in [-1, 1]; scale to the int16 range of the gamepad.
                    xl, yl = j_left[i]
                    xr, yr = j_right[i]
                    move_action["AXIS_LEFTX"] = np.array([int(xl * 32767)], dtype=np.int64)
                    move_action["AXIS_LEFTY"] = np.array([int(yl * 32767)], dtype=np.int64)
                    move_action["AXIS_RIGHTX"] = np.array([int(xr * 32767)], dtype=np.int64)
                    move_action["AXIS_RIGHTY"] = np.array([int(yr * 32767)], dtype=np.int64)

                    button_vector = buttons[i]
                    assert len(button_vector) == len(TOKEN_SET), "Button vector length does not match token set length"

                    # Triggers are analog (0-255); everything else is thresholded to a binary press.
                    for name, value in zip(TOKEN_SET, button_vector):
                        if "TRIGGER" in name:
                            move_action[name] = np.array([value * 255], dtype=np.int64)
                        else:
                            move_action[name] = 1 if value > BUTTON_PRESS_THRES else 0

                    env_actions.append(move_action)

                print(f"Executing {len(env_actions)} actions, each repeated {action_downsample_ratio} times")

                for i, a in enumerate(env_actions):
                    if NO_MENU and (a["START"] or a["GUIDE"] or a["BACK"]):
                        print("Model predicted a menu button, disabling this action")
                        a["GUIDE"] = 0
                        a["START"] = 0
                        a["BACK"] = 0

                    for _ in range(action_downsample_ratio):
                        obs, reward, terminated, truncated, info = env.step(action=a)

                    # Resize obs for the clean (1080p) and annotated debug (720p) recordings
                    obs_viz = np.array(obs).copy()
                    clean_viz = cv2.resize(obs_viz, (1920, 1080), interpolation=cv2.INTER_AREA)
                    debug_viz = create_viz(
                        cv2.resize(obs_viz, (1280, 720), interpolation=cv2.INTER_AREA),
                        i,
                        j_left,
                        j_right,
                        buttons,
                        token_set=TOKEN_SET,
                    )
                    debug_recorder.add_frame(debug_viz)
                    clean_recorder.add_frame(clean_viz)

                # Append each env_actions dictionary to a JSONL file
                with open(PATH_ACTIONS, "a") as f:
                    for i, a in enumerate(env_actions):
                        # Convert numpy arrays to lists for JSON serialization
                        for k, v in a.items():
                            if isinstance(v, np.ndarray):
                                a[k] = v.tolist()
                        a["step"] = step_count
                        a["substep"] = i
                        json.dump(a, f)
                        f.write("\n")

                step_count += 1
        finally:
            env.unpause()
            env.close()
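Each rollout appends one JSON object per substep to its NNNN_ACTIONS.json file (JSON Lines). A short sketch of loading such a log back for analysis; the path below is illustrative, not a committed file:

# Illustrative sketch, not part of this commit.
import json
from pathlib import Path

log = Path("out") / "my_ckpt" / "0001_ACTIONS.json"  # hypothetical checkpoint name
records = [json.loads(line) for line in log.read_text().splitlines()]

# Group by model step; each step holds one record per predicted action.
by_step = {}
for r in records:
    by_step.setdefault(r["step"], []).append(r)
print(len(records), "actions over", len(by_step), "steps")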
64
scripts/serve.py
Normal file
@ -0,0 +1,64 @@
import zmq
import argparse
import pickle

from nitrogen.inference_session import InferenceSession

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Model inference server")
    parser.add_argument("ckpt", type=str, help="Path to checkpoint file")
    parser.add_argument("--port", type=int, default=5555, help="Port to serve on")
    parser.add_argument("--old-layout", action="store_true", help="Use old layout")
    parser.add_argument("--cfg", type=float, default=1.0, help="CFG scale")
    parser.add_argument("--ctx", type=int, default=1, help="Context length")
    args = parser.parse_args()

    session = InferenceSession.from_ckpt(args.ckpt, old_layout=args.old_layout, cfg_scale=args.cfg, context_length=args.ctx)

    # Set up ZeroMQ
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(f"tcp://*:{args.port}")

    # Create poller
    poller = zmq.Poller()
    poller.register(socket, zmq.POLLIN)

    print(f"\n{'='*60}")
    print(f"Server running on port {args.port}")
    print("Waiting for requests...")
    print(f"{'='*60}\n")

    try:
        while True:
            # Poll with a 100 ms timeout so KeyboardInterrupt can be handled
            events = dict(poller.poll(timeout=100))
            if socket in events and events[socket] == zmq.POLLIN:
                # Receive a request only when data is available
                request = socket.recv()
                request = pickle.loads(request)
                if request["type"] == "reset":
                    session.reset()
                    response = {"status": "ok"}
                    print("Session reset")
                elif request["type"] == "info":
                    info = session.info()
                    response = {"status": "ok", "info": info}
                    print("Sent session info")
                elif request["type"] == "predict":
                    raw_image = request["image"]
                    result = session.predict(raw_image)
                    response = {
                        "status": "ok",
                        "pred": result,
                    }
                else:
                    response = {"status": "error", "message": f"Unknown request type: {request['type']}"}
                # Send response
                socket.send(pickle.dumps(response))
    except KeyboardInterrupt:
        print("\nShutting down server...")
    finally:
        socket.close()
        context.term()
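For reference, a minimal hand-rolled client for this protocol: pickled dicts over a ZeroMQ REQ/REP pair, with request types reset / info / predict. This sketch is inferred only from the server loop above; in the repo this role is played by nitrogen.inference_client.ModelClient, whose internals are not shown in this commit.

# Illustrative sketch, not part of this commit.
import pickle
import zmq

context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect("tcp://localhost:5555")

def request(payload):
    # One REQ/REP round trip: pickle out, pickle back in.
    socket.send(pickle.dumps(payload))
    return pickle.loads(socket.recv())

print(request({"type": "reset"}))   # expected: {"status": "ok"}
print(request({"type": "info"}))    # expected: {"status": "ok", "info": {...}}
# predict takes an image, e.g. the preprocessed PIL frame from scripts/play.py:
# pred = request({"type": "predict", "image": obs})["pred"]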