Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions platoon/config_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class RolloutConfig:
timeout: int | None = None # Trajectory timeout (entire rollout)
step_timeout: int = 300 # Per-step timeout (agent.act + env.step)
return_dict: bool = False
propogate_root_success: bool = False
skip_subagent_reward_computation: bool = False
inference_params: InferenceParams = field(default_factory=InferenceParams)

def __post_init__(self) -> None:
Expand Down
64 changes: 64 additions & 0 deletions platoon/utils/subagent_rewards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from typing import Any

from platoon.episode.trajectory import TrajectoryCollection


def _get_trajectories(trajectory_collection: dict[str, Any] | TrajectoryCollection) -> dict[str, Any]:
return (
trajectory_collection["trajectories"]
if isinstance(trajectory_collection, dict)
else trajectory_collection.trajectories
)


def _get_trajectory_reward(trajectory: Any) -> float:
return float(trajectory["reward"] if isinstance(trajectory, dict) else trajectory.reward)


def _set_trajectory_reward(trajectory: Any, reward: float) -> None:
if isinstance(trajectory, dict):
trajectory["reward"] = reward
else:
trajectory.reward = reward


def _get_steps(trajectory: Any) -> list[Any]:
return trajectory.get("steps", []) if isinstance(trajectory, dict) else trajectory.steps


def _get_step_reward_misc(step: Any) -> dict[str, Any]:
if isinstance(step, dict):
return step.setdefault("misc", {}).setdefault("reward_misc", {})
if step.misc is None:
step.misc = {}
return step.misc.setdefault("reward_misc", {})


def propogate_root_success(
    trajectory_collection: dict[str, Any] | TrajectoryCollection,
) -> dict[str, Any] | TrajectoryCollection:
    """Overwrite every trajectory's reward with the root trajectory's success.

    The first trajectory in the collection's iteration order is used as the
    root. Its success value is the ``"reward/success"`` entry of its last
    step's ``reward_misc`` when it has steps and that entry exists, and its
    own reward otherwise.

    For every trajectory in the collection (the root included):

    * its reward is set to the root success;
    * its last step's ``"reward/success"`` is set to the root success;
    * each step with a positive ``"reward/subagent_launched"`` gets
      ``"reward/subagent_succeeded"`` set to that count times the root
      success.

    The collection is modified in place and the same object is returned.
    An empty collection is returned untouched.
    """
    all_trajectories = _get_trajectories(trajectory_collection)
    if not all_trajectories:
        return trajectory_collection

    # The root is whichever trajectory comes first in the mapping.
    root = next(iter(all_trajectories.values()))
    root_steps = _get_steps(root)
    success = _get_trajectory_reward(root)
    if root_steps:
        final_misc = _get_step_reward_misc(root_steps[-1])
        success = float(final_misc.get("reward/success", success))

    for current in all_trajectories.values():
        _set_trajectory_reward(current, success)
        current_steps = _get_steps(current)
        if current_steps:
            _get_step_reward_misc(current_steps[-1])["reward/success"] = success
        for current_step in current_steps:
            misc = _get_step_reward_misc(current_step)
            launch_count = float(misc.get("reward/subagent_launched", 0.0))
            if not launch_count > 0:
                continue
            # Credit launched subagents in proportion to the root outcome.
            misc["reward/subagent_succeeded"] = launch_count * success

    return trajectory_collection
27 changes: 25 additions & 2 deletions plugins/appworld/platoon/appworld/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,12 @@ def __init__(
task: Task,
code_executor: AppWorldCodeExecutor | None = None,
timeout_seconds: int | None = DEFAULT_APPWORLD_TIMEOUT_SECONDS,
skip_subagent_reward_computation: bool = False,
**kwargs,
):
if code_executor is None:
code_executor = AppWorldCodeExecutor(task, timeout_seconds=timeout_seconds)
self._skip_subagent_reward_computation = skip_subagent_reward_computation

super().__init__(task, code_executor, **kwargs)

Expand All @@ -506,6 +508,11 @@ async def reset(self) -> CodeActObservation:

