init commit

nitrogen/inference_session.py (new file, 276 lines)

@@ -0,0 +1,276 @@
import time
import json
from collections import deque

import torch
import numpy as np

from transformers import AutoImageProcessor
from nitrogen.flow_matching_transformer.nitrogen import NitroGen, NitroGen_Config
from nitrogen.mm_tokenizers import NitrogenTokenizerConfig, NitrogenTokenizer, Tokenizer
from nitrogen.cfg import CkptConfig
from nitrogen.shared import PATH_REPO


def summarize_parameters(module, name='model', depth=0, max_depth=3):
    """
    Print a tree-like summary of parameters in a PyTorch module.

    Args:
        module: PyTorch module to summarize
        name: Name of the module (for the root level)
        depth: Current depth in the tree
        max_depth: Maximum depth to traverse
    """
    if depth > max_depth:
        return

    # Count total parameters in this module
    total_params = sum(p.numel() for p in module.parameters())
    trainable_params = sum(p.numel() for p in module.parameters() if p.requires_grad)

    # Print indented summary
    indent = " " * depth
    print(f"{indent}{name}: {total_params:,} params ({trainable_params:,} trainable)")

    # Recursively summarize submodules
    if depth < max_depth:
        for child_name, child_module in module.named_children():
            summarize_parameters(child_module, child_name, depth + 1, max_depth)
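
# Example (hypothetical module, not part of this commit):
#   summarize_parameters(torch.nn.Linear(8, 4))
# prints "model: 36 params (36 trainable)", i.e. 8*4 weights plus 4 biases.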


def load_model(checkpoint_path: str):
    """Load the model, tokenizer, and image processor from a checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    ckpt_config = CkptConfig.model_validate(checkpoint["ckpt_config"])
    model_cfg = ckpt_config.model_cfg
    tokenizer_cfg = ckpt_config.tokenizer_cfg

    print("Checkpoint args:")
    print(json.dumps(ckpt_config.model_dump(), indent=4))

    # Initialize the image processor for the vision encoder
    img_proc = AutoImageProcessor.from_pretrained(model_cfg.vision_encoder_name)

    # Build the model together with its matching tokenizer
    if isinstance(model_cfg, NitroGen_Config):
        assert isinstance(tokenizer_cfg, NitrogenTokenizerConfig), \
            "NitroGen_Config requires NitrogenTokenizerConfig for tokenization"
        tokenizer_cfg.training = False
        if tokenizer_cfg.game_mapping_cfg is not None:
            tokenizer_cfg.game_mapping_cfg.src_files = [
                x.replace("/mnt/amlfs-02/shared/gaming/gamingvla", str(PATH_REPO))
                for x in tokenizer_cfg.game_mapping_cfg.src_files
            ]
        tokenizer = NitrogenTokenizer(tokenizer_cfg)
        game_mapping = tokenizer.game_mapping
        model = NitroGen(config=model_cfg, game_mapping=game_mapping)
        # model.num_inference_timesteps = 16
        action_downsample_ratio = 1
    else:
        raise ValueError(f"Unsupported model config type: {type(model_cfg)}")

    summarize_parameters(model, max_depth=3)

    print(model)

    model.load_state_dict(checkpoint["model"])
    model.eval()
    tokenizer.eval()
    model.to("cuda")

    return model, tokenizer, img_proc, ckpt_config, game_mapping, action_downsample_ratio
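
# Usage sketch (hypothetical checkpoint path):
#   model, tokenizer, img_proc, cfg, game_mapping, ratio = load_model("runs/last.pt")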


class InferenceSession:
    """Manages state for a single inference session."""

    def __init__(
        self,
        model,
        ckpt_path: str,
        tokenizer: Tokenizer,
        img_proc,
        ckpt_config: CkptConfig,
        game_mapping: dict,
        selected_game: str,
        old_layout: bool,
        cfg_scale: float,
        action_downsample_ratio: float,
        context_length=None
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.img_proc = img_proc
        self.ckpt_config = ckpt_config
        self.game_mapping = game_mapping
        self.selected_game = selected_game
        self.old_layout = old_layout
        self.cfg_scale = cfg_scale
        self.action_downsample_ratio = action_downsample_ratio
        self.ckpt_path = ckpt_path

        # Load modality config
        self.modality_config = self.ckpt_config.modality_cfg

        self.max_buffer_size = context_length if context_length is not None else self.modality_config.frame_per_sample
        self.action_interleaving = self.modality_config.action_interleaving
        self.is_flowmatching = isinstance(self.ckpt_config.model_cfg, NitroGen_Config)

        # Buffers
        self.obs_buffer = deque(maxlen=self.max_buffer_size)
        self.action_buffer = deque(maxlen=self.max_buffer_size)
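        # With maxlen set, appending to a full deque evicts the oldest entry,
        # so both buffers act as a sliding window over the model's context.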

    @classmethod
    def from_ckpt(cls, checkpoint_path: str, old_layout=False, cfg_scale=1.0, context_length=None):
        """Create an InferenceSession from a checkpoint."""
        model, tokenizer, img_proc, ckpt_config, game_mapping, action_downsample_ratio = load_model(checkpoint_path)

        if game_mapping is not None:
            # Ask user to pick a game from the list
            print("Available games in tokenizer mapping:")
            for game, idx in game_mapping.items():
                print(f"{idx:03d}: {game}")
            selected_game = input("Enter the game ID to use (leave empty for unconditional): ")
            if selected_game == "":
                selected_game = None
            else:
                selected_idx = int(selected_game)
                assert selected_idx in game_mapping.values(), f"Invalid game ID {selected_idx}"

                candidates = [k for k, v in game_mapping.items() if v == selected_idx]
                assert len(candidates) == 1, f"Multiple games found for ID {selected_idx}: {candidates}"

                selected_game = candidates[0]
        else:
            selected_game = None
            print("No game mapping available, proceeding without game conditioning")

        return cls(
            model,
            checkpoint_path,
            tokenizer,
            img_proc,
            ckpt_config,
            game_mapping,
            selected_game,
            old_layout,
            cfg_scale,
            action_downsample_ratio,
            context_length
        )
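
    # Typical interactive use (hypothetical path; prompts for a game ID):
    #   session = InferenceSession.from_ckpt("runs/last.pt", cfg_scale=1.5)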

    def info(self):
        return {
            "ckpt_path": self.ckpt_path,
            "selected_game": self.selected_game,
            "old_layout": self.old_layout,
            "cfg_scale": self.cfg_scale,
            "context_length": self.max_buffer_size,
            "action_interleaving": self.action_interleaving,
            "is_flowmatching": self.is_flowmatching,
            "action_downsample_ratio": self.action_downsample_ratio,
        }

    def reset(self):
        """Reset all buffers."""
        self.obs_buffer.clear()
        self.action_buffer.clear()

    def predict(self, obs):
        """Predict the next action(s) for a single observation frame."""
        start_time = time.time()

        current_frame = self.img_proc([obs], return_tensors="pt")["pixel_values"]
        self.obs_buffer.append(current_frame)

        # Prepare model inputs
        pixel_values = torch.cat(list(self.obs_buffer), dim=0)
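
        # When action interleaving is enabled, previously predicted actions
        # are fed back to the model as context alongside the frame history.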
        if self.action_interleaving and len(self.action_buffer) > 0:
            action_tensors = {
                key: torch.cat([a[key] for a in list(self.action_buffer)], dim=0)
                for key in ["buttons", "j_left", "j_right"]
            }
        else:
            action_tensors = {"buttons": None, "j_left": None, "j_right": None}

        print("Running inference with the following inputs:")
        print(f"- pixel_values: {pixel_values.shape}")
        print("- action_tensors:")
        for k, v in action_tensors.items():
            if v is not None:
                print(f" - {k}: {v.shape}")
            else:
                print(f" - {k}: None")

        # Run inference
        if self.is_flowmatching:
            predicted_actions = self._predict_flowmatching(pixel_values, action_tensors)
        else:
            predicted_actions = self._predict_ar(pixel_values, action_tensors)

        # Add to action buffer
        self.action_buffer.append(predicted_actions)

        inference_time = time.time() - start_time
        print(f"Inference time: {inference_time:.3f}s")

        # Convert predicted action tensors to numpy arrays
        n_actions = len(predicted_actions["buttons"])
        j_left = predicted_actions["j_left"].squeeze().cpu().numpy()
        j_right = predicted_actions["j_right"].squeeze().cpu().numpy()
        buttons = predicted_actions["buttons"].squeeze().cpu().numpy()

        return {
            "j_left": j_left,
            "j_right": j_right,
            "buttons": buttons,
        }

    def _predict_flowmatching(self, pixel_values, action_tensors):
        """Run flow-matching inference, optionally with classifier-free guidance."""
        available_frames = len(self.obs_buffer)
        frames = torch.zeros((self.max_buffer_size, *pixel_values.shape[1:]),
                             dtype=pixel_values.dtype, device="cuda")
        frames[-available_frames:] = pixel_values
        dropped_frames = torch.zeros((self.max_buffer_size,), dtype=torch.bool, device="cuda")
        dropped_frames[:self.max_buffer_size - available_frames] = True
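        # The frame history is right-aligned in a fixed-size buffer; slots with
        # no observation yet are marked as dropped so the model ignores them.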

        data_with_history = {
            "frames": frames,
            "dropped_frames": dropped_frames,
            "game": self.selected_game
        }
        tokenized_data_with_history = self.tokenizer.encode(data_with_history)

        frame_mask = torch.ones((self.max_buffer_size,), dtype=torch.bool, device="cuda")
        frame_mask[-1] = False
        data_without_history = {
            "frames": frames,
            "dropped_frames": frame_mask,
            "game": None
        }
        tokenized_data_without_history = self.tokenizer.encode(data_without_history)
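
        # Two encodings support classifier-free guidance: the conditional one
        # keeps the full frame history and the game label, while the
        # unconditional one keeps only the latest frame and no game label.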

        # Convert to CUDA tensors with batch dimension
        for tokenized_data in [tokenized_data_with_history, tokenized_data_without_history]:
            for k, v in tokenized_data.items():
                if isinstance(v, torch.Tensor):
                    tokenized_data[k] = v.unsqueeze(0).to("cuda")
                elif isinstance(v, np.ndarray):
                    tokenized_data[k] = torch.tensor(v, device="cuda").unsqueeze(0)
                else:
                    tokenized_data[k] = [v]

        with torch.inference_mode():
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                if self.cfg_scale == 1.0:
                    model_output = self.model.get_action(tokenized_data_with_history,
                                                         old_layout=self.old_layout)
                else:
                    model_output = self.model.get_action_with_cfg(
                        tokenized_data_with_history,
                        tokenized_data_without_history,
                        cfg_scale=self.cfg_scale
                    )
        predicted_actions = self.tokenizer.decode(model_output)

        return predicted_actions
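

# Smoke-test sketch (hypothetical path and dummy frame, not part of this commit):
#   session = InferenceSession.from_ckpt("runs/last.pt")
#   frame = np.zeros((480, 854, 3), dtype=np.uint8)  # dummy RGB observation
#   print(session.predict(frame))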