# NitroGen/scripts/play.py
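# Usage sketch (assumes a NitroGen model server is already listening on --port):
#   python NitroGen/scripts/play.py --process celeste.exe --port 5555
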
import os
import sys
import time
import json
from pathlib import Path
from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image
from nitrogen.game_env import GamepadEnv
from nitrogen.shared import BUTTON_ACTION_TOKENS, PATH_REPO
from nitrogen.inference_viz import create_viz, VideoRecorder
from nitrogen.inference_client import ModelClient
import argparse
parser = argparse.ArgumentParser(description="VLM Inference")
parser.add_argument("--process", type=str, default="celeste.exe", help="Game to play")
parser.add_argument("--allow-menu", action="store_true", help="Allow menu actions (Disabled by default)")
parser.add_argument("--port", type=int, default=5555, help="Port for model server")
args = parser.parse_args()
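
# Connect to the policy server and read its metadata: the checkpoint path (used to name the
# output folder) and how many env steps each predicted action is held for.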
policy = ModelClient(port=args.port)
policy.reset()
policy_info = policy.info()
action_downsample_ratio = policy_info["action_downsample_ratio"]
CKPT_NAME = Path(policy_info["ckpt_path"]).stem
NO_MENU = not args.allow_menu
PATH_DEBUG = PATH_REPO / "debug"
PATH_DEBUG.mkdir(parents=True, exist_ok=True)
PATH_OUT = (PATH_REPO / "out" / CKPT_NAME).resolve()
PATH_OUT.mkdir(parents=True, exist_ok=True)
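
# Predicted button probabilities above this threshold are treated as presses.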
BUTTON_PRESS_THRES = 0.5
# Find existing recordings in PATH_OUT (named 0001_DEBUG.mp4, 0002_DEBUG.mp4, ...).
# If any exist, the next run number is max + 1; otherwise start at 1.
video_files = sorted(PATH_OUT.glob("*_DEBUG.mp4"))
if video_files:
    existing_numbers = [f.name.split("_")[0] for f in video_files]
    existing_numbers = [int(n) for n in existing_numbers if n.isdigit()]
    next_number = max(existing_numbers) + 1
else:
    next_number = 1
PATH_MP4_DEBUG = PATH_OUT / f"{next_number:04d}_DEBUG.mp4"
PATH_MP4_CLEAN = PATH_OUT / f"{next_number:04d}_CLEAN.mp4"
PATH_ACTIONS = PATH_OUT / f"{next_number:04d}_ACTIONS.json"
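
# Downscale the raw observation to a 256x256 RGB image before sending it to the policy.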
def preprocess_img(main_image):
    main_cv = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
    final_image = cv2.resize(main_cv, (256, 256), interpolation=cv2.INTER_AREA)
    return Image.fromarray(cv2.cvtColor(final_image, cv2.COLOR_BGR2RGB))
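
# Neutral controller state: every button released, sticks centered, triggers at zero.
# Stick axes and triggers are stored as 1-element integer arrays.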
zero_action = OrderedDict(
    [
        ("WEST", 0),
        ("SOUTH", 0),
        ("BACK", 0),
        ("DPAD_DOWN", 0),
        ("DPAD_LEFT", 0),
        ("DPAD_RIGHT", 0),
        ("DPAD_UP", 0),
        ("GUIDE", 0),
        ("AXIS_LEFTX", np.array([0], dtype=np.int64)),
        ("AXIS_LEFTY", np.array([0], dtype=np.int64)),
        ("LEFT_SHOULDER", 0),
        ("LEFT_TRIGGER", np.array([0], dtype=np.int64)),
        ("AXIS_RIGHTX", np.array([0], dtype=np.int64)),
        ("AXIS_RIGHTY", np.array([0], dtype=np.int64)),
        ("LEFT_THUMB", 0),
        ("RIGHT_THUMB", 0),
        ("RIGHT_SHOULDER", 0),
        ("RIGHT_TRIGGER", np.array([0], dtype=np.int64)),
        ("START", 0),
        ("EAST", 0),
        ("NORTH", 0),
    ]
)
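
# Button names in the order assumed for the policy's button output vector.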
TOKEN_SET = BUTTON_ACTION_TOKENS
print("Model loaded, starting environment...")
for i in range(3):
    print(f"{3 - i}...")
    time.sleep(1)
env = GamepadEnv(
    game=args.process,
    game_speed=1.0,
    env_fps=60,
    async_mode=True,
)
# These games require opening an in-game menu once to initialize the virtual controller.
if args.process in ("isaac-ng.exe", "Cuphead.exe"):
    print(f"GamepadEnv ready for {args.process} at {env.env_fps} FPS")
    input("Press enter to create a virtual controller and start rollouts...")
    for i in range(3):
        print(f"{3 - i}...")
        time.sleep(1)

    def press(button):
        env.gamepad_emulator.press_button(button)
        env.gamepad_emulator.gamepad.update()
        time.sleep(0.05)
        env.gamepad_emulator.release_button(button)
        env.gamepad_emulator.gamepad.update()

    press("SOUTH")
    for k in range(5):
        press("EAST")
        time.sleep(0.3)
env.reset()
env.pause()
# Initial call to get state
obs, reward, terminated, truncated, info = env.step(action=zero_action)
frames = None
step_count = 0
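
# Record two videos per run: a debug video with the action overlay and a clean gameplay video.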
with VideoRecorder(str(PATH_MP4_DEBUG), fps=60, crf=32, preset="medium") as debug_recorder:
    with VideoRecorder(str(PATH_MP4_CLEAN), fps=60, crf=28, preset="medium") as clean_recorder:
        try:
            while True:
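                # Query the policy with the latest frame; it returns a chunk of left/right
                # stick positions and button probabilities.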
                obs = preprocess_img(obs)
                obs.save(PATH_DEBUG / f"{step_count:05d}.png")
                pred = policy.predict(obs)
                j_left, j_right, buttons = pred["j_left"], pred["j_right"], pred["buttons"]
                n = len(buttons)
                assert n == len(j_left) == len(j_right), "Mismatch in action lengths"
                env_actions = []
                for i in range(n):
                    move_action = zero_action.copy()
                    xl, yl = j_left[i]
                    xr, yr = j_right[i]
                    move_action["AXIS_LEFTX"] = np.array([int(xl * 32767)], dtype=np.int64)
                    move_action["AXIS_LEFTY"] = np.array([int(yl * 32767)], dtype=np.int64)
                    move_action["AXIS_RIGHTX"] = np.array([int(xr * 32767)], dtype=np.int64)
                    move_action["AXIS_RIGHTY"] = np.array([int(yr * 32767)], dtype=np.int64)
                    button_vector = buttons[i]
                    assert len(button_vector) == len(TOKEN_SET), "Button vector length does not match token set length"
                    for name, value in zip(TOKEN_SET, button_vector):
                        if "TRIGGER" in name:
                            move_action[name] = np.array([value * 255], dtype=np.int64)
                        else:
                            move_action[name] = 1 if value > BUTTON_PRESS_THRES else 0
                    env_actions.append(move_action)
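                # Play the chunk back: strip menu buttons if requested, hold each action for
                # action_downsample_ratio env steps, and record a frame into both videos.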
                print(f"Executing {len(env_actions)} actions, each action will be repeated {action_downsample_ratio} times")
                for i, a in enumerate(env_actions):
                    if NO_MENU:
                        if a["START"]:
                            print("Model predicted start, disabling this action")
                        a["GUIDE"] = 0
                        a["START"] = 0
                        a["BACK"] = 0
                    for _ in range(action_downsample_ratio):
                        obs, reward, terminated, truncated, info = env.step(action=a)
                        # resize obs for the recordings: 1080p for the clean video, 720p for the debug overlay
                        obs_viz = np.array(obs).copy()
                        clean_viz = cv2.resize(obs_viz, (1920, 1080), interpolation=cv2.INTER_AREA)
                        debug_viz = create_viz(
                            cv2.resize(obs_viz, (1280, 720), interpolation=cv2.INTER_AREA),  # 720p
                            i,
                            j_left,
                            j_right,
                            buttons,
                            token_set=TOKEN_SET
                        )
                        debug_recorder.add_frame(debug_viz)
                        clean_recorder.add_frame(clean_viz)
                # Append each executed action to the actions file as JSON Lines (one JSON object per line)
                with open(PATH_ACTIONS, "a") as f:
                    for i, a in enumerate(env_actions):
                        # convert numpy arrays to lists for JSON serialization
                        for k, v in a.items():
                            if isinstance(v, np.ndarray):
                                a[k] = v.tolist()
                        a["step"] = step_count
                        a["substep"] = i
                        json.dump(a, f)
                        f.write("\n")
                step_count += 1
        finally:
            env.unpause()
            env.close()