# Listing metadata (not program text): 233 lines, 8.0 KiB, Python.
import os
|
|
import sys
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
from collections import OrderedDict
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from nitrogen.game_env import GamepadEnv
|
|
from nitrogen.shared import BUTTON_ACTION_TOKENS, PATH_REPO
|
|
from nitrogen.inference_viz import create_viz, VideoRecorder
|
|
from nitrogen.inference_client import ModelClient
|
|
|
|
import argparse
|
|
# Command-line interface: which game process to attach to, whether the policy
# may press menu buttons, and which port the model server listens on.
_CLI_OPTIONS = (
    (("--process",), dict(type=str, default="celeste.exe", help="Game to play")),
    (("--allow-menu",), dict(action="store_true", help="Allow menu actions (Disabled by default)")),
    (("--port",), dict(type=int, default=5555, help="Port for model server")),
)

parser = argparse.ArgumentParser(description="VLM Inference")
for _flags, _options in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_options)

args = parser.parse_args()
|
|
|
|
# Connect to the model server on the requested port, reset the policy and
# fetch its metadata.
policy = ModelClient(port=args.port)
policy.reset()
policy_info = policy.info()

# How many consecutive env steps each predicted action is applied for; the
# rollout loop repeats env.step() this many times per action.
action_downsample_ratio = policy_info["action_downsample_ratio"]

# Checkpoint file name without extension; used as the output sub-directory.
CKPT_NAME = Path(policy_info["ckpt_path"]).stem

# Menu buttons are suppressed unless --allow-menu was passed.
NO_MENU = not args.allow_menu

# A button counts as pressed when its predicted value exceeds this threshold.
BUTTON_PRESS_THRES = 0.5

# PATH_DEBUG receives the per-step model input images, PATH_OUT the recordings
# and action logs of this checkpoint.
PATH_DEBUG = PATH_REPO / "debug"
PATH_OUT = (PATH_REPO / "out" / CKPT_NAME).resolve()
for _directory in (PATH_DEBUG, PATH_OUT):
    _directory.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Recordings in PATH_OUT are named "<NNNN>_DEBUG.mp4", "<NNNN>_CLEAN.mp4" and
# "<NNNN>_ACTIONS.json". Continue the numbering from the highest existing
# DEBUG recording, or start at 1 when there is none.
video_files = sorted(PATH_OUT.glob("*_DEBUG.mp4"))
existing_numbers = [
    int(prefix)
    for prefix in (f.name.split("_")[0] for f in video_files)
    if prefix.isdigit()
]
# default=0 covers an empty directory as well as files whose prefix is not
# numeric (e.g. "foo_DEBUG.mp4"), where a bare max() would raise ValueError.
next_number = max(existing_numbers, default=0) + 1

PATH_MP4_DEBUG = PATH_OUT / f"{next_number:04d}_DEBUG.mp4"
PATH_MP4_CLEAN = PATH_OUT / f"{next_number:04d}_CLEAN.mp4"
PATH_ACTIONS = PATH_OUT / f"{next_number:04d}_ACTIONS.json"
|
|
|
|
def preprocess_img(main_image):
    """Downscale a frame to the 256x256 image that is handed to the policy.

    Args:
        main_image: Frame to shrink; anything ``np.array`` accepts. It is
            treated as RGB (assumed to be what the env returns -- TODO confirm).

    Returns:
        A 256x256 ``PIL.Image.Image`` in RGB channel order.
    """
    target_size = (256, 256)

    # Round-trip through OpenCV's BGR layout for the resize, then go back to
    # RGB for PIL.
    bgr_frame = cv2.cvtColor(np.array(main_image), cv2.COLOR_RGB2BGR)
    bgr_small = cv2.resize(bgr_frame, target_size, interpolation=cv2.INTER_AREA)
    rgb_small = cv2.cvtColor(bgr_small, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_small)
|
|
|
|
# np.long does not exist in NumPy 1.24-1.26 (the alias was removed and only
# returned in 2.0). Requesting the C ``long`` dtype by its type code works on
# every release and is the dtype np.long resolves to where it does exist.
_ACTION_DTYPE = np.dtype("l")


def _neutral_axis():
    """Return a fresh 1-element zero array for an analog input (stick/trigger)."""
    return np.array([0], dtype=_ACTION_DTYPE)


