init commit

lm committed on 2025-12-18 22:02:20 +01:00
commit 3f467ff158
15 changed files with 3224 additions and 0 deletions

LICENSE (new file, +37 lines)

@@ -0,0 +1,37 @@
NVIDIA License
1. Definitions
“Licensor” means any person or entity that distributes its Work.
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
2. License Grant
2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works; (b) you comply with Other Licenses, and (c) you identify the specific derivative works that are subject to Your Terms and Other Licenses, as applicable. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. As used herein, “non-commercially” means for non-commercial research purposes only, and excludes any military, surveillance, service of nuclear technology or biometric processing purposes.
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
3.5 Trademarks. This license does not grant any rights to use any Licensor's or its affiliates' names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
3.7 Components Under Other Licenses. The Work may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as open source software licenses and other license terms, including but not limited to the Meta OPT-IML 175B License Agreement (“Other Licenses”). The components are subject to the applicable Other Licenses, including any proprietary notices, disclaimers, requirements and extended use rights; except that this Agreement will prevail regarding the use of third-party software, unless a third-party software license requires its license terms to prevail.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.

README.md (new file, +66 lines)

@@ -0,0 +1,66 @@
<img src="assets/github_banner.gif" width="100%" />
<div align="center">
<p style="font-size: 1.2em;">
<a href="https://nitrogen.minedojo.org/"><strong>Website</strong></a> |
<a href="https://huggingface.co/nvidia/NitroGen"><strong>Model</strong></a> |
<a href="https://huggingface.co/datasets/nvidia/NitroGen"><strong>Dataset</strong></a> |
<a href="https://nitrogen.minedojo.org/assets/documents/nitrogen.pdf"><strong>Paper</strong></a>
</p>
</div>
# NitroGen
NitroGen is an open foundation model for generalist gaming agents. This multi-game model takes pixel input and predicts gamepad actions.
NitroGen is trained through behavior cloning on the largest video-action gameplay dataset, assembled exclusively from internet videos. It can be adapted via post-training to unseen games.
# Installation
## Prerequisites
We **do not distribute game environments**; you must use your own copies of the games. This repository only supports running the agent on **Windows games**. You can serve the model from a Linux machine for inference, but the game itself must run on Windows. We have tested on Windows 11 with Python ≥ 3.12.
## Setup
Install this repo:
```bash
git clone https://github.com/MineDojo/NitroGen.git
cd NitroGen
pip install -e .
```
Download NitroGen checkpoint from [HuggingFace](https://huggingface.co/nvidia/NitroGen):
```bash
hf download nvidia/NitroGen ng.pt
```
# Getting Started
First, start an inference server for the model:
```bash
python scripts/serve.py <path_to_ng.pt>
```
Then, run the agent on the game of your choice:
```bash
python scripts/play.py --process '<game_executable_name>.exe'
```
The `--process` parameter must be the exact executable name of the game you want to play. You can find it by right-clicking the game process in Windows Task Manager (Ctrl+Shift+Esc) and selecting `Properties`. The process name is shown in the `General` tab and ends with `.exe`.
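If the game is already running, you can also list candidate executable names from a terminal. A quick check (the filter string is a placeholder; use any substring of the game's title):
```bash
tasklist | findstr /i "<part_of_game_name>"
```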
<!-- TODO # Paper and Citation
If you find our work useful, please consider citing us!
```bibtex
@article{,
title = {},
author = {},
year = {},
journal = {}
}
``` -->
**Disclaimer**: This project is strictly for research purposes and is not an official NVIDIA product.

assets/github_banner.gif (new binary file, 28 MiB; binary content not shown)

nitrogen/cfg.py (new file, +26 lines)

@@ -0,0 +1,26 @@
from pydantic import BaseModel, Field
from nitrogen.flow_matching_transformer.nitrogen import NitroGen_Config
from nitrogen.mm_tokenizers import NitrogenTokenizerConfig
class ModalityConfig(BaseModel):
frame_per_sample: int = 1 # number of context frames per sample
frame_spacing: int | None = None # how many frames to skip between each frame. If None, use action_per_chunk
action_per_chunk: int = 8
action_shift: int = 1 # how many actions to skip between frame[i] and action_chunk[i]
action_interleaving: bool = False # if True, action chunks will be interleaved with context frames and used by the model to predict the next actions
token_set: str = "new"
def model_post_init(self, __context):
if self.frame_spacing is None:
# Use object.__setattr__ because the model is frozen
object.__setattr__(self, 'frame_spacing', self.action_per_chunk)
assert self.action_shift >= 1, "action_shift must be at least 1 for correct action indexing"
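# Illustrative example: with the defaults above, ModalityConfig(action_per_chunk=8)
# ends up with frame_spacing == 8, i.e. context frames are spaced one action chunk apart.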
class CkptConfig(BaseModel):
experiment_name: str = Field(..., description="Name of the experiment")
model_cfg: NitroGen_Config = Field(..., description="Model configuration. This is a placeholder and should be replaced with the actual model config class.")
tokenizer_cfg: NitrogenTokenizerConfig = Field(..., description="Tokenizer configuration. This is a placeholder and should be replaced with the actual tokenizer config class.")
modality_cfg: ModalityConfig = Field(..., description="Modality configuration for the dataset mixture.")

nitrogen/flow_matching_transformer/modules.py (new file, +434 lines)

@@ -0,0 +1,434 @@
from typing import Optional
from pydantic import BaseModel, Field
import torch
import torch.nn.functional as F
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import (
SinusoidalPositionalEmbedding,
TimestepEmbedding,
Timesteps,
)
from torch import nn
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps):
dtype = next(self.parameters()).dtype
timesteps_proj = self.time_proj(timesteps).to(dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
return timesteps_emb
class AdaLayerNorm(nn.Module):
def __init__(
self,
embedding_dim: int,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = embedding_dim * 2
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self,
x: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
temb = self.linear(self.silu(temb))
scale, shift = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
return x
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
attention_bias: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.norm_type = norm_type
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embeddings` is set, `num_positional_embeddings` must also be set."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(
dim, max_seq_length=num_positional_embeddings
)
else:
self.pos_embed = None
# Define the attention and feed-forward blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
if final_dropout:
self.final_dropout = nn.Dropout(dropout)
else:
self.final_dropout = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
# 0. Self-Attention
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, temb)
else:
norm_hidden_states = self.norm1(hidden_states)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
# encoder_attention_mask=encoder_attention_mask,
)
if self.final_dropout:
attn_output = self.final_dropout(attn_output)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class DiTConfig(BaseModel):
num_attention_heads: int = Field(default=8)
attention_head_dim: int = Field(default=64)
output_dim: int = Field(default=26)
num_layers: int = Field(default=12)
dropout: float = Field(default=0.1)
attention_bias: bool = Field(default=True)
activation_fn: str = Field(default="gelu-approximate")
num_embeds_ada_norm: Optional[int] = Field(default=1000)
upcast_attention: bool = Field(default=False)
norm_type: str = Field(default="ada_norm")
norm_elementwise_affine: bool = Field(default=False)
norm_eps: float = Field(default=1e-5)
max_num_positional_embeddings: int = Field(default=512)
compute_dtype: str = Field(default="float32")
final_dropout: bool = Field(default=True)
positional_embeddings: Optional[str] = Field(default="sinusoidal")
interleave_self_attention: bool = Field(default=False)
cross_attention_dim: Optional[int] = Field(default=None, description="Dimension of the cross-attention embeddings. If None, no cross-attention is used.")
class DiT(ModelMixin):
_supports_gradient_checkpointing = True
def __init__(self, config: DiTConfig):
super().__init__()
self.config = config
self.compute_dtype = getattr(torch, self.config.compute_dtype)
self.attention_head_dim = self.config.attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
# Timestep encoder
self.timestep_encoder = TimestepEncoder(
embedding_dim=self.inner_dim, compute_dtype=self.compute_dtype
)
all_blocks = []
for idx in range(self.config.num_layers):
use_self_attn = idx % 2 == 1 and self.config.interleave_self_attention
curr_cross_attention_dim = self.config.cross_attention_dim if not use_self_attn else None
all_blocks += [
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
norm_type=self.config.norm_type,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
positional_embeddings=self.config.positional_embeddings,
num_positional_embeddings=self.config.max_num_positional_embeddings,
final_dropout=self.config.final_dropout,
cross_attention_dim=curr_cross_attention_dim,
)
]
self.transformer_blocks = nn.ModuleList(all_blocks)
# Output blocks
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
print(
"Total number of DiT parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
def forward(
self,
hidden_states: torch.Tensor, # Shape: (B, T, D)
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
timestep: Optional[torch.LongTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_all_hidden_states: bool = False,
):
# Encode timesteps
temb = self.timestep_encoder(timestep)
# Process through transformer blocks - single pass through the blocks
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
all_hidden_states = [hidden_states]
# Process through transformer blocks
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1 and self.config.interleave_self_attention:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
temb=temb,
)
all_hidden_states.append(hidden_states)
# Output processing
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
if return_all_hidden_states:
return self.proj_out_2(hidden_states), all_hidden_states
else:
return self.proj_out_2(hidden_states)
class SelfAttentionTransformerConfig(BaseModel):
num_attention_heads: int = Field(default=8)
attention_head_dim: int = Field(default=64)
output_dim: int = Field(default=26)
num_layers: int = Field(default=12)
dropout: float = Field(default=0.1)
attention_bias: bool = Field(default=True)
activation_fn: str = Field(default="gelu-approximate")
num_embeds_ada_norm: Optional[int] = Field(default=1000)
upcast_attention: bool = Field(default=False)
max_num_positional_embeddings: int = Field(default=512)
compute_dtype: str = Field(default="float32")
final_dropout: bool = Field(default=True)
positional_embeddings: Optional[str] = Field(default="sinusoidal")
interleave_self_attention: bool = Field(default=False)
class SelfAttentionTransformer(ModelMixin):
_supports_gradient_checkpointing = True
def __init__(self, config: SelfAttentionTransformerConfig):
super().__init__()
self.config = config
self.attention_head_dim = self.config.attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
positional_embeddings=self.config.positional_embeddings,
num_positional_embeddings=self.config.max_num_positional_embeddings,
final_dropout=self.config.final_dropout,
)
for _ in range(self.config.num_layers)
]
)
print(
"Total number of SelfAttentionTransformer parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
def forward(
self,
hidden_states: torch.Tensor, # Shape: (B, T, D)
return_all_hidden_states: bool = False,
):
# Process through transformer blocks - single pass through the blocks
hidden_states = hidden_states.contiguous()
all_hidden_states = [hidden_states]
# Process through transformer blocks
for idx, block in enumerate(self.transformer_blocks):
hidden_states = block(hidden_states)
all_hidden_states.append(hidden_states)
if return_all_hidden_states:
return hidden_states, all_hidden_states
else:
return hidden_states
class CrossAttentionTransformer(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
output_dim: int = 26,
num_layers: int = 12,
dropout: float = 0.1,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: Optional[int] = 1000,
upcast_attention: bool = False,
max_num_positional_embeddings: int = 512,
compute_dtype=torch.float32,
final_dropout: bool = True,
positional_embeddings: Optional[str] = "sinusoidal",
interleave_self_attention=False,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
positional_embeddings=positional_embeddings,
num_positional_embeddings=self.config.max_num_positional_embeddings,
final_dropout=final_dropout,
)
for _ in range(self.config.num_layers)
]
)
print(
"Total number of CrossAttentionTransformer parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
def forward(
self,
hidden_states: torch.Tensor, # Shape: (B, T, D)
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
):
# Process through transformer blocks - single pass through the blocks
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
# Process through transformer blocks
for idx, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
return hidden_states

nitrogen/flow_matching_transformer/nitrogen.py (new file, +755 lines)

@@ -0,0 +1,755 @@
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
from pathlib import Path
import yaml
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.distributions import Beta
from transformers import SiglipVisionModel, AutoModel
from .modules import DiT, DiTConfig, SelfAttentionTransformer, SelfAttentionTransformerConfig
_PAD_TOKEN = 0
_IMG_TOKEN = 1
_IMG_SEP_TOKEN = 5
_LANG_TOKEN = 2
_PROPRIO_TOKEN = 3
_ACT_TOKEN = 4
_GAME_ID_TOKEN = 6
class NitroGen_Config(BaseModel):
model_type: str = Field(default="nitrogen", frozen=True)
add_pos_embed: bool = Field(default=False, description="Whether to add positional embedding")
model_dtype: str = Field(default="float32", description="Model data type.")
diffusion_model_cfg: DiTConfig = Field(..., description="Diffusion model configuration.")
vl_self_attention_cfg: SelfAttentionTransformerConfig = Field(..., description="VL self-attention configuration.")
hidden_size: int = Field(default=1024, description="Input embedding dimension.")
max_seq_len: int = Field(default=1024, description="Maximum sequence length.")
action_dim: int = Field(default=None, description="Action dimension.")
action_horizon: int = Field(default=None, description="Action horizon.")
noise_beta_alpha: float = Field(default=1.5, description="")
noise_beta_beta: float = Field(default=1.0, description="")
noise_s: float = Field(default=0.999, description="Flow matching noise Beta distribution s.")
num_timestep_buckets: int = Field(default=1000, description="Number of timestep discretization buckets.")
num_inference_timesteps: int = Field(default=None, description="Number of inference steps for noise diffusion.")
max_num_embodiments: int = Field(default=1, description="Number of embodiments.")
vision_encoder_name: str = Field(default="google/siglip-large-patch16-256", description="Vision encoder name.")
vision_hidden_size: int = Field(default=768, description="Siglip hidden size.")
add_view_embed: bool = Field(default=False, description="Whether to add view embedding.")
tune_vision_tower: bool = Field(default=True, description="Tune vision if True.")
tune_mm_projector: bool = Field(default=True, description="Tune mm projector if True.")
tune_diffusion_model: bool = Field(default=True, description="Tune diffusion model if True.")
tune_multi_projector: bool = Field(default=True, description="Tune multi projector if True.")
tune_vl_mixing: bool = Field(default=True, description="Tune vl mixing if True.")
@classmethod
def from_yaml(cls, yaml_path: str | Path) -> "NitroGen_Config":
"""Load configuration from a YAML file."""
with open(yaml_path, "r") as f:
config_dict = yaml.safe_load(f)
return cls.model_validate(config_dict)
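# Illustrative usage (the path is hypothetical): cfg = NitroGen_Config.from_yaml("path/to/nitrogen.yaml")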
def swish(x):
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
"""
Produces a sinusoidal encoding of shape (B, T, w)
given timesteps of shape (B, T).
"""
def __init__(self, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps):
# timesteps: shape (B, T)
# We'll compute sin/cos frequencies across dim T
timesteps = timesteps.float() # ensure float
B, T = timesteps.shape
device = timesteps.device
half_dim = self.embedding_dim // 2
# typical log space frequencies for sinusoidal encoding
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
# Expand timesteps to (B, T, 1) then multiply
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc
class CategorySpecificLinear(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim):
super().__init__()
self.num_categories = num_categories
# For each category, we have separate weights and biases.
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x, cat_ids):
selected_W = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_W) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
super().__init__()
self.num_categories = num_categories
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x, cat_ids):
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(self, action_dim, hidden_size, num_embodiments):
super().__init__()
self.hidden_size = hidden_size
self.num_embodiments = num_embodiments
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions, timesteps, cat_ids):
"""
actions: shape (B, T, action_dim)
timesteps: shape (B,) -- a single scalar per batch item
cat_ids: shape (B,)
returns: shape (B, T, hidden_size)
"""
B, T, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# e.g. if timesteps is (B,), replicate across T
if timesteps.dim() == 1 and timesteps.shape[0] == B:
# shape (B,) => (B,T)
timesteps = timesteps.unsqueeze(1).expand(-1, T)
else:
raise ValueError(
"Expected `timesteps` to have shape (B,) so we can replicate across T."
)
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions, cat_ids)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = swish(self.W2(x, cat_ids))
# 5) Finally W3 => (B, T, w)
x = self.W3(x, cat_ids)
return x
class NitroGen(torch.nn.Module):
config_class = NitroGen_Config
supports_gradient_checkpointing = True
def __init__(
self,
config: NitroGen_Config,
game_mapping: dict[str, int] | None = None, # Used to add a game ID token
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.vision_hidden_size = config.vision_hidden_size
if "siglip" in config.vision_encoder_name:
model = SiglipVisionModel.from_pretrained(config.vision_encoder_name)
self.vision_encoder = model.vision_model
self.vision_encoder_type = "siglip"
else:
self.vision_encoder = AutoModel.from_pretrained(config.vision_encoder_name)
self.vision_encoder_type = "hf_auto"
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
self.num_timestep_buckets = config.num_timestep_buckets
# self.model = instantiate(config.diffusion_model_cfg)
self.model = DiT(config=config.diffusion_model_cfg)
self.action_dim = config.action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
# self.vl_self_attention_model = instantiate(config.vl_self_attention_cfg)
self.vl_self_attention_model = SelfAttentionTransformer(config=config.vl_self_attention_cfg)
# if config.qformer_cfg is not None:
# self.qformer = instantiate(config.qformer_cfg)
# else:
# self.qformer = nn.Identity()
# self.state_encoder = CategorySpecificMLP(
# num_categories=config.max_num_embodiments,
# input_dim=config.max_state_dim,
# hidden_dim=self.hidden_size,
# output_dim=self.hidden_size,
# )
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=config.action_dim,
hidden_size=self.hidden_size,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
# self.mm_vision_select_layer = config.mm_vision_select_layer
# if config.mm_projector_cfg is not None:
# self.mm_projector = instantiate(config.mm_projector_cfg)
# else:
self.mm_projector = None
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.hidden_size)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
# if config.add_view_embed:
# self.view_embedding = nn.Embedding(config.max_num_views, self.hidden_size)
# nn.init.normal_(self.view_embedding.weight, mean=0.0, std=0.02)
# self.vision_projector = None
# if config.vision_hidden_size != self.hidden_size:
# self.vision_projector = nn.Sequential(
# nn.Linear(config.vision_hidden_size, self.hidden_size),
# nn.LayerNorm(self.hidden_size),
# )
self.game_mapping = game_mapping
# Create an embedding table for game IDs
# Game ID tokens will be put inside vision-language tokens
# so they need to be projected to the same dimension
if self.game_mapping is not None:
# 0 = unconditional
self.game_embedding = nn.Embedding(
len(self.game_mapping),
self.vision_hidden_size,
padding_idx=0,
scale_grad_by_freq=True
)
self.set_trainable_parameters(
tune_multi_projector=config.tune_multi_projector,
tune_diffusion_model=config.tune_diffusion_model,
tune_vision_tower=config.tune_vision_tower,
tune_mm_projector=config.tune_mm_projector,
tune_vl_mixing=config.tune_vl_mixing,
)
print(
f"Total number of parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad):e}"
)
def set_trainable_parameters(
self,
tune_multi_projector: bool = True,
tune_diffusion_model: bool = True,
tune_vision_tower: bool = True,
tune_mm_projector: bool = True,
tune_vl_mixing: bool = True,
):
self.tune_multi_projector = tune_multi_projector
self.tune_diffusion_model = tune_diffusion_model
self.tune_vision_tower = tune_vision_tower
self.tune_mm_projector = tune_mm_projector
self.tune_vl_mixing = tune_vl_mixing
for param in self.parameters():
param.requires_grad = True
# ### Always freeze language encoder
# self.siglip_model.text_model.requires_grad_(False)
# # Freeze unused parameters in siglip vision encoder
# self.siglip_model.logit_scale.requires_grad = False
# self.siglip_model.logit_bias.requires_grad = False
# For SigLIP, additionally freeze encoder layer 11 and the pooling head
if self.vision_encoder_type == "siglip":
for param in self.vision_encoder.encoder.layers[11].parameters():
param.requires_grad = False
for param in self.vision_encoder.head.parameters():
param.requires_grad = False
# Freeze parameters
if not tune_multi_projector:
# self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if self.config.add_view_embed:
self.view_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
if not tune_vision_tower:
self.vision_encoder.requires_grad_(False)
if self.mm_projector is not None and not tune_mm_projector:
self.mm_projector.requires_grad_(False)
if not tune_vl_mixing:
self.vl_self_attention_model.requires_grad_(False)
print(f"Tune action head multi_projector: {self.tune_multi_projector}")
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
print(f"Tune action head vision tower: {self.tune_vision_tower}")
print(f"Tune action head mm_projector: {self.tune_mm_projector}")
print(f"Tune action head vl_mixing: {self.tune_vl_mixing}")
# Check if any parameters are still trainable. If not, print a warning.
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No action head trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
# self.siglip_model.text_model.eval()
if not self.tune_multi_projector:
# self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
# if self.config.add_view_embed:
# self.view_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
if not self.tune_vision_tower:
self.vision_encoder.eval()
if self.mm_projector is not None and not self.tune_mm_projector:
self.mm_projector.eval()
if not self.tune_vl_mixing:
self.vl_self_attention_model.eval()
# This function is supposedly incorrect
# def sample_time(self, batch_size, device, dtype):
# sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
# return (self.config.noise_s - sample) / self.config.noise_s
def sample_time(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (1 - sample) * self.config.noise_s
def encode_images(self, images): #, view_ids):
batch_size, num_frames, channels, height, width = images.shape
images = images.reshape(-1, channels, height, width)
image_features = self.vision_encoder(images)["last_hidden_state"]
image_features = rearrange(image_features, "(b f) n d -> b f n d", f=num_frames)
# if self.vision_projector is not None:
# # change the hidden dimension of the vision features
# image_features = self.vision_projector(image_features)
if self.mm_projector is not None:
image_features = self.mm_projector(image_features) # [B, 256, 1024] -> [B, 16, 1024]
return image_features
def prepare_input_embs(self, vl_token_ids, sa_token_ids, vision, action, dropped_images, game_ids=None):
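# Builds two embedding streams from the flat token-id layouts:
#   vl_embs (B, T_vl, vision_hidden_size): image patch features, optional game-ID embeddings,
#           and image-separator embeddings placed at their token positions
#   sa_embs (B, T_sa, hidden_size): noisy action embeddings scattered at _ACT_TOKEN positions,
#           plus optional learned positional embeddings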
B, T = vl_token_ids.shape
vl_embs = torch.full(
size=(B, T, self.vision_hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
)
# Extract dimensions from vision tensor
B, num_images, tokens_per_image, hidden_size = vision.shape
# Create mask for _IMG_TOKEN positions
vision_mask = (vl_token_ids == _IMG_TOKEN) # [B, T]
# Flatten vision tensor over the num_images dimension
vision_flat = vision.reshape(B, -1, self.vision_hidden_size) # [B, num_images * tokens_per_image, vision_hidden_size]
# Create a mask for the flattened vision dimension
# Each image contributes tokens_per_image tokens, so expand the mask accordingly
non_dropped_mask_expanded = (dropped_images == 0).unsqueeze(-1).repeat(1, 1, tokens_per_image).reshape(B, -1) # [B, T * tokens_per_image]
# Select only non-dropped vision embeddings
# This will give us the embeddings we need to place
valid_vision_embs = vision_flat[non_dropped_mask_expanded] # [total_valid_tokens, vision_hidden_size]
assert valid_vision_embs.shape[0] == vision_mask.sum().item(), (
f"Number of valid vision embeddings {valid_vision_embs.shape[0]} does not match "
f"the number of _IMG_TOKEN positions {vision_mask.sum().item()}"
)
# Now we need to place these at the vision_mask positions
# Get indices where vision_mask is True
batch_indices, token_indices = vision_mask.nonzero(as_tuple=True)
# Place the valid embeddings at the masked positions
vl_embs[batch_indices, token_indices] = valid_vision_embs
# Handle Game ID tokens
if self.game_mapping is not None and game_ids is not None:
game_mask = vl_token_ids == _GAME_ID_TOKEN # shape: (B, T)
num_game_tokens = game_mask.sum().item()
if num_game_tokens > 0:
# Assert that each batch item has exactly one game token
game_tokens_per_batch = game_mask.sum(dim=1) # [B] - count of game tokens per batch item
assert torch.all(game_tokens_per_batch == 1), (
f"Expected exactly 1 game token per batch item, but got: {game_tokens_per_batch.tolist()}. "
f"Each batch item must have exactly one _GAME_ID_TOKEN."
)
# Get game embeddings for each batch item
game_embs = self.game_embedding(game_ids) # [B, vision_hidden_size]
batch_indices, token_indices = game_mask.nonzero(as_tuple=True)
vl_embs[batch_indices, token_indices] = game_embs[batch_indices].to(dtype=vl_embs.dtype)
# Project image separator using the learnable sep_embedding.
sep_mask = vl_token_ids == _IMG_SEP_TOKEN # shape: (B, T)
num_sep = sep_mask.sum().item()
if num_sep > 0:
# Expand the separator embedding for each occurrence.
repeated_sep = self.vis_sep_embedding.unsqueeze(0).expand(num_sep, self.hidden_size)
# Assign the separator embeddings to the correct positions.
vl_embs[sep_mask] = repeated_sep.to(dtype=vl_embs.dtype)
B, T = sa_token_ids.shape
sa_embs = torch.full(
size=(B, T, self.hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
)
# Project state.
# state_mask = sa_token_ids == _PROPRIO_TOKEN
# state_mask = state_mask.unsqueeze(-1).expand_as(sa_embs)
# sa_embs = sa_embs.masked_scatter(state_mask, state)
# Project action.
action_mask = sa_token_ids == _ACT_TOKEN
action_mask = action_mask.unsqueeze(-1).expand_as(sa_embs)
sa_embs = sa_embs.masked_scatter(action_mask, action)
# Add positional embeddings
pos_ids = torch.arange(T, dtype=torch.long, device=sa_token_ids.device)
if self.config.add_pos_embed:
pos_embs = self.position_embedding(pos_ids) # (T, hidden_size)
pos_embs = pos_embs.unsqueeze(0).expand(B, T, self.hidden_size)
sa_embs = sa_embs + pos_embs
return vl_embs, sa_embs
def pack_actions(self, buttons, j_left, j_right):
# Check that the first three dims of each input is the same
assert buttons.shape[:3] == j_left.shape[:3] == j_right.shape[:3], (
f"buttons shape: {buttons.shape}, "
f"j_left shape: {j_left.shape}, "
f"j_right shape: {j_right.shape}"
)
# Normalize the joysticks to 0,1
j_left = (j_left + 1) / 2.
j_right = (j_right + 1) / 2.
# Concatenate the buttons and joysticks along the last dimension
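# Resulting per-timestep layout along the last dim: [left stick (2), right stick (2), buttons (n)],
# with both sticks rescaled from [-1, 1] to [0, 1]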
action = torch.cat([j_left, j_right, buttons], dim=-1)
# Squeeze the second dimension of each input: this is the number of chunks, which is 1 here
action = action.squeeze(1)
return action
# def unpack_actions(self, actions):
# # Unpack the actions into j_left, j_right, buttons
# j_left = actions[:, :, :2]
# j_right = actions[:, :, 2:4]
# buttons = actions[:, :, 4:]
# # Denormalize the joysticks back to -1,1
# j_left = j_left * 2. - 1.
# j_right = j_right * 2. - 1.
# # Clip into [-1,1]
# j_left = torch.clamp(j_left, -1, 1)
# j_right = torch.clamp(j_right, -1, 1)
# # Threshold the buttons to 0,1
# buttons = (buttons > 0.5).float()
# return j_left, j_right, buttons
# ========= ActionHead required ============
def forward(self, data: dict) -> dict:
self.set_frozen_modules_to_eval_mode()
# data = action_input
embodiment_id = data["embodiment_id"]
# # Check which data is present.
# has_real_action = action_input.has_real_action
has_real_action = data["has_real_action"]
# 1) Encode images/text/state
visual_features = self.encode_images(data["images"]) #, data["view_ids"])
# text_features = self.siglip_model.text_model(
# input_ids=data["lang_input_ids"]
# ).last_hidden_state
# state_features = self.state_encoder(data["state"], embodiment_id)
# 2) Prepare noisy trajectory
actions = data["actions"]
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None] # shape (B,1,1) for broadcast
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
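# With this parameterization, t = 0 is pure noise and t = 1 the clean action chunk; the model
# is trained below to regress the constant velocity (actions - noise) along this straight path.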
# 3) Convert (continuous) t -> discrete if needed
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
# 4) Get action encoder embeddings with correct time argument
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
# 5) Prepare full input to DiT (or your model)
vl_embs, sa_embs = self.prepare_input_embs(
data["vl_token_ids"],
data["sa_token_ids"],
visual_features,
# text_features,
# state_features,
action_features,
data["dropped_images"],
game_ids=data.get("game_id"),
)
vl_embs = self.vl_self_attention_model(vl_embs)
# vl_embs = self.qformer(vl_embs)
model_output, all_hidden_states = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=data["vl_attn_mask"],
timestep=t_discretized,
return_all_hidden_states=True,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
# 6) Flow-matching or velocity-prediction MSE
# Mask for variable-length trajectories
mask = data["actions_mask"] # shape => (B, seq_len_of_actions, ...)
raw_loss = F.mse_loss(pred_actions, velocity, reduction="none")
mask = has_real_action[:, None, None] * mask
raw_loss = raw_loss * mask
action_loss = (has_real_action[:, None, None] * raw_loss).sum() / (mask.sum() + 1e-6)
loss = action_loss
return {
"loss": loss,
}
@torch.inference_mode()
def get_action(self, data: dict, old_layout:bool = False) -> dict:
"""
For i in [0..N-1]:
1) t = i/N
2) velocity = model(x(t), t)
3) x(t + dt) = x(t) + dt * velocity
"""
# data = action_input
embodiment_id = data["embodiment_id"]
batch_size = data["images"].shape[0]
device = data["images"].device
dtype = data["images"].dtype
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.config.action_dim),
dtype=dtype,
device=device,
)
# 1) Hyperparameters for flow sampling
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
# 2) Encode static context (images, text, state) once if it does not depend on actions
visual_features = self.encode_images(data["images"]) #, data["view_ids"])
# text_features = self.siglip_model.text_model(
# input_ids=data["lang_input_ids"]
# ).last_hidden_state
# state_features = self.state_encoder(data["state"], embodiment_id)
# 3) Start denoising the actions
for i in range(num_steps):
# ---- (a) Discretize continuous time in [0,1]
t_cont = i / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
t_discretized = int(t_cont * self.num_timestep_buckets)
# ---- (b) Build embeddings (actions included)
# Pass the *current* actions at time t into the action encoder
action_features = self.action_encoder(
actions,
(torch.ones(actions.shape[0]) * t_discretized).to(device),
embodiment_id,
)
vl_embs, sa_embs = self.prepare_input_embs(
data["vl_token_ids"],
data["sa_token_ids"],
visual_features,
# text_features,
# state_features,
action_features,
data["dropped_images"],
game_ids=data["game_ids"],
)
vl_embs = self.vl_self_attention_model(vl_embs)
# vl_embs = self.qformer(vl_embs)
# ---- (c) Forward pass to get velocity = d/dt x(t)
timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=data["vl_attn_mask"],
timestep=timesteps,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -actions.shape[1] :]
# ---- (d) Naive Euler step: x(t + dt) = x(t) + dt * velocity
actions = actions + dt * pred_velocity
return {
"action_tensor": actions,
}
@torch.inference_mode()
def get_action_with_cfg(self, data_cond: dict, data_uncond: dict, cfg_scale: float = 1.0) -> dict:
"""
Use a form of classifier free guidance to sample actions. This can only be used on
models that were trained on multiple frames of actions. The idea is that we sample
velocity with and without the frame history, and then we push the sampled actions
towards the ones that were sampled with the frame history.
data_cond = conditional input (e.g. with frame history)
data_uncond = unconditional input (without it)
This function works with any kind of conditioning, not just history.
For i in [0..N-1]:
1) t = i/N
2) velocity = v_cond + cfg_scale * (v_cond - v_uncond), where v_* is the model velocity for the conditional / unconditional input
3) x(t + dt) = x(t) + dt * velocity
"""
# data = action_input
embodiment_id = data_cond["embodiment_id"]
batch_size = data_cond["images"].shape[0]
device = data_cond["images"].device
dtype = data_cond["images"].dtype
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.config.action_dim),
dtype=dtype,
device=device,
)
# 1) Hyperparameters for flow sampling
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
# 2) Encode static context (images, text, state) once if it does not depend on actions
visual_features_cond = self.encode_images(data_cond["images"])
visual_features_uncond = self.encode_images(data_uncond["images"])
# text_features = self.siglip_model.text_model(
# input_ids=data["lang_input_ids"]
# ).last_hidden_state
# state_features = self.state_encoder(data["state"], embodiment_id)
# 3) Start denoising the actions
for i in range(num_steps):
# ---- (a) Discretize continuous time in [0,1]
t_cont = i / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
t_discretized = int(t_cont * self.num_timestep_buckets)
# ---- (b) Build embeddings (actions included)
# Pass the *current* actions at time t into the action encoder
action_features = self.action_encoder(
actions,
(torch.ones(actions.shape[0]) * t_discretized).to(device),
embodiment_id,
)
# Predict velocity with history
vl_embs, sa_embs = self.prepare_input_embs(
data_cond["vl_token_ids"],
data_cond["sa_token_ids"],
visual_features_cond,
action_features,
data_cond["dropped_images"],
)
vl_embs = self.vl_self_attention_model(vl_embs)
# ---- (c) Forward pass to get velocity = d/dt x(t)
timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=data_cond["vl_attn_mask"],
timestep=timesteps,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity_cond = pred[:, -actions.shape[1] :]
# Predict velocity without history
vl_embs, sa_embs = self.prepare_input_embs(
data_uncond["vl_token_ids"],
data_uncond["sa_token_ids"],
visual_features_uncond,
action_features,
data_uncond["dropped_images"],
)
vl_embs = self.vl_self_attention_model(vl_embs)
# ---- (c) Forward pass to get velocity = d/dt x(t)
timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=data_uncond["vl_attn_mask"],
timestep=timesteps,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity_uncond = pred[:, -actions.shape[1] :]
# ---- (d) Combine velocities with cfg_scale
pred_velocity = pred_velocity_cond + cfg_scale * (pred_velocity_cond - pred_velocity_uncond)
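# Equivalent to (1 + cfg_scale) * v_cond - cfg_scale * v_uncond; cfg_scale = 0 recovers
# the purely conditional prediction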
# ---- (e) Naive Euler step: x(t + dt) = x(t) + dt * velocity
actions = actions + dt * pred_velocity
return {
"action_tensor": actions,
}
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype

nitrogen/game_env.py (new file, +577 lines)

@@ -0,0 +1,577 @@
import time
import platform
import pyautogui
import dxcam
import pywinctl as pwc
import xspeedhack as xsh
from gymnasium import Env
from gymnasium.spaces import Box, Dict, Discrete
from PIL import Image
import vgamepad as vg
import psutil
assert platform.system().lower() == "windows", "This module is only supported on Windows."
import win32process
import win32gui
import win32api
import win32con
def get_process_info(process_name):
"""
Get process information for a given process name on Windows.
Args:
process_name (str): Name of the process (e.g., "isaac-ng.exe")
Returns:
dict: Dictionary containing pid, window_name, and architecture for the first
matching process. Raises ValueError if no matching process is found.
"""
results = []
# Find all processes with the given name
for proc in psutil.process_iter(['pid', 'name']):
try:
if proc.info['name'].lower() == process_name.lower():
pid = proc.info['pid']
# Get architecture
try:
# Check if process is 32-bit or 64-bit
process_handle = win32api.OpenProcess(
win32con.PROCESS_QUERY_INFORMATION,
False,
pid
)
is_wow64 = win32process.IsWow64Process(process_handle)
win32api.CloseHandle(process_handle)
# On 64-bit Windows: WOW64 means "Windows 32-bit on Windows 64-bit", i.e. a 32-bit process
architecture = "x86" if is_wow64 else "x64"
except Exception:
architecture = "unknown"
# Find windows associated with this PID
windows = []
def enum_window_callback(hwnd, pid_to_find):
_, found_pid = win32process.GetWindowThreadProcessId(hwnd)
if found_pid == pid_to_find:
window_text = win32gui.GetWindowText(hwnd)
if window_text and win32gui.IsWindowVisible(hwnd):
windows.append({
'hwnd': hwnd,
'title': window_text,
'visible': win32gui.IsWindowVisible(hwnd)
})
return True
# Find all windows for this PID
try:
win32gui.EnumWindows(enum_window_callback, pid)
except Exception:
pass
# Choose the best window
window_name = None
if windows:
if len(windows) > 1:
print(f"Multiple windows found for PID {pid}: {[win['title'] for win in windows]}")
print("Using heuristics to select the correct window...")
# Filter out common proxy/helper windows
proxy_keywords = ['d3dproxywindow', 'proxy', 'helper', 'overlay']
# First try to find a visible window without proxy keywords
for win in windows:
if not any(keyword in win['title'].lower() for keyword in proxy_keywords):
window_name = win['title']
break
# If no good window found, just use the first one
if window_name is None and windows:
window_name = windows[0]['title']
results.append({
'pid': pid,
'window_name': window_name,
'architecture': architecture
})
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue
if len(results) == 0:
raise ValueError(f"No process found with name: {process_name}")
elif len(results) > 1:
print(f"Warning: Multiple processes found with name '{process_name}'. Returning first match.")
return results[0]
XBOX_MAPPING = {
"DPAD_UP": "XUSB_GAMEPAD_DPAD_UP",
"DPAD_DOWN": "XUSB_GAMEPAD_DPAD_DOWN",
"DPAD_LEFT": "XUSB_GAMEPAD_DPAD_LEFT",
"DPAD_RIGHT": "XUSB_GAMEPAD_DPAD_RIGHT",
"START": "XUSB_GAMEPAD_START",
"BACK": "XUSB_GAMEPAD_BACK",
"LEFT_SHOULDER": "XUSB_GAMEPAD_LEFT_SHOULDER",
"RIGHT_SHOULDER": "XUSB_GAMEPAD_RIGHT_SHOULDER",
"GUIDE": "XUSB_GAMEPAD_GUIDE",
"WEST": "XUSB_GAMEPAD_X",
"SOUTH": "XUSB_GAMEPAD_A",
"EAST": "XUSB_GAMEPAD_B",
"NORTH": "XUSB_GAMEPAD_Y",
"LEFT_TRIGGER": "LEFT_TRIGGER",
"RIGHT_TRIGGER": "RIGHT_TRIGGER",
"AXIS_LEFTX": "LEFT_JOYSTICK",
"AXIS_LEFTY": "LEFT_JOYSTICK",
"AXIS_RIGHTX": "RIGHT_JOYSTICK",
"AXIS_RIGHTY": "RIGHT_JOYSTICK",
"LEFT_THUMB": "XUSB_GAMEPAD_LEFT_THUMB",
"RIGHT_THUMB": "XUSB_GAMEPAD_RIGHT_THUMB",
}
PS4_MAPPING = {
"DPAD_UP": "DS4_BUTTON_DPAD_NORTH",
"DPAD_DOWN": "DS4_BUTTON_DPAD_SOUTH",
"DPAD_LEFT": "DS4_BUTTON_DPAD_WEST",
"DPAD_RIGHT": "DS4_BUTTON_DPAD_EAST",
"START": "DS4_BUTTON_OPTIONS",
"BACK": "DS4_BUTTON_SHARE",
"LEFT_SHOULDER": "DS4_BUTTON_SHOULDER_LEFT",
"RIGHT_SHOULDER": "DS4_BUTTON_SHOULDER_RIGHT",
"GUIDE": "DS4_BUTTON_GUIDE",
"WEST": "DS4_BUTTON_SQUARE",
"SOUTH": "DS4_BUTTON_CROSS",
"EAST": "DS4_BUTTON_CIRCLE",
"NORTH": "DS4_BUTTON_TRIANGLE",
"LEFT_TRIGGER": "LEFT_TRIGGER",
"RIGHT_TRIGGER": "RIGHT_TRIGGER",
"AXIS_LEFTX": "LEFT_JOYSTICK",
"AXIS_LEFTY": "LEFT_JOYSTICK",
"AXIS_RIGHTX": "RIGHT_JOYSTICK",
"AXIS_RIGHTY": "RIGHT_JOYSTICK",
"LEFT_THUMB": "DS4_BUTTON_THUMB_LEFT",
"RIGHT_THUMB": "DS4_BUTTON_THUMB_RIGHT",
}
class GamepadEmulator:
def __init__(self, controller_type="xbox", system="windows"):
"""
Initialize the GamepadEmulator with a specific controller type and system.
Parameters:
controller_type (str): The type of controller to emulate ("xbox" or "ps4").
system (str): The operating system to use, which affects joystick value handling.
"""
self.controller_type = controller_type
self.system = system
if controller_type == "xbox":
self.gamepad = vg.VX360Gamepad()
self.mapping = XBOX_MAPPING
elif controller_type == "ps4":
self.gamepad = vg.VDS4Gamepad()
self.mapping = PS4_MAPPING
else:
raise ValueError("Unsupported controller type")
# Initialize joystick values to keep track of the current state
self.left_joystick_x: int = 0
self.left_joystick_y: int = 0
self.right_joystick_x: int = 0
self.right_joystick_y: int = 0
def step(self, action):
"""
Perform actions based on the provided action dictionary.
Parameters:
action (dict): Dictionary of an action to be performed. Keys are control names,
and values are their respective states.
"""
self.gamepad.reset()
# Handle buttons
for control in [
"EAST",
"SOUTH",
"NORTH",
"WEST",
"BACK",
"GUIDE",
"START",
"DPAD_DOWN",
"DPAD_LEFT",
"DPAD_RIGHT",
"DPAD_UP",
"LEFT_SHOULDER",
"RIGHT_SHOULDER",
"LEFT_THUMB",
"RIGHT_THUMB",
]:
if control in action:
if action[control]:
self.press_button(control)
else:
self.release_button(control)
# Handle triggers
if "LEFT_TRIGGER" in action:
self.set_trigger("LEFT_TRIGGER", action["LEFT_TRIGGER"][0])
if "RIGHT_TRIGGER" in action:
self.set_trigger("RIGHT_TRIGGER", action["RIGHT_TRIGGER"][0])
# Handle joysticks
if "AXIS_LEFTX" in action and "AXIS_LEFTY" in action:
self.set_joystick("AXIS_LEFTX", action["AXIS_LEFTX"][0])
self.set_joystick("AXIS_LEFTY", action["AXIS_LEFTY"][0])
if "AXIS_RIGHTX" in action and "AXIS_RIGHTY" in action:
self.set_joystick("AXIS_RIGHTX", action["AXIS_RIGHTX"][0])
self.set_joystick("AXIS_RIGHTY", action["AXIS_RIGHTY"][0])
self.gamepad.update()
def press_button(self, button):
"""
Press a button on the gamepad.
Parameters:
button (str): The unified name of the button to press.
"""
button_mapped = self.mapping.get(button)
if self.controller_type == "xbox":
self.gamepad.press_button(button=getattr(vg.XUSB_BUTTON, button_mapped))
elif self.controller_type == "ps4":
self.gamepad.press_button(button=getattr(vg.DS4_BUTTONS, button_mapped))
else:
raise ValueError("Unsupported controller type")
def release_button(self, button):
"""
Release a button on the gamepad.
Parameters:
button (str): The unified name of the button to release.
"""
button_mapped = self.mapping.get(button)
if self.controller_type == "xbox":
self.gamepad.release_button(button=getattr(vg.XUSB_BUTTON, button_mapped))
elif self.controller_type == "ps4":
self.gamepad.release_button(button=getattr(vg.DS4_BUTTONS, button_mapped))
else:
raise ValueError("Unsupported controller type")
def set_trigger(self, trigger, value):
"""
Set the value of a trigger on the gamepad.
Parameters:
trigger (str): The unified name of the trigger.
value (float): The value to set the trigger to (between 0 and 1).
"""
value = int(value)
trigger_mapped = self.mapping.get(trigger)
if trigger_mapped == "LEFT_TRIGGER":
self.gamepad.left_trigger(value=value)
elif trigger_mapped == "RIGHT_TRIGGER":
self.gamepad.right_trigger(value=value)
else:
raise ValueError("Unsupported trigger action")
def set_joystick(self, joystick, value):
"""
Set the position of a joystick on the gamepad.
Parameters:
joystick (str): The name of the joystick axis.
value (float): The value to set the joystick axis to (between -32768 and 32767)
"""
if joystick == "AXIS_LEFTX":
self.left_joystick_x = value
self.gamepad.left_joystick(x_value=self.left_joystick_x, y_value=self.left_joystick_y)
elif joystick == "AXIS_LEFTY":
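# -value - 1 mirrors the signed 16-bit range (32767 <-> -32768), presumably to flip the
# Y-axis convention of the incoming action to match XInput, where positive Y is stick-up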
if self.system == "windows":
value = -value - 1
self.left_joystick_y = value
self.gamepad.left_joystick(x_value=self.left_joystick_x, y_value=self.left_joystick_y)
elif joystick == "AXIS_RIGHTX":
self.right_joystick_x = value
self.gamepad.right_joystick(
x_value=self.right_joystick_x, y_value=self.right_joystick_y
)
elif joystick == "AXIS_RIGHTY":
if self.system == "windows":
value = -value - 1
self.right_joystick_y = value
self.gamepad.right_joystick(
x_value=self.right_joystick_x, y_value=self.right_joystick_y
)
else:
raise ValueError("Unsupported joystick action")
def wakeup(self, duration=0.1):
"""
Wake up the controller by pressing a button.
Parameters:
duration (float): Duration to press the button.
"""
self.gamepad.press_button(vg.XUSB_BUTTON.XUSB_GAMEPAD_LEFT_THUMB)
self.gamepad.update()
time.sleep(duration)
self.gamepad.reset()
self.gamepad.update()
time.sleep(duration)
def reset(self):
"""
Reset the gamepad to its default state.
"""
self.gamepad.reset()
self.gamepad.update()
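# Hedged usage sketch (not part of the original module): GamepadEmulator is
# normally driven by GamepadEnv, but it can be exercised directly with a
# partial action dict. step() resets the pad, applies the given controls,
# then pushes the update to the virtual device. Values below are illustrative.
#
#     pad = GamepadEmulator(controller_type="xbox", system="windows")
#     pad.step({
#         "SOUTH": 1,                # press A / cross
#         "AXIS_LEFTX": [16000],     # raw stick values in [-32768, 32767]
#         "AXIS_LEFTY": [0],         # both axes of a stick must be given together
#         "RIGHT_TRIGGER": [255],    # triggers take values in [0, 255]
#     })
#     pad.reset()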
class PyautoguiScreenshotBackend:
def __init__(self, bbox):
self.bbox = bbox
def screenshot(self):
return pyautogui.screenshot(region=self.bbox)
class DxcamScreenshotBackend:
def __init__(self, bbox):
import dxcam
self.camera = dxcam.create()
self.bbox = bbox
self.last_screenshot = None
def screenshot(self):
screenshot = self.camera.grab(region=self.bbox)
if screenshot is None:
print("DXCAM failed to capture frame, trying to use the latest screenshot")
if self.last_screenshot is not None:
return self.last_screenshot
else:
return Image.new("RGB", (self.bbox[2], self.bbox[3]), (0, 0, 0))
screenshot = Image.fromarray(screenshot)
self.last_screenshot = screenshot
return screenshot
class GamepadEnv(Env):
"""
Base class for creating a game environment controlled with a gamepad.
Attributes:
game (str): Name of the game to interact with.
image_height (int): Height of the observation space.
image_width (int): Width of the observation space.
controller_type (str): Platform for the gamepad emulator ("xbox" or "ps4").
game_speed (float): Speed multiplier for the game.
env_fps (int): Number of actions to perform per second at normal speed.
async_mode (bool): Whether to pause/unpause the game during each step.
"""
def __init__(
self,
game,
image_height=1440,
image_width=2560,
controller_type="xbox",
game_speed=1.0,
env_fps=10,
async_mode=True,
screenshot_backend="dxcam",
):
super().__init__()
# Assert that system is windows
os_name = platform.system().lower()
assert os_name == "windows", "This environment is currently only supported on Windows."
assert controller_type in ["xbox", "ps4"], "Platform must be either 'xbox' or 'ps4'"
assert screenshot_backend in ["pyautogui", "dxcam"], "Screenshot backend must be either 'pyautogui' or 'dxcam'"
self.game = game
self.image_height = int(image_height)
self.image_width = int(image_width)
self.game_speed = game_speed
self.env_fps = env_fps
self.step_duration = self.calculate_step_duration()
self.async_mode = async_mode
self.gamepad_emulator = GamepadEmulator(controller_type=controller_type, system=os_name)
proc_info = get_process_info(game)
self.game_pid = proc_info["pid"]
self.game_arch = proc_info["architecture"]
self.game_window_name = proc_info["window_name"]
print(f"Game process found: {self.game} (PID: {self.game_pid}, Arch: {self.game_arch}, Window: {self.game_window_name})")
if self.game_pid is None:
raise Exception(f"Could not find PID for game: {game}")
self.observation_space = Box(
low=0, high=255, shape=(self.image_height, self.image_width, 3), dtype="uint8"
)
# Define a unified action space
self.action_space = Dict(
{
"BACK": Discrete(2),
"GUIDE": Discrete(2),
"RIGHT_SHOULDER": Discrete(2),
"RIGHT_TRIGGER": Box(low=0.0, high=1.0, shape=(1,)),
"LEFT_TRIGGER": Box(low=0.0, high=1.0, shape=(1,)),
"LEFT_SHOULDER": Discrete(2),
"AXIS_RIGHTX": Box(low=-32768.0, high=32767, shape=(1,)),
"AXIS_RIGHTY": Box(low=-32768.0, high=32767, shape=(1,)),
"AXIS_LEFTX": Box(low=-32768.0, high=32767, shape=(1,)),
"AXIS_LEFTY": Box(low=-32768.0, high=32767, shape=(1,)),
"LEFT_THUMB": Discrete(2),
"RIGHT_THUMB": Discrete(2),
"DPAD_UP": Discrete(2),
"DPAD_RIGHT": Discrete(2),
"DPAD_DOWN": Discrete(2),
"DPAD_LEFT": Discrete(2),
"WEST": Discrete(2),
"SOUTH": Discrete(2),
"EAST": Discrete(2),
"NORTH": Discrete(2),
"START": Discrete(2),
}
)
# Determine window name
windows = pwc.getAllWindows()
self.game_window = None
for window in windows:
if window.title == self.game_window_name:
self.game_window = window
break
if not self.game_window:
raise Exception(f"No window found with game name: {self.game}")
self.game_window.activate()
l, t, r, b = self.game_window.left, self.game_window.top, self.game_window.right, self.game_window.bottom
self.bbox = (l, t, r-l, b-t)
# Initialize speedhack client if using DLL injection
self.speedhack_client = xsh.Client(process_id=self.game_pid, arch=self.game_arch)
# Get the screenshot backend
if screenshot_backend == "dxcam":
self.screenshot_backend = DxcamScreenshotBackend(self.bbox)
elif screenshot_backend == "pyautogui":
self.screenshot_backend = PyautoguiScreenshotBackend(self.bbox)
else:
raise ValueError("Unsupported screenshot backend. Use 'dxcam' or 'pyautogui'.")
def calculate_step_duration(self):
"""
Calculate the step duration based on game speed and environment FPS.
Returns:
float: Calculated step duration.
Example:
If game_speed=1.0 and env_fps=10, then step_duration
will be 0.1 seconds.
"""
return 1.0 / (self.env_fps * self.game_speed)
def unpause(self):
"""
Unpause the game using the specified method.
"""
self.speedhack_client.set_speed(1.0)
def pause(self):
"""
Pause the game using the specified method.
"""
self.speedhack_client.set_speed(0.0)
def perform_action(self, action, duration):
"""
Perform the action without handling the game pause/unpause.
Parameters:
action (dict): Action to be performed.
duration (float): Duration for the action step.
"""
self.gamepad_emulator.step(action)
start = time.perf_counter()
self.unpause()
# Wait until the next step
end = start + duration
now = time.perf_counter()
while now < end:
now = time.perf_counter()
self.pause()
def step(self, action, step_duration=None):
"""
Perform an action in the game environment and return the observation.
Parameters:
action (dict): Dictionary of the action to be performed. Keys are control names,
and values are their respective states.
step_duration (float, optional): Duration for which the action should be performed.
Returns:
tuple: (obs, reward, terminated, truncated, info) where obs is the observation of the game environment after performing the action.
"""
# Determine the duration for this step
duration = step_duration if step_duration is not None else self.step_duration
self.perform_action(action, duration)
obs = self.render() # Render after pausing the game
reward = 0.0
terminated = False
truncated = False
info = {}
return obs, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
"""
Reset the environment to its initial state.
Parameters:
seed (int, optional): Random seed.
options (dict, optional): Additional options for reset.
Returns:
tuple: (obs, info) following the gymnasium reset API.
"""
self.gamepad_emulator.wakeup(duration=0.1)
time.sleep(1.0)
return self.render(), {}
def close(self):
"""
Close the environment and release any resources.
"""
pass # Implement env close logic here
def render(self):
"""
Render the current state of the game window as an observation.
Returns:
Image: Observation of the game environment.
"""
screenshot = self.screenshot_backend.screenshot()
screenshot = screenshot.resize((self.image_width, self.image_height))
return screenshot
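# Hedged end-to-end sketch (Windows only; assumes the target game is already
# running and that the speedhack DLL can attach to it):
#
#     env = GamepadEnv(game="celeste.exe", env_fps=10, screenshot_backend="dxcam")
#     env.reset()
#     noop = {"SOUTH": 0, "AXIS_LEFTX": [0], "AXIS_LEFTY": [0]}   # partial no-op action
#     obs, reward, terminated, truncated, info = env.step(noop)   # obs is a PIL Image
#     env.unpause()
#     env.close()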

91
nitrogen/inference_client.py Normal file
View File

@ -0,0 +1,91 @@
import time
import pickle
import numpy as np
import zmq
class ModelClient:
"""Client for model inference server."""
def __init__(self, host="localhost", port=5555):
"""
Initialize client connection.
Args:
host: Server hostname or IP
port: Server port
"""
self.host = host
self.port = port
self.timeout_ms = 30000
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(f"tcp://{host}:{port}")
self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) # Set receive timeout
print(f"Connected to model server at {host}:{port}")
def predict(self, image: np.ndarray) -> dict:
"""
Send an image and receive predicted actions.
Args:
image: numpy array (H, W, 3) in RGB format
Returns:
dict with the predicted action chunk:
- j_left: (N, 2) array of left joystick positions
- j_right: (N, 2) array of right joystick positions
- buttons: (N, num_buttons) array of button values
"""
request = {
"type": "predict",
"image": image
}
self.socket.send(pickle.dumps(request))
response = pickle.loads(self.socket.recv())
if response["status"] != "ok":
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
return response["pred"]
def reset(self):
"""Reset the server's session (clear buffers)."""
request = {"type": "reset"}
self.socket.send(pickle.dumps(request))
response = pickle.loads(self.socket.recv())
if response["status"] != "ok":
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
print("Session reset")
def info(self) -> dict:
"""Get session info from the server."""
request = {"type": "info"}
self.socket.send(pickle.dumps(request))
response = pickle.loads(self.socket.recv())
if response["status"] != "ok":
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
return response["info"]
def close(self):
"""Close the connection."""
self.socket.close()
self.context.term()
print("Connection closed")
def __enter__(self):
"""Support for context manager."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Close connection when exiting context."""
self.close()
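if __name__ == "__main__":
    # Minimal smoke test (hedged sketch): assumes a model server started via
    # scripts/serve.py is already listening on localhost:5555.
    with ModelClient(host="localhost", port=5555) as client:
        client.reset()
        print(client.info())
        dummy = np.zeros((256, 256, 3), dtype=np.uint8)  # black RGB frame
        pred = client.predict(dummy)
        print({k: np.asarray(v).shape for k, v in pred.items()})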

276
nitrogen/inference_session.py Normal file
View File

@ -0,0 +1,276 @@
import time
import json
from collections import deque
import torch
import numpy as np
from transformers import AutoImageProcessor
from nitrogen.flow_matching_transformer.nitrogen import NitroGen, NitroGen_Config
from nitrogen.mm_tokenizers import NitrogenTokenizerConfig, NitrogenTokenizer, Tokenizer
from nitrogen.cfg import CkptConfig
from nitrogen.shared import PATH_REPO
def summarize_parameters(module, name='model', depth=0, max_depth=3):
"""
Print a tree-like summary of parameters in a PyTorch module.
Args:
module: PyTorch module to summarize
name: Name of the module (for root level)
depth: Current depth in the tree
max_depth: Maximum depth to traverse
"""
if depth > max_depth:
return
# Count total parameters in this module
total_params = sum(p.numel() for p in module.parameters())
trainable_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
# Print indented summary
indent = " " * depth
print(f"{indent}{name}: {total_params:,} params ({trainable_params:,} trainable)")
# Recursively summarize submodules
if depth < max_depth:
for child_name, child_module in module.named_children():
summarize_parameters(child_module, child_name, depth + 1, max_depth)
def load_model(checkpoint_path: str):
"""Load model and args from checkpoint."""
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
ckpt_config = CkptConfig.model_validate(checkpoint["ckpt_config"])
model_cfg = ckpt_config.model_cfg
tokenizer_cfg = ckpt_config.tokenizer_cfg
print("Checkpoint args:")
print(json.dumps(ckpt_config.model_dump(), indent=4))
# Initialize tokenizer and language model
img_proc = AutoImageProcessor.from_pretrained(model_cfg.vision_encoder_name)
# Create VLM with pre-loaded language model
if isinstance(model_cfg, NitroGen_Config):
assert isinstance(tokenizer_cfg, NitrogenTokenizerConfig), \
"NitroGen_Config requires NitrogenTokenizerConfig for tokenization"
tokenizer_cfg.training = False
if tokenizer_cfg.game_mapping_cfg is not None:
tokenizer_cfg.game_mapping_cfg.src_files = [
x.replace("/mnt/amlfs-02/shared/gaming/gamingvla", str(PATH_REPO))
for x in tokenizer_cfg.game_mapping_cfg.src_files
]
tokenizer = NitrogenTokenizer(tokenizer_cfg)
game_mapping = tokenizer.game_mapping
model = NitroGen(config=model_cfg, game_mapping=game_mapping)
# model.num_inference_timesteps = 16
action_downsample_ratio = 1
else:
raise ValueError(f"Unsupported model config type: {type(model_cfg)}")
summarize_parameters(model, max_depth=3)
print(model)
model.load_state_dict(checkpoint["model"])
model.eval()
tokenizer.eval()
model.to("cuda")
return model, tokenizer, img_proc, ckpt_config, game_mapping, action_downsample_ratio
class InferenceSession:
"""Manages state for a single inference session."""
def __init__(
self,
model,
ckpt_path: str,
tokenizer: Tokenizer,
img_proc,
ckpt_config: CkptConfig,
game_mapping: dict,
selected_game: str,
old_layout: bool,
cfg_scale: float,
action_downsample_ratio: float,
context_length=None
):
self.model = model
self.tokenizer = tokenizer
self.img_proc = img_proc
self.ckpt_config = ckpt_config
self.game_mapping = game_mapping
self.selected_game = selected_game
self.old_layout = old_layout
self.cfg_scale = cfg_scale
self.action_downsample_ratio = action_downsample_ratio
self.ckpt_path = ckpt_path
# Load modality config
self.modality_config = self.ckpt_config.modality_cfg
self.max_buffer_size = context_length if context_length is not None else self.modality_config.frame_per_sample
self.action_interleaving = self.modality_config.action_interleaving
self.is_flowmatching = isinstance(self.ckpt_config.model_cfg, NitroGen_Config)
# Buffers
self.obs_buffer = deque(maxlen=self.max_buffer_size)
self.action_buffer = deque(maxlen=self.max_buffer_size)
@classmethod
def from_ckpt(cls, checkpoint_path: str, old_layout=False, cfg_scale=1.0, context_length=None):
"""Create an InferenceSession from a checkpoint."""
model, tokenizer, img_proc, ckpt_config, game_mapping, action_downsample_ratio = load_model(checkpoint_path)
if game_mapping is not None:
# Ask user to pick a game from the list
print("Available games in tokenizer mapping:")
for game, idx in game_mapping.items():
print(f"{idx:03d}: {game}")
selected_game = input("Enter the game ID to use (leave empty for unconditional): ")
if selected_game == "":
selected_game = None
else:
selected_idx = int(selected_game)
assert selected_idx in game_mapping.values(), f"Invalid game ID {selected_idx}"
candidates = [k for k,v in game_mapping.items() if v == selected_idx]
assert len(candidates) == 1, f"Multiple games found for ID {selected_idx}: {candidates}"
selected_game = candidates[0]
else:
selected_game = None
print("No game mapping available, proceeding without game conditioning")
return cls(
model,
checkpoint_path,
tokenizer,
img_proc,
ckpt_config,
game_mapping,
selected_game,
old_layout,
cfg_scale,
action_downsample_ratio,
context_length
)
def info(self):
return {
"ckpt_path": self.ckpt_path,
"selected_game": self.selected_game,
"old_layout": self.old_layout,
"cfg_scale": self.cfg_scale,
"context_length": self.max_buffer_size,
"action_interleaving": self.action_interleaving,
"is_flowmatching": self.is_flowmatching,
"action_downsample_ratio": self.action_downsample_ratio,
}
def reset(self):
"""Reset all buffers."""
self.obs_buffer.clear()
self.action_buffer.clear()
def predict(self, obs):
start_time = time.time()
current_frame = self.img_proc([obs], return_tensors="pt")["pixel_values"]
self.obs_buffer.append(current_frame)
# Prepare model inputs
pixel_values = torch.cat(list(self.obs_buffer), dim=0)
if self.action_interleaving and len(self.action_buffer) > 0:
action_tensors = {
key: torch.cat([a[key] for a in list(self.action_buffer)], dim=0)
for key in ["buttons", "j_left", "j_right"]
}
else:
action_tensors = {"buttons": None, "j_left": None, "j_right": None}
print("Running inference with the following inputs:")
print(f"- pixel_values: {pixel_values.shape}")
print("- action_tensors:")
for k, v in action_tensors.items():
if v is not None:
print(f" - {k}: {v.shape}")
else:
print(f" - {k}: None")
# Run inference
if self.is_flowmatching:
predicted_actions = self._predict_flowmatching(pixel_values, action_tensors)
else:
predicted_actions = self._predict_ar(pixel_values, action_tensors)
# Add to action buffer
self.action_buffer.append(predicted_actions)
inference_time = time.time() - start_time
print(f"Inference time: {inference_time:.3f}s")
# Convert predictions to numpy arrays for the client
j_left = predicted_actions["j_left"].squeeze().cpu().numpy()
j_right = predicted_actions["j_right"].squeeze().cpu().numpy()
buttons = predicted_actions["buttons"].squeeze().cpu().numpy()
return {
"j_left": j_left,
"j_right": j_right,
"buttons": buttons,
}
def _predict_flowmatching(self, pixel_values, action_tensors):
available_frames = len(self.obs_buffer)
frames = torch.zeros((self.max_buffer_size, *pixel_values.shape[1:]),
dtype=pixel_values.dtype, device="cuda")
frames[-available_frames:] = pixel_values
dropped_frames = torch.zeros((self.max_buffer_size,), dtype=torch.bool, device="cuda")
dropped_frames[:self.max_buffer_size - available_frames] = True
data_with_history = {
"frames": frames,
"dropped_frames": dropped_frames,
"game": self.selected_game
}
tokenized_data_with_history = self.tokenizer.encode(data_with_history)
frame_mask = torch.ones((self.max_buffer_size,), dtype=torch.bool, device="cuda")
frame_mask[-1] = False
data_without_history = {
"frames": frames,
"dropped_frames": frame_mask,
"game": None
}
tokenized_data_without_history = self.tokenizer.encode(data_without_history)
# Convert to CUDA tensors with batch dimension
for tokenized_data in [tokenized_data_with_history, tokenized_data_without_history]:
for k, v in tokenized_data.items():
if isinstance(v, torch.Tensor):
tokenized_data[k] = v.unsqueeze(0).to("cuda")
elif isinstance(v, np.ndarray):
tokenized_data[k] = torch.tensor(v, device="cuda").unsqueeze(0)
else:
tokenized_data[k] = [v]
with torch.inference_mode():
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
if self.cfg_scale == 1.0:
model_output = self.model.get_action(tokenized_data_with_history,
old_layout=self.old_layout)
else:
model_output = self.model.get_action_with_cfg(
tokenized_data_with_history,
tokenized_data_without_history,
cfg_scale=self.cfg_scale
)
predicted_actions = self.tokenizer.decode(model_output)
return predicted_actions
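# Note on the frame-history convention used above (illustrative values only):
# observations are right-aligned into a fixed-size tensor and the unfilled
# leading slots are flagged through dropped_frames. With max_buffer_size=4 and
# two frames observed so far:
#
#     frames  = [ zeros, zeros, f0, f1 ]        # (4, C, H, W)
#     dropped = [ True,  True, False, False ]   # True marks missing history
#
# The "without history" input instead marks every frame but the latest as
# dropped and sets game=None, which is what the CFG call uses as its
# unconditional branch.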

252
nitrogen/inference_viz.py Normal file
View File

@ -0,0 +1,252 @@
import numpy as np
import cv2
import av
def create_viz(
frame: np.ndarray,
i: int,
j_left: np.ndarray,
j_right: np.ndarray,
buttons: np.ndarray,
token_set: list,
):
"""
Visualize gamepad actions alongside a gameplay video frame.
Parameters:
- frame: Video frame as numpy array
- i: Current frame index (default 0)
- j_left: 16x2 array of left joystick positions (-1 to 1)
- j_right: 16x2 array of right joystick positions (-1 to 1)
- buttons: 16x17 array of button states (boolean)
- token_set: List of button names
Returns:
- Visualization as numpy array
"""
# Get frame dimensions
frame_height, frame_width = frame.shape[:2]
# Create visualization area
viz_width = min(500, frame_width)
combined_width = frame_width + viz_width
combined_height = frame_height
# Create combined image (frame + visualization)
combined = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)
# Put the frame on the left side
combined[:frame_height, :frame_width] = frame
# Starting position for visualizations
viz_x = frame_width
viz_y = 20
# Draw joysticks if data is provided
if i < len(j_left) and i < len(j_right):
# Add section title
cv2.putText(combined, "JOYSTICKS",
(viz_x + 10, viz_y),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
viz_y += 30 # Move down after title
# Size of joystick visualization
joy_size = min(120, viz_width // 3)
# Horizontal positions of joysticks
joy_left_x = viz_x + 30
joy_right_x = viz_x + viz_width - joy_size - 30
# Draw joystick labels
cv2.putText(combined, "Left", (joy_left_x, viz_y - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1)
cv2.putText(combined, "Right", (joy_right_x, viz_y - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1)
# Draw joysticks
draw_joystick(combined, joy_left_x, viz_y, joy_size, j_left[i])
draw_joystick(combined, joy_right_x, viz_y, joy_size, j_right[i])
viz_y += joy_size + 40 # Move down after joysticks
# Draw buttons if data is provided
if buttons is not None and i < len(buttons):
# Add section title
cv2.putText(combined, "BUTTON STATES",
(viz_x + 10, viz_y),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
viz_y += 30 # Move down after title
# Size and position of button grid
button_grid_x = viz_x + 20
button_grid_y = viz_y
button_size = 20
# Draw button grid
draw_button_grid(combined, button_grid_x, button_grid_y,
button_size, buttons, i, token_set)
return combined
def draw_joystick(img, x, y, size, position):
"""Draw a joystick visualization at the specified position."""
# Draw joystick background
cv2.rectangle(img, (x, y), (x + size, y + size), (50, 50, 50), -1)
cv2.rectangle(img, (x, y), (x + size, y + size), (100, 100, 100), 1)
# Calculate center point
mid_x = x + size // 2
mid_y = y + size // 2
# Draw center cross (0,0 coordinates)
cv2.line(img, (x, mid_y), (x + size, mid_y), (150, 150, 150), 1)
cv2.line(img, (mid_x, y), (mid_x, y + size), (150, 150, 150), 1)
# Draw 2x2 grid
quarter_x = x + size // 4
quarter_y = y + size // 4
three_quarters_x = x + 3 * size // 4
three_quarters_y = y + 3 * size // 4
# Draw grid lines
cv2.line(img, (quarter_x, y), (quarter_x, y + size), (100, 100, 100), 1)
cv2.line(img, (three_quarters_x, y), (three_quarters_x, y + size), (100, 100, 100), 1)
cv2.line(img, (x, quarter_y), (x + size, quarter_y), (100, 100, 100), 1)
cv2.line(img, (x, three_quarters_y), (x + size, three_quarters_y), (100, 100, 100), 1)
# Draw joystick position (clamp coordinates to valid range)
px = max(-1, min(1, position[0]))
py = max(-1, min(1, position[1]))
joy_x = int(mid_x + px * size // 2)
joy_y = int(mid_y - py * size // 2) # Y is inverted in image coordinates
# Draw joystick position as a dot
cv2.circle(img, (joy_x, joy_y), 5, (0, 0, 255), -1) # Red dot
def draw_button_grid(img, x, y, button_size, buttons, current_row, token_set):
"""Draw the button state grid."""
rows, cols = buttons.shape
# Ensure the grid fits in the visualization area
available_width = img.shape[1] - x - 20
if cols * button_size > available_width:
button_size = max(10, available_width // cols)
# Draw column numbers at the top
for col in range(cols):
number_x = x + col * button_size + button_size // 2
number_y = y - 5
cv2.putText(img, str(col + 1), (number_x - 4, number_y),
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)
# Draw button grid
for row in range(rows):
for col in range(cols):
# Calculate button position
bx = x + col * button_size
by = y + row * button_size
# Draw button cell
color = (0, 255, 0) if buttons[row, col] else (0, 0, 0) # Green if pressed, black otherwise
cv2.rectangle(img, (bx, by), (bx + button_size, by + button_size), color, -1)
# Draw grid lines
cv2.rectangle(img, (bx, by), (bx + button_size, by + button_size), (80, 80, 80), 1)
# Highlight current row
highlight_y = y + current_row * button_size
cv2.rectangle(img, (x, highlight_y), (x + cols * button_size, highlight_y + button_size),
(0, 0, 255), 2) # Red highlight
# Draw button legend below the mosaic
if token_set is not None:
legend_y = y + rows * button_size + 20 # Starting Y position for legend
legend_x = x # Starting X position for legend
line_height = 15 # Height of each legend line
# Add legend title
cv2.putText(img, "Button Legend:", (legend_x, legend_y),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)
legend_y += line_height + 5 # Move down after title
# Calculate how many columns to use for the legend based on available space
legend_cols = max(1, min(3, cols // 6)) # Use 1-3 columns depending on button count
legend_items_per_col = (cols + legend_cols - 1) // legend_cols # Items per column with ceiling division
# Draw legend entries
for col in range(min(cols, len(token_set))):
# Calculate position in the legend grid
legend_col = col // legend_items_per_col
legend_row = col % legend_items_per_col
# Calculate position
entry_x = legend_x + legend_col * (available_width // legend_cols)
entry_y = legend_y + legend_row * line_height
# Add legend entry
if col < len(token_set):
cv2.putText(img, f"{col+1}. {token_set[col]}", (entry_x, entry_y),
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)
class VideoRecorder:
def __init__(self, output_file, fps=30, crf=28, preset="fast"):
"""
Initialize a video recorder using PyAV.
Args:
output_file (str): Path to save the video file
fps (int): Frames per second
crf (int): Constant Rate Factor (0-51, higher means smaller file but lower quality)
preset (str): Encoding preset (ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow)
"""
self.output_file = output_file
self.fps = fps
self.crf = str(crf)
self.preset = preset
self.container = av.open(output_file, mode="w")
self.stream = None
def init_stream(self, width, height):
"""Initialize the video stream with the frame dimensions."""
self.stream = self.container.add_stream("h264", rate=self.fps)
self.stream.width = width
self.stream.height = height
self.stream.pix_fmt = "yuv420p"
self.stream.options = {
"crf": self.crf,
"preset": self.preset
}
def add_frame(self, frame):
"""
Add a frame to the video.
Args:
frame (numpy.ndarray): Frame as RGB numpy array
"""
if self.stream is None:
self.init_stream(frame.shape[1], frame.shape[0])
av_frame = av.VideoFrame.from_ndarray(np.array(frame), format="rgb24")
for packet in self.stream.encode(av_frame):
self.container.mux(packet)
def close(self):
"""Flush remaining packets and close the video file."""
try:
if self.stream is not None:
for packet in self.stream.encode():
self.container.mux(packet)
finally:
self.container.close()
def __enter__(self):
"""Support for context manager."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Close the recorder when exiting the context."""
self.close()
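if __name__ == "__main__":
    # Tiny self-contained check (hedged sketch): encode 60 synthetic frames to a
    # hypothetical demo.mp4 to verify the PyAV encoding pipeline.
    with VideoRecorder("demo.mp4", fps=30, crf=28, preset="fast") as rec:
        for t in range(60):
            frame = np.zeros((360, 640, 3), dtype=np.uint8)
            frame[:, : (t * 10) % 640] = (0, 255, 0)  # growing green bar
            rec.add_frame(frame)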

332
nitrogen/mm_tokenizers.py Normal file
View File

@ -0,0 +1,332 @@
import os
import json
from abc import ABC, abstractmethod
from typing import Literal
import numpy as np
import torch
import polars as pl
from pydantic import BaseModel, Field
from itertools import count
DEBUG = int(os.getenv("DEBUG", 0))
class Tokenizer(ABC):
@abstractmethod
def encode(self, data: dict) -> dict:
"""
Transform the input data into a tokenized format.
Args:
data (dict): Input data containing frames and actions.
Returns:
dict: Tokenized data ready for model input.
"""
pass
@abstractmethod
def decode(self, data: dict) -> dict:
"""
Reverse the tokenization process to retrieve original data.
Args:
data (dict): Tokenized data.
Returns:
dict: Original data structure.
"""
pass
@abstractmethod
def train(self):
"""
Set the tokenizer to training mode.
"""
pass
@abstractmethod
def eval(self):
"""
Set the tokenizer to evaluation mode.
"""
pass
# Set IDs for each token type.
_PAD_TOKEN = 0
_IMG_TOKEN = 1
_IMG_SEP_TOKEN = 5 # New separator token
_LANG_TOKEN = 2
_PROPRIO_TOKEN = 3
_ACT_TOKEN = 4
_GAME_ID_TOKEN = 6
_UNCONDITIONAL_ID = None # Special ID for unconditional game
class GameMappingConfig(BaseModel):
src_files: list[str] = Field(default_factory=list, description="List of source parquet files to build game mapping.")
def get_game_mapping(cfg: GameMappingConfig) -> dict:
game_set = set()
for path in cfg.src_files:
df = pl.read_parquet(path)
for game in df['game_label'].unique():
if game == _UNCONDITIONAL_ID:
continue
game_set.add(game)
games = sorted(list(game_set))
# Set the 0th element to be the unconditional game ID
games = [_UNCONDITIONAL_ID] + games
return {game: idx for idx, game in enumerate(games)}
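# Example of the resulting mapping (game names here are hypothetical): index 0
# is always reserved for the unconditional "no game" ID.
#
#     {None: 0, "celeste": 1, "cuphead": 2, ...}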
class NitrogenTokenizerConfig(BaseModel):
tokenizer_id: Literal['nitrogen'] = Field(default='nitrogen', frozen=True)
training: bool = Field(default=True, description="Whether to apply the transform in training mode.")
num_visual_tokens_per_frame: int = Field(default=256, description="Number of visual tokens per frame.")
max_action_dim: int = Field(default=25, description="Maximum action dimension.")
max_sequence_length: int = Field(default=300, description="Maximum sequence length.")
action_horizon: int = Field(default=16, description="Action horizon.")
game_mapping_cfg: GameMappingConfig | None = Field(default=None, description="Game mapping configuration.")
old_layout: bool = Field(default=False, description="Whether to use the old layout for actions. If True, the action layout is [j_left, j_right, buttons]. If False, it is [buttons, j_left, j_right].")
class NitrogenTokenizer(Tokenizer):
"""
Example transform that prepares video, language, state, and actions
into a token-based format suitable for your model.
The sub-methods below (prefixed with `_prepare_`) mirror the original
modular structure.
"""
def __init__(self, config: NitrogenTokenizerConfig):
self.training = config.training
self.num_visual_tokens_per_frame = config.num_visual_tokens_per_frame
self.max_action_dim = config.max_action_dim
self.max_sequence_length = config.max_sequence_length
self.action_horizon = config.action_horizon
self.old_layout = config.old_layout
if config.game_mapping_cfg:
self.game_mapping = get_game_mapping(config.game_mapping_cfg)
with open("game_mapping.json", "w") as f:
json.dump(self.game_mapping, f, indent=2)
else:
self.game_mapping = None
def train(self):
self.training = True
def eval(self):
self.training = False
def check_batch_size(self, data):
# Use video key to determine batch size.
video_ndim = data["images"].ndim
if video_ndim == 4: # Interpret as [T*V, H, W, C]
is_batched = False
batch_size = 1
elif video_ndim == 5: # Interpret as [B, T*V, H, W, C]
is_batched = True
batch_size = data["images"].shape[0]
else:
raise ValueError(f"Unsupported video number of dimensions: {video_ndim}")
return is_batched, batch_size
def _prepare_action(self, data: dict):
"""
Pad to max_action_dim, return masks.
"""
if "action" not in data:
actions = np.zeros((self.action_horizon, self.max_action_dim))
actions_mask = np.zeros((self.action_horizon, self.max_action_dim), dtype=bool)
n_action_tokens = self.action_horizon
return actions, actions_mask, n_action_tokens
actions = data["action"]
assert actions.shape[0] == self.action_horizon, f"{actions.shape=}, {self.action_horizon=}"
n_action_tokens = actions.shape[0] # T
n_action_dims = actions.shape[1]
assert (
n_action_dims <= self.max_action_dim
), f"Action dim {n_action_dims} exceeds max allowed {self.max_action_dim}."
# Pad the channel dimension
actions = np.pad(actions, ((0, 0), (0, self.max_action_dim - n_action_dims)), "constant")
# Create mask: [T, max_action_dim]
actions_mask = np.zeros((n_action_tokens, self.max_action_dim), dtype=bool)
actions_mask[:, :n_action_dims] = True
return actions, actions_mask, n_action_tokens
def _build_token_ids(self, n_images, n_action_tokens): # n_lang_tokens, n_state_tokens):
"""
Build the 1D array of token_ids based on the number of each block.
Return (token_ids, special_pad_token_idx).
"""
vl_token_ids = []
sa_token_ids = []
# 0.5) Add a Game ID placeholder
if self.game_mapping:
vl_token_ids.append(_GAME_ID_TOKEN)
# 1) Video placeholders
for _ in range(n_images):
vl_token_ids.extend([_IMG_TOKEN] * self.num_visual_tokens_per_frame)
# 2) Action tokens
sa_token_ids.extend([_ACT_TOKEN] * n_action_tokens)
return np.array(vl_token_ids), np.array(sa_token_ids)
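# Resulting layout (illustrative, with game conditioning enabled and two
# frames at the default 256 visual tokens per frame):
#
#     vl_token_ids = [GAME_ID] + [IMG] * 256 + [IMG] * 256
#     sa_token_ids = [ACT] * n_action_tokens
#
# vl_token_ids is later left-padded with PAD up to max_sequence_length in
# _prepare_attention_mask; the state-action tokens are handled by the model.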
def _prepare_attention_mask(
self,
vl_token_ids: np.ndarray,
):
"""
Build 1D attention mask for vision-language tokens.
1 indicates valid token, 0 indicates padding token.
State-action attention will be handled separately by the model.
"""
# Only create attention mask for vision-language tokens
vl_seq_len = vl_token_ids.shape[0]
vl_attn_mask = np.ones(vl_seq_len, dtype=bool) # All tokens are valid initially
# Pad vl_token_ids and vl_attn_mask to max_sequence_length
if vl_seq_len > self.max_sequence_length:
raise ValueError("VL sequence length exceeds the max sequence length!")
left_pad_len = self.max_sequence_length - vl_seq_len
# Pad token_ids (with PAD_TOKEN)
vl_token_ids = np.pad(vl_token_ids, (left_pad_len, 0), constant_values=_PAD_TOKEN)
# Pad attention mask with 0 (padding tokens)
vl_attn_mask = np.pad(vl_attn_mask, (left_pad_len, 0), constant_values=0)
return vl_token_ids, vl_attn_mask
def pack_actions(self, buttons, j_left, j_right):
# Check that the first two dims of each input are the same (number of chunks, control frequency)
assert buttons.shape[:2] == j_left.shape[:2] == j_right.shape[:2], (
f"buttons shape: {buttons.shape}, "
f"j_left shape: {j_left.shape}, "
f"j_right shape: {j_right.shape}"
)
# Normalize the joysticks to 0,1
j_left = (j_left + 1) / 2.
j_right = (j_right + 1) / 2.
# Concatenate the buttons and joysticks along the last dimension
action = np.concatenate([buttons,j_left,j_right],axis=-1, dtype=np.float32)
# Squeeze the first dimension of each input: this is the number of chunks, which is 1 here
action = action.squeeze(0)
return action
def unpack_actions(self, actions):
if self.old_layout:
# Unpack the actions into j_left, j_right, buttons
j_left = actions[:, :, :2]
j_right = actions[:, :, 2:4]
buttons = actions[:, :, 4:]
else:
# Unpack the actions into j_left, j_right, buttons
buttons = actions[:, :, :-4]
j_left = actions[:, :, -4:-2]
j_right = actions[:, :, -2:]
# Denormalize the joysticks back to -1,1
j_left = j_left * 2. - 1.
j_right = j_right * 2. - 1.
# Clip into [-1,1]
j_left = torch.clamp(j_left, -1, 1)
j_right = torch.clamp(j_right, -1, 1)
# Threshold the buttons to 0/1
buttons = (buttons > 0.5).float()
return j_left, j_right, buttons
###########################################################################
# apply
###########################################################################
def encode(self, data: dict) -> dict:
"""
Main entry point for the transform. `data` is expected to contain
data["frames"], data["dropped_frames"], and data["game"], plus
data["buttons"], data["j_left"], and data["j_right"] in training mode;
these are packed and tokenized into the model's input format.
"""
# 1) Pack buttons/joysticks into a single action tensor
transformed_data = {**data} # Start with a copy of the input data
n_images = (data["dropped_frames"] == False).sum()
transformed_data["images"] = data["frames"]
transformed_data["dropped_images"] = data["dropped_frames"]
if self.training:
# Keep the original actions in the data for evaluation
packed_actions = self.pack_actions(
data["buttons"],
data["j_left"],
data["j_right"]
)
data["action"] = packed_actions
transformed_data["has_real_action"] = np.ones((), dtype=bool)
actions, actions_mask, n_action_tokens = self._prepare_action(data)
transformed_data["actions"] = actions
transformed_data["actions_mask"] = actions_mask
action_and_mask_keys = ["actions", "actions_mask"]
assert all(
transformed_data[key].shape == transformed_data["actions"].shape
for key in action_and_mask_keys
), f"Shape mismatch: {[(key, transformed_data[key].shape) for key in action_and_mask_keys]}"
else:
n_action_tokens = self.action_horizon
transformed_data["has_detection_target"] = np.zeros((), dtype=bool)
# 5) Build token_ids
vl_token_ids, sa_token_ids = self._build_token_ids(
n_images, n_action_tokens
)
# 6) Build the attention mask only for vision-language tokens
vl_token_ids, vl_attn_mask = self._prepare_attention_mask(vl_token_ids)
transformed_data["vl_token_ids"] = vl_token_ids
transformed_data["sa_token_ids"] = sa_token_ids
transformed_data["vl_attn_mask"] = vl_attn_mask
transformed_data["embodiment_id"] = torch.tensor(0, dtype=torch.long)
if self.game_mapping:
game_name = data["game"]
assert game_name in self.game_mapping, f"Game '{game_name}' not found in game mapping."
transformed_data["game_ids"] = torch.tensor(self.game_mapping[game_name], dtype=torch.long)
else:
transformed_data["game_ids"] = torch.tensor(0, dtype=torch.long)
return transformed_data
def decode(self, data: dict) -> dict:
j_left, j_right, buttons = self.unpack_actions(data["action_tensor"])
return {
"j_left": j_left,
"j_right": j_right,
"buttons": buttons,
}
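if __name__ == "__main__":
    # Hedged round-trip check of the action packing above: pack one 1-chunk
    # window of random controls (21 button channels, matching
    # BUTTON_ACTION_TOKENS), unpack with the default layout, and confirm the
    # joysticks come back denormalized to [-1, 1].
    tok = NitrogenTokenizer(NitrogenTokenizerConfig())
    T, n_buttons = tok.action_horizon, 21
    buttons = (np.random.rand(1, T, n_buttons) > 0.5).astype(np.float32)
    j_left = np.random.uniform(-1, 1, (1, T, 2)).astype(np.float32)
    j_right = np.random.uniform(-1, 1, (1, T, 2)).astype(np.float32)
    packed = tok.pack_actions(buttons, j_left, j_right)         # (T, n_buttons + 4), in [0, 1]
    jl, jr, b = tok.unpack_actions(torch.tensor(packed)[None])  # add a batch dimension back
    assert torch.allclose(jl.squeeze(0), torch.tensor(j_left.squeeze(0)), atol=1e-6)
    assert torch.allclose(b.squeeze(0), torch.tensor(buttons.squeeze(0)))
    print("pack/unpack round trip OK:", packed.shape, tuple(jl.shape), tuple(b.shape))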

10
nitrogen/shared.py Normal file
View File

@ -0,0 +1,10 @@
from pathlib import Path
BUTTON_ACTION_TOKENS = [
'BACK', 'DPAD_DOWN', 'DPAD_LEFT', 'DPAD_RIGHT', 'DPAD_UP', 'EAST', 'GUIDE',
'LEFT_SHOULDER', 'LEFT_THUMB', 'LEFT_TRIGGER', 'NORTH', 'RIGHT_BOTTOM', 'RIGHT_LEFT',
'RIGHT_RIGHT', 'RIGHT_SHOULDER', 'RIGHT_THUMB', 'RIGHT_TRIGGER', 'RIGHT_UP', 'SOUTH',
'START', 'WEST'
]
PATH_REPO = Path(__file__).parent.parent.resolve()

72
pyproject.toml Normal file
View File

@ -0,0 +1,72 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "nitrogen"
version = "1.0.0"
description = "VLM-based game playing agent"
readme = "README.md"
requires-python = ">=3.10"
# Default: everything
dependencies = [
# Shared
"numpy",
"pyzmq",
# Serve
"torch",
"pyyaml",
"einops",
"transformers",
"pydantic",
"diffusers",
"polars",
# Play (Windows-only deps marked)
"pillow",
"opencv-python",
"pyautogui",
"gymnasium",
"psutil",
"av",
"dxcam; sys_platform == 'win32'",
"pywinctl; sys_platform == 'win32'",
"vgamepad; sys_platform == 'win32'",
"pywin32; sys_platform == 'win32'",
"xspeedhack; sys_platform == 'win32'",
]
[project.optional-dependencies]
serve = [
"numpy",
"pyzmq",
"torch",
"pyyaml",
"einops",
"transformers",
"pydantic",
"diffusers",
"polars",
]
play = [
"numpy",
"pyzmq",
"pillow",
"opencv-python",
"pyautogui",
"gymnasium",
"psutil",
"av",
"dxcam",
"pywinctl",
"vgamepad",
"pywin32",
"xspeedhack",
]
[tool.setuptools.packages.find]
where = ["."]
exclude = ["scripts*"]

232
scripts/play.py Normal file
View File

@ -0,0 +1,232 @@
import os
import sys
import time
import json
from pathlib import Path
from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image
from nitrogen.game_env import GamepadEnv
from nitrogen.shared import BUTTON_ACTION_TOKENS, PATH_REPO
from nitrogen.inference_viz import create_viz, VideoRecorder
from nitrogen.inference_client import ModelClient
import argparse
parser = argparse.ArgumentParser(description="VLM Inference")
parser.add_argument("--process", type=str, default="celeste.exe", help="Game to play")
parser.add_argument("--allow-menu", action="store_true", help="Allow menu actions (Disabled by default)")
parser.add_argument("--port", type=int, default=5555, help="Port for model server")
args = parser.parse_args()
policy = ModelClient(port=args.port)
policy.reset()
policy_info = policy.info()
action_downsample_ratio = policy_info["action_downsample_ratio"]
CKPT_NAME = Path(policy_info["ckpt_path"]).stem
NO_MENU = not args.allow_menu
PATH_DEBUG = PATH_REPO / "debug"
PATH_DEBUG.mkdir(parents=True, exist_ok=True)
PATH_OUT = (PATH_REPO / "out" / CKPT_NAME).resolve()
PATH_OUT.mkdir(parents=True, exist_ok=True)
BUTTON_PRESS_THRES = 0.5
# Find in path_out the list of existing video files, named 0001.mp4, 0002.mp4, etc.
# If they exist, find the max number and set the next number to be max + 1
video_files = sorted(PATH_OUT.glob("*_DEBUG.mp4"))
if video_files:
existing_numbers = [f.name.split("_")[0] for f in video_files]
existing_numbers = [int(n) for n in existing_numbers if n.isdigit()]
next_number = max(existing_numbers) + 1
else:
next_number = 1
PATH_MP4_DEBUG = PATH_OUT / f"{next_number:04d}_DEBUG.mp4"
PATH_MP4_CLEAN = PATH_OUT / f"{next_number:04d}_CLEAN.mp4"
PATH_ACTIONS = PATH_OUT / f"{next_number:04d}_ACTIONS.json"
def preprocess_img(main_image):
main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
final_image = cv2.resize(main_cv, (256, 256), interpolation=cv2.INTER_AREA)
return Image.fromarray(cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB))
zero_action = OrderedDict(
[
("WEST", 0),
("SOUTH", 0),
("BACK", 0),
("DPAD_DOWN", 0),
("DPAD_LEFT", 0),
("DPAD_RIGHT", 0),
("DPAD_UP", 0),
("GUIDE", 0),
("AXIS_LEFTX", np.array([0], dtype=np.long)),
("AXIS_LEFTY", np.array([0], dtype=np.long)),
("LEFT_SHOULDER", 0),
("LEFT_TRIGGER", np.array([0], dtype=np.long)),
("AXIS_RIGHTX", np.array([0], dtype=np.long)),
("AXIS_RIGHTY", np.array([0], dtype=np.long)),
("LEFT_THUMB", 0),
("RIGHT_THUMB", 0),
("RIGHT_SHOULDER", 0),
("RIGHT_TRIGGER", np.array([0], dtype=np.long)),
("START", 0),
("EAST", 0),
("NORTH", 0),
]
)
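# Value conventions (descriptive note): buttons are plain 0/1 flags, trigger
# entries are 1-element arrays in [0, 255], and stick axes are 1-element arrays
# of raw gamepad integers in [-32768, 32767]. The model outputs joysticks in
# [-1, 1] and buttons/triggers in [0, 1]; the loop below rescales accordingly.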
TOKEN_SET = BUTTON_ACTION_TOKENS
print("Model loaded, starting environment...")
for i in range(3):
print(f"{3 - i}...")
time.sleep(1)
env = GamepadEnv(
game=args.process,
game_speed=1.0,
env_fps=60,
async_mode=True,
)
# These games require opening an in-game menu to initialize the virtual controller
if args.process == "isaac-ng.exe":
print(f"GamepadEnv ready for {args.process} at {env.env_fps} FPS")
input("Press enter to create a virtual controller and start rollouts...")
for i in range(3):
print(f"{3 - i}...")
time.sleep(1)
def press(button):
env.gamepad_emulator.press_button(button)
env.gamepad_emulator.gamepad.update()
time.sleep(0.05)
env.gamepad_emulator.release_button(button)
env.gamepad_emulator.gamepad.update()
press("SOUTH")
for k in range(5):
press("EAST")
time.sleep(0.3)
if args.process == "Cuphead.exe":
print(f"GamepadEnv ready for {args.process} at {env.env_fps} FPS")
input("Press enter to create a virtual controller and start rollouts...")
for i in range(3):
print(f"{3 - i}...")
time.sleep(1)
def press(button):
env.gamepad_emulator.press_button(button)
env.gamepad_emulator.gamepad.update()
time.sleep(0.05)
env.gamepad_emulator.release_button(button)
env.gamepad_emulator.gamepad.update()
press("SOUTH")
for k in range(5):
press("EAST")
time.sleep(0.3)
env.reset()
env.pause()
# Initial call to get state
obs, reward, terminated, truncated, info = env.step(action=zero_action)
frames = None
step_count = 0
with VideoRecorder(str(PATH_MP4_DEBUG), fps=60, crf=32, preset="medium") as debug_recorder:
with VideoRecorder(str(PATH_MP4_CLEAN), fps=60, crf=28, preset="medium") as clean_recorder:
try:
while True:
obs = preprocess_img(obs)
obs.save(PATH_DEBUG / f"{step_count:05d}.png")
pred = policy.predict(obs)
j_left, j_right, buttons = pred["j_left"], pred["j_right"], pred["buttons"]
n = len(buttons)
assert n == len(j_left) == len(j_right), "Mismatch in action lengths"
env_actions = []
for i in range(n):
move_action = zero_action.copy()
xl, yl = j_left[i]
xr, yr = j_right[i]
move_action["AXIS_LEFTX"] = np.array([int(xl * 32767)], dtype=np.long)
move_action["AXIS_LEFTY"] = np.array([int(yl * 32767)], dtype=np.long)
move_action["AXIS_RIGHTX"] = np.array([int(xr * 32767)], dtype=np.long)
move_action["AXIS_RIGHTY"] = np.array([int(yr * 32767)], dtype=np.long)
button_vector = buttons[i]
assert len(button_vector) == len(TOKEN_SET), "Button vector length does not match token set length"
for name, value in zip(TOKEN_SET, button_vector):
if "TRIGGER" in name:
move_action[name] = np.array([value * 255], dtype=np.long)
else:
move_action[name] = 1 if value > BUTTON_PRESS_THRES else 0
env_actions.append(move_action)
print(f"Executing {len(env_actions)} actions, each action will be repeated {action_downsample_ratio} times")
for i, a in enumerate(env_actions):
if NO_MENU:
if a["START"]:
print("Model predicted start, disabling this action")
a["GUIDE"] = 0
a["START"] = 0
a["BACK"] = 0
for _ in range(action_downsample_ratio):
obs, reward, terminated, truncated, info = env.step(action=a)
# resize obs for the clean (1080p) and debug (720p) recordings
obs_viz = np.array(obs).copy()
clean_viz = cv2.resize(obs_viz, (1920, 1080), interpolation=cv2.INTER_AREA)
debug_viz = create_viz(
cv2.resize(obs_viz, (1280, 720), interpolation=cv2.INTER_AREA), # 720p
i,
j_left,
j_right,
buttons,
token_set=TOKEN_SET
)
debug_recorder.add_frame(debug_viz)
clean_recorder.add_frame(clean_viz)
# Append env_actions dictionaries to the JSONL file
with open(PATH_ACTIONS, "a") as f:
for i, a in enumerate(env_actions):
# convert numpy arrays to lists for JSON serialization
for k, v in a.items():
if isinstance(v, np.ndarray):
a[k] = v.tolist()
a["step"] = step_count
a["substep"] = i
json.dump(a, f)
f.write("\n")
step_count += 1
finally:
env.unpause()
env.close()

64
scripts/serve.py Normal file
View File

@ -0,0 +1,64 @@
import zmq
import argparse
import pickle
from nitrogen.inference_session import InferenceSession
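# Wire protocol (descriptive note): requests and responses are pickled dicts
# exchanged over a ZeroMQ REQ/REP socket.
#   request:  {"type": "predict", "image": <np.ndarray>} | {"type": "reset"} | {"type": "info"}
#   response: {"status": "ok", ...} on success, {"status": "error", "message": ...} otherwise
# ModelClient in nitrogen/inference_client.py implements the matching client side.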
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Model inference server")
parser.add_argument("ckpt", type=str, help="Path to checkpoint file")
parser.add_argument("--port", type=int, default=5555, help="Port to serve on")
parser.add_argument("--old-layout", action="store_true", help="Use old layout")
parser.add_argument("--cfg", type=float, default=1.0, help="CFG scale")
parser.add_argument("--ctx", type=int, default=1, help="Context length")
args = parser.parse_args()
session = InferenceSession.from_ckpt(args.ckpt, old_layout=args.old_layout, cfg_scale=args.cfg, context_length=args.ctx)
# Setup ZeroMQ
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{args.port}")
# Create poller
poller = zmq.Poller()
poller.register(socket, zmq.POLLIN)
print(f"\n{'='*60}")
print(f"Server running on port {args.port}")
print(f"Waiting for requests...")
print(f"{'='*60}\n")
try:
while True:
# Poll with 100ms timeout to allow interrupt handling
events = dict(poller.poll(timeout=100))
if socket in events and events[socket] == zmq.POLLIN:
# Receive request only when data is available
request = socket.recv()
request = pickle.loads(request)
if request["type"] == "reset":
session.reset()
response = {"status": "ok"}
print("Session reset")
elif request["type"] == "info":
info = session.info()
response = {"status": "ok", "info": info}
print("Sent session info")
elif request["type"] == "predict":
raw_image = request["image"]
result = session.predict(raw_image)
response = {
"status": "ok",
"pred": result
}
else:
response = {"status": "error", "message": f"Unknown request type: {request['type']}"}
# Send response
socket.send(pickle.dumps(response))
except KeyboardInterrupt:
print("\nShutting down server...")
exit(0)
finally:
socket.close()
context.term()