From 9615fc90afae43ea9ba34bda4db8f37a5d7faac1 Mon Sep 17 00:00:00 2001 From: Mike Shen Date: Tue, 3 Mar 2026 21:11:06 +0000 Subject: [PATCH 1/2] feat: add aws-rft-sdk package for Reinforcement Fine-Tuning Add a standalone SDK package that integrates SageMaker RFT (Reinforcement Fine-Tuning) with Strands agent framework. Provides: - RolloutFeedbackClient: report rewards back to the training service - @rft_handler: decorator to extract rollout metadata from payloads - RFTContext: thread-local context for propagating training metadata - wrap_model: Strands model adapter that injects X-RFT-* headers --- aws-rft-sdk/pyproject.toml | 19 ++++ aws-rft-sdk/src/aws_rft_sdk/__init__.py | 5 + .../src/aws_rft_sdk/adapters/__init__.py | 0 .../src/aws_rft_sdk/adapters/strands.py | 78 ++++++++++++++++ aws-rft-sdk/src/aws_rft_sdk/client.py | 91 +++++++++++++++++++ aws-rft-sdk/src/aws_rft_sdk/context.py | 45 +++++++++ aws-rft-sdk/src/aws_rft_sdk/handler.py | 70 ++++++++++++++ 7 files changed, 308 insertions(+) create mode 100644 aws-rft-sdk/pyproject.toml create mode 100644 aws-rft-sdk/src/aws_rft_sdk/__init__.py create mode 100644 aws-rft-sdk/src/aws_rft_sdk/adapters/__init__.py create mode 100644 aws-rft-sdk/src/aws_rft_sdk/adapters/strands.py create mode 100644 aws-rft-sdk/src/aws_rft_sdk/client.py create mode 100644 aws-rft-sdk/src/aws_rft_sdk/context.py create mode 100644 aws-rft-sdk/src/aws_rft_sdk/handler.py diff --git a/aws-rft-sdk/pyproject.toml b/aws-rft-sdk/pyproject.toml new file mode 100644 index 0000000000..1b1389d1ae --- /dev/null +++ b/aws-rft-sdk/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "aws-rft-sdk" +version = "0.1.0" +description = "AWS Reinforcement Fine-Tuning SDK for online rollout-based training" +readme = {text = "", content-type = "text/markdown"} +requires-python = ">=3.9" +dependencies = [ + "boto3>=1.35.0", +] + +[project.optional-dependencies] +strands = ["strands-agents>=0.1.0"] + +[tool.hatch.build.targets.wheel] +packages = ["src/aws_rft_sdk"] diff --git a/aws-rft-sdk/src/aws_rft_sdk/__init__.py b/aws-rft-sdk/src/aws_rft_sdk/__init__.py new file mode 100644 index 0000000000..49c5df0055 --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/__init__.py @@ -0,0 +1,5 @@ +from aws_rft_sdk.client import RolloutFeedbackClient +from aws_rft_sdk.handler import rft_handler +from aws_rft_sdk.context import RFTContext + +__all__ = ["RolloutFeedbackClient", "rft_handler", "RFTContext"] diff --git a/aws-rft-sdk/src/aws_rft_sdk/adapters/__init__.py b/aws-rft-sdk/src/aws_rft_sdk/adapters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aws-rft-sdk/src/aws_rft_sdk/adapters/strands.py b/aws-rft-sdk/src/aws_rft_sdk/adapters/strands.py new file mode 100644 index 0000000000..59a6c017bb --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/adapters/strands.py @@ -0,0 +1,78 @@ +"""Strands model adapter — wraps a Strands model to inject RFT headers. + +Usage:: + + from aws_rft_sdk.adapters.strands import wrap_model + from strands.models.openai import OpenAIModel + + model = OpenAIModel( + client_args={"api_key": key, "base_url": endpoint}, + model_id="my-model", + ) + model = wrap_model(model) # Now injects X-RFT-* headers on every call + +Requires the Strands OpenAIModel to pass through ``extra_headers`` kwarg +to the underlying OpenAI client (supported since strands-agents >= X.Y.Z). +""" + +import logging +from typing import Any + +from aws_rft_sdk.context import RFTContext + +logger = logging.getLogger(__name__) + + +def wrap_model(model: Any) -> Any: + """Wrap a Strands model to automatically inject RFT training headers. + + The wrapper reads the current rollout context (populated by ``@rft_handler``) + and adds ``X-RFT-*`` headers to every inference request so the training + inference endpoint can correlate requests with rollouts. + + Args: + model: A Strands model instance (e.g., ``OpenAIModel``). + + Returns: + A wrapped model that transparently injects RFT headers. + """ + return _RFTModelWrapper(model) + + +class _RFTModelWrapper: + """Transparent proxy that injects RFT headers into Strands model calls. + + Delegates all attribute access to the inner model so it quacks like + the original. Intercepts ``stream()`` to inject ``extra_headers``. + """ + + def __init__(self, inner_model: Any): + object.__setattr__(self, "_inner", inner_model) + + def __getattr__(self, name: str) -> Any: + return getattr(self._inner, name) + + def __setattr__(self, name: str, value: Any): + if name == "_inner": + object.__setattr__(self, name, value) + else: + setattr(self._inner, name, value) + + def stream(self, *args: Any, **kwargs: Any) -> Any: + """Intercept stream() to inject RFT headers via extra_headers kwarg.""" + rft_headers = RFTContext.get_headers() + if rft_headers: + existing = kwargs.get("extra_headers") or {} + existing.update(rft_headers) + kwargs["extra_headers"] = existing + logger.debug("Injected RFT headers: %s", list(rft_headers.keys())) + return self._inner.stream(*args, **kwargs) + + def update_config(self, **model_config: Any) -> None: + return self._inner.update_config(**model_config) + + def get_config(self) -> Any: + return self._inner.get_config() + + def structured_output(self, *args: Any, **kwargs: Any) -> Any: + return self._inner.structured_output(*args, **kwargs) diff --git a/aws-rft-sdk/src/aws_rft_sdk/client.py b/aws-rft-sdk/src/aws_rft_sdk/client.py new file mode 100644 index 0000000000..95c1cca6f0 --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/client.py @@ -0,0 +1,91 @@ +"""RolloutFeedbackClient — reports rewards and completion status back to the training service.""" + +import logging +from typing import Optional + +import boto3 + +logger = logging.getLogger(__name__) + + +class RolloutFeedbackClient: + """Client for reporting rollout feedback (rewards) to the RFT training service. + + Typically used inside an @rft_handler-decorated entrypoint to report + the reward computed from the agent's rollout. + + Example:: + + from aws_rft_sdk import RolloutFeedbackClient + + client = RolloutFeedbackClient(payload.get("metadata")) + client.report_complete(reward=0.85) + + Args: + metadata: The ``metadata`` dict from the rollout payload. Contains + training_job_arn, rollout_id, feedback_endpoint, etc. + """ + + def __init__(self, metadata: dict): + self._metadata = metadata or {} + self._training_job_arn = self._metadata.get("training_job_arn") + self._rollout_id = self._metadata.get("rollout_id") + self._feedback_endpoint = self._metadata.get("feedback_endpoint") + self._client = None + + def _get_client(self): + if self._client is None: + kwargs = {} + if self._feedback_endpoint: + kwargs["endpoint_url"] = self._feedback_endpoint + self._client = boto3.client("sagemaker", **kwargs) + return self._client + + def report_complete(self, reward: float): + """Report successful rollout completion with a reward score. + + Args: + reward: The computed reward for this rollout (typically 0.0–1.0). + """ + logger.info( + "Reporting rollout complete: training_job=%s rollout=%s reward=%s", + self._training_job_arn, + self._rollout_id, + reward, + ) + # TODO: Replace with actual RFT feedback API call when available. + # The service API will accept: + # - TrainingJobArn + # - RolloutId + # - Reward (float) + # - Status (COMPLETED) + client = self._get_client() + # Placeholder — actual API TBD + # client.send_rollout_feedback( + # TrainingJobArn=self._training_job_arn, + # RolloutId=self._rollout_id, + # Reward=reward, + # Status="COMPLETED", + # ) + + def report_error(self, error: str, reward: Optional[float] = None): + """Report a rollout error. + + Args: + error: Error description. + reward: Optional partial reward (defaults to 0.0). + """ + logger.error( + "Reporting rollout error: training_job=%s rollout=%s error=%s", + self._training_job_arn, + self._rollout_id, + error, + ) + # TODO: Replace with actual RFT feedback API call when available. + # client.send_rollout_feedback( + # TrainingJobArn=self._training_job_arn, + # RolloutId=self._rollout_id, + # Reward=reward or 0.0, + # Status="FAILED", + # ErrorMessage=error, + # ) diff --git a/aws-rft-sdk/src/aws_rft_sdk/context.py b/aws-rft-sdk/src/aws_rft_sdk/context.py new file mode 100644 index 0000000000..1d412453d9 --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/context.py @@ -0,0 +1,45 @@ +"""Thread-local context for RFT rollout metadata. + +The rft_handler decorator populates this context from the payload metadata. +The Strands model wrapper reads it to inject per-request headers. +""" + +import threading +from typing import Optional + +_context = threading.local() + + +class RFTContext: + """Access the current RFT rollout context. + + Set by @rft_handler, read by wrap_model adapters to inject headers. + """ + + @staticmethod + def get_headers() -> dict: + """Return HTTP headers for the current rollout context.""" + metadata = getattr(_context, "metadata", None) + if metadata is None: + return {} + headers = {} + if metadata.get("training_job_arn"): + headers["X-RFT-Training-Job-Arn"] = metadata["training_job_arn"] + if metadata.get("rollout_id"): + headers["X-RFT-Rollout-Id"] = metadata["rollout_id"] + if metadata.get("episode_id"): + headers["X-RFT-Episode-Id"] = metadata["episode_id"] + return headers + + @staticmethod + def get_metadata() -> Optional[dict]: + """Return the raw metadata dict, or None if not in an RFT context.""" + return getattr(_context, "metadata", None) + + +def _set_metadata(metadata: dict): + _context.metadata = metadata + + +def _clear_metadata(): + _context.metadata = None diff --git a/aws-rft-sdk/src/aws_rft_sdk/handler.py b/aws-rft-sdk/src/aws_rft_sdk/handler.py new file mode 100644 index 0000000000..0c8b1be015 --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/handler.py @@ -0,0 +1,70 @@ +"""@rft_handler decorator — wraps an entrypoint to manage RFT rollout context.""" + +import asyncio +import functools +import inspect +import logging + +from aws_rft_sdk.client import RolloutFeedbackClient +from aws_rft_sdk.context import _set_metadata, _clear_metadata + +logger = logging.getLogger(__name__) + + +def rft_handler(func): + """Decorator that sets up RFT rollout context around an entrypoint. + + Extracts ``metadata`` from the payload, makes it available via + ``RFTContext.get_headers()`` (used by ``wrap_model``), and auto-reports + errors if the function raises. + + Works with both sync and async functions. + + Example:: + + @app.entrypoint + @rft_handler + async def invoke_agent(payload): + user_input = payload.get("instance") + response = await agent.invoke_async(user_input) + return response.message["content"][0]["text"] + """ + + if asyncio.iscoroutinefunction(func): + + @functools.wraps(func) + async def async_wrapper(payload, *args, **kwargs): + metadata = payload.get("metadata", {}) if isinstance(payload, dict) else {} + _set_metadata(metadata) + try: + return await func(payload, *args, **kwargs) + except Exception as e: + logger.error("RFT rollout failed: %s", e) + try: + RolloutFeedbackClient(metadata).report_error(str(e)) + except Exception: + logger.exception("Failed to report rollout error") + raise + finally: + _clear_metadata() + + return async_wrapper + else: + + @functools.wraps(func) + def sync_wrapper(payload, *args, **kwargs): + metadata = payload.get("metadata", {}) if isinstance(payload, dict) else {} + _set_metadata(metadata) + try: + return func(payload, *args, **kwargs) + except Exception as e: + logger.error("RFT rollout failed: %s", e) + try: + RolloutFeedbackClient(metadata).report_error(str(e)) + except Exception: + logger.exception("Failed to report rollout error") + raise + finally: + _clear_metadata() + + return sync_wrapper From 18a63edb5306f64f959c645b09938de5602ff24c Mon Sep 17 00:00:00 2001 From: Mike Shen Date: Tue, 3 Mar 2026 23:01:19 +0000 Subject: [PATCH 2/2] update aws-rft-sdk to use real AgenticRFTRuntimeService APIs - client.py: implement CompleteTrajectory and UpdateReward using SigV4- signed HTTP calls to finetuning-job-runtime.alpha.sagemaker endpoint - context.py: align header names with service API (X-Rft-Job-Arn, X-Trajectory-Id, X-Span-Id); auto-generate span ID per inference call - pyproject.toml: add requests dependency Tested against alpha endpoint: - CompleteTrajectory: 404 for non-existent trajectory (API reachable) - UpdateReward: 404 for non-existent trajectory (API reachable) - SigV4 signing with service name 'sagemaker' confirmed working - End-to-end header injection via wrap_model verified --- aws-rft-sdk/pyproject.toml | 1 + aws-rft-sdk/src/aws_rft_sdk/client.py | 166 +++++++++++++++++-------- aws-rft-sdk/src/aws_rft_sdk/context.py | 24 ++-- 3 files changed, 129 insertions(+), 62 deletions(-) diff --git a/aws-rft-sdk/pyproject.toml b/aws-rft-sdk/pyproject.toml index 1b1389d1ae..a9936bdb1e 100644 --- a/aws-rft-sdk/pyproject.toml +++ b/aws-rft-sdk/pyproject.toml @@ -10,6 +10,7 @@ readme = {text = "", content-type = "text/markdown"} requires-python = ">=3.9" dependencies = [ "boto3>=1.35.0", + "requests>=2.28.0", ] [project.optional-dependencies] diff --git a/aws-rft-sdk/src/aws_rft_sdk/client.py b/aws-rft-sdk/src/aws_rft_sdk/client.py index 95c1cca6f0..67ee973098 100644 --- a/aws-rft-sdk/src/aws_rft_sdk/client.py +++ b/aws-rft-sdk/src/aws_rft_sdk/client.py @@ -1,91 +1,147 @@ -"""RolloutFeedbackClient — reports rewards and completion status back to the training service.""" +"""RolloutFeedbackClient — reports rewards and trajectory completion to AgenticRFTRuntimeService.""" +import json import logging -from typing import Optional +from typing import List, Optional import boto3 +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest logger = logging.getLogger(__name__) +# Alpha endpoint; override via metadata["endpoint"] or AGENTIC_RFT_ENDPOINT env var. +_DEFAULT_ENDPOINT = "https://finetuning-job-runtime.alpha.sagemaker.us-west-2.api.aws" +_SIGNING_SERVICE = "sagemaker" + class RolloutFeedbackClient: - """Client for reporting rollout feedback (rewards) to the RFT training service. + """Client for reporting rollout feedback to the AgenticRFTRuntimeService. - Typically used inside an @rft_handler-decorated entrypoint to report - the reward computed from the agent's rollout. + Calls the real CompleteTrajectory and UpdateReward APIs using SigV4 auth. Example:: from aws_rft_sdk import RolloutFeedbackClient - client = RolloutFeedbackClient(payload.get("metadata")) - client.report_complete(reward=0.85) + client = RolloutFeedbackClient(payload["metadata"]) + client.complete_trajectory() + client.update_reward([0.8, 0.9, 1.0]) Args: - metadata: The ``metadata`` dict from the rollout payload. Contains - training_job_arn, rollout_id, feedback_endpoint, etc. + metadata: The ``metadata`` dict from the rollout payload. Expected keys: + - ``job_arn``: the RFT job ARN + - ``trajectory_id``: trajectory to act on + - ``endpoint`` (optional): override the runtime service URL + - ``region`` (optional): AWS region (default us-west-2) """ def __init__(self, metadata: dict): self._metadata = metadata or {} - self._training_job_arn = self._metadata.get("training_job_arn") - self._rollout_id = self._metadata.get("rollout_id") - self._feedback_endpoint = self._metadata.get("feedback_endpoint") - self._client = None - - def _get_client(self): - if self._client is None: - kwargs = {} - if self._feedback_endpoint: - kwargs["endpoint_url"] = self._feedback_endpoint - self._client = boto3.client("sagemaker", **kwargs) - return self._client + self._job_arn = self._metadata.get("job_arn") + self._trajectory_id = self._metadata.get("trajectory_id") + self._endpoint = ( + self._metadata.get("endpoint") + or _DEFAULT_ENDPOINT + ) + self._region = self._metadata.get("region", "us-west-2") + self._credentials = None + + def _get_credentials(self): + if self._credentials is None: + session = boto3.Session(region_name=self._region) + self._credentials = session.get_credentials().get_frozen_credentials() + return self._credentials + + def _signed_request(self, method: str, path: str, body: dict) -> dict: + """Send a SigV4-signed request to the runtime service.""" + import requests as http_requests + + url = f"{self._endpoint}{path}" + data = json.dumps(body) + headers = {"Content-Type": "application/json"} + + aws_request = AWSRequest(method=method, url=url, data=data, headers=headers) + SigV4Auth(self._get_credentials(), _SIGNING_SERVICE, self._region).add_auth(aws_request) + + resp = http_requests.request( + method=method, + url=url, + headers=dict(aws_request.headers), + data=data, + timeout=30, + ) + resp.raise_for_status() + return resp.json() if resp.text else {} - def report_complete(self, reward: float): - """Report successful rollout completion with a reward score. + def complete_trajectory(self): + """Mark the trajectory as complete (PENDING -> READY). + + Calls POST /CompleteTrajectory with the trajectory ID. + """ + if not self._trajectory_id: + logger.warning("No trajectory_id in metadata; skipping complete_trajectory") + return + + logger.info( + "CompleteTrajectory: trajectory_id=%s", + self._trajectory_id, + ) + self._signed_request("POST", "/CompleteTrajectory", { + "TrajectoryId": self._trajectory_id, + }) + + def update_reward(self, rewards: List[float]): + """Submit reward scores for the trajectory (READY -> REWARD_RECEIVED). + + Calls POST /UpdateReward with per-transition rewards. Args: - reward: The computed reward for this rollout (typically 0.0–1.0). + rewards: List of reward values, one per transition in the trajectory. """ + if not self._trajectory_id: + logger.warning("No trajectory_id in metadata; skipping update_reward") + return + logger.info( - "Reporting rollout complete: training_job=%s rollout=%s reward=%s", - self._training_job_arn, - self._rollout_id, - reward, + "UpdateReward: trajectory_id=%s rewards=%s", + self._trajectory_id, + rewards, ) - # TODO: Replace with actual RFT feedback API call when available. - # The service API will accept: - # - TrainingJobArn - # - RolloutId - # - Reward (float) - # - Status (COMPLETED) - client = self._get_client() - # Placeholder — actual API TBD - # client.send_rollout_feedback( - # TrainingJobArn=self._training_job_arn, - # RolloutId=self._rollout_id, - # Reward=reward, - # Status="COMPLETED", - # ) + self._signed_request("POST", "/UpdateReward", { + "TrajectoryId": self._trajectory_id, + "Rewards": rewards, + }) + + # Convenience wrappers (backward-compatible names) + + def report_complete(self, reward: float): + """Complete the trajectory and report a single reward. + + This is a convenience method that calls complete_trajectory() + then update_reward() with a single reward value. + + Args: + reward: The computed reward for this rollout. + """ + self.complete_trajectory() + self.update_reward([reward]) def report_error(self, error: str, reward: Optional[float] = None): - """Report a rollout error. + """Log a rollout error. Args: error: Error description. - reward: Optional partial reward (defaults to 0.0). + reward: Optional partial reward. """ logger.error( - "Reporting rollout error: training_job=%s rollout=%s error=%s", - self._training_job_arn, - self._rollout_id, + "Rollout error: trajectory_id=%s error=%s", + self._trajectory_id, error, ) - # TODO: Replace with actual RFT feedback API call when available. - # client.send_rollout_feedback( - # TrainingJobArn=self._training_job_arn, - # RolloutId=self._rollout_id, - # Reward=reward or 0.0, - # Status="FAILED", - # ErrorMessage=error, - # ) + # Still try to complete + report zero reward so the trajectory isn't stuck + try: + self.complete_trajectory() + self.update_reward([reward or 0.0]) + except Exception: + logger.exception("Failed to report error reward") diff --git a/aws-rft-sdk/src/aws_rft_sdk/context.py b/aws-rft-sdk/src/aws_rft_sdk/context.py index 1d412453d9..f2ed0b21df 100644 --- a/aws-rft-sdk/src/aws_rft_sdk/context.py +++ b/aws-rft-sdk/src/aws_rft_sdk/context.py @@ -5,6 +5,7 @@ """ import threading +import uuid from typing import Optional _context = threading.local() @@ -14,21 +15,30 @@ class RFTContext: """Access the current RFT rollout context. Set by @rft_handler, read by wrap_model adapters to inject headers. + + The injected headers match the AgenticRFTRuntimeService API: + - ``X-Rft-Job-Arn``: job ARN that identifies the Lego session + - ``X-Trajectory-Id``: groups turns into a single trajectory + - ``X-Span-Id``: unique ID for each turn within the trajectory """ @staticmethod def get_headers() -> dict: - """Return HTTP headers for the current rollout context.""" + """Return HTTP headers for the current rollout context. + + A new ``X-Span-Id`` is generated on every call so each inference + turn gets a unique span within the trajectory. + """ metadata = getattr(_context, "metadata", None) if metadata is None: return {} headers = {} - if metadata.get("training_job_arn"): - headers["X-RFT-Training-Job-Arn"] = metadata["training_job_arn"] - if metadata.get("rollout_id"): - headers["X-RFT-Rollout-Id"] = metadata["rollout_id"] - if metadata.get("episode_id"): - headers["X-RFT-Episode-Id"] = metadata["episode_id"] + if metadata.get("job_arn"): + headers["X-Rft-Job-Arn"] = metadata["job_arn"] + if metadata.get("trajectory_id"): + headers["X-Trajectory-Id"] = metadata["trajectory_id"] + # Auto-generate a span ID for each inference call + headers["X-Span-Id"] = str(uuid.uuid4()) return headers @staticmethod