init commit

nitrogen/cfg.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from pydantic import BaseModel, Field

from nitrogen.flow_matching_transformer.nitrogen import NitroGen_Config
from nitrogen.mm_tokenizers import NitrogenTokenizerConfig

class ModalityConfig(BaseModel):
    frame_per_sample: int = 1  # number of context frames per sample
    frame_spacing: int | None = None  # how many frames to skip between each frame. If None, use action_per_chunk
    action_per_chunk: int = 8
    action_shift: int = 1  # how many actions to skip between frame[i] and action_chunk[i]
    action_interleaving: bool = False  # if True, action chunks will be interleaved with context frames and used by the model to predict the next actions
    token_set: str = "new"

    def model_post_init(self, __context):
        if self.frame_spacing is None:
            # Use object.__setattr__ because the model is frozen
            object.__setattr__(self, 'frame_spacing', self.action_per_chunk)
        assert self.action_shift >= 1, "action_shift must be at least 1 for correct action indexing"


class CkptConfig(BaseModel):
    experiment_name: str = Field(..., description="Name of the experiment")

    model_cfg: NitroGen_Config = Field(..., description="Model configuration. This is a placeholder and should be replaced with the actual model config class.")
    tokenizer_cfg: NitrogenTokenizerConfig = Field(..., description="Tokenizer configuration. This is a placeholder and should be replaced with the actual tokenizer config class.")
    modality_cfg: ModalityConfig = Field(..., description="Modality configuration for the dataset mixture.")
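A minimal usage sketch for ModalityConfig (illustrative, not part of the commit): when frame_spacing is left unset, model_post_init copies it from action_per_chunk.

from nitrogen.cfg import ModalityConfig

cfg = ModalityConfig(frame_per_sample=2, action_per_chunk=8)
assert cfg.frame_spacing == 8   # filled in from action_per_chunk
assert ModalityConfig(frame_spacing=4).frame_spacing == 4   # explicit values are kept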

nitrogen/flow_matching_transformer/modules.py (new file, 434 lines)
@@ -0,0 +1,434 @@
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from diffusers.models.attention import Attention, FeedForward
|
||||
from diffusers.models.embeddings import (
|
||||
SinusoidalPositionalEmbedding,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TimestepEncoder(nn.Module):
|
||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timesteps):
|
||||
dtype = next(self.parameters()).dtype
|
||||
timesteps_proj = self.time_proj(timesteps).to(dtype)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-5,
|
||||
chunk_dim: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.chunk_dim = chunk_dim
|
||||
output_dim = embedding_dim * 2
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, output_dim)
|
||||
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
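        # Project the conditioning embedding into a per-sample (scale, shift) pair and use it to
        # modulate the normalized activations: x <- norm(x) * (1 + scale) + shift.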
temb = self.linear(self.silu(temb))
|
||||
scale, shift = temb.chunk(2, dim=1)
|
||||
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
attention_bias: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
||||
norm_eps: float = 1e-5,
|
||||
final_dropout: bool = False,
|
||||
attention_type: str = "default",
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.dropout = dropout
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.attention_bias = attention_bias
|
||||
self.norm_elementwise_affine = norm_elementwise_affine
|
||||
self.positional_embeddings = positional_embeddings
|
||||
self.num_positional_embeddings = num_positional_embeddings
|
||||
self.norm_type = norm_type
|
||||
|
||||
        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embeddings` is defined, `num_positional_embeddings` must also be defined."
            )
|
||||
|
||||
if positional_embeddings == "sinusoidal":
|
||||
self.pos_embed = SinusoidalPositionalEmbedding(
|
||||
dim, max_seq_length=num_positional_embeddings
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
        # Define the attention and feed-forward blocks. Each block has its own normalization layer.
|
||||
# 1. Self-Attn
|
||||
if norm_type == "ada_norm":
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
upcast_attention=upcast_attention,
|
||||
out_bias=attention_out_bias,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
if final_dropout:
|
||||
self.final_dropout = nn.Dropout(dropout)
|
||||
else:
|
||||
self.final_dropout = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# 0. Self-Attention
|
||||
if self.norm_type == "ada_norm":
|
||||
norm_hidden_states = self.norm1(hidden_states, temb)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
if self.pos_embed is not None:
|
||||
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.final_dropout:
|
||||
attn_output = self.final_dropout(attn_output)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
return hidden_states
|
||||
|
||||
class DiTConfig(BaseModel):
|
||||
num_attention_heads: int = Field(default=8)
|
||||
attention_head_dim: int = Field(default=64)
|
||||
output_dim: int = Field(default=26)
|
||||
num_layers: int = Field(default=12)
|
||||
dropout: float = Field(default=0.1)
|
||||
attention_bias: bool = Field(default=True)
|
||||
activation_fn: str = Field(default="gelu-approximate")
|
||||
num_embeds_ada_norm: Optional[int] = Field(default=1000)
|
||||
upcast_attention: bool = Field(default=False)
|
||||
norm_type: str = Field(default="ada_norm")
|
||||
norm_elementwise_affine: bool = Field(default=False)
|
||||
norm_eps: float = Field(default=1e-5)
|
||||
max_num_positional_embeddings: int = Field(default=512)
|
||||
compute_dtype: str = Field(default="float32")
|
||||
final_dropout: bool = Field(default=True)
|
||||
positional_embeddings: Optional[str] = Field(default="sinusoidal")
|
||||
interleave_self_attention: bool = Field(default=False)
|
||||
cross_attention_dim: Optional[int] = Field(default=None, description="Dimension of the cross-attention embeddings. If None, no cross-attention is used.")
|
||||
|
||||
|
||||
class DiT(ModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
    def __init__(self, config: DiTConfig):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.compute_dtype = getattr(torch, self.config.compute_dtype)
|
||||
|
||||
self.attention_head_dim = self.config.attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Timestep encoder
|
||||
self.timestep_encoder = TimestepEncoder(
|
||||
embedding_dim=self.inner_dim, compute_dtype=self.compute_dtype
|
||||
)
|
||||
|
||||
all_blocks = []
|
||||
for idx in range(self.config.num_layers):
|
||||
|
||||
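            # When interleave_self_attention is enabled, every odd-indexed block drops the
            # cross-attention conditioning and acts as a pure self-attention block.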
use_self_attn = idx % 2 == 1 and self.config.interleave_self_attention
|
||||
curr_cross_attention_dim = self.config.cross_attention_dim if not use_self_attn else None
|
||||
|
||||
all_blocks += [
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
norm_type=self.config.norm_type,
|
||||
norm_elementwise_affine=self.config.norm_elementwise_affine,
|
||||
norm_eps=self.config.norm_eps,
|
||||
positional_embeddings=self.config.positional_embeddings,
|
||||
num_positional_embeddings=self.config.max_num_positional_embeddings,
|
||||
final_dropout=self.config.final_dropout,
|
||||
cross_attention_dim=curr_cross_attention_dim,
|
||||
)
|
||||
]
|
||||
self.transformer_blocks = nn.ModuleList(all_blocks)
|
||||
|
||||
# Output blocks
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||
print(
|
||||
"Total number of DiT parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # Shape: (B, T, D)
|
||||
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_all_hidden_states: bool = False,
|
||||
):
|
||||
# Encode timesteps
|
||||
temb = self.timestep_encoder(timestep)
|
||||
|
||||
# Process through transformer blocks - single pass through the blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
|
||||
all_hidden_states = [hidden_states]
|
||||
|
||||
# Process through transformer blocks
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if idx % 2 == 1 and self.config.interleave_self_attention:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=None,
|
||||
temb=temb,
|
||||
)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
# Output processing
|
||||
conditioning = temb
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
if return_all_hidden_states:
|
||||
return self.proj_out_2(hidden_states), all_hidden_states
|
||||
else:
|
||||
return self.proj_out_2(hidden_states)
|
||||
|
||||
|
||||
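# Illustrative sketch (not part of the original file): how the DiT above can be driven end to end.
# The config values, tensor shapes, and the helper name `_example_dit_forward` are arbitrary
# choices for the example, not assumptions about how the training code uses the model.
def _example_dit_forward():
    cfg = DiTConfig(num_attention_heads=4, attention_head_dim=32, num_layers=2, output_dim=26)
    model = DiT(config=cfg)
    hidden = torch.randn(2, 16, cfg.num_attention_heads * cfg.attention_head_dim)   # (B, T, D) action tokens
    context = torch.randn(2, 8, cfg.num_attention_heads * cfg.attention_head_dim)   # (B, S, D) conditioning tokens
    timestep = torch.zeros(2, dtype=torch.long)  # one discretized flow-matching time per sample
    out = model(hidden_states=hidden, encoder_hidden_states=context, timestep=timestep)
    return out.shape  # (2, 16, 26): one output_dim-sized vector per action token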
class SelfAttentionTransformerConfig(BaseModel):
|
||||
num_attention_heads: int = Field(default=8)
|
||||
attention_head_dim: int = Field(default=64)
|
||||
output_dim: int = Field(default=26)
|
||||
num_layers: int = Field(default=12)
|
||||
dropout: float = Field(default=0.1)
|
||||
attention_bias: bool = Field(default=True)
|
||||
activation_fn: str = Field(default="gelu-approximate")
|
||||
num_embeds_ada_norm: Optional[int] = Field(default=1000)
|
||||
upcast_attention: bool = Field(default=False)
|
||||
max_num_positional_embeddings: int = Field(default=512)
|
||||
compute_dtype: str = Field(default="float32")
|
||||
final_dropout: bool = Field(default=True)
|
||||
positional_embeddings: Optional[str] = Field(default="sinusoidal")
|
||||
interleave_self_attention: bool = Field(default=False)
|
||||
|
||||
class SelfAttentionTransformer(ModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, config: SelfAttentionTransformerConfig):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.attention_head_dim = self.config.attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
positional_embeddings=self.config.positional_embeddings,
|
||||
num_positional_embeddings=self.config.max_num_positional_embeddings,
|
||||
final_dropout=self.config.final_dropout,
|
||||
)
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
print(
|
||||
"Total number of SelfAttentionTransformer parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # Shape: (B, T, D)
|
||||
return_all_hidden_states: bool = False,
|
||||
):
|
||||
|
||||
# Process through transformer blocks - single pass through the blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
all_hidden_states = [hidden_states]
|
||||
|
||||
# Process through transformer blocks
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(hidden_states)
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
if return_all_hidden_states:
|
||||
return hidden_states, all_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CrossAttentionTransformer(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 8,
|
||||
attention_head_dim: int = 64,
|
||||
output_dim: int = 26,
|
||||
num_layers: int = 12,
|
||||
dropout: float = 0.1,
|
||||
attention_bias: bool = True,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
num_embeds_ada_norm: Optional[int] = 1000,
|
||||
upcast_attention: bool = False,
|
||||
max_num_positional_embeddings: int = 512,
|
||||
compute_dtype=torch.float32,
|
||||
final_dropout: bool = True,
|
||||
positional_embeddings: Optional[str] = "sinusoidal",
|
||||
interleave_self_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
self.inner_dim,
|
||||
self.config.num_attention_heads,
|
||||
self.config.attention_head_dim,
|
||||
dropout=self.config.dropout,
|
||||
activation_fn=self.config.activation_fn,
|
||||
attention_bias=self.config.attention_bias,
|
||||
upcast_attention=self.config.upcast_attention,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=self.config.max_num_positional_embeddings,
|
||||
final_dropout=final_dropout,
|
||||
)
|
||||
for _ in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
print(
|
||||
"Total number of CrossAttentionTransformer parameters: ",
|
||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor, # Shape: (B, T, D)
|
||||
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
|
||||
):
|
||||
|
||||
# Process through transformer blocks - single pass through the blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
# Process through transformer blocks
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||

nitrogen/flow_matching_transformer/nitrogen.py (new file, 755 lines)
@@ -0,0 +1,755 @@
from dataclasses import dataclass, field
|
||||
from pydantic import BaseModel, Field
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from torch.distributions import Beta
|
||||
from transformers import SiglipVisionModel, AutoModel
|
||||
|
||||
from .modules import DiT, DiTConfig, SelfAttentionTransformer, SelfAttentionTransformerConfig
|
||||
|
||||
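# Token-type ids used in `vl_token_ids` / `sa_token_ids` to mark what each sequence position holds
# (padding, image patches, image separators, language, proprioception, actions, game id), so that
# prepare_input_embs() can scatter the matching embeddings into place.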
_PAD_TOKEN = 0
|
||||
_IMG_TOKEN = 1
|
||||
_IMG_SEP_TOKEN = 5
|
||||
_LANG_TOKEN = 2
|
||||
_PROPRIO_TOKEN = 3
|
||||
_ACT_TOKEN = 4
|
||||
_GAME_ID_TOKEN = 6
|
||||
|
||||
|
||||
class NitroGen_Config(BaseModel):
|
||||
model_type: str = Field(default="nitrogen", frozen=True)
|
||||
|
||||
add_pos_embed: bool = Field(default=False, description="Whether to add positional embedding")
|
||||
model_dtype: str = Field(default="float32", description="Model data type.")
|
||||
diffusion_model_cfg: DiTConfig = Field(..., description="Diffusion model configuration.")
|
||||
vl_self_attention_cfg: SelfAttentionTransformerConfig = Field(..., description="VL self-attention configuration.")
|
||||
hidden_size: int = Field(default=1024, description="Input embedding dimension.")
|
||||
    max_seq_len: int = Field(default=1024, description="Maximum sequence length.")
|
||||
action_dim: int = Field(default=None, description="Action dimension.")
|
||||
action_horizon: int = Field(default=None, description="Action horizon.")
|
||||
    noise_beta_alpha: float = Field(default=1.5, description="Alpha of the Beta distribution used for flow-matching time sampling.")
    noise_beta_beta: float = Field(default=1.0, description="Beta of the Beta distribution used for flow-matching time sampling.")
|
||||
noise_s: float = Field(default=0.999, description="Flow matching noise Beta distribution s.")
|
||||
num_timestep_buckets: int = Field(default=1000, description="Number of timestep discretization buckets.")
|
||||
num_inference_timesteps: int = Field(default=None, description="Number of inference steps for noise diffusion.")
|
||||
max_num_embodiments: int = Field(default=1, description="Number of embodiments.")
|
||||
vision_encoder_name: str = Field(default="google/siglip-large-patch16-256", description="Vision encoder name.")
|
||||
vision_hidden_size: int = Field(default=768, description="Siglip hidden size.")
|
||||
add_view_embed: bool = Field(default=False, description="Whether to add view embedding.")
|
||||
|
||||
tune_vision_tower: bool = Field(default=True, description="Tune vision if True.")
|
||||
tune_mm_projector: bool = Field(default=True, description="Tune mm projector if True.")
|
||||
tune_diffusion_model: bool = Field(default=True, description="Tune diffusion model if True.")
|
||||
tune_multi_projector: bool = Field(default=True, description="Tune multi projector if True.")
|
||||
tune_vl_mixing: bool = Field(default=True, description="Tune vl mixing if True.")
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, yaml_path: str | Path) -> "NitroGen_Config":
|
||||
"""Load configuration from a YAML file."""
|
||||
with open(yaml_path, "r") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
return cls.model_validate(config_dict)
|
||||
|
||||
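# Illustrative sketch (not part of the original file): building a NitroGen_Config either directly
# or from YAML. The numeric values and the YAML path are placeholders for the example only.
def _example_build_config():
    cfg = NitroGen_Config(
        diffusion_model_cfg=DiTConfig(),
        vl_self_attention_cfg=SelfAttentionTransformerConfig(),
        action_dim=26,
        action_horizon=8,
        num_inference_timesteps=4,
    )
    # Equivalent, if the same fields are stored in a YAML file:
    # cfg = NitroGen_Config.from_yaml("configs/nitrogen.yaml")  # hypothetical path
    return cfg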
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
"""
|
||||
Produces a sinusoidal encoding of shape (B, T, w)
|
||||
given timesteps of shape (B, T).
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, timesteps):
|
||||
# timesteps: shape (B, T)
|
||||
# We'll compute sin/cos frequencies across dim T
|
||||
timesteps = timesteps.float() # ensure float
|
||||
|
||||
B, T = timesteps.shape
|
||||
device = timesteps.device
|
||||
|
||||
half_dim = self.embedding_dim // 2
|
||||
# typical log space frequencies for sinusoidal encoding
|
||||
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
|
||||
torch.log(torch.tensor(10000.0)) / half_dim
|
||||
)
|
||||
# Expand timesteps to (B, T, 1) then multiply
|
||||
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
|
||||
|
||||
sin = torch.sin(freqs)
|
||||
cos = torch.cos(freqs)
|
||||
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
|
||||
|
||||
return enc
|
||||
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
# For each category, we have separate weights and biases.
|
||||
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
|
||||
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
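        # Gather one (input_dim, hidden_dim) weight matrix and one bias vector per sample from its
        # category id, then apply them with a batched matmul: (B, T, input_dim) -> (B, T, hidden_dim).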
selected_W = self.W[cat_ids]
|
||||
selected_b = self.b[cat_ids]
|
||||
return torch.bmm(x, selected_W) + selected_b.unsqueeze(1)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
|
||||
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
|
||||
|
||||
def forward(self, x, cat_ids):
|
||||
hidden = F.relu(self.layer1(x, cat_ids))
|
||||
return self.layer2(hidden, cat_ids)
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(self, action_dim, hidden_size, num_embodiments):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_embodiments = num_embodiments
|
||||
|
||||
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
|
||||
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
|
||||
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
|
||||
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
|
||||
|
||||
def forward(self, actions, timesteps, cat_ids):
|
||||
"""
|
||||
actions: shape (B, T, action_dim)
|
||||
timesteps: shape (B,) -- a single scalar per batch item
|
||||
cat_ids: shape (B,)
|
||||
returns: shape (B, T, hidden_size)
|
||||
"""
|
||||
B, T, _ = actions.shape
|
||||
|
||||
# 1) Expand each batch's single scalar time 'tau' across all T steps
|
||||
# so that shape => (B, T)
|
||||
# e.g. if timesteps is (B,), replicate across T
|
||||
if timesteps.dim() == 1 and timesteps.shape[0] == B:
|
||||
# shape (B,) => (B,T)
|
||||
timesteps = timesteps.unsqueeze(1).expand(-1, T)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Expected `timesteps` to have shape (B,) so we can replicate across T."
|
||||
)
|
||||
|
||||
# 2) Standard action MLP step for shape => (B, T, w)
|
||||
a_emb = self.W1(actions, cat_ids)
|
||||
|
||||
# 3) Get the sinusoidal encoding (B, T, w)
|
||||
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
|
||||
|
||||
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
|
||||
x = torch.cat([a_emb, tau_emb], dim=-1)
|
||||
x = swish(self.W2(x, cat_ids))
|
||||
|
||||
# 5) Finally W3 => (B, T, w)
|
||||
x = self.W3(x, cat_ids)
|
||||
return x
|
||||
|
||||
|
||||
class NitroGen(torch.nn.Module):
|
||||
config_class = NitroGen_Config
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NitroGen_Config,
|
||||
game_mapping: dict[str, int] | None = None, # Used to add a game ID token
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.vision_hidden_size = config.vision_hidden_size
|
||||
|
||||
if "siglip" in config.vision_encoder_name:
|
||||
model = SiglipVisionModel.from_pretrained(config.vision_encoder_name)
|
||||
self.vision_encoder = model.vision_model
|
||||
self.vision_encoder_type = "siglip"
|
||||
else:
|
||||
self.vision_encoder = AutoModel.from_pretrained(config.vision_encoder_name)
|
||||
self.vision_encoder_type = "hf_auto"
|
||||
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
|
||||
self.num_timestep_buckets = config.num_timestep_buckets
|
||||
# self.model = instantiate(config.diffusion_model_cfg)
|
||||
self.model = DiT(config=config.diffusion_model_cfg)
|
||||
self.action_dim = config.action_dim
|
||||
self.action_horizon = config.action_horizon
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
# self.vl_self_attention_model = instantiate(config.vl_self_attention_cfg)
|
||||
self.vl_self_attention_model = SelfAttentionTransformer(config=config.vl_self_attention_cfg)
|
||||
|
||||
# if config.qformer_cfg is not None:
|
||||
# self.qformer = instantiate(config.qformer_cfg)
|
||||
# else:
|
||||
# self.qformer = nn.Identity()
|
||||
|
||||
# self.state_encoder = CategorySpecificMLP(
|
||||
# num_categories=config.max_num_embodiments,
|
||||
# input_dim=config.max_state_dim,
|
||||
# hidden_dim=self.hidden_size,
|
||||
# output_dim=self.hidden_size,
|
||||
# )
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=config.action_dim,
|
||||
hidden_size=self.hidden_size,
|
||||
num_embodiments=config.max_num_embodiments,
|
||||
)
|
||||
|
||||
self.action_decoder = CategorySpecificMLP(
|
||||
num_categories=config.max_num_embodiments,
|
||||
input_dim=self.hidden_size,
|
||||
hidden_dim=self.hidden_size,
|
||||
output_dim=self.action_dim,
|
||||
)
|
||||
# self.mm_vision_select_layer = config.mm_vision_select_layer
|
||||
# if config.mm_projector_cfg is not None:
|
||||
# self.mm_projector = instantiate(config.mm_projector_cfg)
|
||||
# else:
|
||||
self.mm_projector = None
|
||||
if config.add_pos_embed:
|
||||
self.position_embedding = nn.Embedding(config.max_seq_len, self.hidden_size)
|
||||
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
|
||||
# if config.add_view_embed:
|
||||
# self.view_embedding = nn.Embedding(config.max_num_views, self.hidden_size)
|
||||
# nn.init.normal_(self.view_embedding.weight, mean=0.0, std=0.02)
|
||||
|
||||
# self.vision_projector = None
|
||||
# if config.vision_hidden_size != self.hidden_size:
|
||||
# self.vision_projector = nn.Sequential(
|
||||
# nn.Linear(config.vision_hidden_size, self.hidden_size),
|
||||
# nn.LayerNorm(self.hidden_size),
|
||||
# )
|
||||
|
||||
self.game_mapping = game_mapping
|
||||
# Create an embedding table for game IDs
|
||||
# Game ID tokens will be put inside vision-language tokens
|
||||
# so they need to be projected to the same dimension
|
||||
if self.game_mapping is not None:
|
||||
# 0 = unconditional
|
||||
self.game_embedding = nn.Embedding(
|
||||
len(self.game_mapping),
|
||||
self.vision_hidden_size,
|
||||
padding_idx=0,
|
||||
scale_grad_by_freq=True
|
||||
)
|
||||
|
||||
self.set_trainable_parameters(
|
||||
tune_multi_projector=config.tune_multi_projector,
|
||||
tune_diffusion_model=config.tune_diffusion_model,
|
||||
tune_vision_tower=config.tune_vision_tower,
|
||||
tune_mm_projector=config.tune_mm_projector,
|
||||
tune_vl_mixing=config.tune_vl_mixing,
|
||||
)
|
||||
|
||||
        print(
            "Total number of parameters:",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )
|
||||
|
||||
def set_trainable_parameters(
|
||||
self,
|
||||
tune_multi_projector: bool = True,
|
||||
tune_diffusion_model: bool = True,
|
||||
tune_vision_tower: bool = True,
|
||||
tune_mm_projector: bool = True,
|
||||
tune_vl_mixing: bool = True,
|
||||
):
|
||||
self.tune_multi_projector = tune_multi_projector
|
||||
self.tune_diffusion_model = tune_diffusion_model
|
||||
self.tune_vision_tower = tune_vision_tower
|
||||
self.tune_mm_projector = tune_mm_projector
|
||||
self.tune_vl_mixing = tune_vl_mixing
|
||||
|
||||
for param in self.parameters():
|
||||
param.requires_grad = True
|
||||
# ### Always freeze language encoder
|
||||
# self.siglip_model.text_model.requires_grad_(False)
|
||||
# # Freeze unused parameters in siglip vision encoder
|
||||
# self.siglip_model.logit_scale.requires_grad = False
|
||||
# self.siglip_model.logit_bias.requires_grad = False
|
||||
|
||||
        # For siglip, additionally freeze encoder layer 11 and the attention-pooling head.
|
||||
if self.vision_encoder_type == "siglip":
|
||||
for param in self.vision_encoder.encoder.layers[11].parameters():
|
||||
param.requires_grad = False
|
||||
for param in self.vision_encoder.head.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
# Freeze parameters
|
||||
if not tune_multi_projector:
|
||||
# self.state_encoder.requires_grad_(False)
|
||||
self.action_encoder.requires_grad_(False)
|
||||
self.action_decoder.requires_grad_(False)
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.requires_grad_(False)
|
||||
if self.config.add_view_embed:
|
||||
self.view_embedding.requires_grad_(False)
|
||||
if not tune_diffusion_model:
|
||||
self.model.requires_grad_(False)
|
||||
if not tune_vision_tower:
|
||||
self.vision_encoder.requires_grad_(False)
|
||||
if self.mm_projector is not None and not tune_mm_projector:
|
||||
self.mm_projector.requires_grad_(False)
|
||||
if not tune_vl_mixing:
|
||||
self.vl_self_attention_model.requires_grad_(False)
|
||||
|
||||
print(f"Tune action head multi_projector: {self.tune_multi_projector}")
|
||||
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
|
||||
print(f"Tune action head vision tower: {self.tune_vision_tower}")
|
||||
print(f"Tune action head mm_projector: {self.tune_mm_projector}")
|
||||
print(f"Tune action head vl_mixing: {self.tune_vl_mixing}")
|
||||
# Check if any parameters are still trainable. If not, print a warning.
|
||||
if not any(p.requires_grad for p in self.parameters()):
|
||||
print("Warning: No action head trainable parameters found.")
|
||||
|
||||
def set_frozen_modules_to_eval_mode(self):
|
||||
"""
|
||||
Huggingface will call model.train() at each training_step. To ensure
|
||||
the expected behaviors for modules like dropout, batchnorm, etc., we
|
||||
need to call model.eval() for the frozen modules.
|
||||
"""
|
||||
if self.training:
|
||||
# self.siglip_model.text_model.eval()
|
||||
if not self.tune_multi_projector:
|
||||
# self.state_encoder.eval()
|
||||
self.action_encoder.eval()
|
||||
self.action_decoder.eval()
|
||||
if self.config.add_pos_embed:
|
||||
self.position_embedding.eval()
|
||||
# if self.config.add_view_embed:
|
||||
# self.view_embedding.eval()
|
||||
if not self.tune_diffusion_model:
|
||||
self.model.eval()
|
||||
if not self.tune_vision_tower:
|
||||
self.vision_encoder.eval()
|
||||
if self.mm_projector is not None and not self.tune_mm_projector:
|
||||
self.mm_projector.eval()
|
||||
if not self.tune_vl_mixing:
|
||||
self.vl_self_attention_model.eval()
|
||||
|
||||
# This function is supposedly incorrect
|
||||
# def sample_time(self, batch_size, device, dtype):
|
||||
# sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
# return (self.config.noise_s - sample) / self.config.noise_s
|
||||
|
||||
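    # Flow-matching time sampling: draw u ~ Beta(noise_beta_alpha, noise_beta_beta) and map it to
    # t = (1 - u) * noise_s, so t lies in [0, noise_s] with (for alpha > beta) more mass near t = 0,
    # i.e. near the noise end of the interpolation used in forward().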
def sample_time(self, batch_size, device, dtype):
|
||||
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
|
||||
return (1 - sample) * self.config.noise_s
|
||||
|
||||
def encode_images(self, images): #, view_ids):
|
||||
batch_size, num_frames, channels, height, width = images.shape
|
||||
images = images.reshape(-1, channels, height, width)
|
||||
|
||||
image_features = self.vision_encoder(images)["last_hidden_state"]
|
||||
image_features = rearrange(image_features, "(b f) n d -> b f n d", f=num_frames)
|
||||
|
||||
# if self.vision_projector is not None:
|
||||
# # change the hidden dimension of the vision features
|
||||
# image_features = self.vision_projector(image_features)
|
||||
if self.mm_projector is not None:
|
||||
image_features = self.mm_projector(image_features) # [B, 256, 1024] -> [B, 16, 1024]
|
||||
return image_features
|
||||
|
||||
def prepare_input_embs(self, vl_token_ids, sa_token_ids, vision, action, dropped_images, game_ids=None):
|
||||
B, T = vl_token_ids.shape
|
||||
vl_embs = torch.full(
|
||||
size=(B, T, self.vision_hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
|
||||
)
|
||||
|
||||
# Extract dimensions from vision tensor
|
||||
B, num_images, tokens_per_image, hidden_size = vision.shape
|
||||
|
||||
# Create mask for _IMG_TOKEN positions
|
||||
vision_mask = (vl_token_ids == _IMG_TOKEN) # [B, T]
|
||||
|
||||
# Flatten vision tensor over the num_images dimension
|
||||
vision_flat = vision.reshape(B, -1, self.vision_hidden_size) # [B, T * tokens_per_image, hidden_size]
|
||||
|
||||
# Create a mask for the flattened vision dimension
|
||||
# Each image contributes tokens_per_image tokens, so expand the mask accordingly
|
||||
non_dropped_mask_expanded = (dropped_images == 0).unsqueeze(-1).repeat(1, 1, tokens_per_image).reshape(B, -1) # [B, T * tokens_per_image]
|
||||
|
||||
# Select only non-dropped vision embeddings
|
||||
# This will give us the embeddings we need to place
|
||||
valid_vision_embs = vision_flat[non_dropped_mask_expanded] # [total_valid_tokens, 1152]
|
||||
|
||||
assert valid_vision_embs.shape[0] == vision_mask.sum().item(), (
|
||||
f"Number of valid vision embeddings {valid_vision_embs.shape[0]} does not match "
|
||||
f"the number of _IMG_TOKEN positions {vision_mask.sum().item()}"
|
||||
)
|
||||
# Now we need to place these at the vision_mask positions
|
||||
# Get indices where vision_mask is True
|
||||
batch_indices, token_indices = vision_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Place the valid embeddings at the masked positions
|
||||
vl_embs[batch_indices, token_indices] = valid_vision_embs
|
||||
|
||||
# Handle Game ID tokens
|
||||
if self.game_mapping is not None and game_ids is not None:
|
||||
game_mask = vl_token_ids == _GAME_ID_TOKEN # shape: (B, T)
|
||||
num_game_tokens = game_mask.sum().item()
|
||||
if num_game_tokens > 0:
|
||||
|
||||
# Assert that each batch item has exactly one game token
|
||||
game_tokens_per_batch = game_mask.sum(dim=1) # [B] - count of game tokens per batch item
|
||||
assert torch.all(game_tokens_per_batch == 1), (
|
||||
f"Expected exactly 1 game token per batch item, but got: {game_tokens_per_batch.tolist()}. "
|
||||
f"Each batch item must have exactly one _GAME_ID_TOKEN."
|
||||
)
|
||||
|
||||
# Get game embeddings for each batch item
|
||||
game_embs = self.game_embedding(game_ids) # [B, vision_hidden_size]
|
||||
batch_indices, token_indices = game_mask.nonzero(as_tuple=True)
|
||||
vl_embs[batch_indices, token_indices] = game_embs[batch_indices].to(dtype=vl_embs.dtype)
|
||||
|
||||
# Project image separator using the learnable sep_embedding.
|
||||
sep_mask = vl_token_ids == _IMG_SEP_TOKEN # shape: (B, T)
|
||||
num_sep = sep_mask.sum().item()
|
||||
if num_sep > 0:
|
||||
# Expand the separator embedding for each occurrence.
|
||||
repeated_sep = self.vis_sep_embedding.unsqueeze(0).expand(num_sep, self.hidden_size)
|
||||
# Assign the separator embeddings to the correct positions.
|
||||
vl_embs[sep_mask] = repeated_sep.to(dtype=vl_embs.dtype)
|
||||
|
||||
B, T = sa_token_ids.shape
|
||||
sa_embs = torch.full(
|
||||
size=(B, T, self.hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
|
||||
)
|
||||
|
||||
# Project state.
|
||||
# state_mask = sa_token_ids == _PROPRIO_TOKEN
|
||||
# state_mask = state_mask.unsqueeze(-1).expand_as(sa_embs)
|
||||
# sa_embs = sa_embs.masked_scatter(state_mask, state)
|
||||
|
||||
# Project action.
|
||||
action_mask = sa_token_ids == _ACT_TOKEN
|
||||
action_mask = action_mask.unsqueeze(-1).expand_as(sa_embs)
|
||||
sa_embs = sa_embs.masked_scatter(action_mask, action)
|
||||
|
||||
# Add positional embeddings
|
||||
pos_ids = torch.arange(T, dtype=torch.long, device=sa_token_ids.device)
|
||||
if self.config.add_pos_embed:
|
||||
pos_embs = self.position_embedding(pos_ids) # (T, hidden_size)
|
||||
pos_embs = pos_embs.unsqueeze(0).expand(B, T, self.hidden_size)
|
||||
sa_embs = sa_embs + pos_embs
|
||||
return vl_embs, sa_embs
|
||||
|
||||
def pack_actions(self, buttons, j_left, j_right):
|
||||
# Check that the first three dims of each input is the same
|
||||
assert buttons.shape[:3] == j_left.shape[:3] == j_right.shape[:3], (
|
||||
f"buttons shape: {buttons.shape}, "
|
||||
f"j_left shape: {j_left.shape}, "
|
||||
f"j_right shape: {j_right.shape}"
|
||||
)
|
||||
|
||||
# Normalize the joysticks to 0,1
|
||||
j_left = (j_left + 1) / 2.
|
||||
j_right = (j_right + 1) / 2.
|
||||
|
||||
# Concatenate the buttons and joysticks along the last dimension
|
||||
action = torch.cat([j_left, j_right, buttons], dim=-1)
|
||||
|
||||
# Squeeze the second dimension of each input: this is the number of chunks, which is 1 here
|
||||
action = action.squeeze(1)
|
||||
return action
|
||||
|
||||
# def unpack_actions(self, actions):
|
||||
# # Unpack the actions into j_left, j_right, buttons
|
||||
# j_left = actions[:, :, :2]
|
||||
# j_right = actions[:, :, 2:4]
|
||||
# buttons = actions[:, :, 4:]
|
||||
|
||||
# # Denormalize the joysticks back to -1,1
|
||||
# j_left = j_left * 2. - 1.
|
||||
# j_right = j_right * 2. - 1.
|
||||
|
||||
# # Clip into [-1,1]
|
||||
# j_left = torch.clamp(j_left, -1, 1)
|
||||
# j_right = torch.clamp(j_right, -1, 1)
|
||||
|
||||
# # Threshold the buttons to 0,1
|
||||
# buttons = (buttons > 0.5).float()
|
||||
# return j_left, j_right, buttons
|
||||
|
||||
# ========= ActionHead required ============
|
||||
def forward(self, data: dict) -> dict:
|
||||
self.set_frozen_modules_to_eval_mode()
|
||||
|
||||
# data = action_input
|
||||
embodiment_id = data["embodiment_id"]
|
||||
|
||||
# # Check which data is present.
|
||||
# has_real_action = action_input.has_real_action
|
||||
has_real_action = data["has_real_action"]
|
||||
|
||||
# 1) Encode images/text/state
|
||||
visual_features = self.encode_images(data["images"]) #, data["view_ids"])
|
||||
# text_features = self.siglip_model.text_model(
|
||||
# input_ids=data["lang_input_ids"]
|
||||
# ).last_hidden_state
|
||||
# state_features = self.state_encoder(data["state"], embodiment_id)
|
||||
|
||||
# 2) Prepare noisy trajectory
|
||||
actions = data["actions"]
|
||||
noise = torch.randn_like(actions)
|
||||
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
|
||||
t = t[:, None, None] # shape (B,1,1) for broadcast
|
||||
|
||||
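        # Flow-matching setup: interpolate x_t = (1 - t) * noise + t * actions; its time-derivative
        # is the constant velocity (actions - noise), which the model is trained below to regress.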
noisy_trajectory = (1 - t) * noise + t * actions
|
||||
velocity = actions - noise
|
||||
|
||||
# 3) Convert (continuous) t -> discrete if needed
|
||||
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
|
||||
|
||||
# 4) Get action encoder embeddings with correct time argument
|
||||
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
|
||||
|
||||
# 5) Prepare full input to DiT (or your model)
|
||||
vl_embs, sa_embs = self.prepare_input_embs(
|
||||
data["vl_token_ids"],
|
||||
data["sa_token_ids"],
|
||||
visual_features,
|
||||
# text_features,
|
||||
# state_features,
|
||||
action_features,
|
||||
data["dropped_images"],
|
||||
game_ids=data.get("game_id"),
|
||||
)
|
||||
|
||||
vl_embs = self.vl_self_attention_model(vl_embs)
|
||||
# vl_embs = self.qformer(vl_embs)
|
||||
model_output, all_hidden_states = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=data["vl_attn_mask"],
|
||||
timestep=t_discretized,
|
||||
return_all_hidden_states=True,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
|
||||
# 6) Flow-matching or velocity-prediction MSE
|
||||
# Mask for variable-length trajectories
|
||||
mask = data["actions_mask"] # shape => (B, seq_len_of_actions, ...)
|
||||
raw_loss = F.mse_loss(pred_actions, velocity, reduction="none")
|
||||
mask = has_real_action[:, None, None] * mask
|
||||
raw_loss = raw_loss * mask
|
||||
action_loss = (has_real_action[:, None, None] * raw_loss).sum() / (mask.sum() + 1e-6)
|
||||
|
||||
loss = action_loss
|
||||
|
||||
return {
|
||||
"loss": loss,
|
||||
}
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_action(self, data: dict, old_layout:bool = False) -> dict:
|
||||
"""
|
||||
For i in [0..N-1]:
|
||||
1) t = i/N
|
||||
2) velocity = model(x(t), t)
|
||||
3) x(t + dt) = x(t) + dt * velocity
|
||||
"""
|
||||
|
||||
# data = action_input
|
||||
embodiment_id = data["embodiment_id"]
|
||||
|
||||
batch_size = data["images"].shape[0]
|
||||
device = data["images"].device
|
||||
dtype = data["images"].dtype
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.config.action_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# 1) Hyperparameters for flow sampling
|
||||
num_steps = self.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
# 2) Encode static context (images, text, state) once if it does not depend on actions
|
||||
visual_features = self.encode_images(data["images"]) #, data["view_ids"])
|
||||
# text_features = self.siglip_model.text_model(
|
||||
# input_ids=data["lang_input_ids"]
|
||||
# ).last_hidden_state
|
||||
# state_features = self.state_encoder(data["state"], embodiment_id)
|
||||
|
||||
# 3) Start denoising the actions
|
||||
for i in range(num_steps):
|
||||
# ---- (a) Discretize continuous time in [0,1]
|
||||
t_cont = i / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
|
||||
# ---- (b) Build embeddings (actions included)
|
||||
# Pass the *current* actions at time t into the action encoder
|
||||
action_features = self.action_encoder(
|
||||
actions,
|
||||
(torch.ones(actions.shape[0]) * t_discretized).to(device),
|
||||
embodiment_id,
|
||||
)
|
||||
vl_embs, sa_embs = self.prepare_input_embs(
|
||||
data["vl_token_ids"],
|
||||
data["sa_token_ids"],
|
||||
visual_features,
|
||||
# text_features,
|
||||
# state_features,
|
||||
action_features,
|
||||
data["dropped_images"],
|
||||
game_ids=data["game_ids"],
|
||||
)
|
||||
vl_embs = self.vl_self_attention_model(vl_embs)
|
||||
# vl_embs = self.qformer(vl_embs)
|
||||
# ---- (c) Forward pass to get velocity = d/dt x(t)
|
||||
timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=data["vl_attn_mask"],
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_velocity = pred[:, -actions.shape[1] :]
|
||||
|
||||
# ---- (d) Naive Euler step: x(t + dt) = x(t) + dt * velocity
|
||||
actions = actions + dt * pred_velocity
|
||||
|
||||
return {
|
||||
"action_tensor": actions,
|
||||
}
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_action_with_cfg(self, data_cond: dict, data_uncond: dict, cfg_scale: float = 1.0) -> dict:
|
||||
"""
|
||||
Use a form of classifier free guidance to sample actions. This can only be used on
|
||||
models that were trained on multiple frames of actions. The idea is that we sample
|
||||
velocity with and without the frame history, and then we push the sampled actions
|
||||
towards the ones that were sampled with the frame history.
|
||||
|
||||
data_with_hist = conditional input
|
||||
data_without_hist = unconditional input
|
||||
|
||||
This function works with any kind of conditioning, not just history.
|
||||
|
||||
For i in [0..N-1]:
|
||||
1) t = i/N
|
||||
            2) velocity = velocity_cond + cfg_scale * (velocity_cond - velocity_uncond)
|
||||
3) x(t + dt) = x(t) + dt * velocity
|
||||
"""
|
||||
|
||||
# data = action_input
|
||||
embodiment_id = data_cond["embodiment_id"]
|
||||
|
||||
batch_size = data_cond["images"].shape[0]
|
||||
device = data_cond["images"].device
|
||||
dtype = data_cond["images"].dtype
|
||||
actions = torch.randn(
|
||||
size=(batch_size, self.config.action_horizon, self.config.action_dim),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# 1) Hyperparameters for flow sampling
|
||||
num_steps = self.num_inference_timesteps
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
# 2) Encode static context (images, text, state) once if it does not depend on actions
|
||||
visual_features_cond = self.encode_images(data_cond["images"])
|
||||
visual_features_uncond = self.encode_images(data_uncond["images"])
|
||||
# text_features = self.siglip_model.text_model(
|
||||
# input_ids=data["lang_input_ids"]
|
||||
# ).last_hidden_state
|
||||
# state_features = self.state_encoder(data["state"], embodiment_id)
|
||||
|
||||
# 3) Start denoising the actions
|
||||
for i in range(num_steps):
|
||||
# ---- (a) Discretize continuous time in [0,1]
|
||||
t_cont = i / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
|
||||
t_discretized = int(t_cont * self.num_timestep_buckets)
|
||||
|
||||
# ---- (b) Build embeddings (actions included)
|
||||
# Pass the *current* actions at time t into the action encoder
|
||||
action_features = self.action_encoder(
|
||||
actions,
|
||||
(torch.ones(actions.shape[0]) * t_discretized).to(device),
|
||||
embodiment_id,
|
||||
)
|
||||
|
||||
# Predict velocity with history
|
||||
vl_embs, sa_embs = self.prepare_input_embs(
|
||||
data_cond["vl_token_ids"],
|
||||
data_cond["sa_token_ids"],
|
||||
visual_features_cond,
|
||||
action_features,
|
||||
data_cond["dropped_images"],
|
||||
)
|
||||
vl_embs = self.vl_self_attention_model(vl_embs)
|
||||
# ---- (c) Forward pass to get velocity = d/dt x(t)
|
||||
timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=data_cond["vl_attn_mask"],
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_velocity_cond = pred[:, -actions.shape[1] :]
|
||||
|
||||
# Predict velocity without history
|
||||
vl_embs, sa_embs = self.prepare_input_embs(
|
||||
data_uncond["vl_token_ids"],
|
||||
data_uncond["sa_token_ids"],
|
||||
visual_features_uncond,
|
||||
action_features,
|
||||
data_uncond["dropped_images"],
|
||||
)
|
||||
vl_embs = self.vl_self_attention_model(vl_embs)
|
||||
# ---- (c) Forward pass to get velocity = d/dt x(t)
|
||||
timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
|
||||
model_output = self.model(
|
||||
hidden_states=sa_embs,
|
||||
encoder_hidden_states=vl_embs,
|
||||
encoder_attention_mask=data_uncond["vl_attn_mask"],
|
||||
timestep=timesteps,
|
||||
)
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_velocity_uncond = pred[:, -actions.shape[1] :]
|
||||
|
||||
# ---- (d) Combine velocities with cfg_scale
|
||||
pred_velocity = pred_velocity_cond + cfg_scale * (pred_velocity_cond - pred_velocity_uncond)
|
||||
|
||||
# ---- (e) Naive Euler step: x(t + dt) = x(t) + dt * velocity
|
||||
actions = actions + dt * pred_velocity
|
||||
|
||||
return {
|
||||
"action_tensor": actions,
|
||||
}
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(iter(self.parameters())).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(iter(self.parameters())).dtype
|
||||

nitrogen/game_env.py (new file, 577 lines)
@@ -0,0 +1,577 @@
import time
|
||||
import platform
|
||||
|
||||
import pyautogui
|
||||
import dxcam
|
||||
import pywinctl as pwc
|
||||
import xspeedhack as xsh
|
||||
from gymnasium import Env
|
||||
from gymnasium.spaces import Box, Dict, Discrete
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
import vgamepad as vg
|
||||
|
||||
import psutil
|
||||
|
||||
assert platform.system().lower() == "windows", "This module is only supported on Windows."
|
||||
import win32process
|
||||
import win32gui
|
||||
import win32api
|
||||
import win32con
|
||||
|
||||
def get_process_info(process_name):
|
||||
"""
|
||||
Get process information for a given process name on Windows.
|
||||
|
||||
Args:
|
||||
process_name (str): Name of the process (e.g., "isaac-ng.exe")
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing PID, window_name, and architecture
|
||||
for each matching process. Returns empty list if no process found.
|
||||
"""
|
||||
results = []
|
||||
|
||||
# Find all processes with the given name
|
||||
for proc in psutil.process_iter(['pid', 'name']):
|
||||
try:
|
||||
if proc.info['name'].lower() == process_name.lower():
|
||||
pid = proc.info['pid']
|
||||
|
||||
# Get architecture
|
||||
try:
|
||||
# Check if process is 32-bit or 64-bit
|
||||
process_handle = win32api.OpenProcess(
|
||||
win32con.PROCESS_QUERY_INFORMATION,
|
||||
False,
|
||||
pid
|
||||
)
|
||||
is_wow64 = win32process.IsWow64Process(process_handle)
|
||||
win32api.CloseHandle(process_handle)
|
||||
|
||||
# On 64-bit Windows: WOW64 means "Windows 32-bit on Windows 64-bit", i.e. a 32-bit process
|
||||
architecture = "x86" if is_wow64 else "x64"
|
||||
except:
|
||||
architecture = "unknown"
|
||||
|
||||
# Find windows associated with this PID
|
||||
windows = []
|
||||
|
||||
def enum_window_callback(hwnd, pid_to_find):
|
||||
_, found_pid = win32process.GetWindowThreadProcessId(hwnd)
|
||||
if found_pid == pid_to_find:
|
||||
window_text = win32gui.GetWindowText(hwnd)
|
||||
if window_text and win32gui.IsWindowVisible(hwnd):
|
||||
windows.append({
|
||||
'hwnd': hwnd,
|
||||
'title': window_text,
|
||||
'visible': win32gui.IsWindowVisible(hwnd)
|
||||
})
|
||||
return True
|
||||
|
||||
# Find all windows for this PID
|
||||
try:
|
||||
win32gui.EnumWindows(enum_window_callback, pid)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Choose the best window
|
||||
window_name = None
|
||||
if windows:
|
||||
if len(windows) > 1:
|
||||
print(f"Multiple windows found for PID {pid}: {[win['title'] for win in windows]}")
|
||||
print("Using heuristics to select the correct window...")
|
||||
# Filter out common proxy/helper windows
|
||||
proxy_keywords = ['d3dproxywindow', 'proxy', 'helper', 'overlay']
|
||||
|
||||
# First try to find a visible window without proxy keywords
|
||||
for win in windows:
|
||||
if not any(keyword in win['title'].lower() for keyword in proxy_keywords):
|
||||
window_name = win['title']
|
||||
break
|
||||
|
||||
# If no good window found, just use the first one
|
||||
if window_name is None and windows:
|
||||
window_name = windows[0]['title']
|
||||
|
||||
results.append({
|
||||
'pid': pid,
|
||||
'window_name': window_name,
|
||||
'architecture': architecture
|
||||
})
|
||||
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
continue
|
||||
|
||||
if len(results) == 0:
|
||||
raise ValueError(f"No process found with name: {process_name}")
|
||||
elif len(results) > 1:
|
||||
print(f"Warning: Multiple processes found with name '{process_name}'. Returning first match.")
|
||||
|
||||
return results[0]
|
||||
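# Example (illustrative sketch; the process name is a placeholder, not a real dependency):
#   info = get_process_info("game.exe")
#   print(info["pid"], info["window_name"], info["architecture"])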
|
||||
|
||||
XBOX_MAPPING = {
|
||||
"DPAD_UP": "XUSB_GAMEPAD_DPAD_UP",
|
||||
"DPAD_DOWN": "XUSB_GAMEPAD_DPAD_DOWN",
|
||||
"DPAD_LEFT": "XUSB_GAMEPAD_DPAD_LEFT",
|
||||
"DPAD_RIGHT": "XUSB_GAMEPAD_DPAD_RIGHT",
|
||||
"START": "XUSB_GAMEPAD_START",
|
||||
"BACK": "XUSB_GAMEPAD_BACK",
|
||||
"LEFT_SHOULDER": "XUSB_GAMEPAD_LEFT_SHOULDER",
|
||||
"RIGHT_SHOULDER": "XUSB_GAMEPAD_RIGHT_SHOULDER",
|
||||
"GUIDE": "XUSB_GAMEPAD_GUIDE",
|
||||
"WEST": "XUSB_GAMEPAD_X",
|
||||
"SOUTH": "XUSB_GAMEPAD_A",
|
||||
"EAST": "XUSB_GAMEPAD_B",
|
||||
"NORTH": "XUSB_GAMEPAD_Y",
|
||||
"LEFT_TRIGGER": "LEFT_TRIGGER",
|
||||
"RIGHT_TRIGGER": "RIGHT_TRIGGER",
|
||||
"AXIS_LEFTX": "LEFT_JOYSTICK",
|
||||
"AXIS_LEFTY": "LEFT_JOYSTICK",
|
||||
"AXIS_RIGHTX": "RIGHT_JOYSTICK",
|
||||
"AXIS_RIGHTY": "RIGHT_JOYSTICK",
|
||||
"LEFT_THUMB": "XUSB_GAMEPAD_LEFT_THUMB",
|
||||
"RIGHT_THUMB": "XUSB_GAMEPAD_RIGHT_THUMB",
|
||||
}
|
||||
|
||||
PS4_MAPPING = {
|
||||
"DPAD_UP": "DS4_BUTTON_DPAD_NORTH",
|
||||
"DPAD_DOWN": "DS4_BUTTON_DPAD_SOUTH",
|
||||
"DPAD_LEFT": "DS4_BUTTON_DPAD_WEST",
|
||||
"DPAD_RIGHT": "DS4_BUTTON_DPAD_EAST",
|
||||
"START": "DS4_BUTTON_OPTIONS",
|
||||
"BACK": "DS4_BUTTON_SHARE",
|
||||
"LEFT_SHOULDER": "DS4_BUTTON_SHOULDER_LEFT",
|
||||
"RIGHT_SHOULDER": "DS4_BUTTON_SHOULDER_RIGHT",
|
||||
"GUIDE": "DS4_BUTTON_GUIDE",
|
||||
"WEST": "DS4_BUTTON_SQUARE",
|
||||
"SOUTH": "DS4_BUTTON_CROSS",
|
||||
"EAST": "DS4_BUTTON_CIRCLE",
|
||||
"NORTH": "DS4_BUTTON_TRIANGLE",
|
||||
"LEFT_TRIGGER": "LEFT_TRIGGER",
|
||||
"RIGHT_TRIGGER": "RIGHT_TRIGGER",
|
||||
"AXIS_LEFTX": "LEFT_JOYSTICK",
|
||||
"AXIS_LEFTY": "LEFT_JOYSTICK",
|
||||
"AXIS_RIGHTX": "RIGHT_JOYSTICK",
|
||||
"AXIS_RIGHTY": "RIGHT_JOYSTICK",
|
||||
"LEFT_THUMB": "DS4_BUTTON_THUMB_LEFT",
|
||||
"RIGHT_THUMB": "DS4_BUTTON_THUMB_RIGHT",
|
||||
}
|
||||
|
||||
|
||||
class GamepadEmulator:
|
||||
def __init__(self, controller_type="xbox", system="windows"):
|
||||
"""
|
||||
Initialize the GamepadEmulator with a specific controller type and system.
|
||||
|
||||
Parameters:
|
||||
controller_type (str): The type of controller to emulate ("xbox" or "ps4").
|
||||
system (str): The operating system to use, which affects joystick value handling.
|
||||
"""
|
||||
self.controller_type = controller_type
|
||||
self.system = system
|
||||
if controller_type == "xbox":
|
||||
self.gamepad = vg.VX360Gamepad()
|
||||
self.mapping = XBOX_MAPPING
|
||||
elif controller_type == "ps4":
|
||||
self.gamepad = vg.VDS4Gamepad()
|
||||
self.mapping = PS4_MAPPING
|
||||
else:
|
||||
raise ValueError("Unsupported controller type")
|
||||
|
||||
# Initialize joystick values to keep track of the current state
|
||||
self.left_joystick_x: int = 0
|
||||
self.left_joystick_y: int = 0
|
||||
self.right_joystick_x: int = 0
|
||||
self.right_joystick_y: int = 0
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
Perform actions based on the provided action dictionary.
|
||||
|
||||
Parameters:
|
||||
action (dict): Dictionary of an action to be performed. Keys are control names,
|
||||
and values are their respective states.
|
||||
"""
|
||||
self.gamepad.reset()
|
||||
|
||||
# Handle buttons
|
||||
for control in [
|
||||
"EAST",
|
||||
"SOUTH",
|
||||
"NORTH",
|
||||
"WEST",
|
||||
"BACK",
|
||||
"GUIDE",
|
||||
"START",
|
||||
"DPAD_DOWN",
|
||||
"DPAD_LEFT",
|
||||
"DPAD_RIGHT",
|
||||
"DPAD_UP",
|
||||
"LEFT_SHOULDER",
|
||||
"RIGHT_SHOULDER",
|
||||
"LEFT_THUMB",
|
||||
"RIGHT_THUMB",
|
||||
]:
|
||||
if control in action:
|
||||
if action[control]:
|
||||
self.press_button(control)
|
||||
else:
|
||||
self.release_button(control)
|
||||
|
||||
# Handle triggers
|
||||
if "LEFT_TRIGGER" in action:
|
||||
self.set_trigger("LEFT_TRIGGER", action["LEFT_TRIGGER"][0])
|
||||
if "RIGHT_TRIGGER" in action:
|
||||
self.set_trigger("RIGHT_TRIGGER", action["RIGHT_TRIGGER"][0])
|
||||
|
||||
# Handle joysticks
|
||||
if "AXIS_LEFTX" in action and "AXIS_LEFTY" in action:
|
||||
self.set_joystick("AXIS_LEFTX", action["AXIS_LEFTX"][0])
|
||||
self.set_joystick("AXIS_LEFTY", action["AXIS_LEFTY"][0])
|
||||
|
||||
if "AXIS_RIGHTX" in action and "AXIS_RIGHTY" in action:
|
||||
self.set_joystick("AXIS_RIGHTX", action["AXIS_RIGHTX"][0])
|
||||
self.set_joystick("AXIS_RIGHTY", action["AXIS_RIGHTY"][0])
|
||||
|
||||
self.gamepad.update()
|
||||
|
||||
def press_button(self, button):
|
||||
"""
|
||||
Press a button on the gamepad.
|
||||
|
||||
Parameters:
|
||||
button (str): The unified name of the button to press.
|
||||
"""
|
||||
button_mapped = self.mapping.get(button)
|
||||
if self.controller_type == "xbox":
|
||||
self.gamepad.press_button(button=getattr(vg.XUSB_BUTTON, button_mapped))
|
||||
elif self.controller_type == "ps4":
|
||||
self.gamepad.press_button(button=getattr(vg.DS4_BUTTONS, button_mapped))
|
||||
else:
|
||||
raise ValueError("Unsupported controller type")
|
||||
|
||||
def release_button(self, button):
|
||||
"""
|
||||
Release a button on the gamepad.
|
||||
|
||||
Parameters:
|
||||
button (str): The unified name of the button to release.
|
||||
"""
|
||||
button_mapped = self.mapping.get(button)
|
||||
if self.controller_type == "xbox":
|
||||
self.gamepad.release_button(button=getattr(vg.XUSB_BUTTON, button_mapped))
|
||||
elif self.controller_type == "ps4":
|
||||
self.gamepad.release_button(button=getattr(vg.DS4_BUTTONS, button_mapped))
|
||||
else:
|
||||
raise ValueError("Unsupported controller type")
|
||||
|
||||
def set_trigger(self, trigger, value):
|
||||
"""
|
||||
Set the value of a trigger on the gamepad.
|
||||
|
||||
Parameters:
|
||||
trigger (str): The unified name of the trigger.
|
||||
value (float): The value to set the trigger to (between 0 and 1).
|
||||
"""
|
||||
value = int(value)
|
||||
trigger_mapped = self.mapping.get(trigger)
|
||||
if trigger_mapped == "LEFT_TRIGGER":
|
||||
self.gamepad.left_trigger(value=value)
|
||||
elif trigger_mapped == "RIGHT_TRIGGER":
|
||||
self.gamepad.right_trigger(value=value)
|
||||
else:
|
||||
raise ValueError("Unsupported trigger action")
|
||||
|
||||
def set_joystick(self, joystick, value):
|
||||
"""
|
||||
Set the position of a joystick on the gamepad.
|
||||
|
||||
Parameters:
|
||||
joystick (str): The name of the joystick axis.
|
||||
value (float): The value to set the joystick axis to (between -32768 and 32767)
|
||||
"""
|
||||
if joystick == "AXIS_LEFTX":
|
||||
self.left_joystick_x = value
|
||||
self.gamepad.left_joystick(x_value=self.left_joystick_x, y_value=self.left_joystick_y)
|
||||
elif joystick == "AXIS_LEFTY":
|
||||
if self.system == "windows":
|
||||
value = -value - 1
|
||||
self.left_joystick_y = value
|
||||
self.gamepad.left_joystick(x_value=self.left_joystick_x, y_value=self.left_joystick_y)
|
||||
elif joystick == "AXIS_RIGHTX":
|
||||
self.right_joystick_x = value
|
||||
self.gamepad.right_joystick(
|
||||
x_value=self.right_joystick_x, y_value=self.right_joystick_y
|
||||
)
|
||||
elif joystick == "AXIS_RIGHTY":
|
||||
if self.system == "windows":
|
||||
value = -value - 1
|
||||
self.right_joystick_y = value
|
||||
self.gamepad.right_joystick(
|
||||
x_value=self.right_joystick_x, y_value=self.right_joystick_y
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported joystick action")
|
||||
|
||||
def wakeup(self, duration=0.1):
|
||||
"""
|
||||
Wake up the controller by pressing a button.
|
||||
|
||||
Parameters:
|
||||
duration (float): Duration to press the button.
|
||||
"""
|
||||
self.gamepad.press_button(vg.XUSB_BUTTON.XUSB_GAMEPAD_LEFT_THUMB)
|
||||
self.gamepad.update()
|
||||
time.sleep(duration)
|
||||
self.gamepad.reset()
|
||||
self.gamepad.update()
|
||||
time.sleep(duration)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the gamepad to its default state.
|
||||
"""
|
||||
self.gamepad.reset()
|
||||
self.gamepad.update()
|
||||
|
||||
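A minimal usage sketch for GamepadEmulator (illustrative only, not part of the commit): it assumes vgamepad and its ViGEmBus driver are installed and uses the unified control names from XBOX_MAPPING, with the raw value scales used elsewhere in this module.
# Illustrative sketch: drive the virtual pad directly with a unified action dict.
pad = GamepadEmulator(controller_type="xbox", system="windows")
pad.wakeup()                        # press/release a stick button so the game detects the pad
pad.step({
    "SOUTH": 1,                     # hold A
    "DPAD_UP": 0,                   # explicitly released
    "LEFT_TRIGGER": [0],            # trigger values are cast to int by set_trigger
    "AXIS_LEFTX": [16000],          # raw stick units in [-32768, 32767]
    "AXIS_LEFTY": [0],
})
time.sleep(0.1)
pad.reset()                         # return every control to neutral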
class PyautoguiScreenshotBackend:
|
||||
|
||||
def __init__(self, bbox):
|
||||
self.bbox = bbox
|
||||
|
||||
def screenshot(self):
|
||||
return pyautogui.screenshot(region=self.bbox)
|
||||
|
||||
class DxcamScreenshotBackend:
|
||||
def __init__(self, bbox):
|
||||
import dxcam
|
||||
self.camera = dxcam.create()
|
||||
self.bbox = bbox
|
||||
self.last_screenshot = None
|
||||
|
||||
def screenshot(self):
|
||||
screenshot = self.camera.grab(region=self.bbox)
|
||||
if screenshot is None:
|
||||
print("DXCAM failed to capture frame, trying to use the latest screenshot")
|
||||
if self.last_screenshot is not None:
|
||||
return self.last_screenshot
|
||||
else:
|
||||
return Image.new("RGB", (self.bbox[2], self.bbox[3]), (0, 0, 0))
|
||||
screenshot = Image.fromarray(screenshot)
|
||||
self.last_screenshot = screenshot
|
||||
return screenshot
|
||||
|
||||
|
||||
class GamepadEnv(Env):
|
||||
"""
|
||||
Base class for creating a game environment controlled with a gamepad.
|
||||
|
||||
Attributes:
|
||||
game (str): Name of the game to interact with.
|
||||
image_height (int): Height of the observation space.
|
||||
image_width (int): Width of the observation space.
|
||||
controller_type (str): Platform for the gamepad emulator ("xbox" or "ps4").
|
||||
game_speed (float): Speed multiplier for the game.
|
||||
env_fps (int): Number of actions to perform per second at normal speed.
|
||||
async_mode (bool): Whether to pause/unpause the game during each step.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
game,
|
||||
image_height=1440,
|
||||
image_width=2560,
|
||||
controller_type="xbox",
|
||||
game_speed=1.0,
|
||||
env_fps=10,
|
||||
async_mode=True,
|
||||
screenshot_backend="dxcam",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Assert that system is windows
|
||||
os_name = platform.system().lower()
|
||||
assert os_name == "windows", "This environment is currently only supported on Windows."
|
||||
assert controller_type in ["xbox", "ps4"], "Platform must be either 'xbox' or 'ps4'"
|
||||
assert screenshot_backend in ["pyautogui", "dxcam"], "Screenshot backend must be either 'pyautogui' or 'dxcam'"
|
||||
|
||||
self.game = game
|
||||
self.image_height = int(image_height)
|
||||
self.image_width = int(image_width)
|
||||
self.game_speed = game_speed
|
||||
self.env_fps = env_fps
|
||||
self.step_duration = self.calculate_step_duration()
|
||||
self.async_mode = async_mode
|
||||
|
||||
self.gamepad_emulator = GamepadEmulator(controller_type=controller_type, system=os_name)
|
||||
proc_info = get_process_info(game)
|
||||
|
||||
self.game_pid = proc_info["pid"]
|
||||
self.game_arch = proc_info["architecture"]
|
||||
self.game_window_name = proc_info["window_name"]
|
||||
|
||||
print(f"Game process found: {self.game} (PID: {self.game_pid}, Arch: {self.game_arch}, Window: {self.game_window_name})")
|
||||
|
||||
if self.game_pid is None:
|
||||
raise Exception(f"Could not find PID for game: {game}")
|
||||
|
||||
|
||||
self.observation_space = Box(
|
||||
low=0, high=255, shape=(self.image_height, self.image_width, 3), dtype="uint8"
|
||||
)
|
||||
|
||||
# Define a unified action space
|
||||
self.action_space = Dict(
|
||||
{
|
||||
"BACK": Discrete(2),
|
||||
"GUIDE": Discrete(2),
|
||||
"RIGHT_SHOULDER": Discrete(2),
|
||||
"RIGHT_TRIGGER": Box(low=0.0, high=1.0, shape=(1,)),
|
||||
"LEFT_TRIGGER": Box(low=0.0, high=1.0, shape=(1,)),
|
||||
"LEFT_SHOULDER": Discrete(2),
|
||||
"AXIS_RIGHTX": Box(low=-32768.0, high=32767, shape=(1,)),
|
||||
"AXIS_RIGHTY": Box(low=-32768.0, high=32767, shape=(1,)),
|
||||
"AXIS_LEFTX": Box(low=-32768.0, high=32767, shape=(1,)),
|
||||
"AXIS_LEFTY": Box(low=-32768.0, high=32767, shape=(1,)),
|
||||
"LEFT_THUMB": Discrete(2),
|
||||
"RIGHT_THUMB": Discrete(2),
|
||||
"DPAD_UP": Discrete(2),
|
||||
"DPAD_RIGHT": Discrete(2),
|
||||
"DPAD_DOWN": Discrete(2),
|
||||
"DPAD_LEFT": Discrete(2),
|
||||
"WEST": Discrete(2),
|
||||
"SOUTH": Discrete(2),
|
||||
"EAST": Discrete(2),
|
||||
"NORTH": Discrete(2),
|
||||
"START": Discrete(2),
|
||||
}
|
||||
)
|
||||
|
||||
# Determine window name
|
||||
windows = pwc.getAllWindows()
|
||||
self.game_window = None
|
||||
for window in windows:
|
||||
if window.title == self.game_window_name:
|
||||
self.game_window = window
|
||||
break
|
||||
|
||||
if not self.game_window:
|
||||
raise Exception(f"No window found with game name: {self.game}")
|
||||
|
||||
self.game_window.activate()
|
||||
l, t, r, b = self.game_window.left, self.game_window.top, self.game_window.right, self.game_window.bottom
|
||||
self.bbox = (l, t, r-l, b-t)
|
||||
|
||||
# Initialize speedhack client if using DLL injection
|
||||
self.speedhack_client = xsh.Client(process_id=self.game_pid, arch=self.game_arch)
|
||||
|
||||
# Get the screenshot backend
|
||||
if screenshot_backend == "dxcam":
|
||||
self.screenshot_backend = DxcamScreenshotBackend(self.bbox)
|
||||
elif screenshot_backend == "pyautogui":
|
||||
self.screenshot_backend = PyautoguiScreenshotBackend(self.bbox)
|
||||
else:
|
||||
raise ValueError("Unsupported screenshot backend. Use 'dxcam' or 'pyautogui'.")
|
||||
|
||||
|
||||
def calculate_step_duration(self):
|
||||
"""
|
||||
Calculate the step duration based on game speed and environment FPS.
|
||||
|
||||
Returns:
|
||||
float: Calculated step duration.
|
||||
|
||||
Example:
|
||||
If game_speed=1.0 and env_fps=10, then step_duration
|
||||
will be 0.1 seconds.
|
||||
"""
|
||||
return 1.0 / (self.env_fps * self.game_speed)
|
||||
|
||||
def unpause(self):
|
||||
"""
|
||||
Unpause the game using the specified method.
|
||||
"""
|
||||
self.speedhack_client.set_speed(1.0)
|
||||
|
||||
def pause(self):
|
||||
"""
|
||||
Pause the game using the specified method.
|
||||
"""
|
||||
self.speedhack_client.set_speed(0.0)
|
||||
|
||||
def perform_action(self, action, duration):
|
||||
"""
|
||||
Perform the action without handling the game pause/unpause.
|
||||
|
||||
Parameters:
|
||||
action (dict): Action to be performed.
|
||||
duration (float): Duration for the action step.
|
||||
"""
|
||||
self.gamepad_emulator.step(action)
|
||||
start = time.perf_counter()
|
||||
self.unpause()
|
||||
# Wait until the next step
|
||||
end = start + duration  # honor the duration requested for this step
|
||||
now = time.perf_counter()
|
||||
while now < end:
|
||||
now = time.perf_counter()
|
||||
self.pause()
|
||||
|
||||
def step(self, action, step_duration=None):
|
||||
"""
|
||||
Perform an action in the game environment and return the observation.
|
||||
|
||||
Parameters:
|
||||
action (dict): Dictionary of the action to be performed. Keys are control names,
|
||||
and values are their respective states.
|
||||
step_duration (float, optional): Duration for which the action should be performed.
|
||||
|
||||
Returns:
|
||||
tuple: (obs, reward, terminated, truncated, info) where obs is the observation of the game environment after performing the action.
|
||||
"""
|
||||
# Determine the duration for this step
|
||||
duration = step_duration if step_duration is not None else self.step_duration
|
||||
|
||||
self.perform_action(action, duration)
|
||||
|
||||
obs = self.render() # Render after pausing the game
|
||||
|
||||
reward = 0.0
|
||||
terminated = False
|
||||
truncated = False
|
||||
info = {}
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
"""
|
||||
Reset the environment to its initial state.
|
||||
|
||||
Parameters:
|
||||
seed (int, optional): Random seed.
|
||||
options (dict, optional): Additional options for reset.
|
||||
"""
|
||||
self.gamepad_emulator.wakeup(duration=0.1)
|
||||
time.sleep(1.0)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the environment and release any resources.
|
||||
"""
|
||||
pass # Implement env close logic here
|
||||
|
||||
def render(self):
|
||||
"""
|
||||
Render the current state of the game window as an observation.
|
||||
|
||||
Returns:
|
||||
Image: Observation of the game environment.
|
||||
"""
|
||||
screenshot = self.screenshot_backend.screenshot()
|
||||
screenshot = screenshot.resize((self.image_width, self.image_height))
|
||||
|
||||
return screenshot
|
||||
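An end-to-end sketch of the environment loop (illustrative; "MyGame" is a placeholder process name and the game window must already be running on Windows):
# Illustrative sketch: random-policy rollout against a running game window.
env = GamepadEnv(game="MyGame", image_height=720, image_width=1280, env_fps=10)
env.reset()
for _ in range(100):
    action = env.action_space.sample()                      # dict of Discrete/Box samples
    obs, reward, terminated, truncated, info = env.step(action)
obs.save("last_frame.png")                                  # render() returns a PIL image
env.close()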
91
nitrogen/inference_client.py
Normal file
91
nitrogen/inference_client.py
Normal file
@ -0,0 +1,91 @@
|
||||
import time
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
class ModelClient:
|
||||
"""Client for model inference server."""
|
||||
|
||||
def __init__(self, host="localhost", port=5555):
|
||||
"""
|
||||
Initialize client connection.
|
||||
|
||||
Args:
|
||||
host: Server hostname or IP
|
||||
port: Server port
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_ms = 30000
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.REQ)
|
||||
self.socket.connect(f"tcp://{host}:{port}")
|
||||
self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) # Set receive timeout
|
||||
|
||||
print(f"Connected to model server at {host}:{port}")
|
||||
|
||||
def predict(self, image: np.ndarray) -> dict:
|
||||
"""
|
||||
Send an image and receive predicted actions.
|
||||
|
||||
Args:
|
||||
image: numpy array (H, W, 3) in RGB format
|
||||
|
||||
Returns:
|
||||
Predicted actions containing:
|
||||
- j_left: [x, y] left joystick position
|
||||
- j_right: [x, y] right joystick position
|
||||
- buttons: list of button values
|
||||
"""
|
||||
request = {
|
||||
"type": "predict",
|
||||
"image": image
|
||||
}
|
||||
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = pickle.loads(self.socket.recv())
|
||||
|
||||
if response["status"] != "ok":
|
||||
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
|
||||
|
||||
return response["pred"]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the server's session (clear buffers)."""
|
||||
request = {"type": "reset"}
|
||||
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = pickle.loads(self.socket.recv())
|
||||
|
||||
if response["status"] != "ok":
|
||||
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
|
||||
|
||||
print("Session reset")
|
||||
|
||||
def info(self) -> dict:
|
||||
"""Get session info from the server."""
|
||||
request = {"type": "info"}
|
||||
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = pickle.loads(self.socket.recv())
|
||||
|
||||
if response["status"] != "ok":
|
||||
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
|
||||
|
||||
return response["info"]
|
||||
|
||||
def close(self):
|
||||
"""Close the connection."""
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
print("Connection closed")
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Close connection when exiting context."""
|
||||
self.close()
|
||||
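A short client-side sketch (illustrative; it assumes an inference server is already listening on the given host and port):
# Illustrative sketch: query a running inference server with a dummy frame.
frame = np.zeros((720, 1280, 3), dtype=np.uint8)            # placeholder RGB image
with ModelClient(host="localhost", port=5555) as client:
    print(client.info())                                    # session metadata
    pred = client.predict(frame)                            # j_left / j_right / buttons
    client.reset()                                          # clear server-side buffers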
276
nitrogen/inference_session.py
Normal file
276
nitrogen/inference_session.py
Normal file
@ -0,0 +1,276 @@
|
||||
import time
|
||||
import json
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoImageProcessor
|
||||
from nitrogen.flow_matching_transformer.nitrogen import NitroGen, NitroGen_Config
|
||||
from nitrogen.mm_tokenizers import NitrogenTokenizerConfig, NitrogenTokenizer, Tokenizer
|
||||
from nitrogen.cfg import CkptConfig
|
||||
from nitrogen.shared import PATH_REPO
|
||||
|
||||
def summarize_parameters(module, name='model', depth=0, max_depth=3):
|
||||
"""
|
||||
Print a tree-like summary of parameters in a PyTorch module.
|
||||
|
||||
Args:
|
||||
module: PyTorch module to summarize
|
||||
name: Name of the module (for root level)
|
||||
depth: Current depth in the tree
|
||||
max_depth: Maximum depth to traverse
|
||||
"""
|
||||
if depth > max_depth:
|
||||
return
|
||||
|
||||
# Count total parameters in this module
|
||||
total_params = sum(p.numel() for p in module.parameters())
|
||||
trainable_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
|
||||
|
||||
# Print indented summary
|
||||
indent = " " * depth
|
||||
print(f"{indent}{name}: {total_params:,} params ({trainable_params:,} trainable)")
|
||||
|
||||
# Recursively summarize submodules
|
||||
if depth < max_depth:
|
||||
for child_name, child_module in module.named_children():
|
||||
summarize_parameters(child_module, child_name, depth + 1, max_depth)
|
||||
|
||||
|
||||
def load_model(checkpoint_path: str):
|
||||
"""Load model and args from checkpoint."""
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
||||
ckpt_config = CkptConfig.model_validate(checkpoint["ckpt_config"])
|
||||
model_cfg = ckpt_config.model_cfg
|
||||
tokenizer_cfg = ckpt_config.tokenizer_cfg
|
||||
|
||||
print("Checkpoint args:")
|
||||
print(json.dumps(ckpt_config.model_dump(), indent=4))
|
||||
|
||||
# Initialize the image processor for the vision encoder
|
||||
img_proc = AutoImageProcessor.from_pretrained(model_cfg.vision_encoder_name)
|
||||
|
||||
# Build the tokenizer and model from the checkpoint config
|
||||
if isinstance(model_cfg, NitroGen_Config):
|
||||
assert isinstance(tokenizer_cfg, NitrogenTokenizerConfig), \
|
||||
"NitroGen_Config requires NitrogenTokenizerConfig for tokenization"
|
||||
tokenizer_cfg.training = False
|
||||
if tokenizer_cfg.game_mapping_cfg is not None:
|
||||
tokenizer_cfg.game_mapping_cfg.src_files = [
|
||||
x.replace("/mnt/amlfs-02/shared/gaming/gamingvla", str(PATH_REPO))
|
||||
for x in tokenizer_cfg.game_mapping_cfg.src_files
|
||||
]
|
||||
tokenizer = NitrogenTokenizer(tokenizer_cfg)
|
||||
game_mapping = tokenizer.game_mapping
|
||||
model = NitroGen(config=model_cfg, game_mapping=game_mapping)
|
||||
# model.num_inference_timesteps = 16
|
||||
action_downsample_ratio = 1
|
||||
else:
|
||||
raise ValueError(f"Unsupported model config type: {type(model_cfg)}")
|
||||
|
||||
summarize_parameters(model, max_depth=3)
|
||||
|
||||
print(model)
|
||||
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
model.eval()
|
||||
tokenizer.eval()
|
||||
model.to("cuda")
|
||||
|
||||
return model, tokenizer, img_proc, ckpt_config, game_mapping, action_downsample_ratio
|
||||
|
||||
class InferenceSession:
|
||||
"""Manages state for a single inference session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
ckpt_path: str,
|
||||
tokenizer: Tokenizer,
|
||||
img_proc,
|
||||
ckpt_config: CkptConfig,
|
||||
game_mapping: dict,
|
||||
selected_game: str,
|
||||
old_layout: bool,
|
||||
cfg_scale: float,
|
||||
action_downsample_ratio: float,
|
||||
context_length=None
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.img_proc = img_proc
|
||||
self.ckpt_config = ckpt_config
|
||||
self.game_mapping = game_mapping
|
||||
self.selected_game = selected_game
|
||||
self.old_layout = old_layout
|
||||
self.cfg_scale = cfg_scale
|
||||
self.action_downsample_ratio = action_downsample_ratio
|
||||
self.ckpt_path = ckpt_path
|
||||
|
||||
# Load modality config
|
||||
self.modality_config = self.ckpt_config.modality_cfg
|
||||
|
||||
self.max_buffer_size = context_length if context_length is not None else self.modality_config.frame_per_sample
|
||||
self.action_interleaving = self.modality_config.action_interleaving
|
||||
self.is_flowmatching = isinstance(self.ckpt_config.model_cfg, NitroGen_Config)
|
||||
|
||||
# Buffers
|
||||
self.obs_buffer = deque(maxlen=self.max_buffer_size)
|
||||
self.action_buffer = deque(maxlen=self.max_buffer_size)
|
||||
|
||||
@classmethod
|
||||
def from_ckpt(cls, checkpoint_path: str, old_layout=False, cfg_scale=1.0, context_length=None):
|
||||
"""Create an InferenceSession from a checkpoint."""
|
||||
model, tokenizer, img_proc, ckpt_config, game_mapping, action_downsample_ratio = load_model(checkpoint_path)
|
||||
|
||||
if game_mapping is not None:
|
||||
# Ask user to pick a game from the list
|
||||
print("Available games in tokenizer mapping:")
|
||||
for game, idx in game_mapping.items():
|
||||
print(f"{idx:03d}: {game}")
|
||||
selected_game = input("Enter the game ID to use (leave empty for unconditional): ")
|
||||
if selected_game == "":
|
||||
selected_game = None
|
||||
else:
|
||||
selected_idx = int(selected_game)
|
||||
assert selected_idx in game_mapping.values(), f"Invalid game ID {selected_idx}"
|
||||
|
||||
candidates = [k for k,v in game_mapping.items() if v == selected_idx]
|
||||
assert len(candidates) == 1, f"Multiple games found for ID {selected_idx}: {candidates}"
|
||||
|
||||
selected_game = candidates[0]
|
||||
else:
|
||||
selected_game = None
|
||||
print("No game mapping available, proceeding without game conditioning")
|
||||
|
||||
return cls(
|
||||
model,
|
||||
checkpoint_path,
|
||||
tokenizer,
|
||||
img_proc,
|
||||
ckpt_config,
|
||||
game_mapping,
|
||||
selected_game,
|
||||
old_layout,
|
||||
cfg_scale,
|
||||
action_downsample_ratio,
|
||||
context_length
|
||||
)
|
||||
|
||||
def info(self):
|
||||
return {
|
||||
"ckpt_path": self.ckpt_path,
|
||||
"selected_game": self.selected_game,
|
||||
"old_layout": self.old_layout,
|
||||
"cfg_scale": self.cfg_scale,
|
||||
"context_length": self.max_buffer_size,
|
||||
"action_interleaving": self.action_interleaving,
|
||||
"is_flowmatching": self.is_flowmatching,
|
||||
"action_downsample_ratio": self.action_downsample_ratio,
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Reset all buffers."""
|
||||
self.obs_buffer.clear()
|
||||
self.action_buffer.clear()
|
||||
|
||||
def predict(self, obs):
|
||||
start_time = time.time()
|
||||
|
||||
current_frame = self.img_proc([obs], return_tensors="pt")["pixel_values"]
|
||||
self.obs_buffer.append(current_frame)
|
||||
|
||||
# Prepare model inputs
|
||||
pixel_values = torch.cat(list(self.obs_buffer), dim=0)
|
||||
|
||||
if self.action_interleaving and len(self.action_buffer) > 0:
|
||||
action_tensors = {
|
||||
key: torch.cat([a[key] for a in list(self.action_buffer)], dim=0)
|
||||
for key in ["buttons", "j_left", "j_right"]
|
||||
}
|
||||
else:
|
||||
action_tensors = {"buttons": None, "j_left": None, "j_right": None}
|
||||
|
||||
print("Running inference with the following inputs:")
|
||||
print(f"- pixel_values: {pixel_values.shape}")
|
||||
print("- action_tensors:")
|
||||
for k, v in action_tensors.items():
|
||||
if v is not None:
|
||||
print(f" - {k}: {v.shape}")
|
||||
else:
|
||||
print(f" - {k}: None")
|
||||
|
||||
# Run inference
|
||||
if self.is_flowmatching:
|
||||
predicted_actions = self._predict_flowmatching(pixel_values, action_tensors)
|
||||
else:
|
||||
predicted_actions = self._predict_ar(pixel_values, action_tensors)
|
||||
|
||||
# Add to action buffer
|
||||
self.action_buffer.append(predicted_actions)
|
||||
|
||||
inference_time = time.time() - start_time
|
||||
print(f"Inference time: {inference_time:.3f}s")
|
||||
|
||||
# Convert to list of action dicts
|
||||
n_actions = len(predicted_actions["buttons"])
|
||||
j_left = predicted_actions["j_left"].squeeze().cpu().numpy()
|
||||
j_right = predicted_actions["j_right"].squeeze().cpu().numpy()
|
||||
buttons = predicted_actions["buttons"].squeeze().cpu().numpy()
|
||||
|
||||
return {
|
||||
"j_left": j_left,
|
||||
"j_right": j_right,
|
||||
"buttons": buttons,
|
||||
}
|
||||
|
||||
def _predict_flowmatching(self, pixel_values, action_tensors):
|
||||
|
||||
available_frames = len(self.obs_buffer)
|
||||
frames = torch.zeros((self.max_buffer_size, *pixel_values.shape[1:]),
|
||||
dtype=pixel_values.dtype, device="cuda")
|
||||
frames[-available_frames:] = pixel_values
|
||||
dropped_frames = torch.zeros((self.max_buffer_size,), dtype=torch.bool, device="cuda")
|
||||
dropped_frames[:self.max_buffer_size - available_frames] = True
|
||||
|
||||
data_with_history = {
|
||||
"frames": frames,
|
||||
"dropped_frames": dropped_frames,
|
||||
"game": self.selected_game
|
||||
}
|
||||
tokenized_data_with_history = self.tokenizer.encode(data_with_history)
|
||||
|
||||
frame_mask = torch.ones((self.max_buffer_size,), dtype=torch.bool, device="cuda")
|
||||
frame_mask[-1] = False
|
||||
data_without_history = {
|
||||
"frames": frames,
|
||||
"dropped_frames": frame_mask,
|
||||
"game": None
|
||||
}
|
||||
tokenized_data_without_history = self.tokenizer.encode(data_without_history)
|
||||
|
||||
# Convert to CUDA tensors with batch dimension
|
||||
for tokenized_data in [tokenized_data_with_history, tokenized_data_without_history]:
|
||||
for k, v in tokenized_data.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
tokenized_data[k] = v.unsqueeze(0).to("cuda")
|
||||
elif isinstance(v, np.ndarray):
|
||||
tokenized_data[k] = torch.tensor(v, device="cuda").unsqueeze(0)
|
||||
else:
|
||||
tokenized_data[k] = [v]
|
||||
|
||||
with torch.inference_mode():
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
if self.cfg_scale == 1.0:
|
||||
model_output = self.model.get_action(tokenized_data_with_history,
|
||||
old_layout=self.old_layout)
|
||||
else:
|
||||
model_output = self.model.get_action_with_cfg(
|
||||
tokenized_data_with_history,
|
||||
tokenized_data_without_history,
|
||||
cfg_scale=self.cfg_scale
|
||||
)
|
||||
predicted_actions = self.tokenizer.decode(model_output)
|
||||
|
||||
return predicted_actions
|
||||
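A sketch of wiring the session up for local inference (illustrative; the checkpoint path is a placeholder, from_ckpt prompts for a game ID on stdin, and with cfg_scale != 1.0 the game-conditioned and unconditioned passes are combined via get_action_with_cfg):
# Illustrative sketch: load a checkpoint and predict actions for incoming frames.
session = InferenceSession.from_ckpt("checkpoints/last.pt", cfg_scale=1.5)   # placeholder path
session.reset()
frame = np.zeros((720, 1280, 3), dtype=np.uint8)            # placeholder RGB frame
pred = session.predict(frame)                               # {"j_left", "j_right", "buttons"}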
252
nitrogen/inference_viz.py
Normal file
252
nitrogen/inference_viz.py
Normal file
@ -0,0 +1,252 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import av
|
||||
|
||||
def create_viz(
|
||||
frame: np.ndarray,
|
||||
i: int,
|
||||
j_left: np.ndarray,
|
||||
j_right: np.ndarray,
|
||||
buttons: np.ndarray,
|
||||
token_set: list,
|
||||
):
|
||||
"""
|
||||
Visualize gamepad actions alongside a gameplay video frame.
|
||||
|
||||
Parameters:
|
||||
- frame: Video frame as numpy array
|
||||
- i: Current frame index (selects the row in the action arrays)
|
||||
- j_left: 16x2 array of left joystick positions (-1 to 1)
|
||||
- j_right: 16x2 array of right joystick positions (-1 to 1)
|
||||
- buttons: 16x17 array of button states (boolean)
|
||||
- token_set: List of button names
|
||||
|
||||
Returns:
|
||||
- Visualization as numpy array
|
||||
"""
|
||||
# Get frame dimensions
|
||||
frame_height, frame_width = frame.shape[:2]
|
||||
|
||||
# Create visualization area
|
||||
viz_width = min(500, frame_width)
|
||||
combined_width = frame_width + viz_width
|
||||
combined_height = frame_height
|
||||
|
||||
# Create combined image (frame + visualization)
|
||||
combined = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)
|
||||
|
||||
# Put the frame on the left side
|
||||
combined[:frame_height, :frame_width] = frame
|
||||
|
||||
# Starting position for visualizations
|
||||
viz_x = frame_width
|
||||
viz_y = 20
|
||||
|
||||
# Draw joysticks if data is provided
|
||||
if i < len(j_left) and i < len(j_right):
|
||||
# Add section title
|
||||
cv2.putText(combined, "JOYSTICKS",
|
||||
(viz_x + 10, viz_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
|
||||
|
||||
viz_y += 30 # Move down after title
|
||||
|
||||
# Size of joystick visualization
|
||||
joy_size = min(120, viz_width // 3)
|
||||
|
||||
# Horizontal positions of joysticks
|
||||
joy_left_x = viz_x + 30
|
||||
joy_right_x = viz_x + viz_width - joy_size - 30
|
||||
|
||||
# Draw joystick labels
|
||||
cv2.putText(combined, "Left", (joy_left_x, viz_y - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1)
|
||||
cv2.putText(combined, "Right", (joy_right_x, viz_y - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1)
|
||||
|
||||
# Draw joysticks
|
||||
draw_joystick(combined, joy_left_x, viz_y, joy_size, j_left[i])
|
||||
draw_joystick(combined, joy_right_x, viz_y, joy_size, j_right[i])
|
||||
|
||||
viz_y += joy_size + 40 # Move down after joysticks
|
||||
|
||||
# Draw buttons if data is provided
|
||||
if buttons is not None and i < len(buttons):
|
||||
# Add section title
|
||||
cv2.putText(combined, "BUTTON STATES",
|
||||
(viz_x + 10, viz_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
|
||||
|
||||
viz_y += 30 # Move down after title
|
||||
|
||||
# Size and position of button grid
|
||||
button_grid_x = viz_x + 20
|
||||
button_grid_y = viz_y
|
||||
button_size = 20
|
||||
|
||||
# Draw button grid
|
||||
draw_button_grid(combined, button_grid_x, button_grid_y,
|
||||
button_size, buttons, i, token_set)
|
||||
|
||||
return combined
|
||||
|
||||
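A small sketch of calling create_viz on synthetic data (illustrative; shapes follow the docstring above and the button names are placeholders):
# Illustrative sketch: overlay synthetic actions next to a blank frame.
frame = np.zeros((720, 1280, 3), dtype=np.uint8)
j_left = np.random.uniform(-1, 1, size=(16, 2))
j_right = np.random.uniform(-1, 1, size=(16, 2))
buttons = np.random.rand(16, 17) > 0.5
names = [f"BTN_{k}" for k in range(17)]                     # placeholder button names
viz = create_viz(frame, i=0, j_left=j_left, j_right=j_right, buttons=buttons, token_set=names)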
def draw_joystick(img, x, y, size, position):
|
||||
"""Draw a joystick visualization at the specified position."""
|
||||
# Draw joystick background
|
||||
cv2.rectangle(img, (x, y), (x + size, y + size), (50, 50, 50), -1)
|
||||
cv2.rectangle(img, (x, y), (x + size, y + size), (100, 100, 100), 1)
|
||||
|
||||
# Calculate center point
|
||||
mid_x = x + size // 2
|
||||
mid_y = y + size // 2
|
||||
|
||||
# Draw center cross (0,0 coordinates)
|
||||
cv2.line(img, (x, mid_y), (x + size, mid_y), (150, 150, 150), 1)
|
||||
cv2.line(img, (mid_x, y), (mid_x, y + size), (150, 150, 150), 1)
|
||||
|
||||
# Draw 2x2 grid
|
||||
quarter_x = x + size // 4
|
||||
quarter_y = y + size // 4
|
||||
three_quarters_x = x + 3 * size // 4
|
||||
three_quarters_y = y + 3 * size // 4
|
||||
|
||||
# Draw grid lines
|
||||
cv2.line(img, (quarter_x, y), (quarter_x, y + size), (100, 100, 100), 1)
|
||||
cv2.line(img, (three_quarters_x, y), (three_quarters_x, y + size), (100, 100, 100), 1)
|
||||
cv2.line(img, (x, quarter_y), (x + size, quarter_y), (100, 100, 100), 1)
|
||||
cv2.line(img, (x, three_quarters_y), (x + size, three_quarters_y), (100, 100, 100), 1)
|
||||
|
||||
# Draw joystick position (clamp coordinates to valid range)
|
||||
px = max(-1, min(1, position[0]))
|
||||
py = max(-1, min(1, position[1]))
|
||||
|
||||
joy_x = int(mid_x + px * size // 2)
|
||||
joy_y = int(mid_y - py * size // 2) # Y is inverted in image coordinates
|
||||
|
||||
# Draw joystick position as a dot
|
||||
cv2.circle(img, (joy_x, joy_y), 5, (0, 0, 255), -1) # Red dot
|
||||
|
||||
def draw_button_grid(img, x, y, button_size, buttons, current_row, token_set):
|
||||
"""Draw the button state grid."""
|
||||
rows, cols = buttons.shape
|
||||
|
||||
# Ensure the grid fits in the visualization area
|
||||
available_width = img.shape[1] - x - 20
|
||||
if cols * button_size > available_width:
|
||||
button_size = max(10, available_width // cols)
|
||||
|
||||
# Draw column numbers at the top
|
||||
for col in range(cols):
|
||||
number_x = x + col * button_size + button_size // 2
|
||||
number_y = y - 5
|
||||
cv2.putText(img, str(col + 1), (number_x - 4, number_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)
|
||||
|
||||
# Draw button grid
|
||||
for row in range(rows):
|
||||
for col in range(cols):
|
||||
# Calculate button position
|
||||
bx = x + col * button_size
|
||||
by = y + row * button_size
|
||||
|
||||
# Draw button cell
|
||||
color = (0, 255, 0) if buttons[row, col] else (0, 0, 0) # Green if pressed, black otherwise
|
||||
cv2.rectangle(img, (bx, by), (bx + button_size, by + button_size), color, -1)
|
||||
|
||||
# Draw grid lines
|
||||
cv2.rectangle(img, (bx, by), (bx + button_size, by + button_size), (80, 80, 80), 1)
|
||||
|
||||
# Highlight current row
|
||||
highlight_y = y + current_row * button_size
|
||||
cv2.rectangle(img, (x, highlight_y), (x + cols * button_size, highlight_y + button_size),
|
||||
(0, 0, 255), 2) # Red highlight
|
||||
|
||||
# Draw button legend below the mosaic
|
||||
if token_set is not None:
|
||||
legend_y = y + rows * button_size + 20 # Starting Y position for legend
|
||||
legend_x = x # Starting X position for legend
|
||||
line_height = 15 # Height of each legend line
|
||||
|
||||
# Add legend title
|
||||
cv2.putText(img, "Button Legend:", (legend_x, legend_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)
|
||||
legend_y += line_height + 5 # Move down after title
|
||||
|
||||
# Calculate how many columns to use for the legend based on available space
|
||||
legend_cols = max(1, min(3, cols // 6)) # Use 1-3 columns depending on button count
|
||||
legend_items_per_col = (cols + legend_cols - 1) // legend_cols # Items per column with ceiling division
|
||||
|
||||
# Draw legend entries
|
||||
for col in range(min(cols, len(token_set))):
|
||||
# Calculate position in the legend grid
|
||||
legend_col = col // legend_items_per_col
|
||||
legend_row = col % legend_items_per_col
|
||||
|
||||
# Calculate position
|
||||
entry_x = legend_x + legend_col * (available_width // legend_cols)
|
||||
entry_y = legend_y + legend_row * line_height
|
||||
|
||||
# Add legend entry
|
||||
if col < len(token_set):
|
||||
cv2.putText(img, f"{col+1}. {token_set[col]}", (entry_x, entry_y),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)
|
||||
|
||||
class VideoRecorder:
|
||||
def __init__(self, output_file, fps=30, crf=28, preset="fast"):
|
||||
"""
|
||||
Initialize a video recorder using PyAV.
|
||||
|
||||
Args:
|
||||
output_file (str): Path to save the video file
|
||||
fps (int): Frames per second
|
||||
crf (int): Constant Rate Factor (0-51, higher means smaller file but lower quality)
|
||||
preset (str): Encoding preset (ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow)
|
||||
"""
|
||||
self.output_file = output_file
|
||||
self.fps = fps
|
||||
self.crf = str(crf)
|
||||
self.preset = preset
|
||||
self.container = av.open(output_file, mode="w")
|
||||
self.stream = None
|
||||
|
||||
def init_stream(self, width, height):
|
||||
"""Initialize the video stream with the frame dimensions."""
|
||||
self.stream = self.container.add_stream("h264", rate=self.fps)
|
||||
self.stream.width = width
|
||||
self.stream.height = height
|
||||
self.stream.pix_fmt = "yuv420p"
|
||||
self.stream.options = {
|
||||
"crf": self.crf,
|
||||
"preset": self.preset
|
||||
}
|
||||
|
||||
def add_frame(self, frame):
|
||||
"""
|
||||
Add a frame to the video.
|
||||
|
||||
Args:
|
||||
frame (numpy.ndarray): Frame as RGB numpy array
|
||||
"""
|
||||
if self.stream is None:
|
||||
self.init_stream(frame.shape[1], frame.shape[0])
|
||||
|
||||
av_frame = av.VideoFrame.from_ndarray(np.array(frame), format="rgb24")
|
||||
for packet in self.stream.encode(av_frame):
|
||||
self.container.mux(packet)
|
||||
|
||||
def close(self):
|
||||
"""Flush remaining packets and close the video file."""
|
||||
try:
|
||||
if self.stream is not None:
|
||||
for packet in self.stream.encode():
|
||||
self.container.mux(packet)
|
||||
finally:
|
||||
self.container.close()
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Close the recorder when exiting the context."""
|
||||
self.close()
|
||||
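A sketch of recording frames with the PyAV-backed recorder (illustrative; it assumes an H.264 encoder is available and uses even frame dimensions as required by yuv420p):
# Illustrative sketch: write a short synthetic clip.
with VideoRecorder("debug_rollout.mp4", fps=30) as rec:     # placeholder output path
    for t in range(90):
        frame = np.full((480, 640, 3), t % 255, dtype=np.uint8)
        rec.add_frame(frame)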
332
nitrogen/mm_tokenizers.py
Normal file
332
nitrogen/mm_tokenizers.py
Normal file
@ -0,0 +1,332 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import polars as pl
|
||||
from pydantic import BaseModel, Field
|
||||
from itertools import count
|
||||
|
||||
|
||||
DEBUG = int(os.getenv("DEBUG", 0))
|
||||
|
||||
class Tokenizer(ABC):
|
||||
@abstractmethod
|
||||
def encode(self, data: dict) -> dict:
|
||||
"""
|
||||
Transform the input data into a tokenized format.
|
||||
|
||||
Args:
|
||||
data (dict): Input data containing frames and actions.
|
||||
|
||||
Returns:
|
||||
dict: Tokenized data ready for model input.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, data: dict) -> dict:
|
||||
"""
|
||||
Reverse the tokenization process to retrieve original data.
|
||||
|
||||
Args:
|
||||
data (dict): Tokenized data.
|
||||
|
||||
Returns:
|
||||
dict: Original data structure.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self):
|
||||
"""
|
||||
Set the tokenizer to training mode.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def eval(self):
|
||||
"""
|
||||
Set the tokenizer to evaluation mode.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Set IDs for each token type.
|
||||
_PAD_TOKEN = 0
|
||||
_IMG_TOKEN = 1
|
||||
_IMG_SEP_TOKEN = 5 # New separator token
|
||||
_LANG_TOKEN = 2
|
||||
_PROPRIO_TOKEN = 3
|
||||
_ACT_TOKEN = 4
|
||||
_GAME_ID_TOKEN = 6
|
||||
|
||||
|
||||
_UNCONDITIONAL_ID = None # Special ID for unconditional game
|
||||
|
||||
class GameMappingConfig(BaseModel):
|
||||
src_files: list[str] = Field(default_factory=list, description="List of source parquet files to build game mapping.")
|
||||
|
||||
def get_game_mapping(cfg: GameMappingConfig) -> dict:
|
||||
game_set = set()
|
||||
for path in cfg.src_files:
|
||||
df = pl.read_parquet(path)
|
||||
for game in df['game_label'].unique():
|
||||
if game == _UNCONDITIONAL_ID:
|
||||
continue
|
||||
game_set.add(game)
|
||||
games = sorted(list(game_set))
|
||||
|
||||
# Set the 0th element to be the unconditional game ID
|
||||
games = [_UNCONDITIONAL_ID] + games
|
||||
return {game: idx for idx, game in enumerate(games)}
|
||||
|
||||
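For illustration (hypothetical file paths and game labels), get_game_mapping reserves index 0 for the unconditional ID and sorts the remaining games alphabetically:
# Illustrative sketch: build a game -> id mapping from parquet shards.
cfg = GameMappingConfig(src_files=["data/shard_000.parquet", "data/shard_001.parquet"])
mapping = get_game_mapping(cfg)
# e.g. {None: 0, "forza": 1, "halo": 2} -- index 0 is always _UNCONDITIONAL_ID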
class NitrogenTokenizerConfig(BaseModel):
|
||||
tokenizer_id: Literal['nitrogen'] = Field(default='nitrogen', frozen=True)
|
||||
training: bool = Field(default=True, description="Whether to apply the transform in training mode.")
|
||||
num_visual_tokens_per_frame: int = Field(default=256, description="Number of visual tokens per frame.")
|
||||
max_action_dim: int = Field(default=25, description="Maximum action dimension.")
|
||||
max_sequence_length: int = Field(default=300, description="Maximum sequence length.")
|
||||
action_horizon: int = Field(default=16, description="Action horizon.")
|
||||
game_mapping_cfg: GameMappingConfig | None = Field(default=None, description="Game mapping configuration.")
|
||||
old_layout: bool = Field(default=False, description="Whether to use the old layout for actions. If True, the action layout is [j_left, j_right, buttons]; if False, it is [buttons, j_left, j_right], matching pack_actions.")
|
||||
|
||||
class NitrogenTokenizer(Tokenizer):
|
||||
"""
|
||||
Example transform that prepares video, language, state, and actions
|
||||
into a token-based format suitable for your model.
|
||||
|
||||
The sub-methods below (prefixed with `_prepare_`) mirror the original
|
||||
modular structure.
|
||||
"""
|
||||
|
||||
def __init__(self, config: NitrogenTokenizerConfig):
|
||||
self.training = config.training
|
||||
self.num_visual_tokens_per_frame = config.num_visual_tokens_per_frame
|
||||
self.max_action_dim = config.max_action_dim
|
||||
self.max_sequence_length = config.max_sequence_length
|
||||
self.action_horizon = config.action_horizon
|
||||
self.old_layout = config.old_layout
|
||||
|
||||
if config.game_mapping_cfg:
|
||||
self.game_mapping = get_game_mapping(config.game_mapping_cfg)
|
||||
with open("game_mapping.json", "w") as f:
|
||||
import json
|
||||
json.dump(self.game_mapping, f, indent=2)
|
||||
else:
|
||||
self.game_mapping = None
|
||||
|
||||
def train(self):
|
||||
self.training = True
|
||||
|
||||
def eval(self):
|
||||
self.training = False
|
||||
|
||||
def check_batch_size(self, data):
|
||||
# Use video key to determine batch size.
|
||||
video_ndim = data["images"].ndim
|
||||
if video_ndim == 4: # Interpret as [T*V, H, W, C]
|
||||
is_batched = False
|
||||
batch_size = 1
|
||||
elif video_ndim == 5: # Interpret as [B, T*V, H, W, C]
|
||||
is_batched = True
|
||||
batch_size = data["images"].shape[0]
|
||||
else:
|
||||
raise ValueError(f"Unsupported video number of dimensions: {video_ndim}")
|
||||
|
||||
return is_batched, batch_size
|
||||
|
||||
def _prepare_action(self, data: dict):
|
||||
"""
|
||||
Pad to max_action_dim, return masks.
|
||||
"""
|
||||
if "action" not in data:
|
||||
actions = np.zeros((self.action_horizon, self.max_action_dim))
|
||||
actions_mask = np.zeros((self.action_horizon, self.max_action_dim), dtype=bool)
|
||||
n_action_tokens = self.action_horizon
|
||||
return actions, actions_mask, n_action_tokens
|
||||
|
||||
actions = data["action"]
|
||||
assert actions.shape[0] == self.action_horizon, f"{actions.shape=}, {self.action_horizon=}"
|
||||
|
||||
n_action_tokens = actions.shape[0] # T
|
||||
n_action_dims = actions.shape[1]
|
||||
|
||||
assert (
|
||||
n_action_dims <= self.max_action_dim
|
||||
), f"Action dim {n_action_dims} exceeds max allowed {self.max_action_dim}."
|
||||
|
||||
# Pad the channel dimension
|
||||
actions = np.pad(actions, ((0, 0), (0, self.max_action_dim - n_action_dims)), "constant")
|
||||
|
||||
# Create mask: [T, max_action_dim]
|
||||
actions_mask = np.zeros((n_action_tokens, self.max_action_dim), dtype=bool)
|
||||
actions_mask[:, :n_action_dims] = True
|
||||
|
||||
return actions, actions_mask, n_action_tokens
|
||||
|
||||
def _build_token_ids(self, n_images, n_action_tokens): # n_lang_tokens, n_state_tokens):
|
||||
"""
|
||||
Build the 1D array of token_ids based on the number of each block.
|
||||
Return (token_ids, special_pad_token_idx).
|
||||
"""
|
||||
vl_token_ids = []
|
||||
sa_token_ids = []
|
||||
|
||||
# 0.5) Add a Game ID placeholder
|
||||
if self.game_mapping:
|
||||
vl_token_ids.append(_GAME_ID_TOKEN)
|
||||
|
||||
# 1) Video placeholders
|
||||
for _ in range(n_images):
|
||||
vl_token_ids.extend([_IMG_TOKEN] * self.num_visual_tokens_per_frame)
|
||||
|
||||
# 2) Action tokens
|
||||
sa_token_ids.extend([_ACT_TOKEN] * n_action_tokens)
|
||||
|
||||
return np.array(vl_token_ids), np.array(sa_token_ids)
|
||||
|
||||
def _prepare_attention_mask(
|
||||
self,
|
||||
vl_token_ids: np.ndarray,
|
||||
):
|
||||
"""
|
||||
Build 1D attention mask for vision-language tokens.
|
||||
1 indicates valid token, 0 indicates padding token.
|
||||
State-action attention will be handled separately by the model.
|
||||
"""
|
||||
# Only create attention mask for vision-language tokens
|
||||
vl_seq_len = vl_token_ids.shape[0]
|
||||
vl_attn_mask = np.ones(vl_seq_len, dtype=bool) # All tokens are valid initially
|
||||
|
||||
# Pad vl_token_ids and vl_attn_mask to max_sequence_length
|
||||
if vl_seq_len > self.max_sequence_length:
|
||||
raise ValueError("VL sequence length exceeds the max sequence length!")
|
||||
|
||||
left_pad_len = self.max_sequence_length - vl_seq_len
|
||||
|
||||
# Pad token_ids (with PAD_TOKEN)
|
||||
vl_token_ids = np.pad(vl_token_ids, (left_pad_len, 0), constant_values=_PAD_TOKEN)
|
||||
|
||||
# Pad attention mask with 0 (padding tokens)
|
||||
vl_attn_mask = np.pad(vl_attn_mask, (left_pad_len, 0), constant_values=0)
|
||||
|
||||
return vl_token_ids, vl_attn_mask
|
||||
|
||||
def pack_actions(self, buttons, j_left, j_right):
|
||||
# Check that the first two dims of each input is the same (number of chunks, control frequency)
|
||||
assert buttons.shape[:2] == j_left.shape[:2] == j_right.shape[:2], (
|
||||
f"buttons shape: {buttons.shape}, "
|
||||
f"j_left shape: {j_left.shape}, "
|
||||
f"j_right shape: {j_right.shape}"
|
||||
)
|
||||
|
||||
# Normalize the joysticks to 0,1
|
||||
j_left = (j_left + 1) / 2.
|
||||
j_right = (j_right + 1) / 2.
|
||||
|
||||
# Concatenate the buttons and joysticks along the last dimension
|
||||
action = np.concatenate([buttons, j_left, j_right], axis=-1, dtype=np.float32)
|
||||
|
||||
# Squeeze the first dimension of each input: this is the number of chunks, which is 1 here
|
||||
action = action.squeeze(0)
|
||||
return action
|
||||
|
||||
def unpack_actions(self, actions):
|
||||
if self.old_layout:
|
||||
# Unpack the actions into j_left, j_right, buttons
|
||||
j_left = actions[:, :, :2]
|
||||
j_right = actions[:, :, 2:4]
|
||||
buttons = actions[:, :, 4:]
|
||||
else:
|
||||
# Unpack the actions into j_left, j_right, buttons
|
||||
buttons = actions[:, :, :-4]
|
||||
j_left = actions[:, :, -4:-2]
|
||||
j_right = actions[:, :, -2:]
|
||||
|
||||
# Denormalize the joysticks back to -1,1
|
||||
j_left = j_left * 2. - 1.
|
||||
j_right = j_right * 2. - 1.
|
||||
|
||||
# Clip into [-1,1]
|
||||
j_left = torch.clamp(j_left, -1, 1)
|
||||
j_right = torch.clamp(j_right, -1, 1)
|
||||
|
||||
# Threshold the buttons to 0/1
|
||||
buttons = (buttons > 0.5).float()
|
||||
return j_left, j_right, buttons
|
||||
|
||||
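A round-trip sketch of the pack/unpack conventions (illustrative): joysticks are stored normalized to [0, 1] inside the packed tensor and mapped back to [-1, 1] on decode, and buttons are thresholded at 0.5.
# Illustrative sketch of the default (non-old) layout: [buttons, j_left, j_right].
tok = NitrogenTokenizer(NitrogenTokenizerConfig())
buttons = np.random.rand(1, 16, 17) > 0.5                   # (chunks, horizon, n_buttons)
j_left = np.random.uniform(-1, 1, size=(1, 16, 2))
j_right = np.random.uniform(-1, 1, size=(1, 16, 2))
packed = tok.pack_actions(buttons, j_left, j_right)         # (16, 21), joysticks in [0, 1]
out_l, out_r, out_b = tok.unpack_actions(torch.tensor(packed).unsqueeze(0))  # batch dim added back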
###########################################################################
|
||||
# apply
|
||||
###########################################################################
|
||||
def encode(self, data: dict) -> dict:
|
||||
"""
|
||||
Main entry point for the transform. We assume that `data` has
data['frames'], data['dropped_frames'] and, when game conditioning is
enabled, data['game']; in training mode it also needs data['buttons'],
data['j_left'], and data['j_right'] in the shapes expected by
pack_actions. If you have multiple keys for each modality, you could
use your own grouping logic (similar to GR1Transform) first.
|
||||
"""
|
||||
|
||||
# 1) Pack buttons/joysticks into a single action tensor
|
||||
|
||||
|
||||
transformed_data = {**data} # Start with a copy of the input data
|
||||
|
||||
n_images = (data["dropped_frames"] == False).sum()
|
||||
transformed_data["images"] = data["frames"]
|
||||
transformed_data["dropped_images"] = data["dropped_frames"]
|
||||
|
||||
if self.training:
|
||||
# Keep the original actions in the data for evaluation
|
||||
packed_actions = self.pack_actions(
|
||||
data["buttons"],
|
||||
data["j_left"],
|
||||
data["j_right"]
|
||||
)
|
||||
data["action"] = packed_actions
|
||||
|
||||
transformed_data["has_real_action"] = np.ones((), dtype=bool)
|
||||
|
||||
actions, actions_mask, n_action_tokens = self._prepare_action(data)
|
||||
transformed_data["actions"] = actions
|
||||
transformed_data["actions_mask"] = actions_mask
|
||||
|
||||
action_and_mask_keys = ["actions", "actions_mask"]
|
||||
assert all(
|
||||
transformed_data[key].shape == transformed_data["actions"].shape
|
||||
for key in action_and_mask_keys
|
||||
), f"Shape mismatch: {[(key, transformed_data[key].shape) for key in action_and_mask_keys]}"
|
||||
else:
|
||||
n_action_tokens = self.action_horizon
|
||||
|
||||
transformed_data["has_detection_target"] = np.zeros((), dtype=bool)
|
||||
|
||||
# 5) Build token_ids
|
||||
vl_token_ids, sa_token_ids = self._build_token_ids(
|
||||
n_images, n_action_tokens
|
||||
)
|
||||
|
||||
# 6) Build the attention mask only for vision-language tokens
|
||||
vl_token_ids, vl_attn_mask = self._prepare_attention_mask(vl_token_ids)
|
||||
|
||||
transformed_data["vl_token_ids"] = vl_token_ids
|
||||
transformed_data["sa_token_ids"] = sa_token_ids
|
||||
transformed_data["vl_attn_mask"] = vl_attn_mask
|
||||
transformed_data["embodiment_id"] = torch.tensor(0, dtype=torch.long)
|
||||
|
||||
if self.game_mapping:
|
||||
game_name = data["game"]
|
||||
assert game_name in self.game_mapping, f"Game '{game_name}' not found in game mapping."
|
||||
transformed_data["game_ids"] = torch.tensor(self.game_mapping[game_name], dtype=torch.long)
|
||||
else:
|
||||
transformed_data["game_ids"] = torch.tensor(0, dtype=torch.long)
|
||||
return transformed_data
|
||||
|
||||
def decode(self, data: dict) -> dict:
|
||||
j_left, j_right, buttons = self.unpack_actions(data["action_tensor"])
|
||||
|
||||
return {
|
||||
"j_left": j_left,
|
||||
"j_right": j_right,
|
||||
"buttons": buttons,
|
||||
}
|
||||
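A sketch of eval-time encoding for a single sample (illustrative; the frame shapes are placeholders, the default max_sequence_length of 300 fits one 256-token frame, and training-mode encoding would additionally pack buttons/joysticks into `actions`):
# Illustrative sketch: eval-time encoding of a padded frame history.
tok = NitrogenTokenizer(NitrogenTokenizerConfig(training=False))
sample = {
    "frames": torch.zeros(4, 3, 224, 224),                  # placeholder preprocessed frames
    "dropped_frames": torch.tensor([True, True, True, False]),
    "game": None,
}
encoded = tok.encode(sample)
# encoded carries vl_token_ids / sa_token_ids / vl_attn_mask plus game_ids = 0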
10
nitrogen/shared.py
Normal file
10
nitrogen/shared.py
Normal file
@ -0,0 +1,10 @@
|
||||
from pathlib import Path
|
||||
|
||||
BUTTON_ACTION_TOKENS = [
|
||||
'BACK', 'DPAD_DOWN', 'DPAD_LEFT', 'DPAD_RIGHT', 'DPAD_UP', 'EAST', 'GUIDE',
|
||||
'LEFT_SHOULDER', 'LEFT_THUMB', 'LEFT_TRIGGER', 'NORTH', 'RIGHT_BOTTOM', 'RIGHT_LEFT',
|
||||
'RIGHT_RIGHT', 'RIGHT_SHOULDER', 'RIGHT_THUMB', 'RIGHT_TRIGGER', 'RIGHT_UP', 'SOUTH',
|
||||
'START', 'WEST'
|
||||
]
|
||||
|
||||
PATH_REPO = Path(__file__).parent.parent.resolve()