# Neutral gamepad state: every button released, every axis/trigger at 0.
# Buttons are plain ints, analog inputs are 1-element integer arrays. The
# rollout loop copies this dict and overwrites the entries it predicts.
zero_action = OrderedDict(
    [
        ("WEST", 0),
        ("SOUTH", 0),
        ("BACK", 0),
        ("DPAD_DOWN", 0),
        ("DPAD_LEFT", 0),
        ("DPAD_RIGHT", 0),
        ("DPAD_UP", 0),
        ("GUIDE", 0),
        ("AXIS_LEFTX", _neutral_axis()),
        ("AXIS_LEFTY", _neutral_axis()),
        ("LEFT_SHOULDER", 0),
        ("LEFT_TRIGGER", _neutral_axis()),
        ("AXIS_RIGHTX", _neutral_axis()),
        ("AXIS_RIGHTY", _neutral_axis()),
        ("LEFT_THUMB", 0),
        ("RIGHT_THUMB", 0),
        ("RIGHT_SHOULDER", 0),
        ("RIGHT_TRIGGER", _neutral_axis()),
        ("START", 0),
        ("EAST", 0),
        ("NORTH", 0),
    ]
)
|
|
|
|
# Button names; the rollout loop zips them with each predicted button vector.
TOKEN_SET = BUTTON_ACTION_TOKENS

print("Model loaded, starting environment...")
# Three-second countdown before the environment is created.
for _remaining in (3, 2, 1):
    print(f"{_remaining}...")
    time.sleep(1)

_env_settings = dict(
    game=args.process,
    game_speed=1.0,
    env_fps=60,
    async_mode=True,
)
env = GamepadEnv(**_env_settings)
|
|
|
|
# These games require opening a menu once so that they initialize the
# controller.
GAMES_NEEDING_CONTROLLER_INIT = frozenset({"isaac-ng.exe", "Cuphead.exe"})


def press(button):
    """Tap ``button`` on the virtual gamepad: press, hold for 50 ms, release.

    ``gamepad.update()`` is called after the press and again after the
    release.

    Args:
        button: Button name understood by ``env.gamepad_emulator``
            (e.g. ``"SOUTH"``).
    """
    emulator = env.gamepad_emulator
    emulator.press_button(button)
    emulator.gamepad.update()
    time.sleep(0.05)
    emulator.release_button(button)
    emulator.gamepad.update()


if args.process in GAMES_NEEDING_CONTROLLER_INIT:
    print(f"GamepadEnv ready for {args.process} at {env.env_fps} FPS")
    input("Press enter to create a virtual controller and start rollouts...")
    for _remaining in (3, 2, 1):
        print(f"{_remaining}...")
        time.sleep(1)

    # One SOUTH tap followed by five spaced EAST taps.
    # NOTE(review): presumably this opens a menu and backs out of it again --
    # confirm against the games' actual button bindings.
    press("SOUTH")
    for _ in range(5):
        press("EAST")
        time.sleep(0.3)
|
|
|
|
# Rollout bookkeeping. step_count is incremented once per policy query and is
# used to name the debug images and to tag the logged actions.
step_count = 0
frames = None  # not referenced by the rest of the visible script

# Reset the environment and pause it before the first observation is taken.
env.reset()
env.pause()

# Step once with the all-zero action to obtain the first observation.
obs, reward, terminated, truncated, info = env.step(action=zero_action)
|
|
|
|
# C ``long`` by type code: np.long does not exist in NumPy 1.24-1.26, while
# np.dtype("l") is accepted by every release and names the same dtype.
_ACTION_DTYPE = np.dtype("l")

# Factors that turn the policy's stick and trigger outputs into the integer
# values stored in the env action.
_STICK_SCALE = 32767
_TRIGGER_SCALE = 255

# Buttons forced to 0 while NO_MENU is set.
_MENU_BUTTONS = ("GUIDE", "START", "BACK")


