from dataclasses import dataclass, field
from pydantic import BaseModel, Field
from pathlib import Path

import yaml
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.distributions import Beta
from transformers import SiglipVisionModel, AutoModel

from .modules import DiT, DiTConfig, SelfAttentionTransformer, SelfAttentionTransformerConfig

_PAD_TOKEN = 0
_IMG_TOKEN = 1
_IMG_SEP_TOKEN = 5
_LANG_TOKEN = 2
_PROPRIO_TOKEN = 3
_ACT_TOKEN = 4
_GAME_ID_TOKEN = 6


class NitroGen_Config(BaseModel):
    model_type: str = Field(default="nitrogen", frozen=True)

    add_pos_embed: bool = Field(default=False, description="Whether to add positional embedding")
    model_dtype: str = Field(default="float32", description="Model data type.")
    diffusion_model_cfg: DiTConfig = Field(..., description="Diffusion model configuration.")
    vl_self_attention_cfg: SelfAttentionTransformerConfig = Field(..., description="VL self-attention configuration.")
    hidden_size: int = Field(default=1024, description="Input embedding dimension.")
    max_seq_len: int = Field(default=1024, description="Maximum sequence length.")
    action_dim: int = Field(default=None, description="Action dimension.")
    action_horizon: int = Field(default=None, description="Action horizon.")
    noise_beta_alpha: float = Field(default=1.5, description="Flow matching noise Beta distribution alpha.")
    noise_beta_beta: float = Field(default=1.0, description="Flow matching noise Beta distribution beta.")
    noise_s: float = Field(default=0.999, description="Flow matching noise Beta distribution s.")
    num_timestep_buckets: int = Field(default=1000, description="Number of timestep discretization buckets.")
    num_inference_timesteps: int = Field(default=None, description="Number of inference steps for noise diffusion.")
    max_num_embodiments: int = Field(default=1, description="Number of embodiments.")
    vision_encoder_name: str = Field(default="google/siglip-large-patch16-256", description="Vision encoder name.")
    vision_hidden_size: int = Field(default=768, description="Siglip hidden size.")
    add_view_embed: bool = Field(default=False, description="Whether to add view embedding.")

    tune_vision_tower: bool = Field(default=True, description="Tune vision encoder if True.")
    tune_mm_projector: bool = Field(default=True, description="Tune mm projector if True.")
    tune_diffusion_model: bool = Field(default=True, description="Tune diffusion model if True.")
    tune_multi_projector: bool = Field(default=True, description="Tune multi projector if True.")
    tune_vl_mixing: bool = Field(default=True, description="Tune vl mixing if True.")

    @classmethod
    def from_yaml(cls, yaml_path: str | Path) -> "NitroGen_Config":
        """Load configuration from a YAML file."""
        with open(yaml_path, "r") as f:
            config_dict = yaml.safe_load(f)
        return cls.model_validate(config_dict)
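
# Example (sketch): loading a config from YAML and building the model. The YAML path
# and game mapping below are hypothetical placeholders, not part of this repo.
#
#   config = NitroGen_Config.from_yaml("configs/nitrogen.yaml")
#   model = NitroGen(config, game_mapping={"<unconditional>": 0, "some_game": 1})
#   model = model.to("cuda", dtype=torch.float32)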


def swish(x):
    return x * torch.sigmoid(x)


class SinusoidalPositionalEncoding(nn.Module):
    """
    Produces a sinusoidal encoding of shape (B, T, w)
    given timesteps of shape (B, T).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    def forward(self, timesteps):
        # timesteps: shape (B, T)
        # We compute sin/cos frequencies across the embedding dimension.
        timesteps = timesteps.float()  # ensure float

        B, T = timesteps.shape
        device = timesteps.device

        half_dim = self.embedding_dim // 2
        # Typical log-spaced frequencies for sinusoidal encoding.
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            torch.log(torch.tensor(10000.0)) / half_dim
        )
        # Expand timesteps to (B, T, 1) then multiply.
        freqs = timesteps.unsqueeze(-1) * exponent.exp()  # (B, T, half_dim)

        sin = torch.sin(freqs)
        cos = torch.cos(freqs)
        enc = torch.cat([sin, cos], dim=-1)  # (B, T, w)

        return enc
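
# Example (sketch): encoding a per-step timestep grid.
#
#   pe = SinusoidalPositionalEncoding(embedding_dim=64)
#   t = torch.randint(0, 1000, (2, 16))   # (B, T) integer timesteps
#   enc = pe(t)                           # (B, T, 64) -- first half sin, second half cos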


class CategorySpecificLinear(nn.Module):
    def __init__(self, num_categories, input_dim, hidden_dim):
        super().__init__()
        self.num_categories = num_categories
        # For each category, we have separate weights and biases.
        self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
        self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))

    def forward(self, x, cat_ids):
        selected_W = self.W[cat_ids]
        selected_b = self.b[cat_ids]
        return torch.bmm(x, selected_W) + selected_b.unsqueeze(1)
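
# Example (sketch): per-category projection. Shapes follow the usage in this file,
# where cat_ids carries one embodiment/category id per batch element.
#
#   layer = CategorySpecificLinear(num_categories=3, input_dim=8, hidden_dim=16)
#   x = torch.randn(4, 10, 8)                 # (B, T, input_dim)
#   cat_ids = torch.tensor([0, 2, 1, 0])      # (B,)
#   y = layer(x, cat_ids)                     # (B, T, hidden_dim) == (4, 10, 16)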


class CategorySpecificMLP(nn.Module):
    def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.num_categories = num_categories
        self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
        self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)

    def forward(self, x, cat_ids):
        hidden = F.relu(self.layer1(x, cat_ids))
        return self.layer2(hidden, cat_ids)


class MultiEmbodimentActionEncoder(nn.Module):
    def __init__(self, action_dim, hidden_size, num_embodiments):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_embodiments = num_embodiments

        # W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
        self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)  # (d -> w)
        self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)  # (2w -> w)
        self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)  # (w -> w)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps, cat_ids):
        """
        actions: shape (B, T, action_dim)
        timesteps: shape (B,) -- a single scalar per batch item
        cat_ids: shape (B,)
        returns: shape (B, T, hidden_size)
        """
        B, T, _ = actions.shape

        # 1) Expand each batch's single scalar time 'tau' across all T steps
        #    so that shape => (B, T), e.g. if timesteps is (B,), replicate across T
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            # shape (B,) => (B, T)
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # 2) Standard action MLP step for shape => (B, T, w)
        a_emb = self.W1(actions, cat_ids)

        # 3) Get the sinusoidal encoding (B, T, w)
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)

        # 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
        x = torch.cat([a_emb, tau_emb], dim=-1)
        x = swish(self.W2(x, cat_ids))

        # 5) Finally W3 => (B, T, w)
        x = self.W3(x, cat_ids)
        return x
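
# Example (sketch): embedding a noisy action chunk together with its flow-matching
# timestep, as done in NitroGen.forward below. Sizes are illustrative.
#
#   enc = MultiEmbodimentActionEncoder(action_dim=20, hidden_size=1024, num_embodiments=1)
#   noisy_actions = torch.randn(2, 16, 20)             # (B, T, action_dim)
#   t_buckets = torch.randint(0, 1000, (2,)).float()   # (B,) discretized timesteps
#   embodiment_id = torch.zeros(2, dtype=torch.long)   # (B,)
#   feats = enc(noisy_actions, t_buckets, embodiment_id)  # (B, T, 1024)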


class NitroGen(torch.nn.Module):
    config_class = NitroGen_Config
    supports_gradient_checkpointing = True

    def __init__(
        self,
        config: NitroGen_Config,
        game_mapping: dict[str, int] | None = None,  # Used to add a game ID token
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.vision_hidden_size = config.vision_hidden_size

        if "siglip" in config.vision_encoder_name:
            model = SiglipVisionModel.from_pretrained(config.vision_encoder_name)
            self.vision_encoder = model.vision_model
            self.vision_encoder_type = "siglip"
        else:
            self.vision_encoder = AutoModel.from_pretrained(config.vision_encoder_name)
            self.vision_encoder_type = "hf_auto"
        self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
        self.num_timestep_buckets = config.num_timestep_buckets
        # self.model = instantiate(config.diffusion_model_cfg)
        self.model = DiT(config=config.diffusion_model_cfg)
        self.action_dim = config.action_dim
        self.action_horizon = config.action_horizon
        self.num_inference_timesteps = config.num_inference_timesteps

        # self.vl_self_attention_model = instantiate(config.vl_self_attention_cfg)
        self.vl_self_attention_model = SelfAttentionTransformer(config=config.vl_self_attention_cfg)

        # if config.qformer_cfg is not None:
        #     self.qformer = instantiate(config.qformer_cfg)
        # else:
        #     self.qformer = nn.Identity()

        # self.state_encoder = CategorySpecificMLP(
        #     num_categories=config.max_num_embodiments,
        #     input_dim=config.max_state_dim,
        #     hidden_dim=self.hidden_size,
        #     output_dim=self.hidden_size,
        # )
        self.action_encoder = MultiEmbodimentActionEncoder(
            action_dim=config.action_dim,
            hidden_size=self.hidden_size,
            num_embodiments=config.max_num_embodiments,
        )

        self.action_decoder = CategorySpecificMLP(
            num_categories=config.max_num_embodiments,
            input_dim=self.hidden_size,
            hidden_dim=self.hidden_size,
            output_dim=self.action_dim,
        )
        # self.mm_vision_select_layer = config.mm_vision_select_layer
        # if config.mm_projector_cfg is not None:
        #     self.mm_projector = instantiate(config.mm_projector_cfg)
        # else:
        self.mm_projector = None
        if config.add_pos_embed:
            self.position_embedding = nn.Embedding(config.max_seq_len, self.hidden_size)
            nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        # if config.add_view_embed:
        #     self.view_embedding = nn.Embedding(config.max_num_views, self.hidden_size)
        #     nn.init.normal_(self.view_embedding.weight, mean=0.0, std=0.02)

        # self.vision_projector = None
        # if config.vision_hidden_size != self.hidden_size:
        #     self.vision_projector = nn.Sequential(
        #         nn.Linear(config.vision_hidden_size, self.hidden_size),
        #         nn.LayerNorm(self.hidden_size),
        #     )

        # Learnable image-separator embedding placed at _IMG_SEP_TOKEN positions in
        # prepare_input_embs. It lives in the vision-language token space, so it uses
        # the vision hidden size.
        self.vis_sep_embedding = nn.Parameter(0.02 * torch.randn(self.vision_hidden_size))

        self.game_mapping = game_mapping
        # Create an embedding table for game IDs.
        # Game ID tokens are placed inside the vision-language tokens,
        # so they need to be projected to the same dimension.
        if self.game_mapping is not None:
            # 0 = unconditional
            self.game_embedding = nn.Embedding(
                len(self.game_mapping),
                self.vision_hidden_size,
                padding_idx=0,
                scale_grad_by_freq=True,
            )

        self.set_trainable_parameters(
            tune_multi_projector=config.tune_multi_projector,
            tune_diffusion_model=config.tune_diffusion_model,
            tune_vision_tower=config.tune_vision_tower,
            tune_mm_projector=config.tune_mm_projector,
            tune_vl_mixing=config.tune_vl_mixing,
        )

        print(
            "total number of parameters: %e"
            % sum(p.numel() for p in self.parameters() if p.requires_grad)
        )

    def set_trainable_parameters(
        self,
        tune_multi_projector: bool = True,
        tune_diffusion_model: bool = True,
        tune_vision_tower: bool = True,
        tune_mm_projector: bool = True,
        tune_vl_mixing: bool = True,
    ):
        self.tune_multi_projector = tune_multi_projector
        self.tune_diffusion_model = tune_diffusion_model
        self.tune_vision_tower = tune_vision_tower
        self.tune_mm_projector = tune_mm_projector
        self.tune_vl_mixing = tune_vl_mixing

        for param in self.parameters():
            param.requires_grad = True
        # ### Always freeze language encoder
        # self.siglip_model.text_model.requires_grad_(False)
        # # Freeze unused parameters in siglip vision encoder
        # self.siglip_model.logit_scale.requires_grad = False
        # self.siglip_model.logit_bias.requires_grad = False

        # For siglip, freeze the parts of the vision encoder that are not used
        # downstream (encoder layer 11 here, and the pooling head).
        if self.vision_encoder_type == "siglip":
            for param in self.vision_encoder.encoder.layers[11].parameters():
                param.requires_grad = False
            for param in self.vision_encoder.head.parameters():
                param.requires_grad = False

        # Freeze parameters
        if not tune_multi_projector:
            # self.state_encoder.requires_grad_(False)
            self.action_encoder.requires_grad_(False)
            self.action_decoder.requires_grad_(False)
            if self.config.add_pos_embed:
                self.position_embedding.requires_grad_(False)
            # view_embedding is currently not created in __init__, so guard the access.
            if self.config.add_view_embed and hasattr(self, "view_embedding"):
                self.view_embedding.requires_grad_(False)
        if not tune_diffusion_model:
            self.model.requires_grad_(False)
        if not tune_vision_tower:
            self.vision_encoder.requires_grad_(False)
        if self.mm_projector is not None and not tune_mm_projector:
            self.mm_projector.requires_grad_(False)
        if not tune_vl_mixing:
            self.vl_self_attention_model.requires_grad_(False)

        print(f"Tune action head multi_projector: {self.tune_multi_projector}")
        print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
        print(f"Tune action head vision tower: {self.tune_vision_tower}")
        print(f"Tune action head mm_projector: {self.tune_mm_projector}")
        print(f"Tune action head vl_mixing: {self.tune_vl_mixing}")
        # Check if any parameters are still trainable. If not, print a warning.
        if not any(p.requires_grad for p in self.parameters()):
            print("Warning: No action head trainable parameters found.")

    def set_frozen_modules_to_eval_mode(self):
        """
        Huggingface will call model.train() at each training_step. To ensure
        the expected behaviors for modules like dropout, batchnorm, etc., we
        need to call model.eval() for the frozen modules.
        """
        if self.training:
            # self.siglip_model.text_model.eval()
            if not self.tune_multi_projector:
                # self.state_encoder.eval()
                self.action_encoder.eval()
                self.action_decoder.eval()
                if self.config.add_pos_embed:
                    self.position_embedding.eval()
                # if self.config.add_view_embed:
                #     self.view_embedding.eval()
            if not self.tune_diffusion_model:
                self.model.eval()
            if not self.tune_vision_tower:
                self.vision_encoder.eval()
            if self.mm_projector is not None and not self.tune_mm_projector:
                self.mm_projector.eval()
            if not self.tune_vl_mixing:
                self.vl_self_attention_model.eval()

    # NOTE: This earlier version of sample_time is believed to be incorrect; kept for reference.
    # def sample_time(self, batch_size, device, dtype):
    #     sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
    #     return (self.config.noise_s - sample) / self.config.noise_s

    def sample_time(self, batch_size, device, dtype):
        sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
        return (1 - sample) * self.config.noise_s
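
    # Sketch of the time distribution used above: with the default config
    # (noise_beta_alpha=1.5, noise_beta_beta=1.0, noise_s=0.999),
    #
    #   sample ~ Beta(1.5, 1.0) in (0, 1)
    #   t = (1 - sample) * 0.999          # so t lies in (0, 0.999)
    #
    # i.e. t is biased toward small values (mostly-noise samples) and never reaches
    # exactly 1, matching the interpolation x_t = (1 - t) * noise + t * actions
    # used in forward().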

    def encode_images(self, images):  # , view_ids):
        batch_size, num_frames, channels, height, width = images.shape
        images = images.reshape(-1, channels, height, width)

        image_features = self.vision_encoder(images)["last_hidden_state"]
        image_features = rearrange(image_features, "(b f) n d -> b f n d", f=num_frames)

        # if self.vision_projector is not None:
        #     # change the hidden dimension of the vision features
        #     image_features = self.vision_projector(image_features)
        if self.mm_projector is not None:
            image_features = self.mm_projector(image_features)  # [B, 256, 1024] -> [B, 16, 1024]
        return image_features
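
    # Shape sketch for encode_images (sizes are illustrative):
    #
    #   images: (B, num_frames, 3, H, W)
    #   -> vision encoder over B * num_frames images
    #   -> image_features: (B, num_frames, tokens_per_image, vision_hidden_size)
    #
    # prepare_input_embs below scatters these per-frame token grids into the
    # _IMG_TOKEN slots of the vision-language sequence.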

    def prepare_input_embs(self, vl_token_ids, sa_token_ids, vision, action, dropped_images, game_ids=None):
        B, T = vl_token_ids.shape
        vl_embs = torch.full(
            size=(B, T, self.vision_hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
        )

        # Extract dimensions from vision tensor
        B, num_images, tokens_per_image, hidden_size = vision.shape

        # Create mask for _IMG_TOKEN positions
        vision_mask = (vl_token_ids == _IMG_TOKEN)  # [B, T]

        # Flatten vision tensor over the num_images dimension
        vision_flat = vision.reshape(B, -1, self.vision_hidden_size)  # [B, num_images * tokens_per_image, hidden_size]

        # Create a mask for the flattened vision dimension.
        # Each image contributes tokens_per_image tokens, so expand the mask accordingly.
        non_dropped_mask_expanded = (dropped_images == 0).unsqueeze(-1).repeat(1, 1, tokens_per_image).reshape(B, -1)  # [B, num_images * tokens_per_image]

        # Select only non-dropped vision embeddings.
        # This gives us the embeddings we need to place.
        valid_vision_embs = vision_flat[non_dropped_mask_expanded]  # [total_valid_tokens, vision_hidden_size]

        assert valid_vision_embs.shape[0] == vision_mask.sum().item(), (
            f"Number of valid vision embeddings {valid_vision_embs.shape[0]} does not match "
            f"the number of _IMG_TOKEN positions {vision_mask.sum().item()}"
        )
        # Now we need to place these at the vision_mask positions.
        # Get indices where vision_mask is True.
        batch_indices, token_indices = vision_mask.nonzero(as_tuple=True)

        # Place the valid embeddings at the masked positions
        vl_embs[batch_indices, token_indices] = valid_vision_embs

        # Handle game ID tokens
        if self.game_mapping is not None and game_ids is not None:
            game_mask = vl_token_ids == _GAME_ID_TOKEN  # shape: (B, T)
            num_game_tokens = game_mask.sum().item()
            if num_game_tokens > 0:

                # Assert that each batch item has exactly one game token
                game_tokens_per_batch = game_mask.sum(dim=1)  # [B] - count of game tokens per batch item
                assert torch.all(game_tokens_per_batch == 1), (
                    f"Expected exactly 1 game token per batch item, but got: {game_tokens_per_batch.tolist()}. "
                    f"Each batch item must have exactly one _GAME_ID_TOKEN."
                )

                # Get game embeddings for each batch item
                game_embs = self.game_embedding(game_ids)  # [B, vision_hidden_size]
                batch_indices, token_indices = game_mask.nonzero(as_tuple=True)
                vl_embs[batch_indices, token_indices] = game_embs[batch_indices].to(dtype=vl_embs.dtype)

        # Project image separator using the learnable sep_embedding.
        sep_mask = vl_token_ids == _IMG_SEP_TOKEN  # shape: (B, T)
        num_sep = sep_mask.sum().item()
        if num_sep > 0:
            # Expand the separator embedding for each occurrence.
            repeated_sep = self.vis_sep_embedding.unsqueeze(0).expand(num_sep, self.vision_hidden_size)
            # Assign the separator embeddings to the correct positions.
            vl_embs[sep_mask] = repeated_sep.to(dtype=vl_embs.dtype)

        B, T = sa_token_ids.shape
        sa_embs = torch.full(
            size=(B, T, self.hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
        )

        # Project state.
        # state_mask = sa_token_ids == _PROPRIO_TOKEN
        # state_mask = state_mask.unsqueeze(-1).expand_as(sa_embs)
        # sa_embs = sa_embs.masked_scatter(state_mask, state)

        # Project action.
        action_mask = sa_token_ids == _ACT_TOKEN
        action_mask = action_mask.unsqueeze(-1).expand_as(sa_embs)
        sa_embs = sa_embs.masked_scatter(action_mask, action)

        # Add positional embeddings
        pos_ids = torch.arange(T, dtype=torch.long, device=sa_token_ids.device)
        if self.config.add_pos_embed:
            pos_embs = self.position_embedding(pos_ids)  # (T, hidden_size)
            pos_embs = pos_embs.unsqueeze(0).expand(B, T, self.hidden_size)
            sa_embs = sa_embs + pos_embs
        return vl_embs, sa_embs
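
    # Layout sketch for prepare_input_embs (token values follow the _*_TOKEN constants
    # at the top of the file; the exact ordering is decided by the data pipeline, so
    # this is only illustrative):
    #
    #   vl_token_ids, e.g. [_GAME_ID_TOKEN, _IMG_TOKEN x N, _IMG_SEP_TOKEN, _IMG_TOKEN x N, ...]
    #     -> filled with the game embedding, the flattened vision tokens, and the
    #        learnable separator embedding (all of width vision_hidden_size).
    #   sa_token_ids, e.g. [_ACT_TOKEN x action_horizon]
    #     -> filled with the encoded noisy action chunk (width hidden_size),
    #        plus optional learned positional embeddings.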

    def pack_actions(self, buttons, j_left, j_right):
        # Check that the first three dims of each input are the same
        assert buttons.shape[:3] == j_left.shape[:3] == j_right.shape[:3], (
            f"buttons shape: {buttons.shape}, "
            f"j_left shape: {j_left.shape}, "
            f"j_right shape: {j_right.shape}"
        )

        # Normalize the joysticks to [0, 1]
        j_left = (j_left + 1) / 2.
        j_right = (j_right + 1) / 2.

        # Concatenate the buttons and joysticks along the last dimension
        action = torch.cat([j_left, j_right, buttons], dim=-1)

        # Squeeze the second dimension of each input: this is the number of chunks, which is 1 here
        action = action.squeeze(1)
        return action
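
    # Packed-action layout sketch, matching the concatenation above and the
    # commented-out unpack_actions below:
    #
    #   action[..., 0:2] = left joystick,  normalized from [-1, 1] to [0, 1]
    #   action[..., 2:4] = right joystick, normalized from [-1, 1] to [0, 1]
    #   action[..., 4:]  = button states (thresholded at 0.5 when unpacking)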

    # def unpack_actions(self, actions):
    #     # Unpack the actions into j_left, j_right, buttons
    #     j_left = actions[:, :, :2]
    #     j_right = actions[:, :, 2:4]
    #     buttons = actions[:, :, 4:]

    #     # Denormalize the joysticks back to -1,1
    #     j_left = j_left * 2. - 1.
    #     j_right = j_right * 2. - 1.

    #     # Clip into [-1,1]
    #     j_left = torch.clamp(j_left, -1, 1)
    #     j_right = torch.clamp(j_right, -1, 1)

    #     # Threshold the buttons to 0,1
    #     buttons = (buttons > 0.5).float()
    #     return j_left, j_right, buttons

    # ========= ActionHead required ============
    def forward(self, data: dict) -> dict:
        self.set_frozen_modules_to_eval_mode()

        # data = action_input
        embodiment_id = data["embodiment_id"]

        # # Check which data is present.
        # has_real_action = action_input.has_real_action
        has_real_action = data["has_real_action"]

        # 1) Encode images/text/state
        visual_features = self.encode_images(data["images"])  # , data["view_ids"])
        # text_features = self.siglip_model.text_model(
        #     input_ids=data["lang_input_ids"]
        # ).last_hidden_state
        # state_features = self.state_encoder(data["state"], embodiment_id)

        # 2) Prepare noisy trajectory
        actions = data["actions"]
        noise = torch.randn_like(actions)
        t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
        t = t[:, None, None]  # shape (B,1,1) for broadcast

        noisy_trajectory = (1 - t) * noise + t * actions
        velocity = actions - noise

        # 3) Convert (continuous) t -> discrete if needed
        t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()

        # 4) Get action encoder embeddings with correct time argument
        action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)

        # 5) Prepare the full input to the DiT
        vl_embs, sa_embs = self.prepare_input_embs(
            data["vl_token_ids"],
            data["sa_token_ids"],
            visual_features,
            # text_features,
            # state_features,
            action_features,
            data["dropped_images"],
            game_ids=data.get("game_id"),
        )

        vl_embs = self.vl_self_attention_model(vl_embs)
        # vl_embs = self.qformer(vl_embs)
        model_output, all_hidden_states = self.model(
            hidden_states=sa_embs,
            encoder_hidden_states=vl_embs,
            encoder_attention_mask=data["vl_attn_mask"],
            timestep=t_discretized,
            return_all_hidden_states=True,
        )
        pred = self.action_decoder(model_output, embodiment_id)
        pred_actions = pred[:, -actions.shape[1]:]

        # 6) Flow-matching (velocity-prediction) MSE
        # Mask for variable-length trajectories
        mask = data["actions_mask"]  # shape => (B, seq_len_of_actions, ...)
        raw_loss = F.mse_loss(pred_actions, velocity, reduction="none")
        mask = has_real_action[:, None, None] * mask
        raw_loss = raw_loss * mask
        action_loss = (has_real_action[:, None, None] * raw_loss).sum() / (mask.sum() + 1e-6)

        loss = action_loss

        return {
            "loss": loss,
        }
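
    # Flow-matching target used in forward(), written out:
    #
    #   x_t  = (1 - t) * noise + t * actions          # interpolate noise -> data
    #   v*   = d/dt x_t = actions - noise             # constant velocity target
    #   loss = || model(x_t, t, context) - v* ||^2    # masked MSE over real actions
    #
    # Integrating the predicted velocity from t = 0 to 1 (see get_action) recovers
    # an action sample from noise.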

    @torch.inference_mode()
    def get_action(self, data: dict, old_layout: bool = False) -> dict:
        """
        For i in [0..N-1]:
            1) t = i/N
            2) velocity = model(x(t), t)
            3) x(t + dt) = x(t) + dt * velocity
        """

        # data = action_input
        embodiment_id = data["embodiment_id"]

        batch_size = data["images"].shape[0]
        device = data["images"].device
        dtype = data["images"].dtype
        actions = torch.randn(
            size=(batch_size, self.config.action_horizon, self.config.action_dim),
            dtype=dtype,
            device=device,
        )

        # 1) Hyperparameters for flow sampling
        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps

        # 2) Encode static context (images, text, state) once, since it does not depend on actions
        visual_features = self.encode_images(data["images"])  # , data["view_ids"])
        # text_features = self.siglip_model.text_model(
        #     input_ids=data["lang_input_ids"]
        # ).last_hidden_state
        # state_features = self.state_encoder(data["state"], embodiment_id)

        # 3) Start denoising the actions
        for i in range(num_steps):
            # ---- (a) Discretize continuous time in [0,1]
            t_cont = i / float(num_steps)  # e.g. goes 0, 1/N, 2/N, ...
            t_discretized = int(t_cont * self.num_timestep_buckets)

            # ---- (b) Build embeddings (actions included)
            # Pass the *current* actions at time t into the action encoder
            action_features = self.action_encoder(
                actions,
                (torch.ones(actions.shape[0]) * t_discretized).to(device),
                embodiment_id,
            )
            vl_embs, sa_embs = self.prepare_input_embs(
                data["vl_token_ids"],
                data["sa_token_ids"],
                visual_features,
                # text_features,
                # state_features,
                action_features,
                data["dropped_images"],
                game_ids=data["game_ids"],
            )
            vl_embs = self.vl_self_attention_model(vl_embs)
            # vl_embs = self.qformer(vl_embs)
            # ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                encoder_attention_mask=data["vl_attn_mask"],
                timestep=timesteps,
            )
            pred = self.action_decoder(model_output, embodiment_id)
            pred_velocity = pred[:, -actions.shape[1]:]

            # ---- (d) Naive Euler step: x(t + dt) = x(t) + dt * velocity
            actions = actions + dt * pred_velocity

        return {
            "action_tensor": actions,
        }
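
    # Example (sketch): sampling an action chunk. The keys of `data` follow the usage
    # above; building them (tokenization, image preprocessing, masks) is done by the
    # data pipeline and is not shown here. get_action already runs under
    # torch.inference_mode().
    #
    #   out = model.get_action(data)          # data: dict of prepared tensors
    #   actions = out["action_tensor"]        # (B, action_horizon, action_dim)
    #   # Joysticks/buttons can be recovered with the layout documented in pack_actions.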

    @torch.inference_mode()
    def get_action_with_cfg(self, data_cond: dict, data_uncond: dict, cfg_scale: float = 1.0) -> dict:
        """
        Use a form of classifier-free guidance to sample actions. This can only be used on
        models that were trained on multiple frames of actions. The idea is that we sample
        velocity with and without the frame history, and then we push the sampled actions
        towards the ones that were sampled with the frame history.

        data_cond = conditional input (e.g. with frame history)
        data_uncond = unconditional input

        This function works with any kind of conditioning, not just history.

        For i in [0..N-1]:
            1) t = i/N
            2) velocity = v_cond + cfg_scale * (v_cond - v_uncond),
               where v_cond = model(x(t), t, cond) and v_uncond = model(x(t), t, uncond)
            3) x(t + dt) = x(t) + dt * velocity
        """

        # data = action_input
        embodiment_id = data_cond["embodiment_id"]

        batch_size = data_cond["images"].shape[0]
        device = data_cond["images"].device
        dtype = data_cond["images"].dtype
        actions = torch.randn(
            size=(batch_size, self.config.action_horizon, self.config.action_dim),
            dtype=dtype,
            device=device,
        )

        # 1) Hyperparameters for flow sampling
        num_steps = self.num_inference_timesteps
        dt = 1.0 / num_steps

        # 2) Encode static context (images, text, state) once, since it does not depend on actions
        visual_features_cond = self.encode_images(data_cond["images"])
        visual_features_uncond = self.encode_images(data_uncond["images"])
        # text_features = self.siglip_model.text_model(
        #     input_ids=data["lang_input_ids"]
        # ).last_hidden_state
        # state_features = self.state_encoder(data["state"], embodiment_id)

        # 3) Start denoising the actions
        for i in range(num_steps):
            # ---- (a) Discretize continuous time in [0,1]
            t_cont = i / float(num_steps)  # e.g. goes 0, 1/N, 2/N, ...
            t_discretized = int(t_cont * self.num_timestep_buckets)

            # ---- (b) Build embeddings (actions included)
            # Pass the *current* actions at time t into the action encoder
            action_features = self.action_encoder(
                actions,
                (torch.ones(actions.shape[0]) * t_discretized).to(device),
                embodiment_id,
            )

            # Predict velocity with conditioning
            vl_embs, sa_embs = self.prepare_input_embs(
                data_cond["vl_token_ids"],
                data_cond["sa_token_ids"],
                visual_features_cond,
                action_features,
                data_cond["dropped_images"],
            )
            vl_embs = self.vl_self_attention_model(vl_embs)
            # ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                encoder_attention_mask=data_cond["vl_attn_mask"],
                timestep=timesteps,
            )
            pred = self.action_decoder(model_output, embodiment_id)
            pred_velocity_cond = pred[:, -actions.shape[1]:]

            # Predict velocity without conditioning
            vl_embs, sa_embs = self.prepare_input_embs(
                data_uncond["vl_token_ids"],
                data_uncond["sa_token_ids"],
                visual_features_uncond,
                action_features,
                data_uncond["dropped_images"],
            )
            vl_embs = self.vl_self_attention_model(vl_embs)
            # ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.from_numpy(np.array([t_discretized])).to(device).long()
            model_output = self.model(
                hidden_states=sa_embs,
                encoder_hidden_states=vl_embs,
                encoder_attention_mask=data_uncond["vl_attn_mask"],
                timestep=timesteps,
            )
            pred = self.action_decoder(model_output, embodiment_id)
            pred_velocity_uncond = pred[:, -actions.shape[1]:]

            # ---- (d) Combine velocities with cfg_scale
            pred_velocity = pred_velocity_cond + cfg_scale * (pred_velocity_cond - pred_velocity_uncond)

            # ---- (e) Naive Euler step: x(t + dt) = x(t) + dt * velocity
            actions = actions + dt * pred_velocity

        return {
            "action_tensor": actions,
        }
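
    # CFG combination used above, written out: with guidance scale s,
    #
    #   v = v_cond + s * (v_cond - v_uncond)
    #
    # s = 0 reduces to the purely conditional prediction; larger s pushes the
    # trajectory further toward the conditional velocity and away from the
    # unconditional one.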

    @property
    def device(self):
        return next(iter(self.parameters())).device

    @property
    def dtype(self):
        return next(iter(self.parameters())).dtype