init commit
nitrogen/flow_matching_transformer/modules.py (new file, 434 lines)
@@ -0,0 +1,434 @@
from typing import Optional

from pydantic import BaseModel, Field
import torch
import torch.nn.functional as F
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import (
    SinusoidalPositionalEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from torch import nn


class TimestepEncoder(nn.Module):
    def __init__(self, embedding_dim, compute_dtype=torch.float32):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

    def forward(self, timesteps):
        dtype = next(self.parameters()).dtype
        timesteps_proj = self.time_proj(timesteps).to(dtype)
        timesteps_emb = self.timestep_embedder(timesteps_proj)  # (N, D)
        return timesteps_emb


class AdaLayerNorm(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-5,
        chunk_dim: int = 0,
    ):
        super().__init__()
        self.chunk_dim = chunk_dim
        output_dim = embedding_dim * 2
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

    def forward(
        self,
        x: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        temb = self.linear(self.silu(temb))
        scale, shift = temb.chunk(2, dim=1)
        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return x


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.norm_type = norm_type

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(
                dim, max_seq_length=num_positional_embeddings
            )
        else:
            self.pos_embed = None

        # Define two sub-blocks. Each sub-block has its own normalization layer.
        # 1. Self-/cross-attention
        if norm_type == "ada_norm":
            self.norm1 = AdaLayerNorm(dim)
        else:
            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Feed-forward
        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )
        if final_dropout:
            self.final_dropout = nn.Dropout(dropout)
        else:
            self.final_dropout = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # 1. Self-/cross-attention
        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, temb)
        else:
            norm_hidden_states = self.norm1(hidden_states)

        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            # encoder_attention_mask=encoder_attention_mask,
        )
        if self.final_dropout is not None:
            attn_output = self.final_dropout(attn_output)

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 2. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
        return hidden_states


class DiTConfig(BaseModel):
    num_attention_heads: int = Field(default=8)
    attention_head_dim: int = Field(default=64)
    output_dim: int = Field(default=26)
    num_layers: int = Field(default=12)
    dropout: float = Field(default=0.1)
    attention_bias: bool = Field(default=True)
    activation_fn: str = Field(default="gelu-approximate")
    num_embeds_ada_norm: Optional[int] = Field(default=1000)
    upcast_attention: bool = Field(default=False)
    norm_type: str = Field(default="ada_norm")
    norm_elementwise_affine: bool = Field(default=False)
    norm_eps: float = Field(default=1e-5)
    max_num_positional_embeddings: int = Field(default=512)
    compute_dtype: str = Field(default="float32")
    final_dropout: bool = Field(default=True)
    positional_embeddings: Optional[str] = Field(default="sinusoidal")
    interleave_self_attention: bool = Field(default=False)
    cross_attention_dim: Optional[int] = Field(
        default=None,
        description="Dimension of the cross-attention embeddings. If None, no cross-attention is used.",
    )


class DiT(ModelMixin):
    _supports_gradient_checkpointing = True

    def __init__(self, config: DiTConfig):
        super().__init__()

        self.config = config

        self.compute_dtype = getattr(torch, self.config.compute_dtype)

        self.attention_head_dim = self.config.attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        # Timestep encoder
        self.timestep_encoder = TimestepEncoder(
            embedding_dim=self.inner_dim, compute_dtype=self.compute_dtype
        )

        all_blocks = []
        for idx in range(self.config.num_layers):
            # With interleaved self-attention, every odd block drops cross-attention.
            use_self_attn = idx % 2 == 1 and self.config.interleave_self_attention
            curr_cross_attention_dim = self.config.cross_attention_dim if not use_self_attn else None

            all_blocks += [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=self.config.norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    positional_embeddings=self.config.positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=self.config.final_dropout,
                    cross_attention_dim=curr_cross_attention_dim,
                )
            ]
        self.transformer_blocks = nn.ModuleList(all_blocks)

        # Output blocks
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
        self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
        print(
            "Total number of DiT parameters: ",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        encoder_hidden_states: torch.Tensor,  # Shape: (B, S, D)
        timestep: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_all_hidden_states: bool = False,
    ):
        # Encode timesteps
        temb = self.timestep_encoder(timestep)

        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        all_hidden_states = [hidden_states]

        # Process through transformer blocks - single pass through the blocks
        for idx, block in enumerate(self.transformer_blocks):
            if idx % 2 == 1 and self.config.interleave_self_attention:
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=None,
                    encoder_attention_mask=None,
                    temb=temb,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=None,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=None,
                    temb=temb,
                )
            all_hidden_states.append(hidden_states)

        # Output processing
        conditioning = temb
        shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
        if return_all_hidden_states:
            return self.proj_out_2(hidden_states), all_hidden_states
        else:
            return self.proj_out_2(hidden_states)


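# Illustrative usage sketch (not part of the original module): exercises a small
# DiT with the default config to show the expected tensor shapes. Batch size,
# sequence lengths, and the number of layers below are arbitrary assumptions.
def _example_dit_forward():
    cfg = DiTConfig(num_layers=2)
    model = DiT(config=cfg)
    inner_dim = cfg.num_attention_heads * cfg.attention_head_dim
    B, T, S = 2, 16, 32
    sa_tokens = torch.randn(B, T, inner_dim)   # state/action token embeddings
    vl_tokens = torch.randn(B, S, inner_dim)   # vision-language token embeddings
    timesteps = torch.randint(0, 1000, (B,))   # discretized flow-matching timesteps
    out = model(hidden_states=sa_tokens, encoder_hidden_states=vl_tokens, timestep=timesteps)
    assert out.shape == (B, T, cfg.output_dim)
    return out

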
class SelfAttentionTransformerConfig(BaseModel):
    num_attention_heads: int = Field(default=8)
    attention_head_dim: int = Field(default=64)
    output_dim: int = Field(default=26)
    num_layers: int = Field(default=12)
    dropout: float = Field(default=0.1)
    attention_bias: bool = Field(default=True)
    activation_fn: str = Field(default="gelu-approximate")
    num_embeds_ada_norm: Optional[int] = Field(default=1000)
    upcast_attention: bool = Field(default=False)
    max_num_positional_embeddings: int = Field(default=512)
    compute_dtype: str = Field(default="float32")
    final_dropout: bool = Field(default=True)
    positional_embeddings: Optional[str] = Field(default="sinusoidal")
    interleave_self_attention: bool = Field(default=False)


class SelfAttentionTransformer(ModelMixin):
    _supports_gradient_checkpointing = True

    def __init__(self, config: SelfAttentionTransformerConfig):
        super().__init__()

        self.config = config

        self.attention_head_dim = self.config.attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    positional_embeddings=self.config.positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=self.config.final_dropout,
                )
                for _ in range(self.config.num_layers)
            ]
        )
        print(
            "Total number of SelfAttentionTransformer parameters: ",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        return_all_hidden_states: bool = False,
    ):
        # Process through transformer blocks - single pass through the blocks
        hidden_states = hidden_states.contiguous()
        all_hidden_states = [hidden_states]

        for idx, block in enumerate(self.transformer_blocks):
            hidden_states = block(hidden_states)
            all_hidden_states.append(hidden_states)

        if return_all_hidden_states:
            return hidden_states, all_hidden_states
        else:
            return hidden_states


class CrossAttentionTransformer(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        output_dim: int = 26,
        num_layers: int = 12,
        dropout: float = 0.1,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        max_num_positional_embeddings: int = 512,
        compute_dtype=torch.float32,
        final_dropout: bool = True,
        positional_embeddings: Optional[str] = "sinusoidal",
        interleave_self_attention=False,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.gradient_checkpointing = False

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    activation_fn=self.config.activation_fn,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    positional_embeddings=positional_embeddings,
                    num_positional_embeddings=self.config.max_num_positional_embeddings,
                    final_dropout=final_dropout,
                )
                for _ in range(self.config.num_layers)
            ]
        )
        print(
            "Total number of CrossAttentionTransformer parameters: ",
            sum(p.numel() for p in self.parameters() if p.requires_grad),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,  # Shape: (B, T, D)
        encoder_hidden_states: torch.Tensor,  # Shape: (B, S, D)
    ):
        # Process through transformer blocks - single pass through the blocks
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()

        for idx, block in enumerate(self.transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        return hidden_states
nitrogen/flow_matching_transformer/nitrogen.py (new file, 755 lines)
@@ -0,0 +1,755 @@
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
from pathlib import Path

import yaml
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.distributions import Beta
from transformers import SiglipVisionModel, AutoModel

from .modules import DiT, DiTConfig, SelfAttentionTransformer, SelfAttentionTransformerConfig

_PAD_TOKEN = 0
_IMG_TOKEN = 1
_IMG_SEP_TOKEN = 5
_LANG_TOKEN = 2
_PROPRIO_TOKEN = 3
_ACT_TOKEN = 4
_GAME_ID_TOKEN = 6


class NitroGen_Config(BaseModel):
    model_type: str = Field(default="nitrogen", frozen=True)

    add_pos_embed: bool = Field(default=False, description="Whether to add positional embedding")
    model_dtype: str = Field(default="float32", description="Model data type.")
    diffusion_model_cfg: DiTConfig = Field(..., description="Diffusion model configuration.")
    vl_self_attention_cfg: SelfAttentionTransformerConfig = Field(..., description="VL self-attention configuration.")
    hidden_size: int = Field(default=1024, description="Input embedding dimension.")
    max_seq_len: int = Field(default=1024, description="Maximum sequence length.")
    action_dim: int | None = Field(default=None, description="Action dimension.")
    action_horizon: int | None = Field(default=None, description="Action horizon.")
    noise_beta_alpha: float = Field(default=1.5, description="Flow matching noise Beta distribution alpha.")
    noise_beta_beta: float = Field(default=1.0, description="Flow matching noise Beta distribution beta.")
    noise_s: float = Field(default=0.999, description="Flow matching noise Beta distribution s.")
    num_timestep_buckets: int = Field(default=1000, description="Number of timestep discretization buckets.")
    num_inference_timesteps: int | None = Field(default=None, description="Number of inference steps for noise diffusion.")
    max_num_embodiments: int = Field(default=1, description="Number of embodiments.")
    vision_encoder_name: str = Field(default="google/siglip-large-patch16-256", description="Vision encoder name.")
    vision_hidden_size: int = Field(default=768, description="Siglip hidden size.")
    add_view_embed: bool = Field(default=False, description="Whether to add view embedding.")

    tune_vision_tower: bool = Field(default=True, description="Tune vision if True.")
    tune_mm_projector: bool = Field(default=True, description="Tune mm projector if True.")
    tune_diffusion_model: bool = Field(default=True, description="Tune diffusion model if True.")
    tune_multi_projector: bool = Field(default=True, description="Tune multi projector if True.")
    tune_vl_mixing: bool = Field(default=True, description="Tune vl mixing if True.")

    @classmethod
    def from_yaml(cls, yaml_path: str | Path) -> "NitroGen_Config":
        """Load configuration from a YAML file."""
        with open(yaml_path, "r") as f:
            config_dict = yaml.safe_load(f)
        return cls.model_validate(config_dict)


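# Illustrative sketch (not part of the original file): the two nested model
# configs are the only required fields, so a minimal NitroGen_Config can be built
# directly in code; the action_dim/action_horizon/num_inference_timesteps values
# below are arbitrary assumptions. A YAML file with the same keys could
# equivalently be loaded through NitroGen_Config.from_yaml(path).
def _example_build_config() -> NitroGen_Config:
    return NitroGen_Config(
        diffusion_model_cfg=DiTConfig(),
        vl_self_attention_cfg=SelfAttentionTransformerConfig(),
        action_dim=26,
        action_horizon=16,
        num_inference_timesteps=4,
    )

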
def swish(x):
    return x * torch.sigmoid(x)


class SinusoidalPositionalEncoding(nn.Module):
    """
    Produces a sinusoidal encoding of shape (B, T, w)
    given timesteps of shape (B, T).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timesteps):
        # timesteps: shape (B, T)
        # We'll compute sin/cos frequencies across dim T
        timesteps = timesteps.float()  # ensure float

        B, T = timesteps.shape
        device = timesteps.device

        half_dim = self.embedding_dim // 2
        # typical log space frequencies for sinusoidal encoding
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            torch.log(torch.tensor(10000.0)) / half_dim
        )
        # Expand timesteps to (B, T, 1) then multiply
        freqs = timesteps.unsqueeze(-1) * exponent.exp()  # (B, T, half_dim)

        sin = torch.sin(freqs)
        cos = torch.cos(freqs)
        enc = torch.cat([sin, cos], dim=-1)  # (B, T, w)

        return enc


class CategorySpecificLinear(nn.Module):
    def __init__(self, num_categories, input_dim, hidden_dim):
        super().__init__()
        self.num_categories = num_categories
        # For each category, we have separate weights and biases.
        self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))

    def forward(self, x, cat_ids):
        selected_W = self.W[cat_ids]
        selected_b = self.b[cat_ids]
        return torch.bmm(x, selected_W) + selected_b.unsqueeze(1)


class CategorySpecificMLP(nn.Module):
    def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.num_categories = num_categories
        self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
        self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)

    def forward(self, x, cat_ids):
        hidden = F.relu(self.layer1(x, cat_ids))
        return self.layer2(hidden, cat_ids)


class MultiEmbodimentActionEncoder(nn.Module):
    def __init__(self, action_dim, hidden_size, num_embodiments):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_embodiments = num_embodiments

        # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)  # (d -> w)
        self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)  # (2w -> w)
        self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)  # (w -> w)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps, cat_ids):
        """
        actions: shape (B, T, action_dim)
        timesteps: shape (B,) -- a single scalar per batch item
        cat_ids: shape (B,)
        returns: shape (B, T, hidden_size)
        """
        B, T, _ = actions.shape

        # 1) Expand each batch's single scalar time 'tau' across all T steps
        #    so that shape => (B, T)
        #    e.g. if timesteps is (B,), replicate across T
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            # shape (B,) => (B, T)
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # 2) Standard action MLP step for shape => (B, T, w)
        a_emb = self.W1(actions, cat_ids)

        # 3) Get the sinusoidal encoding (B, T, w)
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)

        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = swish(self.W2(x, cat_ids))

        # 5) Finally W3 => (B, T, w)
        x = self.W3(x, cat_ids)
        return x


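# Illustrative sketch (not part of the original file): the action encoder maps a
# noisy action chunk plus one discretized timestep per sample to per-step
# embeddings. The sizes below are arbitrary assumptions.
def _example_action_encoder():
    enc = MultiEmbodimentActionEncoder(action_dim=26, hidden_size=1024, num_embodiments=1)
    B, T = 2, 16
    actions = torch.randn(B, T, 26)             # (B, T, action_dim)
    timesteps = torch.randint(0, 1000, (B,))    # one discretized timestep per sample
    cat_ids = torch.zeros(B, dtype=torch.long)  # embodiment/category index per sample
    out = enc(actions, timesteps, cat_ids)
    assert out.shape == (B, T, 1024)
    return out

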
class NitroGen(torch.nn.Module):
    config_class = NitroGen_Config
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: NitroGen_Config,
        game_mapping: dict[str, int] | None = None,  # Used to add a game ID token
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.vision_hidden_size = config.vision_hidden_size

        if "siglip" in config.vision_encoder_name:
            model = SiglipVisionModel.from_pretrained(config.vision_encoder_name)
            self.vision_encoder = model.vision_model
            self.vision_encoder_type = "siglip"
        else:
            self.vision_encoder = AutoModel.from_pretrained(config.vision_encoder_name)
            self.vision_encoder_type = "hf_auto"
        self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
        self.num_timestep_buckets = config.num_timestep_buckets
        # self.model = instantiate(config.diffusion_model_cfg)
        self.model = DiT(config=config.diffusion_model_cfg)
        self.action_dim = config.action_dim
        self.action_horizon = config.action_horizon
        self.num_inference_timesteps = config.num_inference_timesteps

        # self.vl_self_attention_model = instantiate(config.vl_self_attention_cfg)
        self.vl_self_attention_model = SelfAttentionTransformer(config=config.vl_self_attention_cfg)

        # if config.qformer_cfg is not None:
        #     self.qformer = instantiate(config.qformer_cfg)
        # else:
        #     self.qformer = nn.Identity()

        # self.state_encoder = CategorySpecificMLP(
        #     num_categories=config.max_num_embodiments,
        #     input_dim=config.max_state_dim,
        #     hidden_dim=self.hidden_size,
        #     output_dim=self.hidden_size,
        # )
        self.action_encoder = MultiEmbodimentActionEncoder(
            action_dim=config.action_dim,
            hidden_size=self.hidden_size,
            num_embodiments=config.max_num_embodiments,
        )

        self.action_decoder = CategorySpecificMLP(
            num_categories=config.max_num_embodiments,
            input_dim=self.hidden_size,
            hidden_dim=self.hidden_size,
            output_dim=self.action_dim,
        )
        # self.mm_vision_select_layer = config.mm_vision_select_layer
        # if config.mm_projector_cfg is not None:
        #     self.mm_projector = instantiate(config.mm_projector_cfg)
        # else:
        self.mm_projector = None
        if config.add_pos_embed:
            self.position_embedding = nn.Embedding(config.max_seq_len, self.hidden_size)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        # if config.add_view_embed:
        #     self.view_embedding = nn.Embedding(config.max_num_views, self.hidden_size)
        #     nn.init.normal_(self.view_embedding.weight, mean=0.0, std=0.02)

        # self.vision_projector = None
        # if config.vision_hidden_size != self.hidden_size:
        #     self.vision_projector = nn.Sequential(
        #         nn.Linear(config.vision_hidden_size, self.hidden_size),
        #         nn.LayerNorm(self.hidden_size),
        #     )

        # Learnable image-separator embedding used by `prepare_input_embs`
        # (the 0.02-scaled random initialization is an assumption).
        self.vis_sep_embedding = nn.Parameter(0.02 * torch.randn(self.vision_hidden_size))

        self.game_mapping = game_mapping
        # Create an embedding table for game IDs.
        # Game ID tokens will be put inside vision-language tokens,
        # so they need to be projected to the same dimension.
        if self.game_mapping is not None:
            # 0 = unconditional
            self.game_embedding = nn.Embedding(
                len(self.game_mapping),
                self.vision_hidden_size,
                padding_idx=0,
                scale_grad_by_freq=True,
            )

        self.set_trainable_parameters(
            tune_multi_projector=config.tune_multi_projector,
            tune_diffusion_model=config.tune_diffusion_model,
            tune_vision_tower=config.tune_vision_tower,
            tune_mm_projector=config.tune_mm_projector,
            tune_vl_mixing=config.tune_vl_mixing,
        )

        print(
            "total number of parameters: %e"
            % sum(p.numel() for p in self.parameters() if p.requires_grad)
        )

    def set_trainable_parameters(
        self,
        tune_multi_projector: bool = True,
        tune_diffusion_model: bool = True,
        tune_vision_tower: bool = True,
        tune_mm_projector: bool = True,
        tune_vl_mixing: bool = True,
    ):
        self.tune_multi_projector = tune_multi_projector
        self.tune_diffusion_model = tune_diffusion_model
        self.tune_vision_tower = tune_vision_tower
        self.tune_mm_projector = tune_mm_projector
        self.tune_vl_mixing = tune_vl_mixing

        for param in self.parameters():
            param.requires_grad = True
        # ### Always freeze language encoder
        # self.siglip_model.text_model.requires_grad_(False)
        # # Freeze unused parameters in siglip vision encoder
        # self.siglip_model.logit_scale.requires_grad = False
        # self.siglip_model.logit_bias.requires_grad = False

        # For SigLIP, freeze the parts of the vision tower that are not tuned here
        # (encoder layer 11 and the pooling head).
        if self.vision_encoder_type == "siglip":
            for param in self.vision_encoder.encoder.layers[11].parameters():
                param.requires_grad = False
            for param in self.vision_encoder.head.parameters():
                param.requires_grad = False

        # Freeze parameters
        if not tune_multi_projector:
            # self.state_encoder.requires_grad_(False)
            self.action_encoder.requires_grad_(False)
            self.action_decoder.requires_grad_(False)
            if self.config.add_pos_embed:
                self.position_embedding.requires_grad_(False)
            if self.config.add_view_embed:
                self.view_embedding.requires_grad_(False)
        if not tune_diffusion_model:
            self.model.requires_grad_(False)
        if not tune_vision_tower:
            self.vision_encoder.requires_grad_(False)
        if self.mm_projector is not None and not tune_mm_projector:
            self.mm_projector.requires_grad_(False)
        if not tune_vl_mixing:
            self.vl_self_attention_model.requires_grad_(False)

        print(f"Tune action head multi_projector: {self.tune_multi_projector}")
        print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
        print(f"Tune action head vision tower: {self.tune_vision_tower}")
        print(f"Tune action head mm_projector: {self.tune_mm_projector}")
        print(f"Tune action head vl_mixing: {self.tune_vl_mixing}")
        # Check if any parameters are still trainable. If not, print a warning.
        if not any(p.requires_grad for p in self.parameters()):
            print("Warning: No action head trainable parameters found.")

    def set_frozen_modules_to_eval_mode(self):
        """
        Huggingface will call model.train() at each training_step. To ensure
        the expected behaviors for modules like dropout, batchnorm, etc., we
        need to call model.eval() for the frozen modules.
        """
        if self.training:
            # self.siglip_model.text_model.eval()
            if not self.tune_multi_projector:
                # self.state_encoder.eval()
                self.action_encoder.eval()
                self.action_decoder.eval()
                if self.config.add_pos_embed:
                    self.position_embedding.eval()
                # if self.config.add_view_embed:
                #     self.view_embedding.eval()
            if not self.tune_diffusion_model:
                self.model.eval()
            if not self.tune_vision_tower:
                self.vision_encoder.eval()
            if self.mm_projector is not None and not self.tune_mm_projector:
                self.mm_projector.eval()
            if not self.tune_vl_mixing:
                self.vl_self_attention_model.eval()

    # Note: this earlier version of sample_time is believed to be incorrect and is
    # kept only for reference.
    # def sample_time(self, batch_size, device, dtype):
    #     sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
    #     return (self.config.noise_s - sample) / self.config.noise_s

    def sample_time(self, batch_size, device, dtype):
        # Returns t in [0, noise_s); see `forward`, where t=0 corresponds to pure noise
        # and t=1 to the clean action chunk.
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        return (1 - sample) * self.config.noise_s

    def encode_images(self, images):  # , view_ids):
        batch_size, num_frames, channels, height, width = images.shape
        images = images.reshape(-1, channels, height, width)

        image_features = self.vision_encoder(images)["last_hidden_state"]
        image_features = rearrange(image_features, "(b f) n d -> b f n d", f=num_frames)

        # if self.vision_projector is not None:
        #     # change the hidden dimension of the vision features
        #     image_features = self.vision_projector(image_features)
        if self.mm_projector is not None:
            image_features = self.mm_projector(image_features)  # [B, 256, 1024] -> [B, 16, 1024]
        return image_features

    def prepare_input_embs(self, vl_token_ids, sa_token_ids, vision, action, dropped_images, game_ids=None):
        B, T = vl_token_ids.shape
        vl_embs = torch.full(
            size=(B, T, self.vision_hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
        )

        # Extract dimensions from vision tensor
        B, num_images, tokens_per_image, hidden_size = vision.shape

        # Create mask for _IMG_TOKEN positions
        vision_mask = (vl_token_ids == _IMG_TOKEN)  # [B, T]

        # Flatten vision tensor over the num_images dimension
        vision_flat = vision.reshape(B, -1, self.vision_hidden_size)  # [B, num_images * tokens_per_image, vision_hidden_size]

        # Create a mask for the flattened vision dimension.
        # Each image contributes tokens_per_image tokens, so expand the mask accordingly.
        non_dropped_mask_expanded = (dropped_images == 0).unsqueeze(-1).repeat(1, 1, tokens_per_image).reshape(B, -1)  # [B, num_images * tokens_per_image]

        # Select only non-dropped vision embeddings.
        # This will give us the embeddings we need to place.
        valid_vision_embs = vision_flat[non_dropped_mask_expanded]  # [total_valid_tokens, vision_hidden_size]

        assert valid_vision_embs.shape[0] == vision_mask.sum().item(), (
            f"Number of valid vision embeddings {valid_vision_embs.shape[0]} does not match "
            f"the number of _IMG_TOKEN positions {vision_mask.sum().item()}"
        )
        # Now we need to place these at the vision_mask positions.
        # Get indices where vision_mask is True.
        batch_indices, token_indices = vision_mask.nonzero(as_tuple=True)

        # Place the valid embeddings at the masked positions
        vl_embs[batch_indices, token_indices] = valid_vision_embs

        # Handle game ID tokens
        if self.game_mapping is not None and game_ids is not None:
            game_mask = vl_token_ids == _GAME_ID_TOKEN  # shape: (B, T)
            num_game_tokens = game_mask.sum().item()
            if num_game_tokens > 0:
                # Assert that each batch item has exactly one game token
                game_tokens_per_batch = game_mask.sum(dim=1)  # [B] - count of game tokens per batch item
                assert torch.all(game_tokens_per_batch == 1), (
                    f"Expected exactly 1 game token per batch item, but got: {game_tokens_per_batch.tolist()}. "
                    f"Each batch item must have exactly one _GAME_ID_TOKEN."
                )

                # Get game embeddings for each batch item
                game_embs = self.game_embedding(game_ids)  # [B, vision_hidden_size]
                batch_indices, token_indices = game_mask.nonzero(as_tuple=True)
                vl_embs[batch_indices, token_indices] = game_embs[batch_indices].to(dtype=vl_embs.dtype)

        # Project the image separator using the learnable sep_embedding.
        sep_mask = vl_token_ids == _IMG_SEP_TOKEN  # shape: (B, T)
        num_sep = sep_mask.sum().item()
        if num_sep > 0:
            # Expand the separator embedding for each occurrence
            # (it must match the vision-language embedding width).
            repeated_sep = self.vis_sep_embedding.unsqueeze(0).expand(num_sep, self.vision_hidden_size)
            # Assign the separator embeddings to the correct positions.
            vl_embs[sep_mask] = repeated_sep.to(dtype=vl_embs.dtype)

        B, T = sa_token_ids.shape
        sa_embs = torch.full(
            size=(B, T, self.hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
        )

        # Project state.
        # state_mask = sa_token_ids == _PROPRIO_TOKEN
        # state_mask = state_mask.unsqueeze(-1).expand_as(sa_embs)
        # sa_embs = sa_embs.masked_scatter(state_mask, state)

        # Project action.
        action_mask = sa_token_ids == _ACT_TOKEN
        action_mask = action_mask.unsqueeze(-1).expand_as(sa_embs)
        sa_embs = sa_embs.masked_scatter(action_mask, action)

        # Add positional embeddings
        pos_ids = torch.arange(T, dtype=torch.long, device=sa_token_ids.device)
        if self.config.add_pos_embed:
            pos_embs = self.position_embedding(pos_ids)  # (T, hidden_size)
            pos_embs = pos_embs.unsqueeze(0).expand(B, T, self.hidden_size)
            sa_embs = sa_embs + pos_embs
        return vl_embs, sa_embs

    def pack_actions(self, buttons, j_left, j_right):
        # Check that the first three dims of each input are the same
        assert buttons.shape[:3] == j_left.shape[:3] == j_right.shape[:3], (
            f"buttons shape: {buttons.shape}, "
            f"j_left shape: {j_left.shape}, "
            f"j_right shape: {j_right.shape}"
        )

        # Normalize the joysticks from [-1, 1] to [0, 1]
        j_left = (j_left + 1) / 2.
        j_right = (j_right + 1) / 2.

        # Concatenate the joysticks and buttons along the last dimension
        action = torch.cat([j_left, j_right, buttons], dim=-1)

        # Squeeze the second dimension of each input: this is the number of chunks, which is 1 here
        action = action.squeeze(1)
        return action

    # def unpack_actions(self, actions):
    #     # Unpack the actions into j_left, j_right, buttons
    #     j_left = actions[:, :, :2]
    #     j_right = actions[:, :, 2:4]
    #     buttons = actions[:, :, 4:]

    #     # Denormalize the joysticks back to -1,1
    #     j_left = j_left * 2. - 1.
    #     j_right = j_right * 2. - 1.

    #     # Clip into [-1,1]
    #     j_left = torch.clamp(j_left, -1, 1)
    #     j_right = torch.clamp(j_right, -1, 1)

    #     # Threshold the buttons to 0,1
    #     buttons = (buttons > 0.5).float()
    #     return j_left, j_right, buttons

    # ========= ActionHead required ============
    def forward(self, data: dict) -> dict:
        self.set_frozen_modules_to_eval_mode()

        # data = action_input
        embodiment_id = data["embodiment_id"]

        # # Check which data is present.
        # has_real_action = action_input.has_real_action
        has_real_action = data["has_real_action"]

        # 1) Encode images/text/state
        visual_features = self.encode_images(data["images"])  # , data["view_ids"])
        # text_features = self.siglip_model.text_model(
        #     input_ids=data["lang_input_ids"]
        # ).last_hidden_state
        # state_features = self.state_encoder(data["state"], embodiment_id)

        # 2) Prepare noisy trajectory
        actions = data["actions"]
        noise = torch.randn_like(actions)
        t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
        t = t[:, None, None]  # shape (B, 1, 1) for broadcast

        noisy_trajectory = (1 - t) * noise + t * actions
        velocity = actions - noise

        # 3) Convert (continuous) t -> discrete if needed
        t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()

        # 4) Get action encoder embeddings with the correct time argument
        action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)

        # 5) Prepare full input to DiT (or your model)
        vl_embs, sa_embs = self.prepare_input_embs(
            data["vl_token_ids"],
            data["sa_token_ids"],
            visual_features,
            # text_features,
            # state_features,
            action_features,
            data["dropped_images"],
            game_ids=data.get("game_id"),
        )

        vl_embs = self.vl_self_attention_model(vl_embs)
        # vl_embs = self.qformer(vl_embs)
        model_output, all_hidden_states = self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=vl_embs,
            encoder_attention_mask=data["vl_attn_mask"],
            timestep=t_discretized,
            return_all_hidden_states=True,
        )
        pred = self.action_decoder(model_output, embodiment_id)
        pred_actions = pred[:, -actions.shape[1]:]

        # 6) Flow-matching (velocity-prediction) MSE
        # Mask for variable-length trajectories
        mask = data["actions_mask"]  # shape => (B, seq_len_of_actions, ...)
        raw_loss = F.mse_loss(pred_actions, velocity, reduction="none")
        mask = has_real_action[:, None, None] * mask
        raw_loss = raw_loss * mask
        action_loss = (has_real_action[:, None, None] * raw_loss).sum() / (mask.sum() + 1e-6)

        loss = action_loss

        return {
            "loss": loss,
        }

    @torch.inference_mode()
    def get_action(self, data: dict, old_layout: bool = False) -> dict:
        """
        For i in [0..N-1]:
            1) t = i / N
            2) velocity = model(x(t), t)
            3) x(t + dt) = x(t) + dt * velocity
        """
        # data = action_input
        embodiment_id = data["embodiment_id"]

        batch_size = data["images"].shape[0]
        device = data["images"].device
        dtype = data["images"].dtype
        actions = torch.randn(
            size=(batch_size, self.config.action_horizon, self.config.action_dim),
            dtype=dtype,
            device=device,
        )

        # 1) Hyperparameters for flow sampling
        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps

        # 2) Encode static context (images, text, state) once, since it does not depend on actions
        visual_features = self.encode_images(data["images"])  # , data["view_ids"])
        # text_features = self.siglip_model.text_model(
        #     input_ids=data["lang_input_ids"]
        # ).last_hidden_state
        # state_features = self.state_encoder(data["state"], embodiment_id)

        # 3) Start denoising the actions
        for i in range(num_steps):
            # ---- (a) Discretize continuous time in [0, 1]
            t_cont = i / float(num_steps)  # e.g. goes 0, 1/N, 2/N, ...
            t_discretized = int(t_cont * self.num_timestep_buckets)

            # ---- (b) Build embeddings (actions included)
            # Pass the *current* actions at time t into the action encoder
            action_features = self.action_encoder(
                actions,
                (torch.ones(actions.shape[0]) * t_discretized).to(device),
                embodiment_id,
            )
            vl_embs, sa_embs = self.prepare_input_embs(
                data["vl_token_ids"],
                data["sa_token_ids"],
                visual_features,
                # text_features,
                # state_features,
                action_features,
                data["dropped_images"],
                game_ids=data["game_ids"],
            )
            vl_embs = self.vl_self_attention_model(vl_embs)
            # vl_embs = self.qformer(vl_embs)
            # ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                encoder_attention_mask=data["vl_attn_mask"],
                timestep=timesteps,
            )
            pred = self.action_decoder(model_output, embodiment_id)
            pred_velocity = pred[:, -actions.shape[1]:]

            # ---- (d) Naive Euler step: x(t + dt) = x(t) + dt * velocity
            actions = actions + dt * pred_velocity

        return {
            "action_tensor": actions,
        }

    @torch.inference_mode()
    def get_action_with_cfg(self, data_cond: dict, data_uncond: dict, cfg_scale: float = 1.0) -> dict:
        """
        Use a form of classifier-free guidance to sample actions. This can only be used on
        models that were trained on multiple frames of actions. The idea is that we sample
        velocity with and without the frame history, and then we push the sampled actions
        towards the ones that were sampled with the frame history.

        data_cond   = conditional input (e.g. with history)
        data_uncond = unconditional input (e.g. without history)

        This function works with any kind of conditioning, not just history.

        For i in [0..N-1]:
            1) t = i / N
            2) velocity = v_cond + cfg_scale * (v_cond - v_uncond),
               where v_cond = model(x(t), t, cond) and v_uncond = model(x(t), t, uncond)
            3) x(t + dt) = x(t) + dt * velocity
        """
        # data = action_input
        embodiment_id = data_cond["embodiment_id"]

        batch_size = data_cond["images"].shape[0]
        device = data_cond["images"].device
        dtype = data_cond["images"].dtype
        actions = torch.randn(
            size=(batch_size, self.config.action_horizon, self.config.action_dim),
            dtype=dtype,
            device=device,
        )

        # 1) Hyperparameters for flow sampling
        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps

        # 2) Encode static context (images, text, state) once, since it does not depend on actions
        visual_features_cond = self.encode_images(data_cond["images"])
        visual_features_uncond = self.encode_images(data_uncond["images"])
        # text_features = self.siglip_model.text_model(
        #     input_ids=data["lang_input_ids"]
        # ).last_hidden_state
        # state_features = self.state_encoder(data["state"], embodiment_id)

        # 3) Start denoising the actions
        for i in range(num_steps):
            # ---- (a) Discretize continuous time in [0, 1]
            t_cont = i / float(num_steps)  # e.g. goes 0, 1/N, 2/N, ...
            t_discretized = int(t_cont * self.num_timestep_buckets)

            # ---- (b) Build embeddings (actions included)
            # Pass the *current* actions at time t into the action encoder
            action_features = self.action_encoder(
                actions,
                (torch.ones(actions.shape[0]) * t_discretized).to(device),
                embodiment_id,
            )

            # Predict velocity with the conditional input
            vl_embs, sa_embs = self.prepare_input_embs(
                data_cond["vl_token_ids"],
                data_cond["sa_token_ids"],
                visual_features_cond,
                action_features,
                data_cond["dropped_images"],
            )
            vl_embs = self.vl_self_attention_model(vl_embs)
            # ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                encoder_attention_mask=data_cond["vl_attn_mask"],
                timestep=timesteps,
            )
            pred = self.action_decoder(model_output, embodiment_id)
            pred_velocity_cond = pred[:, -actions.shape[1]:]

            # Predict velocity with the unconditional input
            vl_embs, sa_embs = self.prepare_input_embs(
                data_uncond["vl_token_ids"],
                data_uncond["sa_token_ids"],
                visual_features_uncond,
                action_features,
                data_uncond["dropped_images"],
            )
            vl_embs = self.vl_self_attention_model(vl_embs)
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                encoder_attention_mask=data_uncond["vl_attn_mask"],
                timestep=timesteps,
            )
            pred = self.action_decoder(model_output, embodiment_id)
            pred_velocity_uncond = pred[:, -actions.shape[1]:]

            # ---- (d) Combine velocities with cfg_scale
            pred_velocity = pred_velocity_cond + cfg_scale * (pred_velocity_cond - pred_velocity_uncond)

            # ---- (e) Naive Euler step: x(t + dt) = x(t) + dt * velocity
            actions = actions + dt * pred_velocity

        return {
            "action_tensor": actions,
        }

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype


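# Illustrative sketch (not part of the original file) of the flow-matching setup
# used in NitroGen.forward: x(t) interpolates linearly from noise (t=0) to the
# clean action chunk (t=1), and the network regresses the constant velocity
# `actions - noise` along that path. Shapes below are arbitrary assumptions.
def _example_flow_matching_target():
    actions = torch.randn(2, 16, 26)        # (B, action_horizon, action_dim)
    noise = torch.randn_like(actions)
    t = torch.rand(2)[:, None, None]        # t in (0, 1), broadcast over (B, 1, 1)
    noisy_trajectory = (1 - t) * noise + t * actions  # x(t) on the straight path
    velocity = actions - noise              # d/dt x(t): the regression target
    return noisy_trajectory, velocity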