init commit
This commit is contained in:
252
nitrogen/inference_viz.py
Normal file
252
nitrogen/inference_viz.py
Normal file
@ -0,0 +1,252 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import av
|
||||
|
||||
def create_viz(
    frame: np.ndarray,
    i: int,
    j_left: np.ndarray,
    j_right: np.ndarray,
    buttons: np.ndarray,
    token_set: list,
) -> np.ndarray:
    """
    Visualize gamepad actions alongside a gameplay video frame.

    The frame is copied onto the left side of a black uint8 canvas; a side
    panel of width ``min(500, frame_width)`` on the right shows the joystick
    positions and the button grid for step ``i``.

    Parameters:
    - frame: Video frame as numpy array of shape (H, W, 3). Grayscale (H, W)
      frames are replicated across the three channels, and for 4-channel
      frames only the first three channels are used.
    - i: Index of the current step inside j_left / j_right / buttons
    - j_left: Nx2 array of left joystick positions (-1 to 1), or None
    - j_right: Nx2 array of right joystick positions (-1 to 1), or None
    - buttons: 2-D array (steps x buttons) of button states, or None
    - token_set: List of button names used for the legend, or None

    A section is skipped (not an error) when its data is None or when ``i``
    is outside the valid index range of that data.

    Returns:
    - Visualization as numpy array of shape (H, W + panel_width, 3), uint8
    """
    frame = np.asarray(frame)
    if frame.ndim == 2:
        # A trailing singleton axis lets the assignment below broadcast the
        # grayscale values over all three canvas channels.
        frame = frame[:, :, np.newaxis]
    elif frame.ndim == 3 and frame.shape[2] == 4:
        # The canvas has three channels only; ignore the fourth one.
        frame = frame[:, :, :3]

    # Get frame dimensions
    frame_height, frame_width = frame.shape[:2]

    # Create combined image (frame + visualization panel)
    viz_width = min(500, frame_width)
    combined = np.zeros((frame_height, frame_width + viz_width, 3), dtype=np.uint8)

    # Put the frame on the left side
    combined[:frame_height, :frame_width] = frame

    # Starting position for visualizations
    viz_x = frame_width
    viz_y = 20

    # Draw joysticks if data is provided. The lower bound matters: a negative
    # index would silently wrap around to the end of the arrays.
    has_joysticks = (
        j_left is not None
        and j_right is not None
        and 0 <= i < len(j_left)
        and i < len(j_right)
    )
    if has_joysticks:
        # Add section title
        cv2.putText(combined, "JOYSTICKS",
                    (viz_x + 10, viz_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)

        viz_y += 30  # Move down after title

        # Size of joystick visualization
        joy_size = min(120, viz_width // 3)

        # Horizontal positions of joysticks
        joy_left_x = viz_x + 30
        joy_right_x = viz_x + viz_width - joy_size - 30

        # Draw joystick labels
        cv2.putText(combined, "Left", (joy_left_x, viz_y - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1)
        cv2.putText(combined, "Right", (joy_right_x, viz_y - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1)

        # Draw joysticks
        draw_joystick(combined, joy_left_x, viz_y, joy_size, j_left[i])
        draw_joystick(combined, joy_right_x, viz_y, joy_size, j_right[i])

        viz_y += joy_size + 40  # Move down after joysticks

    # Draw buttons if data is provided
    if buttons is not None and 0 <= i < len(buttons):
        # Add section title
        cv2.putText(combined, "BUTTON STATES",
                    (viz_x + 10, viz_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)

        viz_y += 30  # Move down after title

        # Draw button grid (20 px cells, indented 20 px into the panel)
        draw_button_grid(combined, viz_x + 20, viz_y, 20, buttons, i, token_set)

    return combined
|
||||
|
||||
def draw_joystick(img, x, y, size, position):
    """
    Draw a joystick visualization at the specified position.

    Draws a filled square with an outline, a center cross, quarter grid
    lines and a dot marking the current stick deflection. ``img`` is
    modified in place; nothing is returned.

    Parameters:
    - img: Image to draw on (passed straight to the cv2 drawing functions)
    - x, y: Top-left corner of the joystick square in pixels
    - size: Side length of the joystick square in pixels
    - position: Pair (px, py) of stick deflections; each value is clamped
      to the range -1 to 1 before drawing
    """
    # Draw joystick background
    cv2.rectangle(img, (x, y), (x + size, y + size), (50, 50, 50), -1)
    cv2.rectangle(img, (x, y), (x + size, y + size), (100, 100, 100), 1)

    # Calculate center point
    half = size // 2
    mid_x = x + half
    mid_y = y + half

    # Draw center cross (0,0 coordinates)
    cv2.line(img, (x, mid_y), (x + size, mid_y), (150, 150, 150), 1)
    cv2.line(img, (mid_x, y), (mid_x, y + size), (150, 150, 150), 1)

    # Draw 2x2 grid
    quarter_x = x + size // 4
    quarter_y = y + size // 4
    three_quarters_x = x + 3 * size // 4
    three_quarters_y = y + 3 * size // 4

    # Draw grid lines
    cv2.line(img, (quarter_x, y), (quarter_x, y + size), (100, 100, 100), 1)
    cv2.line(img, (three_quarters_x, y), (three_quarters_x, y + size), (100, 100, 100), 1)
    cv2.line(img, (x, quarter_y), (x + size, quarter_y), (100, 100, 100), 1)
    cv2.line(img, (x, three_quarters_y), (x + size, three_quarters_y), (100, 100, 100), 1)

    # Draw joystick position (clamp coordinates to valid range).
    # float() turns numpy scalars into plain Python floats so that round()
    # below always yields an int.
    px = max(-1.0, min(1.0, float(position[0])))
    py = max(-1.0, min(1.0, float(position[1])))

    # Round the offset symmetrically around the center. Floor-dividing the
    # float product (``px * size // 2``) would push negative deflections one
    # pixel further out than equal positive ones, and for odd sizes place
    # full negative deflection outside the square.
    joy_x = mid_x + int(round(px * half))
    joy_y = mid_y - int(round(py * half))  # Y is inverted in image coordinates

    # Draw joystick position as a dot
    cv2.circle(img, (joy_x, joy_y), 5, (0, 0, 255), -1)  # Red dot
|
||||
|
||||
def draw_button_grid(img, x, y, button_size, buttons, current_row, token_set):
    """
    Render the button-state mosaic, its column numbers and an optional legend.

    One cell is drawn per entry of the 2-D ``buttons`` array: green when the
    entry is truthy, black otherwise. The row ``current_row`` is outlined in
    red. When ``token_set`` is not None, a numbered legend of button names is
    written below the mosaic. ``img`` is modified in place.

    Parameters:
    - img: Image to draw on
    - x, y: Top-left corner of the mosaic in pixels
    - button_size: Requested cell size in pixels; reduced (never below 10)
      when the columns would not fit into the remaining image width
    - buttons: 2-D array of button states (rows x columns)
    - current_row: Index of the row to highlight
    - token_set: Sequence of button names for the legend, or None
    """
    n_rows, n_cols = buttons.shape
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_color = (200, 200, 200)

    # Shrink the cells if the full row would overflow the image
    # (a 20 px margin is kept on the right).
    usable_width = img.shape[1] - x - 20
    cell = button_size
    if n_cols * cell > usable_width:
        cell = max(10, usable_width // n_cols)

    # 1-based column numbers just above the mosaic
    header_y = y - 5
    for index in range(n_cols):
        label_x = x + index * cell + cell // 2 - 4
        cv2.putText(img, str(index + 1), (label_x, header_y),
                    font, 0.4, text_color, 1)

    # Cells are painted row by row: fill first, then the gray outline on top.
    # Neighbouring cells share a border, so this order is kept deliberately.
    for r in range(n_rows):
        top = y + r * cell
        for c in range(n_cols):
            left = x + c * cell
            upper_left = (left, top)
            lower_right = (left + cell, top + cell)
            if buttons[r, c]:
                fill = (0, 255, 0)  # pressed -> green
            else:
                fill = (0, 0, 0)  # released -> black
            cv2.rectangle(img, upper_left, lower_right, fill, -1)
            cv2.rectangle(img, upper_left, lower_right, (80, 80, 80), 1)

    # Red frame around the current row
    marker_top = y + current_row * cell
    cv2.rectangle(img, (x, marker_top),
                  (x + n_cols * cell, marker_top + cell),
                  (0, 0, 255), 2)

    # Without names there is no legend to draw
    if token_set is None:
        return

    row_step = 15  # vertical distance between legend lines
    text_y = y + n_rows * cell + 20

    # Legend title
    cv2.putText(img, "Button Legend:", (x, text_y), font, 0.5, text_color, 1)
    text_y += row_step + 5

    # One to three legend columns depending on the button count; entries are
    # distributed column-major with ceiling division.
    legend_columns = max(1, min(3, n_cols // 6))
    per_column = (n_cols + legend_columns - 1) // legend_columns
    column_stride = usable_width // legend_columns

    for idx in range(min(n_cols, len(token_set))):
        which_col, which_row = divmod(idx, per_column)
        anchor = (x + which_col * column_stride, text_y + which_row * row_step)
        cv2.putText(img, f"{idx+1}. {token_set[idx]}", anchor,
                    font, 0.4, text_color, 1)
|
||||
|
||||
class VideoRecorder:
    """
    Write frames to an H.264 video file using PyAV.

    The output container is opened on construction; the video stream is
    created lazily from the dimensions of the first frame passed to
    ``add_frame``. Use as a context manager or call ``close()`` when done;
    ``close()`` may safely be called more than once.
    """

    def __init__(self, output_file, fps=30, crf=28, preset="fast"):
        """
        Initialize a video recorder using PyAV.

        Args:
            output_file (str): Path to save the video file
            fps (int): Frames per second
            crf (int): Constant Rate Factor (0-51, higher means smaller file but lower quality)
            preset (str): Encoding preset (ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow)
        """
        self.output_file = output_file
        self.fps = fps
        self.crf = str(crf)  # stream options are passed as strings
        self.preset = preset
        self.container = av.open(output_file, mode="w")
        self.stream = None  # created on the first add_frame() call
        self._closed = False

    def init_stream(self, width, height):
        """
        Initialize the video stream with the frame dimensions.

        Args:
            width (int): Frame width in pixels
            height (int): Frame height in pixels
        """
        # NOTE(review): H.264 with yuv420p generally expects even width and
        # height - confirm that callers never pass odd-sized frames.
        self.stream = self.container.add_stream("h264", rate=self.fps)
        self.stream.width = width
        self.stream.height = height
        self.stream.pix_fmt = "yuv420p"
        self.stream.options = {
            "crf": self.crf,
            "preset": self.preset,
        }

    def add_frame(self, frame):
        """
        Add a frame to the video.

        The first call also creates the stream, using this frame's width and
        height.

        Args:
            frame (numpy.ndarray): Frame as RGB numpy array
        """
        if self.stream is None:
            self.init_stream(frame.shape[1], frame.shape[0])

        av_frame = av.VideoFrame.from_ndarray(np.array(frame), format="rgb24")
        for packet in self.stream.encode(av_frame):
            self.container.mux(packet)

    def close(self):
        """
        Flush remaining packets and close the video file.

        Idempotent: only the first call flushes the encoder and closes the
        container, so an explicit ``close()`` inside a ``with`` block does
        not make ``__exit__`` flush and close a second time.
        """
        if getattr(self, "_closed", False):
            return
        # Mark as closed up front so a failure while flushing cannot lead to
        # a second flush attempt on a later call.
        self._closed = True
        try:
            if self.stream is not None:
                # encode() without a frame drains the encoder
                for packet in self.stream.encode():
                    self.container.mux(packet)
        finally:
            self.container.close()

    def __enter__(self):
        """Support for context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the recorder when exiting the context."""
        self.close()
|
||||
Reference in New Issue
Block a user