Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions aws-rft-sdk/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "aws-rft-sdk"
version = "0.1.0"
description = "AWS Reinforcement Fine-Tuning SDK for online rollout-based training"
readme = {text = "", content-type = "text/markdown"}
requires-python = ">=3.9"
dependencies = [
"boto3>=1.35.0",
"requests>=2.28.0",
]

[project.optional-dependencies]
strands = ["strands-agents>=0.1.0"]

[tool.hatch.build.targets.wheel]
packages = ["src/aws_rft_sdk"]
5 changes: 5 additions & 0 deletions aws-rft-sdk/src/aws_rft_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from aws_rft_sdk.client import RolloutFeedbackClient
from aws_rft_sdk.handler import rft_handler
from aws_rft_sdk.context import RFTContext

__all__ = ["RolloutFeedbackClient", "rft_handler", "RFTContext"]
Empty file.
78 changes: 78 additions & 0 deletions aws-rft-sdk/src/aws_rft_sdk/adapters/strands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Strands model adapter — wraps a Strands model to inject RFT headers.

Usage::

from aws_rft_sdk.adapters.strands import wrap_model
from strands.models.openai import OpenAIModel

model = OpenAIModel(
client_args={"api_key": key, "base_url": endpoint},
model_id="my-model",
)
model = wrap_model(model) # Now injects X-RFT-* headers on every call

Requires the Strands OpenAIModel to pass through ``extra_headers`` kwarg
to the underlying OpenAI client (supported since strands-agents >= X.Y.Z).
"""

import logging
from typing import Any

from aws_rft_sdk.context import RFTContext

logger = logging.getLogger(__name__)


def wrap_model(model: Any) -> Any:
"""Wrap a Strands model to automatically inject RFT training headers.

The wrapper reads the current rollout context (populated by ``@rft_handler``)
and adds ``X-RFT-*`` headers to every inference request so the training
inference endpoint can correlate requests with rollouts.

Args:
model: A Strands model instance (e.g., ``OpenAIModel``).

Returns:
A wrapped model that transparently injects RFT headers.
"""
return _RFTModelWrapper(model)


class _RFTModelWrapper:
"""Transparent proxy that injects RFT headers into Strands model calls.

Delegates all attribute access to the inner model so it quacks like
the original. Intercepts ``stream()`` to inject ``extra_headers``.
"""

def __init__(self, inner_model: Any):
object.__setattr__(self, "_inner", inner_model)

def __getattr__(self, name: str) -> Any:
return getattr(self._inner, name)

def __setattr__(self, name: str, value: Any):
if name == "_inner":
object.__setattr__(self, name, value)
else:
setattr(self._inner, name, value)

def stream(self, *args: Any, **kwargs: Any) -> Any:
"""Intercept stream() to inject RFT headers via extra_headers kwarg."""
rft_headers = RFTContext.get_headers()
if rft_headers:
existing = kwargs.get("extra_headers") or {}
existing.update(rft_headers)
kwargs["extra_headers"] = existing
logger.debug("Injected RFT headers: %s", list(rft_headers.keys()))
return self._inner.stream(*args, **kwargs)

def update_config(self, **model_config: Any) -> None:
return self._inner.update_config(**model_config)

def get_config(self) -> Any:
return self._inner.get_config()

def structured_output(self, *args: Any, **kwargs: Any) -> Any:
return self._inner.structured_output(*args, **kwargs)
147 changes: 147 additions & 0 deletions aws-rft-sdk/src/aws_rft_sdk/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""RolloutFeedbackClient — reports rewards and trajectory completion to AgenticRFTRuntimeService."""

import json
import logging
from typing import List, Optional

import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

logger = logging.getLogger(__name__)

# Alpha endpoint; override via metadata["endpoint"] or AGENTIC_RFT_ENDPOINT env var.
_DEFAULT_ENDPOINT = "https://finetuning-job-runtime.alpha.sagemaker.us-west-2.api.aws"
_SIGNING_SERVICE = "sagemaker"


class RolloutFeedbackClient:
"""Client for reporting rollout feedback to the AgenticRFTRuntimeService.

Calls the real CompleteTrajectory and UpdateReward APIs using SigV4 auth.

Example::

from aws_rft_sdk import RolloutFeedbackClient

client = RolloutFeedbackClient(payload["metadata"])
client.complete_trajectory()
client.update_reward([0.8, 0.9, 1.0])

Args:
metadata: The ``metadata`` dict from the rollout payload. Expected keys:
- ``job_arn``: the RFT job ARN
- ``trajectory_id``: trajectory to act on
- ``endpoint`` (optional): override the runtime service URL
- ``region`` (optional): AWS region (default us-west-2)
"""

def __init__(self, metadata: dict):
self._metadata = metadata or {}
self._job_arn = self._metadata.get("job_arn")
self._trajectory_id = self._metadata.get("trajectory_id")
self._endpoint = (
self._metadata.get("endpoint")
or _DEFAULT_ENDPOINT
)
self._region = self._metadata.get("region", "us-west-2")
self._credentials = None

def _get_credentials(self):
if self._credentials is None:
session = boto3.Session(region_name=self._region)
self._credentials = session.get_credentials().get_frozen_credentials()
return self._credentials

def _signed_request(self, method: str, path: str, body: dict) -> dict:
"""Send a SigV4-signed request to the runtime service."""
import requests as http_requests

url = f"{self._endpoint}{path}"
data = json.dumps(body)
headers = {"Content-Type": "application/json"}

aws_request = AWSRequest(method=method, url=url, data=data, headers=headers)
SigV4Auth(self._get_credentials(), _SIGNING_SERVICE, self._region).add_auth(aws_request)

