init commit
This commit is contained in:
91
nitrogen/inference_client.py
Normal file
91
nitrogen/inference_client.py
Normal file
@ -0,0 +1,91 @@
|
||||
import time
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import zmq
|
||||
|
||||
class ModelClient:
|
||||
"""Client for model inference server."""
|
||||
|
||||
def __init__(self, host="localhost", port=5555):
|
||||
"""
|
||||
Initialize client connection.
|
||||
|
||||
Args:
|
||||
host: Server hostname or IP
|
||||
port: Server port
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_ms = 30000
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.socket = self.context.socket(zmq.REQ)
|
||||
self.socket.connect(f"tcp://{host}:{port}")
|
||||
self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) # Set receive timeout
|
||||
|
||||
print(f"Connected to model server at {host}:{port}")
|
||||
|
||||
def predict(self, image: np.ndarray) -> dict:
|
||||
"""
|
||||
Send an image and receive predicted actions.
|
||||
|
||||
Args:
|
||||
image: numpy array (H, W, 3) in RGB format
|
||||
|
||||
Returns:
|
||||
List of action dicts, each containing:
|
||||
- j_left: [x, y] left joystick position
|
||||
- j_right: [x, y] right joystick position
|
||||
- buttons: list of button values
|
||||
"""
|
||||
request = {
|
||||
"type": "predict",
|
||||
"image": image
|
||||
}
|
||||
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = pickle.loads(self.socket.recv())
|
||||
|
||||
if response["status"] != "ok":
|
||||
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
|
||||
|
||||
return response["pred"]
|
||||
|
||||
def reset(self):
|
||||
"""Reset the server's session (clear buffers)."""
|
||||
request = {"type": "reset"}
|
||||
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = pickle.loads(self.socket.recv())
|
||||
|
||||
if response["status"] != "ok":
|
||||
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
|
||||
|
||||
print("Session reset")
|
||||
|
||||
def info(self) -> dict:
|
||||
"""Get session info from the server."""
|
||||
request = {"type": "info"}
|
||||
|
||||
self.socket.send(pickle.dumps(request))
|
||||
response = pickle.loads(self.socket.recv())
|
||||
|
||||
if response["status"] != "ok":
|
||||
raise RuntimeError(f"Server error: {response.get('message', 'Unknown error')}")
|
||||
|
||||
return response["info"]
|
||||
|
||||
def close(self):
|
||||
"""Close the connection."""
|
||||
self.socket.close()
|
||||
self.context.term()
|
||||
print("Connection closed")
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Close connection when exiting context."""
|
||||
self.close()
|
||||
Reference in New Issue
Block a user