Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/art/local/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

from .. import dev
from ..backend import AnyTrainableModel, Backend
from ..costs import build_cost_calculator, get_model_pricing
from ..metrics_taxonomy import (
TRAIN_GRADIENT_STEPS_KEY,
average_metric_samples,
Expand Down Expand Up @@ -206,6 +207,11 @@ async def register(
# (wandb initialization is now handled by the model's _get_wandb_run method)
if model.trainable and "WANDB_API_KEY" in os.environ:
_ = model._get_wandb_run()
if model.trainable:
trainable_model = cast(TrainableModel, model)
pricing = get_model_pricing(trainable_model.base_model)
if pricing is not None:
trainable_model.set_cost_calculator(build_cost_calculator(pricing))

def _model_inference_name(self, model: Model, step: int | None = None) -> str:
"""Return the inference name for a model checkpoint.
Expand Down
88 changes: 87 additions & 1 deletion src/art/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,47 @@
StateType = TypeVar("StateType", bound=dict[str, Any], default=dict[str, Any])

METRICS_BUILDER_STATE_KEY = "_metrics_builder_state"


class _OpenAIChatCompletionsProxy:
def __init__(self, completions: Any, record_costs: Any) -> None:
self._completions = completions
self._record_costs = record_costs

async def create(self, *args: Any, **kwargs: Any) -> Any:
response = await self._completions.create(*args, **kwargs)
self._record_costs(response)
return response

def __getattr__(self, name: str) -> Any:
return getattr(self._completions, name)


class _OpenAIChatProxy:
    """Wrap an OpenAI ``chat`` resource with a cost-aware ``completions``.

    Only ``completions`` is replaced (by a recording proxy); every other
    attribute is delegated to the wrapped ``chat`` resource.
    """

    def __init__(self, chat: Any, record_costs: Any) -> None:
        self._wrapped_chat = chat
        self.completions = _OpenAIChatCompletionsProxy(
            chat.completions, record_costs
        )

    def __getattr__(self, name: str) -> Any:
        # Delegate everything not shadowed here to the wrapped chat resource.
        return getattr(self._wrapped_chat, name)


class _OpenAIClientProxy:
    """Wrap an ``AsyncOpenAI`` client so chat completions report costs.

    ``chat`` is replaced with a recording proxy. ``with_options`` re-wraps
    the derived client so the cost hook survives option changes; everything
    else is delegated to the wrapped client.
    """

    def __init__(self, client: Any, record_costs: Any) -> None:
        self._wrapped = client
        self._on_response = record_costs
        self.chat = _OpenAIChatProxy(client.chat, record_costs)

    def with_options(self, *args: Any, **kwargs: Any) -> "_OpenAIClientProxy":
        """Return a new proxy around ``client.with_options(...)``."""
        derived = self._wrapped.with_options(*args, **kwargs)
        return _OpenAIClientProxy(derived, self._on_response)

    def __getattr__(self, name: str) -> Any:
        # Delegate everything not shadowed here to the wrapped client.
        return getattr(self._wrapped, name)


METRIC_SECTIONS = frozenset(
{
"reward",
Expand Down Expand Up @@ -233,6 +274,12 @@ async def register(self, backend: "Backend") -> None:
def openai_client(
self,
) -> AsyncOpenAI:
"""Return ART's managed inference client.

For trainable models with configured pricing, chat completion calls made
through this client automatically emit Tinker inference costs when an
ART metrics context is active.
"""
if self._openai_client is not None:
return self._openai_client

Expand All @@ -245,7 +292,7 @@ def openai_client(
raise ValueError(
"In order to create an OpenAI client you must provide an `inference_api_key` and `inference_base_url`."
)
self._openai_client = AsyncOpenAI(
raw_client = AsyncOpenAI(
base_url=self.inference_base_url,
api_key=self.inference_api_key,
http_client=DefaultAsyncHttpxClient(
Expand All @@ -255,6 +302,13 @@ def openai_client(
),
),
)
# Wrap the raw OpenAI client so ART-owned inference calls can add
# split-scoped Tinker costs without rollout code needing to do it
# manually.
self._openai_client = cast(
AsyncOpenAI,
_OpenAIClientProxy(raw_client, self._record_openai_completion_costs),
)
return self._openai_client

def litellm_completion_params(self, step: int | None = None) -> dict:
Expand Down Expand Up @@ -304,6 +358,10 @@ def get_inference_name(self, step: int | None = None) -> str:
return f"{base_name}@{step}"
return base_name

def _record_openai_completion_costs(self, _response: Any) -> None:
"""Hook for subclasses that want to auto-log managed inference costs."""
return

def _get_output_dir(self) -> str:
"""Get the output directory for this model."""
return f"{self.base_path}/{self.project}/models/{self.name}"
Expand Down Expand Up @@ -946,6 +1004,34 @@ def _noop_cost_calculator(
) -> dict[str, float]:
return {}

def _record_openai_completion_costs(self, _response: Any) -> None:
    """Log inference costs for one completion into the active metrics builder.

    Silently does nothing when no metrics context is active, when the
    builder's cost context is empty, or when the cost calculator yields no
    metrics.
    """
    try:
        metrics_builder = MetricsBuilder.get_active()
    except LookupError:
        # No active metrics context — nothing to record.
        return

    context = metrics_builder.cost_context.strip("/")
    if not context:
        return

    usage = getattr(_response, "usage", None)
    n_prompt = int(getattr(usage, "prompt_tokens", 0) or 0)
    n_completion = int(getattr(usage, "completion_tokens", 0) or 0)
    choice_count = len(getattr(_response, "choices", None) or [])
    # Prompt tokens are charged once per returned choice (at least once).
    scaled_prompt = n_prompt * max(choice_count, 1)

    computed = self._cost_calculator(scaled_prompt, n_completion, context)
    if not computed:
        return

    for metric_key, metric_value in computed.items():
        # Only "costs/"-namespaced metrics are forwarded to the builder,
        # with the prefix stripped.
        if metric_key.startswith("costs/"):
            metrics_builder.add_cost(
                metric_key[len("costs/") :], float(metric_value)
            )

@overload
def __new__(
cls,
Expand Down
154 changes: 154 additions & 0 deletions tests/unit/test_model_openai_client_costs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import importlib
from typing import Any

import pytest

from art import TrainableModel
from art.costs import build_cost_calculator, get_model_pricing


class _FakeUsage:
def __init__(self, prompt_tokens: int, completion_tokens: int) -> None:
self.prompt_tokens = prompt_tokens
self.completion_tokens = completion_tokens


class _FakeResponse:
    """Minimal stand-in for an OpenAI chat completion response.

    Carries only what the cost-recording hook inspects: a ``usage`` payload
    and a ``choices`` list of the requested length.
    """

    def __init__(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        *,
        num_choices: int = 1,
    ) -> None:
        self.usage = _FakeUsage(prompt_tokens, completion_tokens)
        placeholder_choices = []
        for _ in range(num_choices):
            placeholder_choices.append(object())
        self.choices = placeholder_choices


class _FakeCompletions:
    """Async ``completions`` stub that always yields one canned response."""

    def __init__(self, response: _FakeResponse) -> None:
        self._canned_response = response

    async def create(self, *args: Any, **kwargs: Any) -> _FakeResponse:
        # All arguments are ignored; every call resolves to the canned
        # response supplied at construction time.
        return self._canned_response


def _patch_async_openai(
    monkeypatch: pytest.MonkeyPatch, response: _FakeResponse
) -> None:
    """Replace ``art.model.AsyncOpenAI`` with a stub returning *response*."""
    model_module = importlib.import_module("art.model")

    class _StubAsyncOpenAI:
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            chat_namespace = {"completions": _FakeCompletions(response)}
            self.chat = type("FakeChat", (), chat_namespace)()

        def with_options(self, *args: Any, **kwargs: Any) -> "_StubAsyncOpenAI":
            # Options never matter for the stub; reuse the same instance.
            return self

    monkeypatch.setattr(model_module, "AsyncOpenAI", _StubAsyncOpenAI)


def _build_model() -> TrainableModel:
    """Create a trainable model wired with gpt-oss-20b Tinker pricing."""
    trainable = TrainableModel(
        name="test-run",
        project="test-project",
        base_model="openai/gpt-oss-20b",
    )
    trainable.inference_api_key = "test-key"
    trainable.inference_base_url = "http://example.test/v1"

    model_pricing = get_model_pricing("openai/gpt-oss-20b")
    assert model_pricing is not None
    trainable.set_cost_calculator(build_cost_calculator(model_pricing))
    return trainable


class TestModelOpenAIClientCosts:
    """Auto-logging of Tinker inference costs via the managed OpenAI client."""

    @pytest.mark.asyncio
    async def test_openai_client_automatically_logs_train_tinker_costs(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A train-split completion emits prefill, sample, and total costs."""
        _patch_async_openai(monkeypatch, _FakeResponse(1_000, 2_000))
        trainable = _build_model()
        metrics_builder = trainable.metrics_builder("train")

        with metrics_builder.activate_context():
            client = trainable.openai_client()
            await client.chat.completions.create(
                model=trainable.get_inference_name(),
                messages=[{"role": "user", "content": "hello"}],
            )

        flushed = await metrics_builder.flush()
        assert flushed["costs/train/tinker_prefill"] == pytest.approx(0.00012)
        assert flushed["costs/train/tinker_sample"] == pytest.approx(0.0006)
        assert flushed["costs/train"] == pytest.approx(0.00072)

    @pytest.mark.asyncio
    async def test_openai_client_automatically_logs_eval_tinker_costs(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """An eval-split completion emits costs under the eval namespace."""
        _patch_async_openai(monkeypatch, _FakeResponse(500, 250))
        trainable = _build_model()
        metrics_builder = trainable.metrics_builder("eval")

        with metrics_builder.activate_context():
            client = trainable.openai_client()
            await client.chat.completions.create(
                model=trainable.get_inference_name(),
                messages=[{"role": "user", "content": "hello"}],
            )

        flushed = await metrics_builder.flush()
        assert flushed["costs/eval/tinker_prefill"] == pytest.approx(0.00006)
        assert flushed["costs/eval/tinker_sample"] == pytest.approx(0.000075)
        assert flushed["costs/eval"] == pytest.approx(0.000135)

    @pytest.mark.asyncio
    async def test_openai_client_does_not_log_costs_without_active_metrics_context(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Without an activated context, no cost metrics are recorded."""
        _patch_async_openai(monkeypatch, _FakeResponse(1_000, 2_000))
        trainable = _build_model()
        metrics_builder = trainable.metrics_builder("train")

        # Note: the call happens OUTSIDE activate_context() on purpose.
        client = trainable.openai_client()
        await client.chat.completions.create(
            model=trainable.get_inference_name(),
            messages=[{"role": "user", "content": "hello"}],
        )

        flushed = await metrics_builder.flush()
        assert flushed == {}

    @pytest.mark.asyncio
    async def test_multiple_choices_scale_prefill_cost_once_per_sample(
        self,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """With n=3 choices the prefill cost triples; sample cost does not."""
        _patch_async_openai(
            monkeypatch, _FakeResponse(1_000, 2_000, num_choices=3)
        )
        trainable = _build_model()
        metrics_builder = trainable.metrics_builder("train")

        with metrics_builder.activate_context():
            client = trainable.openai_client()
            await client.chat.completions.create(
                model=trainable.get_inference_name(),
                messages=[{"role": "user", "content": "hello"}],
                n=3,
            )

        flushed = await metrics_builder.flush()
        assert flushed["costs/train/tinker_prefill"] == pytest.approx(0.00036)
        assert flushed["costs/train/tinker_sample"] == pytest.approx(0.0006)

    def test_manual_cost_calculator_still_returns_tinker_metrics(self) -> None:
        """The cost calculator remains directly invocable with token counts."""
        trainable = _build_model()

        manual_metrics = trainable.cost_calculator(1_000, 2_000, "train")

        assert manual_metrics["costs/train/tinker_prefill"] == pytest.approx(
            0.00012
        )
        assert manual_metrics["costs/train/tinker_sample"] == pytest.approx(
            0.0006
        )
Loading