# Setup: parse the CLI options, query the model server for its configuration
# and create the debug/output directories.
import argparse
import json
import os
import sys
import time
from collections import OrderedDict
from pathlib import Path

import cv2
import numpy as np
from PIL import Image

from nitrogen.game_env import GamepadEnv
from nitrogen.shared import BUTTON_ACTION_TOKENS, PATH_REPO
from nitrogen.inference_viz import create_viz, VideoRecorder
from nitrogen.inference_client import ModelClient

# --- Command line -----------------------------------------------------------
parser = argparse.ArgumentParser(description="VLM Inference")
parser.add_argument(
    "--process",
    type=str,
    default="celeste.exe",
    help="Game to play",
)
parser.add_argument(
    "--allow-menu",
    action="store_true",
    help="Allow menu actions (Disabled by default)",
)
parser.add_argument(
    "--port",
    type=int,
    default=5555,
    help="Port for model server",
)
args = parser.parse_args()

# Menu actions stay disabled unless --allow-menu was passed.
NO_MENU = not args.allow_menu
# A predicted button value strictly above this threshold counts as pressed.
BUTTON_PRESS_THRES = 0.5

# --- Model server -----------------------------------------------------------
policy = ModelClient(port=args.port)
policy.reset()
policy_info = policy.info()
action_downsample_ratio = policy_info["action_downsample_ratio"]
CKPT_NAME = Path(policy_info["ckpt_path"]).stem

# --- Output locations -------------------------------------------------------
# Debug frames go to <repo>/debug, recordings to <repo>/out/<checkpoint name>.
PATH_DEBUG = PATH_REPO / "debug"
PATH_OUT = (PATH_REPO / "out" / CKPT_NAME).resolve()
for _directory in (PATH_DEBUG, PATH_OUT):
    _directory.mkdir(parents=True, exist_ok=True)

# Find in PATH_OUT the list of existing video files, named 0001_DEBUG.mp4, 0002_DEBUG.mp4, etc.
# Pick the index of this run: one past the highest numeric prefix among the
# existing NNNN_DEBUG.mp4 files in PATH_OUT, or 1 when there are none.
video_files = sorted(PATH_OUT.glob("*_DEBUG.mp4"))
existing_numbers = [
    int(prefix)
    for prefix in (f.name.split("_")[0] for f in video_files)
    if prefix.isdigit()
]
# default=0 avoids a ValueError when files match the glob but none of them
# carries a numeric prefix.
next_number = max(existing_numbers, default=0) + 1

PATH_MP4_DEBUG = PATH_OUT / f"{next_number:04d}_DEBUG.mp4"
PATH_MP4_CLEAN = PATH_OUT / f"{next_number:04d}_CLEAN.mp4"
PATH_ACTIONS = PATH_OUT / f"{next_number:04d}_ACTIONS.json"


def preprocess_img(main_image):
    """Resize a frame to 256x256 and return it as a PIL image.

    Args:
        main_image: Frame convertible with ``np.array``; it is treated as a
            3-channel RGB image.

    Returns:
        A ``PIL.Image.Image`` of size 256x256 in RGB channel order.
    """
    main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
    final_image = cv2.resize(main_cv, (256, 256), interpolation=cv2.INTER_AREA)
    return Image.fromarray(cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB))


def _axis(value):
    """Wrap ``value`` in the one-element integer array used for axes/triggers.

    Non-integer values are cast to the integer dtype by ``np.array``.
    """
    # NOTE(review): np.long is missing from some NumPy releases (the alias was
    # removed in 1.x and re-added in 2.0) — confirm the pinned NumPy version.
    return np.array([value], dtype=np.long)


# Template action with every button, axis and trigger set to 0. Each predicted
# action starts from a shallow copy of this mapping.
zero_action = OrderedDict(
    [
        ("WEST", 0),
        ("SOUTH", 0),
        ("BACK", 0),
        ("DPAD_DOWN", 0),
        ("DPAD_LEFT", 0),
        ("DPAD_RIGHT", 0),
        ("DPAD_UP", 0),
        ("GUIDE", 0),
        ("AXIS_LEFTX", _axis(0)),
        ("AXIS_LEFTY", _axis(0)),
        ("LEFT_SHOULDER", 0),
        ("LEFT_TRIGGER", _axis(0)),
        ("AXIS_RIGHTX", _axis(0)),
        ("AXIS_RIGHTY", _axis(0)),
        ("LEFT_THUMB", 0),
        ("RIGHT_THUMB", 0),
        ("RIGHT_SHOULDER", 0),
        ("RIGHT_TRIGGER", _axis(0)),
        ("START", 0),
        ("EAST", 0),
        ("NORTH", 0),
    ]
)

TOKEN_SET = BUTTON_ACTION_TOKENS

# Games that need a menu interaction before the virtual controller is usable.
MENU_INIT_GAMES = ("isaac-ng.exe", "Cuphead.exe")


def _countdown(seconds=3):
    """Print ``N...`` down to ``1...``, sleeping one second after each line."""
    for remaining in range(seconds, 0, -1):
        print(f"{remaining}...")
        time.sleep(1)


