From 8a8d098b38b78316dad1c5d42a0b6ea32172c358 Mon Sep 17 00:00:00 2001
From: Vivek Kalyan
Date: Tue, 17 Mar 2026 17:51:22 -0700
Subject: [PATCH] fix: auto-log inference tinker costs from managed openai client

---
 src/art/local/backend.py                     |   6 +
 src/art/model.py                             | 113 ++++++++++++-
 tests/unit/test_model_openai_client_costs.py | 168 +++++++++++++++++++
 3 files changed, 286 insertions(+), 1 deletion(-)
 create mode 100644 tests/unit/test_model_openai_client_costs.py

diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index 5baf200f..14376356 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -44,6 +44,7 @@
 
 from .. import dev
 from ..backend import AnyTrainableModel, Backend
+from ..costs import build_cost_calculator, get_model_pricing
 from ..metrics_taxonomy import (
     TRAIN_GRADIENT_STEPS_KEY,
     average_metric_samples,
@@ -206,6 +207,11 @@ async def register(
         # (wandb initialization is now handled by the model's _get_wandb_run method)
         if model.trainable and "WANDB_API_KEY" in os.environ:
             _ = model._get_wandb_run()
+        if model.trainable:
+            trainable_model = cast(TrainableModel, model)
+            pricing = get_model_pricing(trainable_model.base_model)
+            if pricing is not None:
+                trainable_model.set_cost_calculator(build_cost_calculator(pricing))
 
     def _model_inference_name(self, model: Model, step: int | None = None) -> str:
         """Return the inference name for a model checkpoint.
diff --git a/src/art/model.py b/src/art/model.py
index ab08a88d..4e8ef03a 100644
--- a/src/art/model.py
+++ b/src/art/model.py
@@ -36,6 +36,65 @@
 StateType = TypeVar("StateType", bound=dict[str, Any], default=dict[str, Any])
 
 METRICS_BUILDER_STATE_KEY = "_metrics_builder_state"
+
+
+class _OpenAIChatCompletionsProxy:
+    """Wrap ``chat.completions`` so each ``create`` call reports its response."""
+
+    def __init__(self, completions: Any, record_costs: Any) -> None:
+        self._completions = completions
+        self._record_costs = record_costs
+
+    async def create(self, *args: Any, **kwargs: Any) -> Any:
+        response = await self._completions.create(*args, **kwargs)
+        self._record_costs(response)
+        return response
+
+    def __getattr__(self, name: str) -> Any:
+        # Only reached when normal lookup fails. copy/pickle probe instances
+        # that never ran __init__, so a missing target must raise instead of
+        # re-entering __getattr__ forever.
+        if name == "_completions":
+            raise AttributeError(name)
+        return getattr(self._completions, name)
+
+
+class _OpenAIChatProxy:
+    """Wrap ``client.chat`` and expose a cost-recording ``completions``."""
+
+    def __init__(self, chat: Any, record_costs: Any) -> None:
+        self._chat = chat
+        self.completions = _OpenAIChatCompletionsProxy(chat.completions, record_costs)
+
+    def __getattr__(self, name: str) -> Any:
+        # See _OpenAIChatCompletionsProxy.__getattr__ for why this guard exists.
+        if name == "_chat":
+            raise AttributeError(name)
+        return getattr(self._chat, name)
+
+
+class _OpenAIClientProxy:
+    """Wrap an OpenAI client so chat completions report their responses."""
+
+    def __init__(self, client: Any, record_costs: Any) -> None:
+        self._client = client
+        self._record_costs = record_costs
+        self.chat = _OpenAIChatProxy(client.chat, record_costs)
+
+    def with_options(self, *args: Any, **kwargs: Any) -> "_OpenAIClientProxy":
+        # Re-wrap so clients derived via with_options keep recording costs.
+        return _OpenAIClientProxy(
+            self._client.with_options(*args, **kwargs),
+            self._record_costs,
+        )
+
+    def __getattr__(self, name: str) -> Any:
+        # See _OpenAIChatCompletionsProxy.__getattr__ for why this guard exists.
+        if name == "_client":
+            raise AttributeError(name)
+        return getattr(self._client, name)
+
+
 METRIC_SECTIONS = frozenset(
     {
         "reward",
@@ -233,6 +292,12 @@ async def register(self, backend: "Backend") -> None:
     def openai_client(
         self,
     ) -> AsyncOpenAI:
+        """Return ART's managed inference client.
+
+        For trainable models with configured pricing, chat completion calls made
+        through this client automatically emit Tinker inference costs when an
+        ART metrics context is active.
+        """
         if self._openai_client is not None:
             return self._openai_client
 
@@ -245,7 +310,7 @@ def openai_client(
                 raise ValueError(
                     "In order to create an OpenAI client you must provide an `inference_api_key` and `inference_base_url`."
                 )
-        self._openai_client = AsyncOpenAI(
+        raw_client = AsyncOpenAI(
             base_url=self.inference_base_url,
             api_key=self.inference_api_key,
             http_client=DefaultAsyncHttpxClient(
@@ -255,6 +320,13 @@ def openai_client(
                 ),
             ),
         )
+        # Wrap the raw OpenAI client so ART-owned inference calls can add
+        # split-scoped Tinker costs without rollout code needing to do it
+        # manually.
+        self._openai_client = cast(
+            AsyncOpenAI,
+            _OpenAIClientProxy(raw_client, self._record_openai_completion_costs),
+        )
         return self._openai_client
 
     def litellm_completion_params(self, step: int | None = None) -> dict:
@@ -304,6 +376,10 @@ def get_inference_name(self, step: int | None = None) -> str:
             return f"{base_name}@{step}"
         return base_name
 
+    def _record_openai_completion_costs(self, _response: Any) -> None:
+        """Hook for subclasses that want to auto-log managed inference costs."""
+        return
+
     def _get_output_dir(self) -> str:
         """Get the output directory for this model."""
         return f"{self.base_path}/{self.project}/models/{self.name}"
@@ -946,6 +1022,41 @@ def _noop_cost_calculator(
     ) -> dict[str, float]:
         return {}
 
+    def _record_openai_completion_costs(self, _response: Any) -> None:
+        """Add inference costs for one completion to the active metrics builder.
+
+        Does nothing when no metrics context is active, when the active
+        context has an empty cost scope, or when the response carries no
+        usage data (presumably the case for streamed responses).
+        """
+        try:
+            builder = MetricsBuilder.get_active()
+        except LookupError:
+            return
+
+        cost_context = builder.cost_context.strip("/")
+        usage = getattr(_response, "usage", None)
+        if not cost_context or usage is None:
+            return
+
+        prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
+        completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0)
+        num_choices = len(getattr(_response, "choices", None) or [])
+        # Prefill is charged once per returned choice (at least once).
+        effective_prompt_tokens = prompt_tokens * max(num_choices, 1)
+
+        cost_metrics = self._cost_calculator(
+            effective_prompt_tokens,
+            completion_tokens,
+            cost_context,
+        )
+        if not cost_metrics:
+            return
+
+        for key, value in cost_metrics.items():
+            if key.startswith("costs/"):
+                builder.add_cost(key.removeprefix("costs/"), float(value))
+
     @overload
     def __new__(
         cls,
diff --git a/tests/unit/test_model_openai_client_costs.py b/tests/unit/test_model_openai_client_costs.py
new file mode 100644
index 00000000..b88e6bb6
--- /dev/null
+++ b/tests/unit/test_model_openai_client_costs.py
@@ -0,0 +1,168 @@
+import copy
+import importlib
+from typing import Any
+
+import pytest
+
+from art import TrainableModel
+from art.costs import build_cost_calculator, get_model_pricing
+
+
+class _FakeUsage:
+    def __init__(self, prompt_tokens: int, completion_tokens: int) -> None:
+        self.prompt_tokens = prompt_tokens
+        self.completion_tokens = completion_tokens
+
+
+class _FakeResponse:
+    def __init__(
+        self,
+        prompt_tokens: int,
+        completion_tokens: int,
+        *,
+        num_choices: int = 1,
+    ) -> None:
+        self.usage = _FakeUsage(prompt_tokens, completion_tokens)
+        self.choices = [object() for _ in range(num_choices)]
+
+
+class _FakeCompletions:
+    def __init__(self, response: _FakeResponse) -> None:
+        self._response = response
+
+    async def create(self, *args: Any, **kwargs: Any) -> _FakeResponse:
+        return self._response
+
+
+def _patch_async_openai(
+    monkeypatch: pytest.MonkeyPatch, response: _FakeResponse
+) -> None:
+    model_module = importlib.import_module("art.model")
+
+    class _FakeAsyncOpenAI:
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            self.chat = type(
+                "FakeChat",
+                (),
+                {"completions": _FakeCompletions(response)},
+            )()
+
+        def with_options(self, *args: Any, **kwargs: Any) -> "_FakeAsyncOpenAI":
+            return self
+
+    monkeypatch.setattr(model_module, "AsyncOpenAI", _FakeAsyncOpenAI)
+
+
+def _build_model() -> TrainableModel:
+    pricing = get_model_pricing("openai/gpt-oss-20b")
+    assert pricing is not None
+
+    model = TrainableModel(
+        name="test-run",
+        project="test-project",
+        base_model="openai/gpt-oss-20b",
+    )
+    model.inference_api_key = "test-key"
+    model.inference_base_url = "http://example.test/v1"
+    model.set_cost_calculator(build_cost_calculator(pricing))
+    return model
+
+
+class TestModelOpenAIClientCosts:
+    @pytest.mark.asyncio
+    async def test_openai_client_automatically_logs_train_tinker_costs(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        _patch_async_openai(monkeypatch, _FakeResponse(1_000, 2_000))
+        model = _build_model()
+        builder = model.metrics_builder("train")
+
+        with builder.activate_context():
+            await model.openai_client().chat.completions.create(
+                model=model.get_inference_name(),
+                messages=[{"role": "user", "content": "hello"}],
+            )
+
+        metrics = await builder.flush()
+        assert metrics["costs/train/tinker_prefill"] == pytest.approx(0.00012)
+        assert metrics["costs/train/tinker_sample"] == pytest.approx(0.0006)
+        assert metrics["costs/train"] == pytest.approx(0.00072)
+
+    @pytest.mark.asyncio
+    async def test_openai_client_automatically_logs_eval_tinker_costs(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        _patch_async_openai(monkeypatch, _FakeResponse(500, 250))
+        model = _build_model()
+        builder = model.metrics_builder("eval")
+
+        with builder.activate_context():
+            await model.openai_client().chat.completions.create(
+                model=model.get_inference_name(),
+                messages=[{"role": "user", "content": "hello"}],
+            )
+
+        metrics = await builder.flush()
+        assert metrics["costs/eval/tinker_prefill"] == pytest.approx(0.00006)
+        assert metrics["costs/eval/tinker_sample"] == pytest.approx(0.000075)
+        assert metrics["costs/eval"] == pytest.approx(0.000135)
+
+    @pytest.mark.asyncio
+    async def test_openai_client_does_not_log_costs_without_active_metrics_context(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        _patch_async_openai(monkeypatch, _FakeResponse(1_000, 2_000))
+        model = _build_model()
+        builder = model.metrics_builder("train")
+
+        await model.openai_client().chat.completions.create(
+            model=model.get_inference_name(),
+            messages=[{"role": "user", "content": "hello"}],
+        )
+
+        metrics = await builder.flush()
+        assert metrics == {}
+
+    @pytest.mark.asyncio
+    async def test_multiple_choices_scale_prefill_cost_once_per_sample(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        _patch_async_openai(monkeypatch, _FakeResponse(1_000, 2_000, num_choices=3))
+        model = _build_model()
+        builder = model.metrics_builder("train")
+
+        with builder.activate_context():
+            await model.openai_client().chat.completions.create(
+                model=model.get_inference_name(),
+                messages=[{"role": "user", "content": "hello"}],
+                n=3,
+            )
+
+        metrics = await builder.flush()
+        assert metrics["costs/train/tinker_prefill"] == pytest.approx(0.00036)
+        assert metrics["costs/train/tinker_sample"] == pytest.approx(0.0006)
+
+    def test_manual_cost_calculator_still_returns_tinker_metrics(self) -> None:
+        model = _build_model()
+
+        metrics = model.cost_calculator(1_000, 2_000, "train")
+
+        assert metrics["costs/train/tinker_prefill"] == pytest.approx(0.00012)
+        assert metrics["costs/train/tinker_sample"] == pytest.approx(0.0006)
+
+    def test_openai_client_proxy_survives_copy(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        _patch_async_openai(monkeypatch, _FakeResponse(1_000, 2_000))
+        model = _build_model()
+        client = model.openai_client()
+
+        # copy.copy looks up __setstate__ on an instance that has not run __init__.
+        clone = copy.copy(client)
+
+        assert clone.chat is client.chat