# NitroGen/nitrogen/flow_matching_transformer/nitrogen.py
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
from pathlib import Path
import yaml
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.distributions import Beta
from transformers import SiglipVisionModel, AutoModel
from .modules import DiT, DiTConfig, SelfAttentionTransformer, SelfAttentionTransformerConfig
_PAD_TOKEN = 0
_IMG_TOKEN = 1
_IMG_SEP_TOKEN = 5
_LANG_TOKEN = 2
_PROPRIO_TOKEN = 3
_ACT_TOKEN = 4
_GAME_ID_TOKEN = 6
class NitroGen_Config(BaseModel):
model_type: str = Field(default="nitrogen", frozen=True)
add_pos_embed: bool = Field(default=False, description="Whether to add positional embedding")
model_dtype: str = Field(default="float32", description="Model data type.")
diffusion_model_cfg: DiTConfig = Field(..., description="Diffusion model configuration.")
vl_self_attention_cfg: SelfAttentionTransformerConfig = Field(..., description="VL self-attention configuration.")
hidden_size: int = Field(default=1024, description="Input embedding dimension.")
    max_seq_len: int = Field(default=1024, description="Maximum sequence length.")
    action_dim: int | None = Field(default=None, description="Action dimension.")
    action_horizon: int | None = Field(default=None, description="Action horizon.")
    noise_beta_alpha: float = Field(default=1.5, description="Flow matching noise Beta distribution alpha.")
    noise_beta_beta: float = Field(default=1.0, description="Flow matching noise Beta distribution beta.")
    noise_s: float = Field(default=0.999, description="Flow matching time scaling factor s.")
    num_timestep_buckets: int = Field(default=1000, description="Number of timestep discretization buckets.")
    num_inference_timesteps: int | None = Field(default=None, description="Number of flow integration steps at inference.")
max_num_embodiments: int = Field(default=1, description="Number of embodiments.")
vision_encoder_name: str = Field(default="google/siglip-large-patch16-256", description="Vision encoder name.")
vision_hidden_size: int = Field(default=768, description="Siglip hidden size.")
add_view_embed: bool = Field(default=False, description="Whether to add view embedding.")
tune_vision_tower: bool = Field(default=True, description="Tune vision if True.")
tune_mm_projector: bool = Field(default=True, description="Tune mm projector if True.")
tune_diffusion_model: bool = Field(default=True, description="Tune diffusion model if True.")
tune_multi_projector: bool = Field(default=True, description="Tune multi projector if True.")
tune_vl_mixing: bool = Field(default=True, description="Tune vl mixing if True.")
@classmethod
def from_yaml(cls, yaml_path: str | Path) -> "NitroGen_Config":
"""Load configuration from a YAML file."""
with open(yaml_path, "r") as f:
config_dict = yaml.safe_load(f)
return cls.model_validate(config_dict)
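# Minimal usage sketch (hypothetical YAML path and game names; adjust to your setup):
#
#     config = NitroGen_Config.from_yaml("configs/nitrogen.yaml")
#     model = NitroGen(config, game_mapping={"unconditional": 0, "my_game": 1})
#
# The YAML file is expected to contain the NitroGen_Config fields, including the nested
# diffusion_model_cfg and vl_self_attention_cfg blocks.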
def swish(x):
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
"""
Produces a sinusoidal encoding of shape (B, T, w)
given timesteps of shape (B, T).
"""
def __init__(self, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps):
# timesteps: shape (B, T)
# We'll compute sin/cos frequencies across dim T
timesteps = timesteps.float() # ensure float
B, T = timesteps.shape
device = timesteps.device
half_dim = self.embedding_dim // 2
# typical log space frequencies for sinusoidal encoding
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
# Expand timesteps to (B, T, 1) then multiply
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
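        # NOTE: output width is 2 * (embedding_dim // 2); embedding_dim is assumed even
        # so that this matches self.embedding_dim.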
return enc
class CategorySpecificLinear(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim):
super().__init__()
self.num_categories = num_categories
# For each category, we have separate weights and biases.
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x, cat_ids):
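        # x: (B, T, input_dim), cat_ids: (B,) integer category indices.
        # One (input_dim, hidden_dim) weight and one bias are gathered per batch item,
        # then applied with a batched matmul -> (B, T, hidden_dim).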
selected_W = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_W) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
super().__init__()
self.num_categories = num_categories
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x, cat_ids):
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(self, action_dim, hidden_size, num_embodiments):
super().__init__()
self.hidden_size = hidden_size
self.num_embodiments = num_embodiments
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions, timesteps, cat_ids):
"""
actions: shape (B, T, action_dim)
timesteps: shape (B,) -- a single scalar per batch item
cat_ids: shape (B,)
returns: shape (B, T, hidden_size)
"""
B, T, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# e.g. if timesteps is (B,), replicate across T
if timesteps.dim() == 1 and timesteps.shape[0] == B:
# shape (B,) => (B,T)
timesteps = timesteps.unsqueeze(1).expand(-1, T)
else:
raise ValueError(
"Expected `timesteps` to have shape (B,) so we can replicate across T."
)
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions, cat_ids)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = swish(self.W2(x, cat_ids))
# 5) Finally W3 => (B, T, w)
x = self.W3(x, cat_ids)
return x
class NitroGen(torch.nn.Module):
config_class = NitroGen_Config
supports_gradient_checkpointing = True
def __init__(
self,
config: NitroGen_Config,
game_mapping: dict[str, int] | None = None, # Used to add a game ID token
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.vision_hidden_size = config.vision_hidden_size
if "siglip" in config.vision_encoder_name:
model = SiglipVisionModel.from_pretrained(config.vision_encoder_name)
self.vision_encoder = model.vision_model
self.vision_encoder_type = "siglip"
else:
self.vision_encoder = AutoModel.from_pretrained(config.vision_encoder_name)
self.vision_encoder_type = "hf_auto"
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
self.num_timestep_buckets = config.num_timestep_buckets
# self.model = instantiate(config.diffusion_model_cfg)
self.model = DiT(config=config.diffusion_model_cfg)
self.action_dim = config.action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
# self.vl_self_attention_model = instantiate(config.vl_self_attention_cfg)
self.vl_self_attention_model = SelfAttentionTransformer(config=config.vl_self_attention_cfg)
# if config.qformer_cfg is not None:
# self.qformer = instantiate(config.qformer_cfg)
# else:
# self.qformer = nn.Identity()
# self.state_encoder = CategorySpecificMLP(
# num_categories=config.max_num_embodiments,
# input_dim=config.max_state_dim,
# hidden_dim=self.hidden_size,
# output_dim=self.hidden_size,
# )
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=config.action_dim,
hidden_size=self.hidden_size,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
# self.mm_vision_select_layer = config.mm_vision_select_layer
# if config.mm_projector_cfg is not None:
# self.mm_projector = instantiate(config.mm_projector_cfg)
# else:
self.mm_projector = None
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.hidden_size)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
# if config.add_view_embed:
# self.view_embedding = nn.Embedding(config.max_num_views, self.hidden_size)
# nn.init.normal_(self.view_embedding.weight, mean=0.0, std=0.02)
# self.vision_projector = None
# if config.vision_hidden_size != self.hidden_size:
# self.vision_projector = nn.Sequential(
# nn.Linear(config.vision_hidden_size, self.hidden_size),
# nn.LayerNorm(self.hidden_size),
# )
self.game_mapping = game_mapping
# Create an embedding table for game IDs
# Game ID tokens will be put inside vision-language tokens
# so they need to be projected to the same dimension
if self.game_mapping is not None:
# 0 = unconditional
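            # padding_idx=0 keeps the unconditional embedding fixed at zero (its row
            # receives no gradient updates); scale_grad_by_freq rescales gradients by
            # the inverse in-batch frequency of each game id.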
self.game_embedding = nn.Embedding(
len(self.game_mapping),
self.vision_hidden_size,
padding_idx=0,
scale_grad_by_freq=True
)
self.set_trainable_parameters(
tune_multi_projector=config.tune_multi_projector,
tune_diffusion_model=config.tune_diffusion_model,
tune_vision_tower=config.tune_vision_tower,
tune_mm_projector=config.tune_mm_projector,
tune_vl_mixing=config.tune_vl_mixing,
)
        print(
            "total number of trainable parameters: %e"
            % sum(p.numel() for p in self.parameters() if p.requires_grad)
        )
def set_trainable_parameters(
self,
tune_multi_projector: bool = True,
tune_diffusion_model: bool = True,
tune_vision_tower: bool = True,
tune_mm_projector: bool = True,
tune_vl_mixing: bool = True,
):
self.tune_multi_projector = tune_multi_projector
self.tune_diffusion_model = tune_diffusion_model
self.tune_vision_tower = tune_vision_tower
self.tune_mm_projector = tune_mm_projector
self.tune_vl_mixing = tune_vl_mixing
for param in self.parameters():
param.requires_grad = True
# ### Always freeze language encoder
# self.siglip_model.text_model.requires_grad_(False)
# # Freeze unused parameters in siglip vision encoder
# self.siglip_model.logit_scale.requires_grad = False
# self.siglip_model.logit_bias.requires_grad = False
        # For SigLIP, always keep encoder layer 11 and the pooling head frozen
        # (the pooled output is not used by this model).
        if self.vision_encoder_type == "siglip":
            for param in self.vision_encoder.encoder.layers[11].parameters():
                param.requires_grad = False
            for param in self.vision_encoder.head.parameters():
                param.requires_grad = False
# Freeze parameters
if not tune_multi_projector:
# self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
            # view_embedding is currently not instantiated in __init__, so guard the access.
            if self.config.add_view_embed and hasattr(self, "view_embedding"):
                self.view_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
if not tune_vision_tower:
self.vision_encoder.requires_grad_(False)
if self.mm_projector is not None and not tune_mm_projector:
self.mm_projector.requires_grad_(False)
if not tune_vl_mixing:
self.vl_self_attention_model.requires_grad_(False)
print(f"Tune action head multi_projector: {self.tune_multi_projector}")
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
print(f"Tune action head vision tower: {self.tune_vision_tower}")
print(f"Tune action head mm_projector: {self.tune_mm_projector}")
print(f"Tune action head vl_mixing: {self.tune_vl_mixing}")
# Check if any parameters are still trainable. If not, print a warning.
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No action head trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
# self.siglip_model.text_model.eval()
if not self.tune_multi_projector:
# self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
# if self.config.add_view_embed:
# self.view_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
if not self.tune_vision_tower:
self.vision_encoder.eval()
if self.mm_projector is not None and not self.tune_mm_projector:
self.mm_projector.eval()
if not self.tune_vl_mixing:
self.vl_self_attention_model.eval()
    # NOTE: the earlier sample_time implementation below is believed to be incorrect
    # and is kept only for reference.
# def sample_time(self, batch_size, device, dtype):
# sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
# return (self.config.noise_s - sample) / self.config.noise_s
def sample_time(self, batch_size, device, dtype):
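        # Beta(noise_beta_alpha=1.5, noise_beta_beta=1.0) is skewed towards 1, so
        # t = (1 - sample) * noise_s places more mass near t = 0. With the interpolation
        # x_t = (1 - t) * noise + t * actions used in forward(), this biases training
        # towards noisier timesteps.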
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (1 - sample) * self.config.noise_s
def encode_images(self, images): #, view_ids):
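        # images: (B, num_frames, C, H, W); frames are folded into the batch, encoded,
        # and unfolded back to (B, num_frames, tokens_per_image, hidden), where hidden
        # is the vision encoder's hidden size (assumed to match config.vision_hidden_size).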
batch_size, num_frames, channels, height, width = images.shape
images = images.reshape(-1, channels, height, width)
image_features = self.vision_encoder(images)["last_hidden_state"]
image_features = rearrange(image_features, "(b f) n d -> b f n d", f=num_frames)
# if self.vision_projector is not None:
# # change the hidden dimension of the vision features
# image_features = self.vision_projector(image_features)
if self.mm_projector is not None:
image_features = self.mm_projector(image_features) # [B, 256, 1024] -> [B, 16, 1024]
return image_features
def prepare_input_embs(self, vl_token_ids, sa_token_ids, vision, action, dropped_images, game_ids=None):
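        # Builds the two embedding sequences consumed by the DiT from token-id templates:
        #   vl_embs (width vision_hidden_size): image patch features at _IMG_TOKEN slots,
        #     game-id embeddings at _GAME_ID_TOKEN slots, separator embeddings at
        #     _IMG_SEP_TOKEN slots.
        #   sa_embs (width hidden_size): noisy-action embeddings at _ACT_TOKEN slots,
        #     plus positional embeddings when config.add_pos_embed is set.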
B, T = vl_token_ids.shape
vl_embs = torch.full(
size=(B, T, self.vision_hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
)
# Extract dimensions from vision tensor
B, num_images, tokens_per_image, hidden_size = vision.shape
# Create mask for _IMG_TOKEN positions
vision_mask = (vl_token_ids == _IMG_TOKEN) # [B, T]
        # Flatten vision tensor over the num_images dimension
        vision_flat = vision.reshape(B, -1, self.vision_hidden_size)  # [B, num_images * tokens_per_image, vision_hidden_size]
        # Create a mask for the flattened vision dimension
        # Each image contributes tokens_per_image tokens, so expand the mask accordingly
        non_dropped_mask_expanded = (dropped_images == 0).unsqueeze(-1).repeat(1, 1, tokens_per_image).reshape(B, -1)  # [B, num_images * tokens_per_image]
        # Select only non-dropped vision embeddings
        # This will give us the embeddings we need to place
        valid_vision_embs = vision_flat[non_dropped_mask_expanded]  # [total_valid_tokens, vision_hidden_size]
assert valid_vision_embs.shape[0] == vision_mask.sum().item(), (
f"Number of valid vision embeddings {valid_vision_embs.shape[0]} does not match "
f"the number of _IMG_TOKEN positions {vision_mask.sum().item()}"
)
# Now we need to place these at the vision_mask positions
# Get indices where vision_mask is True
batch_indices, token_indices = vision_mask.nonzero(as_tuple=True)
# Place the valid embeddings at the masked positions
vl_embs[batch_indices, token_indices] = valid_vision_embs
# Handle Game ID tokens
if self.game_mapping is not None and game_ids is not None:
game_mask = vl_token_ids == _GAME_ID_TOKEN # shape: (B, T)
num_game_tokens = game_mask.sum().item()
if num_game_tokens > 0:
# Assert that each batch item has exactly one game token
game_tokens_per_batch = game_mask.sum(dim=1) # [B] - count of game tokens per batch item
assert torch.all(game_tokens_per_batch == 1), (
f"Expected exactly 1 game token per batch item, but got: {game_tokens_per_batch.tolist()}. "
f"Each batch item must have exactly one _GAME_ID_TOKEN."
)
# Get game embeddings for each batch item
game_embs = self.game_embedding(game_ids) # [B, vision_hidden_size]
batch_indices, token_indices = game_mask.nonzero(as_tuple=True)
vl_embs[batch_indices, token_indices] = game_embs[batch_indices].to(dtype=vl_embs.dtype)
        # Project image separator using the learnable sep_embedding.
        # NOTE: assumes `self.vis_sep_embedding` is a learnable vector of size
        # `vision_hidden_size` (it is not created in the __init__ shown here).
        sep_mask = vl_token_ids == _IMG_SEP_TOKEN  # shape: (B, T)
        num_sep = sep_mask.sum().item()
        if num_sep > 0:
            # Expand the separator embedding for each occurrence.
            repeated_sep = self.vis_sep_embedding.unsqueeze(0).expand(num_sep, self.vision_hidden_size)
            # Assign the separator embeddings to the correct positions.
            vl_embs[sep_mask] = repeated_sep.to(dtype=vl_embs.dtype)
B, T = sa_token_ids.shape
sa_embs = torch.full(
size=(B, T, self.hidden_size), fill_value=0.0, dtype=vision.dtype, device=vision.device
)
# Project state.
# state_mask = sa_token_ids == _PROPRIO_TOKEN
# state_mask = state_mask.unsqueeze(-1).expand_as(sa_embs)
# sa_embs = sa_embs.masked_scatter(state_mask, state)
# Project action.
action_mask = sa_token_ids == _ACT_TOKEN
action_mask = action_mask.unsqueeze(-1).expand_as(sa_embs)
sa_embs = sa_embs.masked_scatter(action_mask, action)
# Add positional embeddings
pos_ids = torch.arange(T, dtype=torch.long, device=sa_token_ids.device)
if self.config.add_pos_embed:
pos_embs = self.position_embedding(pos_ids) # (T, hidden_size)
pos_embs = pos_embs.unsqueeze(0).expand(B, T, self.hidden_size)
sa_embs = sa_embs + pos_embs
return vl_embs, sa_embs
def pack_actions(self, buttons, j_left, j_right):
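        # Packs controller inputs into one action vector per timestep, assuming the
        # layout [j_left (2), j_right (2), buttons (...)] used by the commented-out
        # unpack_actions below; joystick values are mapped from [-1, 1] to [0, 1].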
        # Check that the first three dims of each input are the same
assert buttons.shape[:3] == j_left.shape[:3] == j_right.shape[:3], (
f"buttons shape: {buttons.shape}, "
f"j_left shape: {j_left.shape}, "
f"j_right shape: {j_right.shape}"
)
        # Normalize the joysticks from [-1, 1] to [0, 1]
j_left = (j_left + 1) / 2.
j_right = (j_right + 1) / 2.
# Concatenate the buttons and joysticks along the last dimension
action = torch.cat([j_left, j_right, buttons], dim=-1)
# Squeeze the second dimension of each input: this is the number of chunks, which is 1 here
action = action.squeeze(1)
return action
# def unpack_actions(self, actions):
# # Unpack the actions into j_left, j_right, buttons
# j_left = actions[:, :, :2]
# j_right = actions[:, :, 2:4]
# buttons = actions[:, :, 4:]
# # Denormalize the joysticks back to -1,1
# j_left = j_left * 2. - 1.
# j_right = j_right * 2. - 1.
# # Clip into [-1,1]
# j_left = torch.clamp(j_left, -1, 1)
# j_right = torch.clamp(j_right, -1, 1)
# # Threshold the buttons to 0,1
# buttons = (buttons > 0.5).float()
# return j_left, j_right, buttons
# ========= ActionHead required ============
def forward(self, data: dict) -> dict:
self.set_frozen_modules_to_eval_mode()
# data = action_input
embodiment_id = data["embodiment_id"]
# # Check which data is present.
# has_real_action = action_input.has_real_action
has_real_action = data["has_real_action"]
# 1) Encode images/text/state
visual_features = self.encode_images(data["images"]) #, data["view_ids"])
# text_features = self.siglip_model.text_model(
# input_ids=data["lang_input_ids"]
# ).last_hidden_state
# state_features = self.state_encoder(data["state"], embodiment_id)
# 2) Prepare noisy trajectory
actions = data["actions"]
noise = torch.randn_like(actions)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None] # shape (B,1,1) for broadcast
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
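        # For the linear interpolation path x_t = (1 - t) * noise + t * actions,
        # the target velocity is dx_t/dt = actions - noise.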
# 3) Convert (continuous) t -> discrete if needed
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
# 4) Get action encoder embeddings with correct time argument
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
# 5) Prepare full input to DiT (or your model)
vl_embs, sa_embs = self.prepare_input_embs(
data["vl_token_ids"],
data["sa_token_ids"],
visual_features,
# text_features,
# state_features,
action_features,
data["dropped_images"],
game_ids=data.get("game_id"),
)
vl_embs = self.vl_self_attention_model(vl_embs)
# vl_embs = self.qformer(vl_embs)
model_output, all_hidden_states = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=data["vl_attn_mask"],
timestep=t_discretized,
return_all_hidden_states=True,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
# 6) Flow-matching or velocity-prediction MSE
# Mask for variable-length trajectories
mask = data["actions_mask"] # shape => (B, seq_len_of_actions, ...)
raw_loss = F.mse_loss(pred_actions, velocity, reduction="none")
        # Zero out padded action steps and samples without real actions.
        mask = has_real_action[:, None, None] * mask
        raw_loss = raw_loss * mask
        action_loss = raw_loss.sum() / (mask.sum() + 1e-6)
loss = action_loss
return {
"loss": loss,
}
@torch.inference_mode()
    def get_action(self, data: dict, old_layout: bool = False) -> dict:
"""
For i in [0..N-1]:
1) t = i/N
2) velocity = model(x(t), t)
3) x(t + dt) = x(t) + dt * velocity
"""
# data = action_input
embodiment_id = data["embodiment_id"]
batch_size = data["images"].shape[0]
device = data["images"].device
dtype = data["images"].dtype
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.config.action_dim),
dtype=dtype,
device=device,
)
# 1) Hyperparameters for flow sampling
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
# 2) Encode static context (images, text, state) once if it does not depend on actions
visual_features = self.encode_images(data["images"]) #, data["view_ids"])
# text_features = self.siglip_model.text_model(
# input_ids=data["lang_input_ids"]
# ).last_hidden_state
# state_features = self.state_encoder(data["state"], embodiment_id)
# 3) Start denoising the actions
for i in range(num_steps):
# ---- (a) Discretize continuous time in [0,1]
t_cont = i / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
t_discretized = int(t_cont * self.num_timestep_buckets)
# ---- (b) Build embeddings (actions included)
# Pass the *current* actions at time t into the action encoder
action_features = self.action_encoder(
actions,
(torch.ones(actions.shape[0]) * t_discretized).to(device),
embodiment_id,
)
vl_embs, sa_embs = self.prepare_input_embs(
data["vl_token_ids"],
data["sa_token_ids"],
visual_features,
# text_features,
# state_features,
action_features,
data["dropped_images"],
game_ids=data["game_ids"],
)
vl_embs = self.vl_self_attention_model(vl_embs)
# vl_embs = self.qformer(vl_embs)
# ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.tensor([t_discretized], device=device, dtype=torch.long)
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=data["vl_attn_mask"],
timestep=timesteps,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -actions.shape[1] :]
# ---- (d) Naive Euler step: x(t + dt) = x(t) + dt * velocity
actions = actions + dt * pred_velocity
return {
"action_tensor": actions,
}
@torch.inference_mode()
def get_action_with_cfg(self, data_cond: dict, data_uncond: dict, cfg_scale: float = 1.0) -> dict:
"""
Use a form of classifier free guidance to sample actions. This can only be used on
models that were trained on multiple frames of actions. The idea is that we sample
velocity with and without the frame history, and then we push the sampled actions
towards the ones that were sampled with the frame history.
data_with_hist = conditional input
data_without_hist = unconditional input
This function works with any kind of conditioning, not just history.
For i in [0..N-1]:
1) t = i/N
2) velocity = (1 - cfg_scale) * model(x(t), t, None) + cfg_scale * model(x(t), t, history)
3) x(t + dt) = x(t) + dt * velocity
"""
# data = action_input
embodiment_id = data_cond["embodiment_id"]
batch_size = data_cond["images"].shape[0]
device = data_cond["images"].device
dtype = data_cond["images"].dtype
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.config.action_dim),
dtype=dtype,
device=device,
)
# 1) Hyperparameters for flow sampling
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
# 2) Encode static context (images, text, state) once if it does not depend on actions
visual_features_cond = self.encode_images(data_cond["images"])
visual_features_uncond = self.encode_images(data_uncond["images"])
# text_features = self.siglip_model.text_model(
# input_ids=data["lang_input_ids"]
# ).last_hidden_state
# state_features = self.state_encoder(data["state"], embodiment_id)
# 3) Start denoising the actions
for i in range(num_steps):
# ---- (a) Discretize continuous time in [0,1]
t_cont = i / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
t_discretized = int(t_cont * self.num_timestep_buckets)
# ---- (b) Build embeddings (actions included)
# Pass the *current* actions at time t into the action encoder
action_features = self.action_encoder(
actions,
(torch.ones(actions.shape[0]) * t_discretized).to(device),
embodiment_id,
)
# Predict velocity with history
vl_embs, sa_embs = self.prepare_input_embs(
data_cond["vl_token_ids"],
data_cond["sa_token_ids"],
visual_features_cond,
action_features,
data_cond["dropped_images"],
)
vl_embs = self.vl_self_attention_model(vl_embs)
# ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.tensor([t_discretized], device=device, dtype=torch.long)
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=data_cond["vl_attn_mask"],
timestep=timesteps,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity_cond = pred[:, -actions.shape[1] :]
# Predict velocity without history
vl_embs, sa_embs = self.prepare_input_embs(
data_uncond["vl_token_ids"],
data_uncond["sa_token_ids"],
visual_features_uncond,
action_features,
data_uncond["dropped_images"],
)
vl_embs = self.vl_self_attention_model(vl_embs)
# ---- (c) Forward pass to get velocity = d/dt x(t)
            timesteps = torch.tensor([t_discretized], device=device, dtype=torch.long)
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=data_uncond["vl_attn_mask"],
timestep=timesteps,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity_uncond = pred[:, -actions.shape[1] :]
# ---- (d) Combine velocities with cfg_scale
pred_velocity = pred_velocity_cond + cfg_scale * (pred_velocity_cond - pred_velocity_uncond)
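            # Equivalent to (1 + cfg_scale) * v_cond - cfg_scale * v_uncond;
            # cfg_scale = 0 recovers the purely conditional prediction.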
# ---- (e) Naive Euler step: x(t + dt) = x(t) + dt * velocity
actions = actions + dt * pred_velocity
return {
"action_tensor": actions,
}
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype