init commit
nitrogen/mm_tokenizers.py (new file, 332 lines)
@@ -0,0 +1,332 @@
import json
import os
from abc import ABC, abstractmethod
from itertools import count
from typing import Literal

import numpy as np
import polars as pl
import torch
from pydantic import BaseModel, Field

DEBUG = int(os.getenv("DEBUG", 0))


class Tokenizer(ABC):
    @abstractmethod
    def encode(self, data: dict) -> dict:
        """
        Transform the input data into a tokenized format.

        Args:
            data (dict): Input data containing frames and actions.

        Returns:
            dict: Tokenized data ready for model input.
        """

    @abstractmethod
    def decode(self, data: dict) -> dict:
        """
        Reverse the tokenization process to retrieve the original data.

        Args:
            data (dict): Tokenized data.

        Returns:
            dict: Original data structure.
        """

    @abstractmethod
    def train(self):
        """Set the tokenizer to training mode."""

    @abstractmethod
    def eval(self):
        """Set the tokenizer to evaluation mode."""


# Set IDs for each token type.
_PAD_TOKEN = 0
_IMG_TOKEN = 1
_IMG_SEP_TOKEN = 5  # New separator token
_LANG_TOKEN = 2
_PROPRIO_TOKEN = 3
_ACT_TOKEN = 4
_GAME_ID_TOKEN = 6

_UNCONDITIONAL_ID = None  # Special ID for the unconditional game


class GameMappingConfig(BaseModel):
    src_files: list[str] = Field(
        default_factory=list,
        description="List of source parquet files used to build the game mapping.",
    )


def get_game_mapping(cfg: GameMappingConfig) -> dict:
    """Collect the unique game labels from the source files and map each to an index."""
    game_set = set()
    for path in cfg.src_files:
        df = pl.read_parquet(path)
        for game in df["game_label"].unique():
            if game == _UNCONDITIONAL_ID:
                continue
            game_set.add(game)
    games = sorted(game_set)

    # Set the 0th element to be the unconditional game ID.
    games = [_UNCONDITIONAL_ID] + games
    return {game: idx for idx, game in enumerate(games)}
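
# Example (hypothetical files and labels, not from the original commit): if the
# source parquet files contain game labels {"mario", "zelda"}, the mapping
# reserves index 0 for the unconditional ID:
#     {None: 0, "mario": 1, "zelda": 2}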


class NitrogenTokenizerConfig(BaseModel):
    tokenizer_id: Literal['nitrogen'] = Field(default='nitrogen', frozen=True)
    training: bool = Field(default=True, description="Whether to apply the transform in training mode.")
    num_visual_tokens_per_frame: int = Field(default=256, description="Number of visual tokens per frame.")
    max_action_dim: int = Field(default=25, description="Maximum action dimension.")
    max_sequence_length: int = Field(default=300, description="Maximum sequence length.")
    action_horizon: int = Field(default=16, description="Action horizon.")
    game_mapping_cfg: GameMappingConfig | None = Field(default=None, description="Game mapping configuration.")
    old_layout: bool = Field(
        default=False,
        description=(
            "Whether to use the old layout for actions. If True, the action layout is "
            "[j_left, j_right, buttons]; if False, it is [buttons, j_left, j_right] "
            "(matching pack_actions/unpack_actions below)."
        ),
    )


class NitrogenTokenizer(Tokenizer):
    """
    Example transform that prepares video, language, state, and actions
    into a token-based format suitable for your model.

    The sub-methods below (prefixed with `_prepare_`) mirror the original
    modular structure.
    """

    def __init__(self, config: NitrogenTokenizerConfig):
        self.training = config.training
        self.num_visual_tokens_per_frame = config.num_visual_tokens_per_frame
        self.max_action_dim = config.max_action_dim
        self.max_sequence_length = config.max_sequence_length
        self.action_horizon = config.action_horizon
        self.old_layout = config.old_layout

        if config.game_mapping_cfg:
            self.game_mapping = get_game_mapping(config.game_mapping_cfg)
            # Persist the mapping so it can be inspected and reused later.
            with open("game_mapping.json", "w") as f:
                json.dump(self.game_mapping, f, indent=2)
        else:
            self.game_mapping = None

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def check_batch_size(self, data):
        # Use the video key to determine the batch size.
        video_ndim = data["images"].ndim
        if video_ndim == 4:  # Interpret as [T*V, H, W, C]
            is_batched = False
            batch_size = 1
        elif video_ndim == 5:  # Interpret as [B, T*V, H, W, C]
            is_batched = True
            batch_size = data["images"].shape[0]
        else:
            raise ValueError(f"Unsupported video number of dimensions: {video_ndim}")

        return is_batched, batch_size
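
    # For example (hypothetical shapes): images of shape [8, 4, 224, 224, 3]
    # yield (True, 8), while images of shape [4, 224, 224, 3] yield (False, 1).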

    def _prepare_action(self, data: dict):
        """
        Pad to max_action_dim, return masks.
        """
        if "action" not in data:
            actions = np.zeros((self.action_horizon, self.max_action_dim))
            actions_mask = np.zeros((self.action_horizon, self.max_action_dim), dtype=bool)
            n_action_tokens = self.action_horizon
            return actions, actions_mask, n_action_tokens

        actions = data["action"]
        assert actions.shape[0] == self.action_horizon, f"{actions.shape=}, {self.action_horizon=}"

        n_action_tokens = actions.shape[0]  # T
        n_action_dims = actions.shape[1]

        assert (
            n_action_dims <= self.max_action_dim
        ), f"Action dim {n_action_dims} exceeds max allowed {self.max_action_dim}."

        # Pad the channel dimension.
        actions = np.pad(actions, ((0, 0), (0, self.max_action_dim - n_action_dims)), "constant")

        # Create mask: [T, max_action_dim]
        actions_mask = np.zeros((n_action_tokens, self.max_action_dim), dtype=bool)
        actions_mask[:, :n_action_dims] = True

        return actions, actions_mask, n_action_tokens
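
    # A worked example (hypothetical shapes, not from the original commit):
    # with max_action_dim=25 and data["action"] of shape [16, 12], this
    # returns actions [16, 25] (zero-padded on the channel dim), actions_mask
    # [16, 25] (True on the first 12 dims), and n_action_tokens == 16.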

    def _build_token_ids(self, n_images, n_action_tokens):
        """
        Build the 1D arrays of token IDs based on the number of each block.
        Return (vl_token_ids, sa_token_ids).
        """
        vl_token_ids = []
        sa_token_ids = []

        # 0.5) Add a game-ID placeholder.
        if self.game_mapping:
            vl_token_ids.append(_GAME_ID_TOKEN)

        # 1) Video placeholders.
        for _ in range(n_images):
            vl_token_ids.extend([_IMG_TOKEN] * self.num_visual_tokens_per_frame)

        # 2) Action tokens.
        sa_token_ids.extend([_ACT_TOKEN] * n_action_tokens)

        return np.array(vl_token_ids), np.array(sa_token_ids)
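
    # Sketch of the resulting layout (hypothetical sizes): with a game mapping
    # present, 2 frames, and 16 action tokens, vl_token_ids is
    # [_GAME_ID_TOKEN] + [_IMG_TOKEN] * 512 and sa_token_ids is [_ACT_TOKEN] * 16.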

    def _prepare_attention_mask(
        self,
        vl_token_ids: np.ndarray,
    ):
        """
        Build 1D attention mask for vision-language tokens.
        1 indicates valid token, 0 indicates padding token.
        State-action attention will be handled separately by the model.
        """
        # Only create an attention mask for vision-language tokens.
        vl_seq_len = vl_token_ids.shape[0]
        vl_attn_mask = np.ones(vl_seq_len, dtype=bool)  # All tokens are valid initially.

        # Left-pad vl_token_ids and vl_attn_mask to max_sequence_length.
        if vl_seq_len > self.max_sequence_length:
            raise ValueError("VL sequence length exceeds the max sequence length!")

        left_pad_len = self.max_sequence_length - vl_seq_len

        # Pad token_ids with _PAD_TOKEN.
        vl_token_ids = np.pad(vl_token_ids, (left_pad_len, 0), constant_values=_PAD_TOKEN)

        # Pad the attention mask with 0 (padding tokens).
        vl_attn_mask = np.pad(vl_attn_mask, (left_pad_len, 0), constant_values=0)

        return vl_token_ids, vl_attn_mask
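
    # For example (hypothetical sizes): with max_sequence_length=300 and 257
    # vision-language tokens, both arrays are left-padded by 43 positions,
    # giving a vl_attn_mask of [0] * 43 + [1] * 257.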

    def pack_actions(self, buttons, j_left, j_right):
        # Check that the first two dims of each input are the same
        # (number of chunks, control frequency).
        assert buttons.shape[:2] == j_left.shape[:2] == j_right.shape[:2], (
            f"buttons shape: {buttons.shape}, "
            f"j_left shape: {j_left.shape}, "
            f"j_right shape: {j_right.shape}"
        )

        # Normalize the joysticks from [-1, 1] to [0, 1].
        j_left = (j_left + 1) / 2.0
        j_right = (j_right + 1) / 2.0

        # Concatenate the buttons and joysticks along the last dimension.
        action = np.concatenate([buttons, j_left, j_right], axis=-1, dtype=np.float32)

        # Squeeze the first dimension: this is the number of chunks, which is 1 here.
        action = action.squeeze(0)
        return action
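
    # A worked example (hypothetical shapes): buttons [1, 16, 8],
    # j_left [1, 16, 2], and j_right [1, 16, 2] pack into an action of shape
    # [16, 12], laid out as [buttons, j_left, j_right] along the last dim.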

    def unpack_actions(self, actions):
        if self.old_layout:
            # Old layout: [j_left, j_right, buttons].
            j_left = actions[:, :, :2]
            j_right = actions[:, :, 2:4]
            buttons = actions[:, :, 4:]
        else:
            # New layout: [buttons, j_left, j_right].
            buttons = actions[:, :, :-4]
            j_left = actions[:, :, -4:-2]
            j_right = actions[:, :, -2:]

        # Denormalize the joysticks back to [-1, 1].
        j_left = j_left * 2.0 - 1.0
        j_right = j_right * 2.0 - 1.0

        # Clip into [-1, 1].
        j_left = torch.clamp(j_left, -1, 1)
        j_right = torch.clamp(j_right, -1, 1)

        # Threshold the buttons to 0/1.
        buttons = (buttons > 0.5).float()
        return j_left, j_right, buttons
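
    # Round-trip note: with old_layout=False this inverts pack_actions (up to
    # joystick clipping and button thresholding), e.g. a batched action tensor
    # of shape [B, 16, 12] unpacks into buttons [B, 16, 8] and two [B, 16, 2]
    # joysticks back in [-1, 1].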

    ###########################################################################
    # encode / decode
    ###########################################################################
    def encode(self, data: dict) -> dict:
        """
        Main entry point for the transform. We assume that `data` contains
        data['frames'] and data['dropped_frames'] and, in training mode,
        data['buttons'], data['j_left'], and data['j_right'] in the shapes
        needed. If you have multiple keys for each modality, you could use
        your own grouping logic (similar to GR1Transform) first.
        """
        transformed_data = {**data}  # Start with a shallow copy of the input data.

        # 1) Carry the images through and count the valid (non-dropped) frames.
        n_images = (data["dropped_frames"] == False).sum()
        transformed_data["images"] = data["frames"]
        transformed_data["dropped_images"] = data["dropped_frames"]

        if self.training:
            # 2) Pack buttons/joysticks into a single action tensor.
            # Keep the original actions in the data for evaluation.
            packed_actions = self.pack_actions(
                data["buttons"],
                data["j_left"],
                data["j_right"],
            )
            data["action"] = packed_actions

            transformed_data["has_real_action"] = np.ones((), dtype=bool)

            actions, actions_mask, n_action_tokens = self._prepare_action(data)
            transformed_data["actions"] = actions
            transformed_data["actions_mask"] = actions_mask

            action_and_mask_keys = ["actions", "actions_mask"]
            assert all(
                transformed_data[key].shape == transformed_data["actions"].shape
                for key in action_and_mask_keys
            ), f"Shape mismatch: {[(key, transformed_data[key].shape) for key in action_and_mask_keys]}"
        else:
            n_action_tokens = self.action_horizon

        transformed_data["has_detection_target"] = np.zeros((), dtype=bool)

        # 3) Build the token IDs.
        vl_token_ids, sa_token_ids = self._build_token_ids(n_images, n_action_tokens)

        # 4) Build the attention mask, only for the vision-language tokens.
        vl_token_ids, vl_attn_mask = self._prepare_attention_mask(vl_token_ids)

        transformed_data["vl_token_ids"] = vl_token_ids
        transformed_data["sa_token_ids"] = sa_token_ids
        transformed_data["vl_attn_mask"] = vl_attn_mask
        transformed_data["embodiment_id"] = torch.tensor(0, dtype=torch.long)

        if self.game_mapping:
            game_name = data["game"]
            assert game_name in self.game_mapping, f"Game '{game_name}' not found in game mapping."
            transformed_data["game_ids"] = torch.tensor(self.game_mapping[game_name], dtype=torch.long)
        else:
            transformed_data["game_ids"] = torch.tensor(0, dtype=torch.long)
        return transformed_data

    def decode(self, data: dict) -> dict:
        j_left, j_right, buttons = self.unpack_actions(data["action_tensor"])

        return {
            "j_left": j_left,
            "j_right": j_right,
            "buttons": buttons,
        }
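

# A minimal smoke-test sketch, not part of the original commit. It assumes a
# hypothetical sample with one frame, a 16-step action horizon, and 8 buttons;
# the frame resolution (224x224) is likewise made up for illustration.
if __name__ == "__main__":
    cfg = NitrogenTokenizerConfig()
    tokenizer = NitrogenTokenizer(cfg)

    T, n_buttons = cfg.action_horizon, 8
    sample = {
        "frames": np.zeros((1, 224, 224, 3), dtype=np.uint8),
        "dropped_frames": np.zeros(1, dtype=bool),
        "buttons": np.zeros((1, T, n_buttons), dtype=np.float32),
        "j_left": np.zeros((1, T, 2), dtype=np.float32),
        "j_right": np.zeros((1, T, 2), dtype=np.float32),
    }
    out = tokenizer.encode(sample)
    print(out["actions"].shape)       # (16, 25): action_horizon x max_action_dim
    print(out["vl_token_ids"].shape)  # (300,): left-padded to max_sequence_length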