diff --git a/aws-rft-sdk/pyproject.toml b/aws-rft-sdk/pyproject.toml new file mode 100644 index 0000000000..a9936bdb1e --- /dev/null +++ b/aws-rft-sdk/pyproject.toml @@ -0,0 +1,20 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "aws-rft-sdk" +version = "0.1.0" +description = "AWS Reinforcement Fine-Tuning SDK for online rollout-based training" +readme = {text = "", content-type = "text/markdown"} +requires-python = ">=3.9" +dependencies = [ + "boto3>=1.35.0", + "requests>=2.28.0", +] + +[project.optional-dependencies] +strands = ["strands-agents>=0.1.0"] + +[tool.hatch.build.targets.wheel] +packages = ["src/aws_rft_sdk"] diff --git a/aws-rft-sdk/src/aws_rft_sdk/__init__.py b/aws-rft-sdk/src/aws_rft_sdk/__init__.py new file mode 100644 index 0000000000..49c5df0055 --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/__init__.py @@ -0,0 +1,5 @@ +from aws_rft_sdk.client import RolloutFeedbackClient +from aws_rft_sdk.handler import rft_handler +from aws_rft_sdk.context import RFTContext + +__all__ = ["RolloutFeedbackClient", "rft_handler", "RFTContext"] diff --git a/aws-rft-sdk/src/aws_rft_sdk/adapters/__init__.py b/aws-rft-sdk/src/aws_rft_sdk/adapters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aws-rft-sdk/src/aws_rft_sdk/adapters/strands.py b/aws-rft-sdk/src/aws_rft_sdk/adapters/strands.py new file mode 100644 index 0000000000..59a6c017bb --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/adapters/strands.py @@ -0,0 +1,78 @@ +"""Strands model adapter — wraps a Strands model to inject RFT headers. + +Usage:: + + from aws_rft_sdk.adapters.strands import wrap_model + from strands.models.openai import OpenAIModel + + model = OpenAIModel( + client_args={"api_key": key, "base_url": endpoint}, + model_id="my-model", + ) + model = wrap_model(model) # Now injects X-RFT-* headers on every call + +Requires the Strands OpenAIModel to pass through ``extra_headers`` kwarg +to the underlying OpenAI client (supported since strands-agents >= X.Y.Z). +""" + +import logging +from typing import Any + +from aws_rft_sdk.context import RFTContext + +logger = logging.getLogger(__name__) + + +def wrap_model(model: Any) -> Any: + """Wrap a Strands model to automatically inject RFT training headers. + + The wrapper reads the current rollout context (populated by ``@rft_handler``) + and adds ``X-RFT-*`` headers to every inference request so the training + inference endpoint can correlate requests with rollouts. + + Args: + model: A Strands model instance (e.g., ``OpenAIModel``). + + Returns: + A wrapped model that transparently injects RFT headers. + """ + return _RFTModelWrapper(model) + + +class _RFTModelWrapper: + """Transparent proxy that injects RFT headers into Strands model calls. + + Delegates all attribute access to the inner model so it quacks like + the original. Intercepts ``stream()`` to inject ``extra_headers``. + """ + + def __init__(self, inner_model: Any): + object.__setattr__(self, "_inner", inner_model) + + def __getattr__(self, name: str) -> Any: + return getattr(self._inner, name) + + def __setattr__(self, name: str, value: Any): + if name == "_inner": + object.__setattr__(self, name, value) + else: + setattr(self._inner, name, value) + + def stream(self, *args: Any, **kwargs: Any) -> Any: + """Intercept stream() to inject RFT headers via extra_headers kwarg.""" + rft_headers = RFTContext.get_headers() + if rft_headers: + existing = kwargs.get("extra_headers") or {} + existing.update(rft_headers) + kwargs["extra_headers"] = existing + logger.debug("Injected RFT headers: %s", list(rft_headers.keys())) + return self._inner.stream(*args, **kwargs) + + def update_config(self, **model_config: Any) -> None: + return self._inner.update_config(**model_config) + + def get_config(self) -> Any: + return self._inner.get_config() + + def structured_output(self, *args: Any, **kwargs: Any) -> Any: + return self._inner.structured_output(*args, **kwargs) diff --git a/aws-rft-sdk/src/aws_rft_sdk/client.py b/aws-rft-sdk/src/aws_rft_sdk/client.py new file mode 100644 index 0000000000..67ee973098 --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/client.py @@ -0,0 +1,147 @@ +"""RolloutFeedbackClient — reports rewards and trajectory completion to AgenticRFTRuntimeService.""" + +import json +import logging +from typing import List, Optional + +import boto3 +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest + +logger = logging.getLogger(__name__) + +# Alpha endpoint; override via metadata["endpoint"] or AGENTIC_RFT_ENDPOINT env var. +_DEFAULT_ENDPOINT = "https://finetuning-job-runtime.alpha.sagemaker.us-west-2.api.aws" +_SIGNING_SERVICE = "sagemaker" + + +class RolloutFeedbackClient: + """Client for reporting rollout feedback to the AgenticRFTRuntimeService. + + Calls the real CompleteTrajectory and UpdateReward APIs using SigV4 auth. + + Example:: + + from aws_rft_sdk import RolloutFeedbackClient + + client = RolloutFeedbackClient(payload["metadata"]) + client.complete_trajectory() + client.update_reward([0.8, 0.9, 1.0]) + + Args: + metadata: The ``metadata`` dict from the rollout payload. Expected keys: + - ``job_arn``: the RFT job ARN + - ``trajectory_id``: trajectory to act on + - ``endpoint`` (optional): override the runtime service URL + - ``region`` (optional): AWS region (default us-west-2) + """ + + def __init__(self, metadata: dict): + self._metadata = metadata or {} + self._job_arn = self._metadata.get("job_arn") + self._trajectory_id = self._metadata.get("trajectory_id") + self._endpoint = ( + self._metadata.get("endpoint") + or _DEFAULT_ENDPOINT + ) + self._region = self._metadata.get("region", "us-west-2") + self._credentials = None + + def _get_credentials(self): + if self._credentials is None: + session = boto3.Session(region_name=self._region) + self._credentials = session.get_credentials().get_frozen_credentials() + return self._credentials + + def _signed_request(self, method: str, path: str, body: dict) -> dict: + """Send a SigV4-signed request to the runtime service.""" + import requests as http_requests + + url = f"{self._endpoint}{path}" + data = json.dumps(body) + headers = {"Content-Type": "application/json"} + + aws_request = AWSRequest(method=method, url=url, data=data, headers=headers) + SigV4Auth(self._get_credentials(), _SIGNING_SERVICE, self._region).add_auth(aws_request) + + resp = http_requests.request( + method=method, + url=url, + headers=dict(aws_request.headers), + data=data, + timeout=30, + ) + resp.raise_for_status() + return resp.json() if resp.text else {} + + def complete_trajectory(self): + """Mark the trajectory as complete (PENDING -> READY). + + Calls POST /CompleteTrajectory with the trajectory ID. + """ + if not self._trajectory_id: + logger.warning("No trajectory_id in metadata; skipping complete_trajectory") + return + + logger.info( + "CompleteTrajectory: trajectory_id=%s", + self._trajectory_id, + ) + self._signed_request("POST", "/CompleteTrajectory", { + "TrajectoryId": self._trajectory_id, + }) + + def update_reward(self, rewards: List[float]): + """Submit reward scores for the trajectory (READY -> REWARD_RECEIVED). + + Calls POST /UpdateReward with per-transition rewards. + + Args: + rewards: List of reward values, one per transition in the trajectory. + """ + if not self._trajectory_id: + logger.warning("No trajectory_id in metadata; skipping update_reward") + return + + logger.info( + "UpdateReward: trajectory_id=%s rewards=%s", + self._trajectory_id, + rewards, + ) + self._signed_request("POST", "/UpdateReward", { + "TrajectoryId": self._trajectory_id, + "Rewards": rewards, + }) + + # Convenience wrappers (backward-compatible names) + + def report_complete(self, reward: float): + """Complete the trajectory and report a single reward. + + This is a convenience method that calls complete_trajectory() + then update_reward() with a single reward value. + + Args: + reward: The computed reward for this rollout. + """ + self.complete_trajectory() + self.update_reward([reward]) + + def report_error(self, error: str, reward: Optional[float] = None): + """Log a rollout error. + + Args: + error: Error description. + reward: Optional partial reward. + """ + logger.error( + "Rollout error: trajectory_id=%s error=%s", + self._trajectory_id, + error, + ) + # Still try to complete + report zero reward so the trajectory isn't stuck + try: + self.complete_trajectory() + self.update_reward([reward or 0.0]) + except Exception: + logger.exception("Failed to report error reward") diff --git a/aws-rft-sdk/src/aws_rft_sdk/context.py b/aws-rft-sdk/src/aws_rft_sdk/context.py new file mode 100644 index 0000000000..f2ed0b21df --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/context.py @@ -0,0 +1,55 @@ +"""Thread-local context for RFT rollout metadata. + +The rft_handler decorator populates this context from the payload metadata. +The Strands model wrapper reads it to inject per-request headers. +""" + +import threading +import uuid +from typing import Optional + +_context = threading.local() + + +class RFTContext: + """Access the current RFT rollout context. + + Set by @rft_handler, read by wrap_model adapters to inject headers. + + The injected headers match the AgenticRFTRuntimeService API: + - ``X-Rft-Job-Arn``: job ARN that identifies the Lego session + - ``X-Trajectory-Id``: groups turns into a single trajectory + - ``X-Span-Id``: unique ID for each turn within the trajectory + """ + + @staticmethod + def get_headers() -> dict: + """Return HTTP headers for the current rollout context. + + A new ``X-Span-Id`` is generated on every call so each inference + turn gets a unique span within the trajectory. + """ + metadata = getattr(_context, "metadata", None) + if metadata is None: + return {} + headers = {} + if metadata.get("job_arn"): + headers["X-Rft-Job-Arn"] = metadata["job_arn"] + if metadata.get("trajectory_id"): + headers["X-Trajectory-Id"] = metadata["trajectory_id"] + # Auto-generate a span ID for each inference call + headers["X-Span-Id"] = str(uuid.uuid4()) + return headers + + @staticmethod + def get_metadata() -> Optional[dict]: + """Return the raw metadata dict, or None if not in an RFT context.""" + return getattr(_context, "metadata", None) + + +def _set_metadata(metadata: dict): + _context.metadata = metadata + + +def _clear_metadata(): + _context.metadata = None diff --git a/aws-rft-sdk/src/aws_rft_sdk/handler.py b/aws-rft-sdk/src/aws_rft_sdk/handler.py new file mode 100644 index 0000000000..0c8b1be015 --- /dev/null +++ b/aws-rft-sdk/src/aws_rft_sdk/handler.py @@ -0,0 +1,70 @@ +"""@rft_handler decorator — wraps an entrypoint to manage RFT rollout context.""" + +import asyncio +import functools +import inspect +import logging + +from aws_rft_sdk.client import RolloutFeedbackClient +from aws_rft_sdk.context import _set_metadata, _clear_metadata + +logger = logging.getLogger(__name__) + + +def rft_handler(func): + """Decorator that sets up RFT rollout context around an entrypoint. + + Extracts ``metadata`` from the payload, makes it available via + ``RFTContext.get_headers()`` (used by ``wrap_model``), and auto-reports + errors if the function raises. + + Works with both sync and async functions. + + Example:: + + @app.entrypoint + @rft_handler + async def invoke_agent(payload): + user_input = payload.get("instance") + response = await agent.invoke_async(user_input) + return response.message["content"][0]["text"] + """ + + if asyncio.iscoroutinefunction(func): + + @functools.wraps(func) + async def async_wrapper(payload, *args, **kwargs): + metadata = payload.get("metadata", {}) if isinstance(payload, dict) else {} + _set_metadata(metadata) + try: + return await func(payload, *args, **kwargs) + except Exception as e: + logger.error("RFT rollout failed: %s", e) + try: + RolloutFeedbackClient(metadata).report_error(str(e)) + except Exception: + logger.exception("Failed to report rollout error") + raise + finally: + _clear_metadata() + + return async_wrapper + else: + + @functools.wraps(func) + def sync_wrapper(payload, *args, **kwargs): + metadata = payload.get("metadata", {}) if isinstance(payload, dict) else {} + _set_metadata(metadata) + try: + return func(payload, *args, **kwargs) + except Exception as e: + logger.error("RFT rollout failed: %s", e) + try: + RolloutFeedbackClient(metadata).report_error(str(e)) + except Exception: + logger.exception("Failed to report rollout error") + raise + finally: + _clear_metadata() + + return sync_wrapper