# NitroGen/nitrogen/mm_tokenizers.py
import os
import json
from abc import ABC, abstractmethod
from typing import Literal
import numpy as np
import torch
import polars as pl
from pydantic import BaseModel, Field
from itertools import count
DEBUG = int(os.getenv("DEBUG", 0))
class Tokenizer(ABC):
@abstractmethod
def encode(self, data: dict) -> dict:
"""
Transform the input data into a tokenized format.
Args:
data (dict): Input data containing frames and actions.
Returns:
dict: Tokenized data ready for model input.
"""
pass
@abstractmethod
def decode(self, data: dict) -> dict:
"""
Reverse the tokenization process to retrieve original data.
Args:
data (dict): Tokenized data.
Returns:
dict: Original data structure.
"""
pass
@abstractmethod
def train(self):
"""
Set the tokenizer to training mode.
"""
pass
@abstractmethod
def eval(self):
"""
Set the tokenizer to evaluation mode.
"""
pass
# Set IDs for each token type.
_PAD_TOKEN = 0
_IMG_TOKEN = 1
_IMG_SEP_TOKEN = 5 # New separator token
_LANG_TOKEN = 2
_PROPRIO_TOKEN = 3
_ACT_TOKEN = 4
_GAME_ID_TOKEN = 6
_UNCONDITIONAL_ID = None # Special ID for unconditional game
class GameMappingConfig(BaseModel):
src_files: list[str] = Field(default_factory=list, description="List of source parquet files to build game mapping.")
def get_game_mapping(cfg: GameMappingConfig) -> dict:
"""
Build a mapping from game label to integer ID by scanning the source parquet files.
Index 0 is always reserved for the unconditional game (label None).
"""
game_set = set()
for path in cfg.src_files:
df = pl.read_parquet(path)
for game in df['game_label'].unique():
if game == _UNCONDITIONAL_ID:
continue
game_set.add(game)
games = sorted(game_set)
# Reserve index 0 for the unconditional game ID
games = [_UNCONDITIONAL_ID] + games
return {game: idx for idx, game in enumerate(games)}
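# Illustrative example (hypothetical game labels): if the parquet files contain
# game_label values {"mario_kart", "zelda"}, the returned mapping would be
#   {None: 0, "mario_kart": 1, "zelda": 2}
# i.e. index 0 is always the unconditional game.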
class NitrogenTokenizerConfig(BaseModel):
tokenizer_id: Literal['nitrogen'] = Field(default='nitrogen', frozen=True)
training: bool = Field(default=True, description="Whether to apply the transform in training mode.")
num_visual_tokens_per_frame: int = Field(default=256, description="Number of visual tokens per frame.")
max_action_dim: int = Field(default=25, description="Maximum action dimension.")
max_sequence_length: int = Field(default=300, description="Maximum sequence length.")
action_horizon: int = Field(default=16, description="Action horizon.")
game_mapping_cfg: GameMappingConfig | None = Field(default=None, description="Game mapping configuration.")
old_layout: bool = Field(default=False, description="Whether to use the old layout for actions. If True, the action layout is [j_left, j_right, buttons]. If False, it is [buttons, j_left, j_right].")
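# Example (hypothetical file name): a tokenizer conditioned on game IDs would be
# configured with something like
#   NitrogenTokenizerConfig(game_mapping_cfg=GameMappingConfig(src_files=["train.parquet"]))
# while the default game_mapping_cfg=None disables game conditioning entirely.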
class NitrogenTokenizer(Tokenizer):
"""
Example transform that prepares video, language, state, and actions
into a token-based format suitable for your model.
The sub-methods below (prefixed with `_prepare_`) mirror the original
modular structure.
"""
def __init__(self, config: NitrogenTokenizerConfig):
self.training = config.training
self.num_visual_tokens_per_frame = config.num_visual_tokens_per_frame
self.max_action_dim = config.max_action_dim
self.max_sequence_length = config.max_sequence_length
self.action_horizon = config.action_horizon
self.old_layout = config.old_layout
if config.game_mapping_cfg:
self.game_mapping = get_game_mapping(config.game_mapping_cfg)
# Write the game mapping to disk as JSON.
with open("game_mapping.json", "w") as f:
json.dump(self.game_mapping, f, indent=2)
else:
self.game_mapping = None
def train(self):
self.training = True
def eval(self):
self.training = False
def check_batch_size(self, data):
# Use the images key to determine batch size.
video_ndim = data["images"].ndim
if video_ndim == 4: # Interpret as [T*V, H, W, C]
is_batched = False
batch_size = 1
elif video_ndim == 5: # Interpret as [B, T*V, H, W, C]
is_batched = True
batch_size = data["images"].shape[0]
else:
raise ValueError(f"Unsupported video number of dimensions: {video_ndim}")
return is_batched, batch_size
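# Illustrative shapes (following the layouts noted above):
#   images of shape (8, 224, 224, 3)    -> (is_batched=False, batch_size=1)
#   images of shape (4, 8, 224, 224, 3) -> (is_batched=True,  batch_size=4)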
def _prepare_action(self, data: dict):
"""
Pad to max_action_dim, return masks.
"""
if "action" not in data:
actions = np.zeros((self.action_horizon, self.max_action_dim))
actions_mask = np.zeros((self.action_horizon, self.max_action_dim), dtype=bool)
n_action_tokens = self.action_horizon
return actions, actions_mask, n_action_tokens
actions = data["action"]
assert actions.shape[0] == self.action_horizon, f"{actions.shape=}, {self.action_horizon=}"
n_action_tokens = actions.shape[0] # T
n_action_dims = actions.shape[1]
assert (
n_action_dims <= self.max_action_dim
), f"Action dim {n_action_dims} exceeds max allowed {self.max_action_dim}."
# Pad the channel dimension
actions = np.pad(actions, ((0, 0), (0, self.max_action_dim - n_action_dims)), "constant")
# Create mask: [T, max_action_dim]
actions_mask = np.zeros((n_action_tokens, self.max_action_dim), dtype=bool)
actions_mask[:, :n_action_dims] = True
return actions, actions_mask, n_action_tokens
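# Illustrative example (assuming the defaults action_horizon=16, max_action_dim=25):
# an input action of shape (16, 21) is zero-padded to (16, 25) and the mask is True
# on the first 21 channels of every timestep; if "action" is missing, both arrays
# are all zeros / all False.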
def _build_token_ids(self, n_images, n_action_tokens):
"""
Build the 1D arrays of token IDs for each block.
Return (vl_token_ids, sa_token_ids).
"""
vl_token_ids = []
sa_token_ids = []
# 0.5) Add a Game ID placeholder
if self.game_mapping:
vl_token_ids.append(_GAME_ID_TOKEN)
# 1) Video placeholders
for _ in range(n_images):
vl_token_ids.extend([_IMG_TOKEN] * self.num_visual_tokens_per_frame)
# 2) Action tokens
sa_token_ids.extend([_ACT_TOKEN] * n_action_tokens)
return np.array(vl_token_ids), np.array(sa_token_ids)
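# Illustrative example (assuming a game mapping is configured and the default
# num_visual_tokens_per_frame=256): with n_images=1 and n_action_tokens=16,
#   vl_token_ids = [_GAME_ID_TOKEN] + [_IMG_TOKEN] * 256   (length 257)
#   sa_token_ids = [_ACT_TOKEN] * 16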
def _prepare_attention_mask(
self,
vl_token_ids: np.ndarray,
):
"""
Build 1D attention mask for vision-language tokens.
1 indicates valid token, 0 indicates padding token.
State-action attention will be handled separately by the model.
"""
# Only create attention mask for vision-language tokens
vl_seq_len = vl_token_ids.shape[0]
vl_attn_mask = np.ones(vl_seq_len, dtype=bool) # All tokens are valid initially
# Pad vl_token_ids and vl_attn_mask to max_sequence_length
if vl_seq_len > self.max_sequence_length:
raise ValueError("VL sequence length exceeds the max sequence length!")
left_pad_len = self.max_sequence_length - vl_seq_len
# Pad token_ids (with PAD_TOKEN)
vl_token_ids = np.pad(vl_token_ids, (left_pad_len, 0), constant_values=_PAD_TOKEN)
# Pad attention mask with 0 (padding tokens)
vl_attn_mask = np.pad(vl_attn_mask, (left_pad_len, 0), constant_values=0)
return vl_token_ids, vl_attn_mask
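# Illustrative example (default max_sequence_length=300): a 257-token VL sequence
# is left-padded with 43 _PAD_TOKENs, so the returned mask is [0] * 43 + [1] * 257.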
def pack_actions(self, buttons, j_left, j_right):
# Check that the first two dims of each input is the same (number of chunks, control frequency)
assert buttons.shape[:2] == j_left.shape[:2] == j_right.shape[:2], (
f"buttons shape: {buttons.shape}, "
f"j_left shape: {j_left.shape}, "
f"j_right shape: {j_right.shape}"
)
# Normalize the joysticks from [-1, 1] to [0, 1]
j_left = (j_left + 1) / 2.
j_right = (j_right + 1) / 2.
# Concatenate the buttons and joysticks along the last dimension
action = np.concatenate([buttons, j_left, j_right], axis=-1, dtype=np.float32)
# Drop the leading chunk dimension (a single chunk is expected here)
action = action.squeeze(0)
return action
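# Illustrative shapes (the button count of 17 is hypothetical): buttons (1, 16, 17),
# j_left (1, 16, 2), j_right (1, 16, 2) -> packed action (16, 21), with both
# joysticks rescaled from [-1, 1] to [0, 1].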
def unpack_actions(self, actions):
if self.old_layout:
# Old layout: [j_left, j_right, buttons]
j_left = actions[:, :, :2]
j_right = actions[:, :, 2:4]
buttons = actions[:, :, 4:]
else:
# Current layout: [buttons, j_left, j_right]
buttons = actions[:, :, :-4]
j_left = actions[:, :, -4:-2]
j_right = actions[:, :, -2:]
# Denormalize the joysticks from [0, 1] back to [-1, 1]
j_left = j_left * 2. - 1.
j_right = j_right * 2. - 1.
# Clip into [-1,1]
j_left = torch.clamp(j_left, -1, 1)
j_right = torch.clamp(j_right, -1, 1)
# Threshold the buttons to 0/1
buttons = (buttons > 0.5).float()
return j_left, j_right, buttons
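# Illustrative example (old_layout=False, 17 buttons assumed): an action tensor of
# shape (B, 16, 21) yields buttons (B, 16, 17) thresholded at 0.5, and j_left /
# j_right (B, 16, 2) mapped from [0, 1] back to [-1, 1] and clipped.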
###########################################################################
# encode / decode
###########################################################################
def encode(self, data: dict) -> dict:
"""
Main entry point for the transform. We assume that `data` has
data['video'], data['language'], data['state'], and data['action'] in
the shapes needed. If you have multiple keys for each modality, you
could use your own grouping logic (similar to GR1Transform) first.
"""
# 1) Copy the input and count the frames that were not dropped
transformed_data = {**data} # Start with a copy of the input data
n_images = (~data["dropped_frames"]).sum()
transformed_data["images"] = data["frames"]
transformed_data["dropped_images"] = data["dropped_frames"]
if self.training:
# 2) Pack buttons/joysticks into a single action tensor and store it for _prepare_action
packed_actions = self.pack_actions(
data["buttons"],
data["j_left"],
data["j_right"]
)
data["action"] = packed_actions
transformed_data["has_real_action"] = np.ones((), dtype=bool)
actions, actions_mask, n_action_tokens = self._prepare_action(data)
transformed_data["actions"] = actions
transformed_data["actions_mask"] = actions_mask
action_and_mask_keys = ["actions", "actions_mask"]
assert all(
transformed_data[key].shape == transformed_data["actions"].shape
for key in action_and_mask_keys
), f"Shape mismatch: {[(key, transformed_data[key].shape) for key in action_and_mask_keys]}"
else:
n_action_tokens = self.action_horizon
transformed_data["has_detection_target"] = np.zeros((), dtype=bool)
# 3) Build token_ids
vl_token_ids, sa_token_ids = self._build_token_ids(
n_images, n_action_tokens
)
# 4) Build the attention mask only for vision-language tokens
vl_token_ids, vl_attn_mask = self._prepare_attention_mask(vl_token_ids)
transformed_data["vl_token_ids"] = vl_token_ids
transformed_data["sa_token_ids"] = sa_token_ids
transformed_data["vl_attn_mask"] = vl_attn_mask
transformed_data["embodiment_id"] = torch.tensor(0, dtype=torch.long)
if self.game_mapping:
game_name = data["game"]
assert game_name in self.game_mapping, f"Game '{game_name}' not found in game mapping."
transformed_data["game_ids"] = torch.tensor(self.game_mapping[game_name], dtype=torch.long)
else:
transformed_data["game_ids"] = torch.tensor(0, dtype=torch.long)
return transformed_data
def decode(self, data: dict) -> dict:
j_left, j_right, buttons = self.unpack_actions(data["action_tensor"])
return {
"j_left": j_left,
"j_right": j_right,
"buttons": buttons,
}
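# Minimal usage sketch (illustrative, not part of the library): with the default
# config, a sample holding one non-dropped 224x224 frame and a hypothetical
# (1, 16, 17) / (1, 16, 2) / (1, 16, 2) buttons / j_left / j_right chunk encodes
# to vl_token_ids of shape (300,), sa_token_ids of shape (16,), and padded
# actions of shape (16, 25):
#
#   cfg = NitrogenTokenizerConfig()
#   tokenizer = NitrogenTokenizer(cfg)
#   sample = {
#       "frames": np.zeros((1, 224, 224, 3), dtype=np.uint8),
#       "dropped_frames": np.zeros(1, dtype=bool),
#       "buttons": np.zeros((1, 16, 17), dtype=np.float32),
#       "j_left": np.zeros((1, 16, 2), dtype=np.float32),
#       "j_right": np.zeros((1, 16, 2), dtype=np.float32),
#   }
#   out = tokenizer.encode(sample)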