def press(button):
    """Tap ``button`` on the virtual gamepad: press, wait 50 ms, release."""
    emulator = env.gamepad_emulator
    emulator.press_button(button)
    emulator.gamepad.update()
    time.sleep(0.05)
    emulator.release_button(button)
    emulator.gamepad.update()


def _to_env_actions(j_left, j_right, buttons):
    """Convert one model prediction into a list of environment actions.

    Args:
        j_left: Sequence of ``(x, y)`` pairs for the left stick.
        j_right: Sequence of ``(x, y)`` pairs for the right stick.
        buttons: Sequence of button vectors, each aligned with ``TOKEN_SET``.

    Returns:
        One action mapping per predicted step, each a copy of ``zero_action``
        with the stick axes, triggers and buttons filled in.

    Raises:
        ValueError: If the three sequences differ in length, or a button
            vector does not have one entry per token in ``TOKEN_SET``.
    """
    if not len(buttons) == len(j_left) == len(j_right):
        raise ValueError("Mismatch in action lengths")

    actions = []
    for (xl, yl), (xr, yr), button_vector in zip(j_left, j_right, buttons):
        if len(button_vector) != len(TOKEN_SET):
            raise ValueError("Button vector length does not match token set length")

        move_action = zero_action.copy()
        # Stick values are scaled by 32767 and truncated to int.
        move_action["AXIS_LEFTX"] = _axis(int(xl * 32767))
        move_action["AXIS_LEFTY"] = _axis(int(yl * 32767))
        move_action["AXIS_RIGHTX"] = _axis(int(xr * 32767))
        move_action["AXIS_RIGHTY"] = _axis(int(yr * 32767))

        for name, value in zip(TOKEN_SET, button_vector):
            if "TRIGGER" in name:
                # Triggers are scaled by 255 instead of being thresholded.
                move_action[name] = _axis(value * 255)
            else:
                move_action[name] = 1 if value > BUTTON_PRESS_THRES else 0

        actions.append(move_action)
    return actions


def _to_json_record(action, step, substep):
    """Return a JSON-serializable copy of ``action`` tagged with its indices.

    NumPy arrays are converted to lists; the input mapping is left untouched.
    """
    record = {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in action.items()
    }
    record["step"] = step
    record["substep"] = substep
    return record


print("Model loaded, starting environment...")
_countdown()

env = GamepadEnv(
    game=args.process,
    game_speed=1.0,
    env_fps=60,
    async_mode=True,
)

# These games require opening a menu to initialize the controller.
if args.process in MENU_INIT_GAMES:
    print(f"GamepadEnv ready for {args.process} at {env.env_fps} FPS")
    input("Press enter to create a virtual controller and start rollouts...")
    _countdown()

    press("SOUTH")
    for _ in range(5):
        press("EAST")
        time.sleep(0.3)

env.reset()
env.pause()

# Initial call to get the first observation.
obs, reward, terminated, truncated, info = env.step(action=zero_action)

step_count = 0
with VideoRecorder(str(PATH_MP4_DEBUG), fps=60, crf=32, preset="medium") as debug_recorder:
    with VideoRecorder(str(PATH_MP4_CLEAN), fps=60, crf=28, preset="medium") as clean_recorder:
        try:
            # Runs until an exception (e.g. KeyboardInterrupt) ends the loop;
            # the finally clause then unpauses and closes the environment.
            while True:
                obs = preprocess_img(obs)
                obs.save(PATH_DEBUG / f"{step_count:05d}.png")

                pred = policy.predict(obs)
                j_left, j_right, buttons = pred["j_left"], pred["j_right"], pred["buttons"]
                env_actions = _to_env_actions(j_left, j_right, buttons)

                print(f"Executing {len(env_actions)} actions, each action will be repeated {action_downsample_ratio} times")

                for i, a in enumerate(env_actions):
                    if NO_MENU:
                        if a["START"]:
                            print("Model predicted start, disabling this action")
                        # With menu actions disabled, never send these buttons.
                        a["GUIDE"] = 0
                        a["START"] = 0
                        a["BACK"] = 0

                    for _ in range(action_downsample_ratio):
                        obs, reward, terminated, truncated, info = env.step(action=a)

                        # The clean recording is resized to 1080p, the
                        # annotated debug recording to 720p.
                        obs_viz = np.array(obs)
                        clean_viz = cv2.resize(obs_viz, (1920, 1080), interpolation=cv2.INTER_AREA)
                        debug_viz = create_viz(
                            cv2.resize(obs_viz, (1280, 720), interpolation=cv2.INTER_AREA),
                            i,
                            j_left,
                            j_right,
                            buttons,
                            token_set=TOKEN_SET,
                        )
                        debug_recorder.add_frame(debug_viz)
                        clean_recorder.add_frame(clean_viz)

                # Append the executed actions to the JSONL file, one per line.
                with open(PATH_ACTIONS, "a", encoding="utf-8") as f:
                    for i, a in enumerate(env_actions):
                        json.dump(_to_json_record(a, step_count, i), f)
                        f.write("\n")

                step_count += 1
        finally:
            env.unpause()
            env.close()