Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions platoon/episode/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,53 @@
finish_message,
)

CLEANUP_TIMEOUT = 180 # seconds to allow each close() call before giving up
# from openhands.sdk.conversation import ConversationExecutionStatus
# from platoon.utils.openhands_utils import is_finished
# def agent_finished(obs):
# if obs.conversation_state.execution_status in [
# ConversationExecutionStatus.FINISHED,
# ConversationExecutionStatus.STUCK,
# ConversationExecutionStatus.ERROR
# ]:
# return True
# return False

# NOTE: This function should be called using asyncio.create_task() to make sure edits to contextvars do not leak to parent context
async def run_episode(agent: Agent, env: Env, verbose: bool = False, timeout: int = 300) -> Trajectory:
async def run_episode(agent: Agent, env: Env, verbose: bool = True, timeout: int = 300) -> Trajectory:
curr = "pre reset"
try:
step_count = 0
set_context_vars(agent, env)
print("waiting for env.reset()", flush=True)
obs = await env.reset()
curr = "reset done"
# while True:
# import time
# time.sleep(10000000)
while not halt_episode(obs):
# if agent_finished(obs):
# print("OpenHands Finished -- waiting for agent.act() to complete", flush=True)
print("waiting for agent.act()", flush=True)
curr = "before act"
action = await asyncio.wait_for(agent.act(obs), timeout=timeout)
# if agent_finished(obs):
# print("OpenHands Finished -- waiting for env.step() to complete", flush=True)
print("waiting for env.step()", flush=True)
curr = "before step"
obs = await asyncio.wait_for(env.step(action), timeout=timeout)
print("completed env.step()", flush=True)
# if agent_finished(obs):
# print("OpenHands Finished -- env.step() completed", flush=True)
# if not is_finished(obs):
# print(f"WARNING: Conversation execution status is {obs.conversation_state.execution_status} but is_finished() returned False", flush=True)
step_count += 1
except asyncio.CancelledError:
# Task was cancelled by parent (e.g. rollout timeout via wait_for).
# Catch it so the finally block can run normally without re-cancellation.
error_message.set(f"Episode cancelled at step {step_count} (likely rollout timeout)")
if verbose:
print(f"Episode cancelled at step {step_count}", flush=True)
except Exception as e:
tb_summary = traceback.extract_tb(e.__traceback__)
origin = ""
Expand All @@ -40,13 +77,30 @@ async def run_episode(agent: Agent, env: Env, verbose: bool = False, timeout: in
print(detailed_msg)
error_message.set(detailed_msg)
finally:
await agent.close()
await env.close()
# Cleanup with bounded timeouts so a blocking close() can't stall the process.
# Use asyncio.shield() so that CancelledError from a parent wait_for()
# doesn't prevent cleanup from running.
# for label, closeable in [("agent", agent), ("env", env)]:
# try:
# await asyncio.shield(
# asyncio.wait_for(closeable.close(), timeout=CLEANUP_TIMEOUT)
# )
# except asyncio.CancelledError:
# print(f"Warning: {label}.close() was cancelled, cleanup may be incomplete", flush=True)
# except asyncio.TimeoutError:
# print(f"Warning: {label}.close() timed out after {CLEANUP_TIMEOUT}s, skipping", flush=True)
# except Exception as e:
# print(f"Warning: {label}.close() raised {e}, skipping", flush=True)
# Finalize trajectory and emit a finish event to sinks
traj_collection = current_trajectory_collection.get()
traj = current_trajectory.get()
traj.error_message = error_message.get()
traj.finish_message = finish_message.get()
print(f"Current state: {curr}", flush=True)
if traj.finish_message is None:
traj.finish_message = f"Episode finished without a finish message: {curr}"
# if traj.error_message is None:
# traj.error_message = "Rollout finished without an error or finish message"
# TODO: We could move out trajectory finish logic (adding up rewards, setting finish message, etc.) from env logic to here.
traj_collection.finish_trajectory(traj.id)
return traj
Expand Down
8 changes: 8 additions & 0 deletions platoon/train/areal/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,14 @@ async def patched_create(
# Convert messages to prompt format
tools_val = tools if not is_omitted(tools) else None
if self.chat_template_type == "hf":
for message in messages_list:
if isinstance(message["content"], list):
new_content = "".join(
item.get("text", "")
for item in message["content"]
if isinstance(item, dict) and item.get("type") == "text"
)
message["content"] = new_content
prompt_token_ids = self.tokenizer.apply_chat_template(
messages_list,
tools=tools_val,
Expand Down
2 changes: 1 addition & 1 deletion platoon/train/areal/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
self.ref.initialize(None, self.ft_spec)

# Setup proxy server
self.llm_client = ArealOpenAI(engine=self.rollout, tokenizer=self.tokenizer)
self.llm_client = ArealOpenAI(engine=self.rollout, tokenizer=self.tokenizer, tool_call_parser="qwen25")
free_port = find_free_ports(1)[0]
self.proxy_server = ProxyServer(free_port, client=self.llm_client)
self.proxy_server.start(wait_until_ready=True)
Expand Down
Loading