async def evaluate(self) -> tuple[float, dict]:
score, reward_misc = 0., {}
is_subagent_task = isinstance(self._task, SubTask) and bool(self._task.parent_tasks)
if self._skip_subagent_reward_computation and is_subagent_task:
reward_misc["reason"] = "Skipped subagent reward computation"
reward_misc["reward/success"] = 0.0
return 0.0, reward_misc

if self._state.finished:
if isinstance(self._task, SubTask) and self._task.parent_tasks:
Expand Down Expand Up @@ -541,6 +548,7 @@ async def fork(self, task: Task) -> AppWorldEnv:
return type(self)(
task,
code_executor=code_executor,
skip_subagent_reward_computation=self._skip_subagent_reward_computation,
)


Expand All @@ -552,11 +560,18 @@ def __init__(
task: Task,
code_executor: AppWorldRecursiveCodeExecutor | None = None,
timeout_seconds: int | None = DEFAULT_APPWORLD_TIMEOUT_SECONDS,
skip_subagent_reward_computation: bool = False,
**kwargs,
):
if code_executor is None:
code_executor = AppWorldRecursiveCodeExecutor(task, timeout_seconds=timeout_seconds)
super().__init__(task, code_executor, **kwargs)
super().__init__(
task,
code_executor,
timeout_seconds=timeout_seconds,
skip_subagent_reward_computation=skip_subagent_reward_computation,
**kwargs,
)

@property
def code_executor(self) -> AppWorldRecursiveCodeExecutor:
Expand Down Expand Up @@ -643,6 +658,7 @@ def __init__(
code_executor: AppWorldDepthAwareCodeExecutor | None = None,
subagent_max_steps: int = 25,
timeout_seconds: int | None = DEFAULT_APPWORLD_TIMEOUT_SECONDS,
skip_subagent_reward_computation: bool = False,
**kwargs,
):
self._subagent_max_steps = subagent_max_steps
Expand All @@ -652,13 +668,20 @@ def __init__(
subagent_max_steps=subagent_max_steps,
timeout_seconds=timeout_seconds,
)
super().__init__(task, code_executor, **kwargs)
super().__init__(
task,
code_executor,
timeout_seconds=timeout_seconds,
skip_subagent_reward_computation=skip_subagent_reward_computation,
**kwargs,
)

async def fork(self, task: Task) -> "AppWorldDepthAwareEnv":
code_executor = await self.code_executor.fork(task)
return AppWorldDepthAwareEnv(
task,
code_executor=code_executor,
subagent_max_steps=self._subagent_max_steps,
skip_subagent_reward_computation=self._skip_subagent_reward_computation,
)

24 changes: 19 additions & 5 deletions plugins/appworld/platoon/appworld/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from platoon.episode.loop import run_episode
from platoon.episode.trajectory import DepthAwareStepBudgetTracker, TrajectoryCollection
from platoon.utils.llm_client import LiteLLMClient
from platoon.utils.subagent_rewards import propogate_root_success
from platoon.visualization.event_sinks import JsonlFileSink

from .agent import AppWorldAgent, AppWorldDepthAwareAgent, AppWorldRecursiveAgent
Expand Down Expand Up @@ -88,7 +89,11 @@ async def run_recursive_rollout(task: Task, config: RolloutConfig) -> dict | Tra
base_url=config.model_endpoint,
api_key=config.model_api_key,
)
env = AppWorldRecursiveEnv(task, timeout_seconds=config.step_timeout)
env = AppWorldRecursiveEnv(
task,
timeout_seconds=config.step_timeout,
skip_subagent_reward_computation=config.skip_subagent_reward_computation,
)
agent = AppWorldRecursiveAgent(
llm_client=llm_client,
inference_params=config.inference_params,
Expand Down Expand Up @@ -123,10 +128,14 @@ async def run_recursive_rollout(task: Task, config: RolloutConfig) -> dict | Tra
)
raise

result: dict | TrajectoryCollection
if config.return_dict:
return current_trajectory_collection.get().to_dict()
result = current_trajectory_collection.get().to_dict()
else:
return current_trajectory_collection.get()
result = current_trajectory_collection.get()
if config.propogate_root_success:
result = propogate_root_success(result)
return result

