diff --git a/sagemaker-rft/LICENSE b/sagemaker-rft/LICENSE new file mode 100644 index 0000000000..f49a4e16e6 --- /dev/null +++ b/sagemaker-rft/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/sagemaker-rft/README.md b/sagemaker-rft/README.md new file mode 100644 index 0000000000..fec01b28c6 --- /dev/null +++ b/sagemaker-rft/README.md @@ -0,0 +1,70 @@ +# SageMaker RFT SDK + +Integration SDK for the multi-turn reinforcement fine-tuning (RFT) platform on Amazon SageMaker. 
+ +## Installation + +```bash +pip install sagemaker-rft +``` + +With framework adapters: +```bash +pip install sagemaker-rft[strands] # Strands framework support +pip install sagemaker-rft[langchain] # LangChain framework support +``` + +## Integration Patterns + +### Strands + AgentCore (simplest) + +```python +from sagemaker.rft import rft_handler, RolloutFeedbackClient +from sagemaker.rft.adapters.strands import wrap_model + +model = OpenAIModel(client_args={"base_url": "$TRAINING_INFERENCE_ENDPOINT"}, model_id="$TRAINING_MODEL_NAME") +model = wrap_model(model) +agent = Agent(model=model, tools=[...]) + +@app.entrypoint +@rft_handler +async def invoke_agent(payload): + result = await agent.invoke_async(payload["instance"]) + return result.message["content"][0]["text"] +``` + +### Strands Standalone + +```python +from sagemaker.rft import set_rollout_context, RolloutFeedbackClient +from sagemaker.rft.adapters.strands import wrap_model + +model = wrap_model(OpenAIModel(...)) +agent = Agent(model=model, tools=[...]) + +@app.post("/rollout") +def rollout(request: RolloutRequest): + set_rollout_context(request.metadata, request.inference_params) + try: + result = agent(request.instance["instruction"]) + RolloutFeedbackClient(request.metadata).complete_trajectory() + RolloutFeedbackClient(request.metadata).update_reward(compute_reward(result)) + except Exception: + logger.exception("Rollout failed") + raise + return {"status": "ok"} +``` + +### Custom Integration + +```python +from sagemaker.rft import make_inference_headers, RolloutFeedbackClient, RolloutRequest + +@app.post("/rollout") +def rollout(request: RolloutRequest): + client = OpenAI(base_url=endpoint, default_headers=make_inference_headers(request.metadata)) + result = my_agent.run(request.instance, client) + RolloutFeedbackClient(request.metadata).complete_trajectory() + RolloutFeedbackClient(request.metadata).update_reward(compute_reward(result)) + return {"status": "ok"} +``` diff --git a/sagemaker-rft/VERSION b/sagemaker-rft/VERSION new file mode 100644 index
0000000000..6e8bf73aa5 --- /dev/null +++ b/sagemaker-rft/VERSION @@ -0,0 +1 @@ +0.1.0 diff --git a/sagemaker-rft/pyproject.toml b/sagemaker-rft/pyproject.toml new file mode 100644 index 0000000000..8d163cd4f6 --- /dev/null +++ b/sagemaker-rft/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "sagemaker-rft" +dynamic = ["version"] +description = "Integration SDK for multi-turn reinforcement fine-tuning on Amazon SageMaker." +readme = "README.md" +requires-python = ">=3.9" +authors = [ + { name = "Amazon Web Services" }, +] +keywords = ["AI", "AWS", "Amazon", "ML", "RFT", "reinforcement-learning"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Natural Language :: English", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "boto3>=1.29.0", + "requests>=2.28.0", + "pydantic>=2.0.0", +] + +[project.urls] +Homepage = "https://github.com/aws/sagemaker-python-sdk" + +[project.optional-dependencies] +strands = ["strands-agents"] +langchain = ["httpx>=0.24.0"] +test = [ + "pytest", + "pytest-cov", + "mock", +] + +[tool.setuptools.packages.find] +where = ["src/"] +include = ["sagemaker*"] +namespaces = true + +[tool.setuptools.dynamic] +version = { file = "VERSION" } + +[tool.pytest.ini_options] +addopts = ["-vv"] +testpaths = ["tests"] + +[tool.black] +line-length = 100 diff --git a/sagemaker-rft/src/sagemaker/__init__.py b/sagemaker-rft/src/sagemaker/__init__.py new file mode 100644 index 0000000000..aa8d919a4a --- /dev/null +++ b/sagemaker-rft/src/sagemaker/__init__.py @@ -0,0 +1 @@ +# Namespace package — do not add code here. 
diff --git a/sagemaker-rft/src/sagemaker/rft/__init__.py b/sagemaker-rft/src/sagemaker/rft/__init__.py new file mode 100644 index 0000000000..92f53967da --- /dev/null +++ b/sagemaker-rft/src/sagemaker/rft/__init__.py @@ -0,0 +1,60 @@ +"""SageMaker RFT SDK - Integration library for multi-turn RL training platform. + +Strands + AgentCore (simplest):: + + from sagemaker.rft import rft_handler, RolloutFeedbackClient + from sagemaker.rft.adapters.strands import wrap_model + + model = wrap_model(OpenAIModel(...)) + + @app.entrypoint + @rft_handler + async def invoke_agent(payload): + result = await agent.invoke_async(payload["instance"]) + return result + +Strands Standalone:: + + from sagemaker.rft import set_rollout_context, RolloutFeedbackClient + from sagemaker.rft.adapters.strands import wrap_model + + model = wrap_model(model) + + @app.post("/rollout") + def rollout(request): + set_rollout_context(request.metadata, request.inference_params) + result = agent(request.instance) + RolloutFeedbackClient(request.metadata).report_complete(reward) + +Custom Integration:: + + from sagemaker.rft import make_inference_headers, RolloutFeedbackClient + + @app.post("/rollout") + def handle(request): + headers = make_inference_headers(request.metadata) + client = OpenAI(base_url=endpoint, default_headers=headers) + result = my_agent.run(request.instance, client) + RolloutFeedbackClient(request.metadata).report_complete(reward) +""" + +from sagemaker.rft.headers import make_inference_headers, get_inference_headers +from sagemaker.rft.feedback import RolloutFeedbackClient +from sagemaker.rft.models import RolloutMetadata, RolloutRequest, InferenceParams +from sagemaker.rft.decorators import rft_handler +from sagemaker.rft.context import set_rollout_context, clear_rollout_context, get_inference_params + +__all__ = [ + "make_inference_headers", + "get_inference_headers", + "RolloutFeedbackClient", + "RolloutMetadata", + "RolloutRequest", + "InferenceParams", + "rft_handler", + 
"set_rollout_context", + "clear_rollout_context", + "get_inference_params", +] + +__version__ = "0.1.0" diff --git a/sagemaker-rft/src/sagemaker/rft/adapters/__init__.py b/sagemaker-rft/src/sagemaker/rft/adapters/__init__.py new file mode 100644 index 0000000000..5a341d6507 --- /dev/null +++ b/sagemaker-rft/src/sagemaker/rft/adapters/__init__.py @@ -0,0 +1,6 @@ +"""Framework-specific adapters for automatic header injection. + +Available adapters: + - strands: For Strands framework (OpenAIModel, LiteLLMModel) + - langchain: For LangChain (ChatOpenAI with custom httpx clients) +""" diff --git a/sagemaker-rft/src/sagemaker/rft/adapters/langchain.py b/sagemaker-rft/src/sagemaker/rft/adapters/langchain.py new file mode 100644 index 0000000000..12578823e4 --- /dev/null +++ b/sagemaker-rft/src/sagemaker/rft/adapters/langchain.py @@ -0,0 +1,66 @@ +"""LangChain adapter for automatic header injection. + +Provides helper functions to create httpx clients that automatically inject +X-Metadata headers into inference requests using the rollout context. +""" + +from __future__ import annotations + +import httpx + +from sagemaker.rft.headers import get_inference_headers + + +def _inject_headers(request: httpx.Request) -> None: + """Event hook that injects headers from rollout context.""" + headers = get_inference_headers() + if headers: + request.headers.update(headers) + + +def create_http_client(**kwargs) -> httpx.Client: + """Create an httpx Client that auto-injects X-Metadata headers. + + Use with LangChain's ChatOpenAI http_client parameter. + + Args: + **kwargs: Additional arguments passed to httpx.Client. + + Returns: + httpx.Client configured with header injection. 
async def _inject_headers_async(request: httpx.Request) -> None:
    """Coroutine request hook that injects headers from the rollout context.

    ``httpx.AsyncClient`` awaits every event hook, so registering the plain
    ``_inject_headers`` function there fails on the first request. This thin
    coroutine wrapper delegates to it so both clients share one code path.
    """
    _inject_headers(request)


def create_async_http_client(**kwargs) -> httpx.AsyncClient:
    """Create an httpx AsyncClient that auto-injects X-Metadata headers.

    Use with LangChain's ChatOpenAI http_async_client parameter.

    Args:
        **kwargs: Additional arguments passed to httpx.AsyncClient. A supplied
            ``event_hooks`` mapping is copied, never mutated; its hooks must
            themselves be coroutine functions, as the async client requires.

    Returns:
        httpx.AsyncClient configured with header injection.
    """
    supplied_hooks = kwargs.pop("event_hooks", None) or {}
    # Copy the mapping and each hook list so the caller's objects (possibly
    # shared with another client) are left untouched.
    event_hooks = {event: list(hooks) for event, hooks in supplied_hooks.items()}
    event_hooks.setdefault("request", []).append(_inject_headers_async)
    return httpx.AsyncClient(event_hooks=event_hooks, **kwargs)


def create_http_clients(**kwargs) -> tuple[httpx.Client, httpx.AsyncClient]:
    """Create both sync and async httpx clients with header injection.

    Args:
        **kwargs: Additional arguments passed to both clients.

    Returns:
        Tuple of (httpx.Client, httpx.AsyncClient).
    """
    supplied_hooks = kwargs.pop("event_hooks", None) or {}

    def _fresh_hooks() -> dict:
        # Each client gets its own copy; otherwise the injection hook appended
        # for the sync client would leak into the async client's hook list.
        return {event: list(hooks) for event, hooks in supplied_hooks.items()}

    sync_client = create_http_client(event_hooks=_fresh_hooks(), **kwargs)
    async_client = create_async_http_client(event_hooks=_fresh_hooks(), **kwargs)
    return sync_client, async_client
def wrap_model(model: Any) -> Any:
    """Wrap a Strands model to auto-inject headers and inference params from context.

    Creates a transparent proxy that:
    1. Injects RFT headers (X-Rft-Job-Arn, X-Trajectory-Id, X-Span-Id) via
       client_args["default_headers"] on every stream() call
    2. Injects inference parameters (temperature, max_tokens, top_p)

    Args:
        model: A Strands model instance (e.g., ``OpenAIModel``).

    Returns:
        A wrapped model that transparently injects RFT headers.
    """
    return _RFTModelWrapper(model)


class _RFTModelWrapper:
    """Transparent proxy that injects RFT headers into Strands model calls.

    Delegates all attribute reads and writes to the inner model so it quacks
    like the original. Only ``stream()`` adds behavior of its own.
    """

    def __init__(self, inner_model: Any):
        # Bypass our own __setattr__, which forwards every other name to the inner model.
        object.__setattr__(self, "_inner", inner_model)

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails, i.e. for names the proxy does not define.
        return getattr(self._inner, name)

    def __setattr__(self, name: str, value: Any):
        if name == "_inner":
            object.__setattr__(self, name, value)
        else:
            setattr(self._inner, name, value)

    def stream(self, *args: Any, **kwargs: Any) -> Any:
        """Inject RFT headers and inference params, then delegate to the inner ``stream()``.

        Headers from ``get_inference_headers()`` are merged into
        ``client_args["default_headers"]`` when the inner model exposes a
        ``client_args`` mapping. Inference parameters from the rollout context
        are copied into the inner model's ``params`` dict when it has one.

        Args:
            *args: Positional arguments forwarded unchanged.
            **kwargs: Keyword arguments forwarded unchanged.

        Returns:
            Whatever the inner model's ``stream()`` returns.
        """
        rft_headers = get_inference_headers()
        if rft_headers:
            client_args = getattr(self._inner, "client_args", None)
            if client_args is not None:
                # Merge into a copy so a headers dict the caller still holds
                # (or shares between models) is not modified behind its back.
                merged = dict(client_args.get("default_headers") or {})
                merged.update(rft_headers)
                client_args["default_headers"] = merged
                logger.debug("Injected RFT headers: %s", list(rft_headers.keys()))

        # Inject inference params if available
        inference_params = get_inference_params()
        # The context may hold a pydantic model (e.g. RolloutRequest.inference_params)
        # rather than a dict; normalize so the ``.get`` lookups below work.
        dump = getattr(inference_params, "model_dump", None)
        if callable(dump):
            inference_params = dump()
        if inference_params:
            params = getattr(self._inner, "params", None)
            if isinstance(params, dict):
                for key in ("temperature", "max_tokens", "top_p"):
                    value = inference_params.get(key)
                    if value is not None:
                        params[key] = value

        return self._inner.stream(*args, **kwargs)

    def update_config(self, **model_config: Any) -> None:
        """Forward configuration updates to the inner model."""
        return self._inner.update_config(**model_config)

    def get_config(self) -> Any:
        """Return the inner model's configuration."""
        return self._inner.get_config()

    def structured_output(self, *args: Any, **kwargs: Any) -> Any:
        """Forward structured-output calls to the inner model unchanged."""
        return self._inner.structured_output(*args, **kwargs)
def rft_handler(func: Callable) -> Callable:
    """Decorator for AgentCore Runtime entrypoints to handle RFT rollout lifecycle.

    Automatically:
    1. Sets rollout context (metadata + inference_params) for header injection
    2. Clears context when done
    3. Logs errors if the rollout fails

    Works with both sync and async functions. Any arguments passed after
    ``payload`` are forwarded to ``func`` unchanged.

    Args:
        func: A sync or async function that handles the rollout. Its first
            argument is the payload dict; ``metadata`` and ``inference_params``
            are read from it with ``dict.get``.

    Returns:
        Wrapped function with automatic context management.

    Example::

        @app.entrypoint
        @rft_handler
        def invoke_agent(payload):
            result = agent(payload["instance"])
            return result
    """
    # Local import keeps this block self-contained; asyncio.iscoroutinefunction
    # is deprecated as of Python 3.14 in favour of the inspect variant.
    import inspect

    def _enter(payload: dict) -> None:
        # Shared by both wrappers so the context setup cannot drift apart.
        set_rollout_context(payload.get("metadata"), payload.get("inference_params"))

    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def async_wrapper(payload: dict, *args: Any, **kwargs: Any) -> Any:
            _enter(payload)
            try:
                return await func(payload, *args, **kwargs)
            except Exception:
                logger.exception("RFT rollout failed")
                raise
            finally:
                clear_rollout_context()

        return async_wrapper

    @wraps(func)
    def sync_wrapper(payload: dict, *args: Any, **kwargs: Any) -> Any:
        # functools.wraps makes the wrapper report func's signature, so a
        # caller that inspects it may pass extra arguments (for example a
        # context object); accept and forward them instead of failing.
        _enter(payload)
        try:
            return func(payload, *args, **kwargs)
        except Exception:
            logger.exception("RFT rollout failed")
            raise
        finally:
            clear_rollout_context()

    return sync_wrapper
+ feedback.complete_trajectory() + feedback.update_reward(reward=0.95) + except Exception: + logger.exception("Rollout failed") + raise + """ + + def __init__(self, metadata: dict[str, Any] | RolloutMetadata) -> None: + """Initialize the feedback client. + + Args: + metadata: Rollout metadata dict or RolloutMetadata instance. + Expected keys: ``endpoint``, ``job_arn``, ``trajectory_id``, ``region``. + """ + if isinstance(metadata, RolloutMetadata): + metadata = metadata.model_dump() + elif not isinstance(metadata, dict): + raise TypeError( + f"metadata must be a dict or RolloutMetadata, got {type(metadata).__name__}." + ) + + self._endpoint = metadata.get("endpoint", "").rstrip("/") + self._job_arn = metadata.get("job_arn", "") + self._trajectory_id = metadata.get("trajectory_id", "") + self._region = metadata.get("region", "us-west-2") + self._metadata = metadata + + def complete_trajectory(self) -> None: + """Report trajectory completion to the runtime service.""" + logger.info("CompleteTrajectory: trajectory_id=%s", self._trajectory_id) + body = json.dumps({ + "trajectoryId": self._trajectory_id, + "jobArn": self._job_arn, + }) + self._signed_post("/CompleteTrajectory", body) + + def update_reward(self, reward: float | list[float]) -> None: + """Report reward(s) to the runtime service. + + Args: + reward: A single float or list of floats for multi-step rewards. 
+ """ + rewards = [reward] if isinstance(reward, (int, float)) else reward + body = json.dumps({ + "trajectoryId": self._trajectory_id, + "jobArn": self._job_arn, + "rewards": rewards, + }) + self._signed_post("/UpdateReward", body) + + def _signed_post(self, path: str, body: str) -> None: + """Send a SigV4-signed POST to the runtime service.""" + import requests as req_lib + + url = f"{self._endpoint}{path}" + try: + session = boto3.Session() + credentials = session.get_credentials().get_frozen_credentials() + request = AWSPreparedRequest( + method="POST", + url=url, + headers={"Content-Type": "application/json"}, + body=body, + ) + SigV4Auth(credentials, "sagemaker", self._region).add_auth(request) + response = req_lib.post( + url, + headers=dict(request.headers), + data=body, + timeout=10, + ) + response.raise_for_status() + except Exception as e: + logger.warning("Failed %s: %s", path, e) diff --git a/sagemaker-rft/src/sagemaker/rft/headers.py b/sagemaker-rft/src/sagemaker/rft/headers.py new file mode 100644 index 0000000000..bf8995d842 --- /dev/null +++ b/sagemaker-rft/src/sagemaker/rft/headers.py @@ -0,0 +1,66 @@ +"""Header utilities for inference calls. + +Provides functions to create the HTTP headers required by the +AgenticRFTRuntimeService for each inference request. + +The runtime service expects three separate headers: + - ``X-Rft-Job-Arn``: job ARN that identifies the training session + - ``X-Trajectory-Id``: groups turns into a single trajectory + - ``X-Span-Id``: unique ID for each turn within the trajectory +""" + +from __future__ import annotations + +import uuid +import warnings +from typing import Any + +from sagemaker.rft.context import get_rollout_context +from sagemaker.rft.models import RolloutMetadata + + +def make_inference_headers(metadata: dict[str, Any] | RolloutMetadata) -> dict[str, str]: + """Create headers dict for inference calls. + + Add these headers to your HTTP client when making inference calls + to the RFT Runtime Service. 
def get_inference_headers() -> dict[str, str]:
    """Build inference headers from the rollout context currently in effect.

    Companion to set_rollout_context(): lets code deep in the call stack
    obtain the headers without having the metadata passed down explicitly.
    The headers are built by make_inference_headers() on every call, so each
    call carries a freshly generated ``X-Span-Id``.

    Returns:
        Headers dict for the current rollout, or an empty dict (after
        emitting a warning) when no rollout context has been set.
    """
    context = get_rollout_context()
    if context is not None:
        return make_inference_headers(context)

    # No context: warn at the caller's location and fall back to no headers.
    warnings.warn(
        "get_inference_headers() called but no rollout context set. "
        "Did you forget to call set_rollout_context()? "
        "Returning empty headers.",
        stacklevel=2,
    )
    return {}
class InferenceParams(BaseModel):
    """Inference parameters for rollout sampling.

    All fields are optional - if not provided, model defaults are used.
    """

    # Optional[...] rather than ``X | None``: pydantic evaluates field
    # annotations at class creation, and the ``|`` form cannot be evaluated
    # on Python 3.9, which this package declares as supported.
    temperature: Optional[float] = Field(default=None, description="Sampling temperature")
    max_tokens: Optional[int] = Field(default=None, description="Maximum tokens to generate")
    top_p: Optional[float] = Field(default=None, description="Top-p (nucleus) sampling")


class RolloutRequest(BaseModel):
    """Request format sent by the trainer to your /rollout endpoint.

    This is the enforced contract. Your server must accept this exact format.
    """

    # ``model_name`` / ``model_endpoint`` fall inside pydantic's protected
    # ``model_`` namespace; opt out so defining them does not emit warnings.
    model_config = {"protected_namespaces": ()}

    instance: Dict[str, Any] = Field(
        description="Problem instance from customer's data file"
    )
    metadata: RolloutMetadata = Field(description="Platform-provided rollout context")
    inference_params: Optional[InferenceParams] = Field(
        default=None,
        description="Optional inference parameters (temperature, max_tokens, top_p)",
    )
    model_name: Optional[str] = Field(
        default=None, description="Optional model name override from trainer"
    )
    model_endpoint: Optional[str] = Field(
        default=None, description="Optional model endpoint override from trainer"
    )