26 lines
1.5 KiB
Python
26 lines
1.5 KiB
Python
from pydantic import BaseModel, Field
|
|
|
|
from nitrogen.flow_matching_transformer.nitrogen import NitroGen_Config
|
|
from nitrogen.mm_tokenizers import NitrogenTokenizerConfig
|
|
|
|
class ModalityConfig(BaseModel):
|
|
frame_per_sample: int = 1 # number of context frames per sample
|
|
frame_spacing: int | None = None # how many frames to skip between each frame. If None, use action_per_chunk
|
|
action_per_chunk: int = 8
|
|
action_shift: int = 1 # how many actions to skip between frame[i] and action_chunk[i]
|
|
action_interleaving: bool = False # if True, action chunks will be interleaved with context frames and used by the model to predict the next actions
|
|
token_set: str = "new"
|
|
|
|
def model_post_init(self, __context):
|
|
if self.frame_spacing is None:
|
|
# Use object.__setattr__ because the model is frozen
|
|
object.__setattr__(self, 'frame_spacing', self.action_per_chunk)
|
|
assert self.action_shift >= 1, "Frame shift must be at least 1 for correct action indexing"
|
|
|
|
|
|
class CkptConfig(BaseModel):
|
|
experiment_name: str = Field(..., description="Name of the experiment")
|
|
|
|
model_cfg: NitroGen_Config = Field(..., description="Model configuration. This is a placeholder and should be replaced with the actual model config class.")
|
|
tokenizer_cfg: NitrogenTokenizerConfig = Field(..., description="Tokenizer configuration. This is a placeholder and should be replaced with the actual tokenizer config class.")
|
|
modality_cfg: ModalityConfig = Field(..., description="Modality configuration for the dataset mixture.") |