except Exception as e:
if config.verbose:
Expand Down Expand Up @@ -171,6 +180,7 @@ async def run_depth_aware_rollout(
task,
subagent_max_steps=per_subagent_max_steps,
timeout_seconds=config.step_timeout,
skip_subagent_reward_computation=config.skip_subagent_reward_computation,
)
agent = AppWorldDepthAwareAgent(
llm_client=llm_client,
Expand Down Expand Up @@ -209,10 +219,14 @@ async def run_depth_aware_rollout(
)
raise

result: dict | TrajectoryCollection
if config.return_dict:
return current_trajectory_collection.get().to_dict()
result = current_trajectory_collection.get().to_dict()
else:
return current_trajectory_collection.get()
result = current_trajectory_collection.get()
if config.propogate_root_success:
result = propogate_root_success(result)
return result

except Exception as e:
if config.verbose:
Expand Down
13 changes: 11 additions & 2 deletions plugins/deepdive/platoon/deepdive/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,10 @@ async def fork(self, task: Task) -> DeepDiveRecursiveCodeExecutor:
)

class DeepDiveEnv(CodeActEnv):
def __init__(self, task: Task):
def __init__(self, task: Task, skip_subagent_reward_computation: bool = False):
#task.fork_strategy = "task"
super().__init__(task, DeepDiveCodeExecutor(task))
self._skip_subagent_reward_computation = skip_subagent_reward_computation