resp = http_requests.request(
method=method,
url=url,
headers=dict(aws_request.headers),
data=data,
timeout=30,
)
resp.raise_for_status()
return resp.json() if resp.text else {}

def complete_trajectory(self):
"""Mark the trajectory as complete (PENDING -> READY).

Calls POST /CompleteTrajectory with the trajectory ID.
"""
if not self._trajectory_id:
logger.warning("No trajectory_id in metadata; skipping complete_trajectory")
return

logger.info(
"CompleteTrajectory: trajectory_id=%s",
self._trajectory_id,
)
self._signed_request("POST", "/CompleteTrajectory", {
"TrajectoryId": self._trajectory_id,
})

def update_reward(self, rewards: List[float]):
"""Submit reward scores for the trajectory (READY -> REWARD_RECEIVED).

Calls POST /UpdateReward with per-transition rewards.

Args:
rewards: List of reward values, one per transition in the trajectory.
"""
if not self._trajectory_id:
logger.warning("No trajectory_id in metadata; skipping update_reward")
return

logger.info(
"UpdateReward: trajectory_id=%s rewards=%s",
self._trajectory_id,
rewards,
)
self._signed_request("POST", "/UpdateReward", {
"TrajectoryId": self._trajectory_id,
"Rewards": rewards,
})

# Convenience wrappers (backward-compatible names)

def report_complete(self, reward: float):
"""Complete the trajectory and report a single reward.

This is a convenience method that calls complete_trajectory()
then update_reward() with a single reward value.

Args:
reward: The computed reward for this rollout.
"""
self.complete_trajectory()
self.update_reward([reward])

def report_error(self, error: str, reward: Optional[float] = None):
"""Log a rollout error.

Args:
error: Error description.
reward: Optional partial reward.
"""
logger.error(
"Rollout error: trajectory_id=%s error=%s",
self._trajectory_id,
error,
)
# Still try to complete + report zero reward so the trajectory isn't stuck
try:
self.complete_trajectory()
self.update_reward([reward or 0.0])
except Exception:
logger.exception("Failed to report error reward")
55 changes: 55 additions & 0 deletions aws-rft-sdk/src/aws_rft_sdk/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Thread-local context for RFT rollout metadata.

The rft_handler decorator populates this context from the payload metadata.
The Strands model wrapper reads it to inject per-request headers.
"""

import threading
import uuid
from typing import Optional

_context = threading.local()


class RFTContext:
"""Access the current RFT rollout context.

Set by @rft_handler, read by wrap_model adapters to inject headers.

The injected headers match the AgenticRFTRuntimeService API:
- ``X-Rft-Job-Arn``: job ARN that identifies the Lego session
- ``X-Trajectory-Id``: groups turns into a single trajectory
- ``X-Span-Id``: unique ID for each turn within the trajectory
"""

@staticmethod
def get_headers() -> dict:
"""Return HTTP headers for the current rollout context.

A new ``X-Span-Id`` is generated on every call so each inference
turn gets a unique span within the trajectory.
"""
metadata = getattr(_context, "metadata", None)
if metadata is None:
return {}
headers = {}
if metadata.get("job_arn"):
headers["X-Rft-Job-Arn"] = metadata["job_arn"]
if metadata.get("trajectory_id"):
headers["X-Trajectory-Id"] = metadata["trajectory_id"]
# Auto-generate a span ID for each inference call
headers["X-Span-Id"] = str(uuid.uuid4())
return headers

@staticmethod
def get_metadata() -> Optional[dict]:
"""Return the raw metadata dict, or None if not in an RFT context."""
return getattr(_context, "metadata", None)


def _set_metadata(metadata: dict):
_context.metadata = metadata


def _clear_metadata():
_context.metadata = None
70 changes: 70 additions & 0 deletions aws-rft-sdk/src/aws_rft_sdk/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""@rft_handler decorator — wraps an entrypoint to manage RFT rollout context."""

import asyncio
import functools
import inspect
import logging

from aws_rft_sdk.client import RolloutFeedbackClient
from aws_rft_sdk.context import _set_metadata, _clear_metadata

logger = logging.getLogger(__name__)


def rft_handler(func):
"""Decorator that sets up RFT rollout context around an entrypoint.

Extracts ``metadata`` from the payload, makes it available via
``RFTContext.get_headers()`` (used by ``wrap_model``), and auto-reports
errors if the function raises.

Works with both sync and async functions.

Example::

@app.entrypoint
@rft_handler
async def invoke_agent(payload):
user_input = payload.get("instance")
response = await agent.invoke_async(user_input)
return response.message["content"][0]["text"]
"""

if asyncio.iscoroutinefunction(func):

@functools.wraps(func)
async def async_wrapper(payload, *args, **kwargs):
metadata = payload.get("metadata", {}) if isinstance(payload, dict) else {}
_set_metadata(metadata)
try:
return await func(payload, *args, **kwargs)
except Exception as e:
logger.error("RFT rollout failed: %s", e)
try:
RolloutFeedbackClient(metadata).report_error(str(e))
except Exception:
logger.exception("Failed to report rollout error")
raise
finally:
_clear_metadata()

return async_wrapper
else:

@functools.wraps(func)
def sync_wrapper(payload, *args, **kwargs):
metadata = payload.get("metadata", {}) if isinstance(payload, dict) else {}
_set_metadata(metadata)
try:
return func(payload, *args, **kwargs)
except Exception as e:
logger.error("RFT rollout failed: %s", e)
try:
RolloutFeedbackClient(metadata).report_error(str(e))
except Exception:
logger.exception("Failed to report rollout error")
raise
finally:
_clear_metadata()

return sync_wrapper