def _to_env_action(left_xy, right_xy, button_vector):
    """Convert one predicted timestep into an action dict for ``env.step``.

    Args:
        left_xy: ``(x, y)`` of the left stick; each component is multiplied by
            32767 and truncated to an int (presumably in [-1, 1] -- confirm
            against the model server).
        right_xy: ``(x, y)`` of the right stick, handled like ``left_xy``.
        button_vector: One value per name in ``TOKEN_SET``. Names containing
            "TRIGGER" are multiplied by 255 and stored as a 1-element integer
            array; all others become 1 when above ``BUTTON_PRESS_THRES``,
            otherwise 0.

    Returns:
        A copy of ``zero_action`` with the predicted entries overwritten.

    Raises:
        ValueError: If ``button_vector`` and ``TOKEN_SET`` differ in length.
    """
    if len(button_vector) != len(TOKEN_SET):
        raise ValueError("Button vector length does not match token set length")

    action = zero_action.copy()

    xl, yl = left_xy
    xr, yr = right_xy
    stick_values = (
        ("AXIS_LEFTX", xl),
        ("AXIS_LEFTY", yl),
        ("AXIS_RIGHTX", xr),
        ("AXIS_RIGHTY", yr),
    )
    for axis_name, value in stick_values:
        action[axis_name] = np.array([int(value * _STICK_SCALE)], dtype=_ACTION_DTYPE)

    for name, value in zip(TOKEN_SET, button_vector):
        if "TRIGGER" in name:
            action[name] = np.array([value * _TRIGGER_SCALE], dtype=_ACTION_DTYPE)
        else:
            action[name] = 1 if value > BUTTON_PRESS_THRES else 0

    return action


def _action_record(action, step, substep):
    """Return a JSON-serializable copy of ``action`` tagged with its position.

    NumPy arrays are converted to lists; the executed action dict itself is
    left untouched.
    """
    record = {
        key: value.tolist() if isinstance(value, np.ndarray) else value
        for key, value in action.items()
    }
    record["step"] = step
    record["substep"] = substep
    return record


# Main rollout: query the policy, execute the predicted action chunk, record
# every env frame to two videos and log the executed actions. The loop only
# ends through an exception (e.g. Ctrl+C); the env is then unpaused and closed.
with VideoRecorder(str(PATH_MP4_DEBUG), fps=60, crf=32, preset="medium") as debug_recorder:
    with VideoRecorder(str(PATH_MP4_CLEAN), fps=60, crf=28, preset="medium") as clean_recorder:
        try:
            while True:
                # Keep the exact model input on disk for debugging.
                obs = preprocess_img(obs)
                obs.save(PATH_DEBUG / f"{step_count:05d}.png")

                pred = policy.predict(obs)
                j_left, j_right, buttons = pred["j_left"], pred["j_right"], pred["buttons"]

                # Explicit check instead of assert so it survives ``python -O``.
                if not (len(buttons) == len(j_left) == len(j_right)):
                    raise ValueError("Mismatch in action lengths")

                env_actions = [
                    _to_env_action(left_xy, right_xy, button_vector)
                    for left_xy, right_xy, button_vector in zip(j_left, j_right, buttons)
                ]

                print(f"Executing {len(env_actions)} actions, each action will be repeated {action_downsample_ratio} times")

                for i, a in enumerate(env_actions):
                    if NO_MENU:
                        if a["START"]:
                            print("Model predicted start, disabling this action")
                        # All three menu buttons are cleared on every action,
                        # not only when START is predicted, so a lone GUIDE or
                        # BACK cannot open a menu either.
                        for menu_button in _MENU_BUTTONS:
                            a[menu_button] = 0

                    for _ in range(action_downsample_ratio):
                        obs, reward, terminated, truncated, info = env.step(action=a)

                        # Clean video at 1920x1080; debug overlay drawn on a
                        # 1280x720 copy of the same frame.
                        obs_viz = np.array(obs).copy()
                        clean_viz = cv2.resize(obs_viz, (1920, 1080), interpolation=cv2.INTER_AREA)
                        debug_viz = create_viz(
                            cv2.resize(obs_viz, (1280, 720), interpolation=cv2.INTER_AREA),
                            i,
                            j_left,
                            j_right,
                            buttons,
                            token_set=TOKEN_SET,
                        )
                        debug_recorder.add_frame(debug_viz)
                        clean_recorder.add_frame(clean_viz)

                # Append the executed actions, one JSON object per line, to
                # the actions file (JSON Lines content despite the .json name).
                with open(PATH_ACTIONS, "a", encoding="utf-8") as f:
                    for i, a in enumerate(env_actions):
                        json.dump(_action_record(a, step_count, i), f)
                        f.write("\n")

                step_count += 1
        finally:
            # Close the env even if unpausing fails.
            try:
                env.unpause()
            finally:
                env.close()
|