def _parse_rubric_response(self, response: str) -> dict:
"""Parse the LLM response to extract structured data.
Expand Down Expand Up @@ -282,6 +283,12 @@ def parse_rubric_response(self, response: str) -> tuple[float, str]:

async def evaluate(self) -> tuple[float, dict]:
score, reward_misc = 0., {}
is_subagent_task = "deepdive" not in (self._task.id or "")
if self._skip_subagent_reward_computation and is_subagent_task:
reward_misc["reason"] = "Skipped subagent reward computation"
reward_misc["success"] = False
reward_misc["reward/success"] = 0.0
return 0.0, reward_misc

final_message = finish_message.get()
if final_message is None and self._state.history:
Expand Down Expand Up @@ -347,8 +354,9 @@ def __init__(
self,
task: Task,
subagent_max_steps: int | None = 25,
skip_subagent_reward_computation: bool = False,
):
super().__init__(task)
super().__init__(task, skip_subagent_reward_computation=skip_subagent_reward_computation)
self._code_executor = DeepDiveRecursiveCodeExecutor(
task=task,
subagent_max_steps=subagent_max_steps
Expand All @@ -372,4 +380,5 @@ async def fork(self, task: Task) -> DeepDiveRecursiveEnv:
return DeepDiveRecursiveEnv(
task=task,
subagent_max_steps=self.subagent_max_steps,
skip_subagent_reward_computation=self._skip_subagent_reward_computation,
)
15 changes: 12 additions & 3 deletions plugins/deepdive/platoon/deepdive/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from platoon.episode.loop import run_episode
from platoon.episode.trajectory import DepthAwareStepBudgetTracker, TrajectoryCollection
from platoon.utils.llm_client import LiteLLMClient
from platoon.utils.subagent_rewards import propogate_root_success
from platoon.visualization.event_sinks import JsonlFileSink

from .agent import DeepDiveAgent, DeepDiveRecursiveAgent
Expand Down Expand Up @@ -91,7 +92,10 @@ async def run_recursive_rollout(task: Task, config: RolloutConfig) -> dict | Tra
base_url=config.model_endpoint,
api_key=config.model_api_key,
)
env = DeepDiveRecursiveEnv(task)
env = DeepDiveRecursiveEnv(
task,
skip_subagent_reward_computation=config.skip_subagent_reward_computation,
)
agent = DeepDiveRecursiveAgent(
llm_client=llm_client,
inference_params=config.inference_params,
Expand Down Expand Up @@ -133,9 +137,14 @@ async def run_recursive_rollout(task: Task, config: RolloutConfig) -> dict | Tra
)
raise

result: dict | TrajectoryCollection
if config.return_dict:
return current_trajectory_collection.get().to_dict()
return current_trajectory_collection.get()
result = current_trajectory_collection.get().to_dict()
else:
result = current_trajectory_collection.get()
if config.propogate_root_success:
result = propogate_root_success(result)
return result
except Exception as e:
if config.verbose:
print(f"Error running rollout for task {task.id}: {e}")
Expand Down
12 changes: 10 additions & 2 deletions plugins/oolong/platoon/oolong/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,12 @@ async def fork(self, task: Task) -> OolongRecursiveCodeExecutor:


class OolongEnv(CodeActEnv):
def __init__(self, task: Task):
def __init__(self, task: Task, skip_subagent_reward_computation: bool = False):
task.fork_strategy = "task"
code_executor = OolongCodeExecutor(task)
# Remove context task misc to avoid massive context logging in events
self.context = task.misc.pop('context')
self._skip_subagent_reward_computation = skip_subagent_reward_computation
super().__init__(task, code_executor)

def _parse_rubric_response(self, response: str) -> dict:
Expand Down Expand Up @@ -204,6 +205,11 @@ async def evaluate(self) -> tuple[float, dict]:

score = 0.0
reward_misc = {}
is_subagent_task = "oolong" not in (self._task.id or "")
if self._skip_subagent_reward_computation and is_subagent_task:
reward_misc["reason"] = "Skipped subagent reward computation"
reward_misc["reward/success"] = 0.0
return 0.0, reward_misc

if self._state.finished:
if not "oolong" in self._task.id:
Expand Down Expand Up @@ -287,12 +293,13 @@ async def evaluate(self) -> tuple[float, dict]:
class OolongRecursiveEnv(OolongEnv):
def __init__(self, task: Task,
subagent_max_steps: int | None = 25,
skip_subagent_reward_computation: bool = False,
):
code_executor = OolongRecursiveCodeExecutor(
task,
subagent_max_steps=subagent_max_steps
)
super().__init__(task)
super().__init__(task, skip_subagent_reward_computation=skip_subagent_reward_computation)
self._code_executor = code_executor
self.subagent_max_steps = subagent_max_steps

Expand All @@ -315,6 +322,7 @@ async def fork(self, task: Task) -> OolongRecursiveEnv:
return OolongRecursiveEnv(
task,
subagent_max_steps=self.subagent_max_steps,
skip_subagent_reward_computation=self._skip_subagent_reward_computation,
)


14 changes: 11 additions & 3 deletions plugins/oolong/platoon/oolong/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .agent import OolongAgent, OolongRecursiveAgent
from platoon.config_defs import RolloutConfig
from platoon.utils.llm_client import LiteLLMClient
from platoon.utils.subagent_rewards import propogate_root_success
from platoon.episode.context import current_trajectory_collection, budget_tracker
from platoon.episode.loop import run_episode
from platoon.episode.trajectory import TrajectoryCollection, DepthAwareStepBudgetTracker
Expand Down Expand Up @@ -109,7 +110,10 @@ async def run_recursive_rollout(task: Task, config: RolloutConfig) -> dict | Tra
# Disable Qwen3 reasoning/thinking mode for faster inference
# default_extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)
env = OolongRecursiveEnv(task)
env = OolongRecursiveEnv(
task,
skip_subagent_reward_computation=config.skip_subagent_reward_computation,
)
agent = OolongRecursiveAgent(
llm_client=llm_client,
inference_params=config.inference_params,
Expand Down Expand Up @@ -150,10 +154,14 @@ async def run_recursive_rollout(task: Task, config: RolloutConfig) -> dict | Tra
logger.warning(f"Process {os.getpid()}: Task cancellation did not complete in 5s for {task.id}, abandoning")
raise

result: dict | TrajectoryCollection
if config.return_dict:
return current_trajectory_collection.get().to_dict()
result = current_trajectory_collection.get().to_dict()
else:
return current_trajectory_collection.get()
result = current_trajectory_collection.get()
if config.propogate_root_success:
result = propogate_root_success(result)
return result

except Exception as e:
if config.verbose:
Expand Down
Loading
Loading