From e37d86bdb72db3b5d91d3b3b4be9caa655aceb6d Mon Sep 17 00:00:00 2001 From: amabito Date: Tue, 31 Mar 2026 09:29:06 +0900 Subject: [PATCH 1/3] feat(evaluators): add contrib budget evaluator for per-agent cost tracking New contrib evaluator "budget" that tracks cumulative token/cost usage per agent, channel, user. Configurable time windows via window_seconds. Design per reviewer feedback: - Contrib evaluator (not builtin) for production hardening - Integer limit + Currency enum (USD/EUR/tokens) - window_seconds (int) instead of named windows - group_by for dynamic per-user/per-channel budgets - Evaluator owns cost computation from pricing table - BudgetStore protocol + InMemoryBudgetStore (dict + Lock) - Store derives period keys internally, injectable clock Addresses #130. 55 tests (incl. thread safety, NaN/Inf, scope injection, double-count). --- evaluators/contrib/budget/README.md | 3 + evaluators/contrib/budget/pyproject.toml | 47 ++ .../__init__.py | 0 .../budget/__init__.py | 14 + .../budget/config.py | 107 ++++ .../budget/evaluator.py | 199 ++++++ .../budget/memory_store.py | 247 ++++++++ .../budget/store.py | 67 ++ evaluators/contrib/budget/tests/__init__.py | 0 .../contrib/budget/tests/budget/__init__.py | 0 .../budget/tests/budget/test_budget.py | 599 ++++++++++++++++++ pyproject.toml | 6 + 12 files changed, 1289 insertions(+) create mode 100644 evaluators/contrib/budget/README.md create mode 100644 evaluators/contrib/budget/pyproject.toml create mode 100644 evaluators/contrib/budget/src/agent_control_evaluator_budget/__init__.py create mode 100644 evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py create mode 100644 evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/config.py create mode 100644 evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py create mode 100644 evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py create mode 100644 
evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py create mode 100644 evaluators/contrib/budget/tests/__init__.py create mode 100644 evaluators/contrib/budget/tests/budget/__init__.py create mode 100644 evaluators/contrib/budget/tests/budget/test_budget.py diff --git a/evaluators/contrib/budget/README.md b/evaluators/contrib/budget/README.md new file mode 100644 index 00000000..ddd159e8 --- /dev/null +++ b/evaluators/contrib/budget/README.md @@ -0,0 +1,3 @@ +# Budget Evaluator + +Cumulative LLM cost and token budget tracking for agent-control. diff --git a/evaluators/contrib/budget/pyproject.toml b/evaluators/contrib/budget/pyproject.toml new file mode 100644 index 00000000..6115e442 --- /dev/null +++ b/evaluators/contrib/budget/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "agent-control-evaluator-budget" +version = "0.1.0" +description = "Budget evaluator for agent-control -- cumulative LLM cost and token tracking" +readme = "README.md" +requires-python = ">=3.12" +license = { text = "Apache-2.0" } +authors = [{ name = "Agent Control Team" }] +dependencies = [ + "agent-control-evaluators>=3.0.0", + "agent-control-models>=3.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", + "ruff>=0.1.0", + "mypy>=1.8.0", +] + +[project.entry-points."agent_control.evaluators"] +budget = "agent_control_evaluator_budget.budget:BudgetEvaluator" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/agent_control_evaluator_budget"] + +[tool.ruff] +line-length = 100 +target-version = "py312" + +[tool.ruff.lint] +select = ["E", "F", "I"] + +[tool.uv.sources] +agent-control-evaluators = { path = "../../builtin", editable = true } +agent-control-models = { path = "../../../models", editable = true } + +[dependency-groups] +dev = [ + "pytest>=9.0.2", + "pytest-asyncio>=1.3.0", +] diff --git 
a/evaluators/contrib/budget/src/agent_control_evaluator_budget/__init__.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py new file mode 100644 index 00000000..b0e6f6d4 --- /dev/null +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py @@ -0,0 +1,14 @@ +"""Budget evaluator for per-agent LLM cost and token tracking.""" + +from agent_control_evaluator_budget.budget.config import BudgetEvaluatorConfig +from agent_control_evaluator_budget.budget.evaluator import BudgetEvaluator +from agent_control_evaluator_budget.budget.memory_store import InMemoryBudgetStore +from agent_control_evaluator_budget.budget.store import BudgetSnapshot, BudgetStore + +__all__ = [ + "BudgetEvaluator", + "BudgetEvaluatorConfig", + "BudgetSnapshot", + "BudgetStore", + "InMemoryBudgetStore", +] diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/config.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/config.py new file mode 100644 index 00000000..6a261f43 --- /dev/null +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/config.py @@ -0,0 +1,107 @@ +"""Configuration for the budget evaluator.""" + +from __future__ import annotations + +from enum import Enum + +from agent_control_evaluators._base import EvaluatorConfig +from pydantic import Field, field_validator, model_validator + +# --------------------------------------------------------------------------- +# Window convenience constants (seconds) +# --------------------------------------------------------------------------- + +WINDOW_HOURLY = 3600 +WINDOW_DAILY = 86400 +WINDOW_WEEKLY = 604800 +WINDOW_MONTHLY = 2592000 # 30 days + + +class Currency(str, Enum): + """Supported budget 
class BudgetLimitRule(EvaluatorConfig):
    """One budget ceiling bound to a scope and an optional time window.

    Several rules may apply to the same step; each rule is an independent
    ceiling for its own combination of scope dimensions and window.

    Attributes:
        scope: Static dimension values a step must carry for the rule to
            apply. An empty dict makes the rule global. Examples:
                {"agent": "summarizer"} -- per-agent limit
                {"agent": "summarizer", "channel": "slack"} -- agent+channel
        group_by: Name of a dimension that splits the budget, e.g.
            group_by="user_id" gives every user a separate budget.
            None keeps a single shared budget for the rule.
        window_seconds: Accumulation window in seconds (see the WINDOW_*
            constants). None accumulates forever, with no reset.
        limit: Spend ceiling for the window in minor units (e.g. cents for
            USD). None leaves spend uncapped.
        currency: Currency the spend ceiling is expressed in. USD by default.
        limit_tokens: Token ceiling for the window. None leaves tokens
            uncapped.
    """

    scope: dict[str, str] = Field(default_factory=dict)
    group_by: str | None = None
    window_seconds: int | None = None
    limit: int | None = None
    currency: Currency = Currency.USD
    limit_tokens: int | None = None

    @model_validator(mode="after")
    def at_least_one_limit(self) -> "BudgetLimitRule":
        """Reject a rule that caps neither spend nor tokens."""
        has_ceiling = self.limit is not None or self.limit_tokens is not None
        if not has_ceiling:
            raise ValueError("At least one of limit or limit_tokens must be set")
        return self

    @field_validator("limit")
    @classmethod
    def validate_limit(cls, v: int | None) -> int | None:
        """Accept None or a strictly positive spend ceiling."""
        if v is None or v > 0:
            return v
        raise ValueError("limit must be a positive integer")

    @field_validator("limit_tokens")
    @classmethod
    def validate_limit_tokens(cls, v: int | None) -> int | None:
        """Accept None or a strictly positive token ceiling."""
        if v is None or v > 0:
            return v
        raise ValueError("limit_tokens must be positive")

    @field_validator("window_seconds")
    @classmethod
    def validate_window_seconds(cls, v: int | None) -> int | None:
        """Accept None or a strictly positive window length."""
        if v is None or v > 0:
            return v
        raise ValueError("window_seconds must be positive")
+ """ + + limits: list[BudgetLimitRule] = Field(min_length=1) + pricing: dict[str, dict[str, float]] | None = None + token_path: str | None = None + model_path: str | None = None + metadata_paths: dict[str, str] = Field(default_factory=dict) diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py new file mode 100644 index 00000000..6d1ca128 --- /dev/null +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py @@ -0,0 +1,199 @@ +"""Budget evaluator -- tracks cumulative LLM token/cost usage. + +Deterministic evaluator: confidence is always 1.0, matched is True when +any configured limit is exceeded. Utilization ratio and spend breakdown +are returned in result metadata, not in confidence. +""" + +from __future__ import annotations + +import logging +import math +from typing import Any + +from agent_control_evaluators._base import Evaluator, EvaluatorMetadata +from agent_control_evaluators._registry import register_evaluator +from agent_control_models import EvaluatorResult + +from .config import BudgetEvaluatorConfig +from .memory_store import InMemoryBudgetStore + +logger = logging.getLogger(__name__) + + +def _extract_by_path(data: Any, path: str) -> Any: + """Extract a value from nested data using dot-notation path.""" + current = data + for part in path.split("."): + if part.startswith("__"): + return None + if isinstance(current, dict): + current = current.get(part) + elif hasattr(current, part): + current = getattr(current, part) + else: + return None + if current is None: + return None + return current + + +def _extract_tokens(data: Any, token_path: str | None) -> tuple[int, int]: + """Extract (input_tokens, output_tokens) from step data. + + Tries token_path first, then standard field names. + Returns (0, 0) if no token information found. 
+ """ + if data is None: + return 0, 0 + + if token_path: + val = _extract_by_path(data, token_path) + if isinstance(val, int) and not isinstance(val, bool) and val >= 0: + return 0, val + if isinstance(val, dict): + data = val + + if isinstance(data, dict): + usage = data.get("usage", data) + if isinstance(usage, dict): + inp = usage.get("input_tokens") + if inp is None: + inp = usage.get("prompt_tokens") + out = usage.get("output_tokens") + if out is None: + out = usage.get("completion_tokens") + inp_ok = isinstance(inp, int) and not isinstance(inp, bool) + out_ok = isinstance(out, int) and not isinstance(out, bool) + if inp_ok and out_ok: + return max(0, inp), max(0, out) + total = usage.get("total_tokens") + if isinstance(total, int) and not isinstance(total, bool) and total > 0: + return 0, max(0, total) + return 0, 0 + + +def _estimate_cost( + model: str | None, + input_tokens: int, + output_tokens: int, + pricing: dict[str, dict[str, float]] | None, +) -> int: + """Estimate cost in minor units from model pricing table. Returns 0 if unknown.""" + if not model or not pricing: + return 0 + rates = pricing.get(model) + if not rates: + return 0 + input_rate = rates.get("input_per_1k", 0.0) + output_rate = rates.get("output_per_1k", 0.0) + cost = (input_tokens * input_rate + output_tokens * output_rate) / 1000.0 + if not math.isfinite(cost) or cost < 0: + return 0 + return math.ceil(cost) + + +def _extract_metadata(data: Any, metadata_paths: dict[str, str]) -> dict[str, str]: + """Extract metadata fields from step data using configured paths.""" + result: dict[str, str] = {} + for field_name, path in metadata_paths.items(): + val = _extract_by_path(data, path) + if val is not None: + result[field_name] = str(val) + return result + + +@register_evaluator +class BudgetEvaluator(Evaluator[BudgetEvaluatorConfig]): + """Tracks cumulative LLM token and cost usage per scope and time window. 
@register_evaluator
class BudgetEvaluator(Evaluator[BudgetEvaluatorConfig]):
    """Tracks cumulative LLM token and cost usage per scope and time window.

    Deterministic evaluator: matched=True when any configured limit is
    exceeded, confidence=1.0 always. Utilization and the spend breakdown
    are reported through the result metadata.

    The evaluator is stateful: every ``evaluate`` call records the step's
    usage in an ``InMemoryBudgetStore`` built from ``config.limits``.
    """

    metadata = EvaluatorMetadata(
        name="budget",
        version="2.0.0",
        description="Cumulative LLM token and cost budget tracking",
    )
    config_model = BudgetEvaluatorConfig

    def __init__(self, config: BudgetEvaluatorConfig) -> None:
        super().__init__(config)
        self._store = InMemoryBudgetStore(rules=config.limits)

    async def evaluate(self, data: Any) -> EvaluatorResult:
        """Record the step's usage and check it against every matching limit.

        Args:
            data: Step data holding token usage and, optionally, the model
                name and scope dimensions at the configured paths.

        Returns:
            A result with matched=True when at least one matching rule is
            exceeded. Metadata always carries ``all_snapshots`` and the
            usage recorded for this step (``input_tokens``,
            ``output_tokens``, ``cost``); it adds ``breached_rules`` on a
            breach and ``max_utilization`` otherwise.
        """
        if data is None:
            return EvaluatorResult(
                matched=False,
                confidence=1.0,
                message="No data to evaluate",
            )

        input_tokens, output_tokens = _extract_tokens(data, self.config.token_path)

        model: str | None = None
        if self.config.model_path:
            raw_model = _extract_by_path(data, self.config.model_path)
            if raw_model is not None:
                model = str(raw_model)

        cost = _estimate_cost(model, input_tokens, output_tokens, self.config.pricing)
        step_scope = _extract_metadata(data, self.config.metadata_paths)

        snapshots = self._store.record_and_check(
            scope=step_scope,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
        )

        all_snaps: list[dict[str, Any]] = [
            {
                "spent": snap.spent,
                "spent_tokens": snap.spent_tokens,
                "limit": snap.limit,
                "limit_tokens": snap.limit_tokens,
                "utilization": round(snap.utilization, 4),
                "exceeded": snap.exceeded,
            }
            for snap in snapshots
        ]
        breached = [info for info in all_snaps if info["exceeded"]]

        # Usage recorded for this step; shared by both result shapes.
        usage: dict[str, Any] = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost": cost,
        }

        if breached:
            first = breached[0]
            return EvaluatorResult(
                matched=True,
                confidence=1.0,
                message=f"Budget exceeded (utilization={first['utilization']:.0%})",
                metadata={
                    "breached_rules": breached,
                    "all_snapshots": all_snaps,
                    **usage,
                },
            )

        # default=0.0 covers the case where no rule matched the step.
        max_util = max((info["utilization"] for info in all_snaps), default=0.0)
        return EvaluatorResult(
            matched=False,
            confidence=1.0,
            message=f"Within budget (utilization={max_util:.0%})",
            metadata={
                "all_snapshots": all_snaps,
                **usage,
                "max_utilization": round(max_util, 4),
            },
        )
all_snaps, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cost": cost, + }, + ) + + max_util = max((s["utilization"] for s in all_snaps), default=0.0) + return EvaluatorResult( + matched=False, + confidence=1.0, + message=f"Within budget (utilization={max_util:.0%})", + metadata={ + "all_snapshots": all_snaps, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "cost": cost, + "max_utilization": round(max_util, 4), + }, + ) diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py new file mode 100644 index 00000000..21130cf8 --- /dev/null +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py @@ -0,0 +1,247 @@ +"""In-memory budget store implementation. + +Not suitable for multi-process deployments. For distributed setups, +use a Redis or Postgres-backed store (separate package). +""" + +from __future__ import annotations + +import threading +import time +from collections.abc import Callable +from dataclasses import dataclass + +from .config import BudgetLimitRule +from .store import BudgetSnapshot + + +def _sanitize_scope_value(val: str) -> str: + """Percent-encode pipe and equals in scope values to prevent key injection.""" + return val.replace("%", "%25").replace("|", "%7C").replace("=", "%3D") + + +def _build_scope_key( + rule_scope: dict[str, str], + group_by: str | None, + step_scope: dict[str, str], +) -> str: + """Build a composite scope key from rule dimensions and group_by field.""" + parts: list[str] = [] + for k, v in sorted(rule_scope.items()): + parts.append(f"{k}={_sanitize_scope_value(v)}") + if group_by and group_by in step_scope: + parts.append(f"{group_by}={_sanitize_scope_value(step_scope[group_by])}") + return "|".join(parts) if parts else "__global__" + + +def _derive_period_key(window_seconds: int | None, now: float) -> str: + """Derive a period 
key from window_seconds and a timestamp. + + Periods are aligned to UTC epoch boundaries. For example, + window_seconds=86400 produces keys like "P86400:19800" where + 19800 is the number of complete windows since epoch. + """ + if window_seconds is None: + return "" + period_index = int(now) // window_seconds + return f"P{window_seconds}:{period_index}" + + +def _scope_matches(rule: BudgetLimitRule, scope: dict[str, str]) -> bool: + """Check if rule's scope dimensions match step scope.""" + for key, expected in rule.scope.items(): + if scope.get(key) != expected: + return False + if rule.group_by and rule.group_by not in scope: + return False + return True + + +def _compute_utilization( + spent: int, + spent_tokens: int, + limit: int | None, + limit_tokens: int | None, +) -> float: + """Return max(spend_ratio, token_ratio) clamped to [0.0, 1.0].""" + ratios: list[float] = [] + if limit is not None and limit > 0: + ratios.append(min(spent / limit, 1.0)) + if limit_tokens is not None and limit_tokens > 0: + ratios.append(min(spent_tokens / limit_tokens, 1.0)) + return max(ratios) if ratios else 0.0 + + +@dataclass +class _Bucket: + """Internal mutable accumulator for a single (scope, period) pair.""" + + spent: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + + @property + def total_tokens(self) -> int: + return self.input_tokens + self.output_tokens + + +class InMemoryBudgetStore: + """Thread-safe in-memory budget store. + + Initialized with a list of BudgetLimitRule. Derives period keys + internally from window_seconds + injected clock. + + NOTE: Currency conversion is not handled here. The cost integer + passed to record_and_check is assumed to be in the same unit as + the rule's currency. Cross-currency conversion (e.g. USD->EUR) + is the caller's responsibility and will be addressed when cost + calculation moves into the evaluator (pending design review). 
class InMemoryBudgetStore:
    """Thread-safe in-memory budget store.

    Initialized with a list of BudgetLimitRule. Period keys are derived
    internally from each rule's window_seconds and the injected clock.

    Buckets are capped at ``max_buckets``. When the cap is hit, buckets
    whose time window has already ended are evicted first; only if the
    store is still full does recording fail closed for the new bucket.

    NOTE: Currency conversion is not handled here. The cost integer
    passed to record_and_check is assumed to already be expressed in the
    currency of every rule it is recorded against; cross-currency
    conversion (e.g. USD->EUR) is the caller's responsibility.
    """

    _DEFAULT_MAX_BUCKETS = 100_000

    def __init__(
        self,
        rules: list[BudgetLimitRule],
        *,
        clock: Callable[[], float] = time.time,
        max_buckets: int = _DEFAULT_MAX_BUCKETS,
    ) -> None:
        self._rules = rules
        self._clock = clock
        self._lock = threading.Lock()
        self._buckets: dict[tuple[str, str, str], _Bucket] = {}
        # End-of-window timestamp for each windowed bucket. Cumulative
        # buckets (window_seconds=None) never expire and are not listed.
        self._expires_at: dict[tuple[str, str, str], float] = {}
        self._max_buckets = max_buckets

    def record_and_check(
        self,
        scope: dict[str, str],
        input_tokens: int,
        output_tokens: int,
        cost: int,
    ) -> list[BudgetSnapshot]:
        """Atomically record usage and return snapshots for all matching rules.

        Usage is added once per distinct (scope key, period key, currency)
        bucket, even when several rules share that bucket. A rule whose
        bucket cannot be created because the store is full yields a
        fail-closed snapshot (exceeded=True, utilization=1.0).
        """
        now = self._clock()
        snapshots: list[BudgetSnapshot] = []
        recorded: set[tuple[str, str, str]] = set()

        with self._lock:
            for rule in self._rules:
                if not _scope_matches(rule, scope):
                    continue

                scope_key = _build_scope_key(rule.scope, rule.group_by, scope)
                period_key = _derive_period_key(rule.window_seconds, now)
                cur = rule.currency
                currency_key = cur.value if hasattr(cur, "value") else str(cur)
                pair = (scope_key, period_key, currency_key)

                if pair in recorded:
                    # Usage already added through an earlier rule.
                    bucket = self._buckets.get(pair)
                    if bucket is None:
                        continue
                else:
                    bucket = self._get_or_create_bucket(
                        pair, now=now, window_seconds=rule.window_seconds
                    )
                    if bucket is None:
                        # Max buckets reached -- fail closed
                        snapshots.append(
                            BudgetSnapshot(
                                spent=0,
                                spent_tokens=0,
                                limit=rule.limit,
                                limit_tokens=rule.limit_tokens,
                                utilization=1.0,
                                exceeded=True,
                            )
                        )
                        continue
                    bucket.spent += cost
                    bucket.input_tokens += input_tokens
                    bucket.output_tokens += output_tokens
                    recorded.add(pair)

                snapshots.append(self._snapshot_of(bucket, rule.limit, rule.limit_tokens))

        return snapshots

    def get_snapshot(
        self,
        scope_key: str,
        period_key: str,
        limit: int | None = None,
        limit_tokens: int | None = None,
        currency: str = "usd",
    ) -> BudgetSnapshot:
        """Read current budget state without recording usage.

        An unknown bucket is reported as zero usage, not exceeded.
        """
        key = (scope_key, period_key, currency)
        with self._lock:
            bucket = self._buckets.get(key)
            if bucket is None:
                return BudgetSnapshot(
                    spent=0,
                    spent_tokens=0,
                    limit=limit,
                    limit_tokens=limit_tokens,
                    utilization=0.0,
                    exceeded=False,
                )
            return self._snapshot_of(bucket, limit, limit_tokens)

    def reset(self, scope_key: str | None = None, period_key: str | None = None) -> None:
        """Clear accumulated usage.

        With no arguments everything is cleared; otherwise only buckets
        matching the given scope key and/or period key are removed.
        """
        with self._lock:
            if scope_key is None and period_key is None:
                self._buckets.clear()
                self._expires_at.clear()
                return
            keys_to_remove = [
                k
                for k in self._buckets
                if (scope_key is None or k[0] == scope_key)
                and (period_key is None or k[1] == period_key)
            ]
            for k in keys_to_remove:
                del self._buckets[k]
                self._expires_at.pop(k, None)

    @staticmethod
    def _snapshot_of(
        bucket: _Bucket,
        limit: int | None,
        limit_tokens: int | None,
    ) -> BudgetSnapshot:
        """Build a snapshot of ``bucket`` measured against the given limits."""
        total_tokens = bucket.total_tokens
        # A limit counts as breached once usage reaches it (>=, not >).
        exceeded = (limit is not None and bucket.spent >= limit) or (
            limit_tokens is not None and total_tokens >= limit_tokens
        )
        return BudgetSnapshot(
            spent=bucket.spent,
            spent_tokens=total_tokens,
            limit=limit,
            limit_tokens=limit_tokens,
            utilization=_compute_utilization(bucket.spent, total_tokens, limit, limit_tokens),
            exceeded=exceeded,
        )

    def _evict_expired(self, now: float) -> None:
        """Drop every windowed bucket whose window ended at or before ``now``."""
        stale = [key for key, end in self._expires_at.items() if end <= now]
        for key in stale:
            del self._expires_at[key]
            self._buckets.pop(key, None)

    def _get_or_create_bucket(
        self,
        key: tuple[str, str, str],
        *,
        now: float | None = None,
        window_seconds: int | None = None,
    ) -> _Bucket | None:
        """Get or create a bucket. Returns None if max_buckets reached.

        When the store is full, buckets of finished windows are evicted
        before giving up, so windowed rules do not exhaust the cap over
        time. ``now`` and ``window_seconds`` are used to remember when the
        new bucket's window ends; without them the bucket never expires.
        Must be called with the lock held.
        """
        bucket = self._buckets.get(key)
        if bucket is not None:
            return bucket
        if len(self._buckets) >= self._max_buckets:
            self._evict_expired(self._clock() if now is None else now)
            if len(self._buckets) >= self._max_buckets:
                return None
        bucket = _Bucket()
        self._buckets[key] = bucket
        if now is not None and window_seconds is not None:
            window_end = (int(now) // window_seconds + 1) * window_seconds
            self._expires_at[key] = float(window_end)
        return bucket
Returns None if max_buckets reached.""" + bucket = self._buckets.get(key) + if bucket is not None: + return bucket + if len(self._buckets) >= self._max_buckets: + return None + bucket = _Bucket() + self._buckets[key] = bucket + return bucket diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py new file mode 100644 index 00000000..0c767d3b --- /dev/null +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py @@ -0,0 +1,67 @@ +"""BudgetStore protocol -- interface for budget storage backends. + +Implementations must provide atomic record-and-check: a single call +that records usage and returns the current totals. This prevents +read-then-write race conditions under concurrent access. + +Built-in: InMemoryBudgetStore (dict + threading.Lock). +External: Redis, PostgreSQL, etc. (separate packages). +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol, runtime_checkable + + +@dataclass(frozen=True) +class BudgetSnapshot: + """Immutable view of budget state at a point in time. + + Attributes: + spent: Cumulative spend in minor units (e.g. cents for USD). + spent_tokens: Cumulative tokens (input + output) in this scope+period. + limit: Configured spend ceiling in minor units, or None if uncapped. + limit_tokens: Configured token ceiling, or None if uncapped. + utilization: max(spend_ratio, token_ratio) clamped to [0.0, 1.0]. + 0.0 when no limits are set. + exceeded: True when any limit is breached. + """ + + spent: int + spent_tokens: int + limit: int | None + limit_tokens: int | None + utilization: float + exceeded: bool + + +@runtime_checkable +class BudgetStore(Protocol): + """Protocol for budget storage backends. + + The store is initialized with a list of BudgetLimitRule and derives + period keys internally from window_seconds + current time. 
+ + Callers pass only usage data: scope dict, input_tokens, output_tokens, cost. + """ + + def record_and_check( + self, + scope: dict[str, str], + input_tokens: int, + output_tokens: int, + cost: int, + ) -> list[BudgetSnapshot]: + """Atomically record usage and return snapshots for all matching rules. + + Args: + scope: Scope dimensions from the step (e.g. {"agent": "summarizer"}). + input_tokens: Input tokens consumed by this call. + output_tokens: Output tokens consumed by this call. + cost: Cost in minor units (e.g. cents for USD). + + Returns: + List of BudgetSnapshot, one per matching rule. + """ + ... diff --git a/evaluators/contrib/budget/tests/__init__.py b/evaluators/contrib/budget/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/evaluators/contrib/budget/tests/budget/__init__.py b/evaluators/contrib/budget/tests/budget/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/evaluators/contrib/budget/tests/budget/test_budget.py b/evaluators/contrib/budget/tests/budget/test_budget.py new file mode 100644 index 00000000..561f4edf --- /dev/null +++ b/evaluators/contrib/budget/tests/budget/test_budget.py @@ -0,0 +1,599 @@ +"""Tests for the budget evaluator (contrib). + +Given/When/Then comment style per reviewer request. 
+""" + +from __future__ import annotations + +import threading +from typing import Any + +import pytest +from pydantic import ValidationError + +from agent_control_evaluator_budget.budget.config import ( + WINDOW_DAILY, + WINDOW_MONTHLY, + WINDOW_WEEKLY, + BudgetEvaluatorConfig, + BudgetLimitRule, + Currency, +) +from agent_control_evaluator_budget.budget.evaluator import ( + BudgetEvaluator, + _extract_tokens, +) +from agent_control_evaluator_budget.budget.memory_store import ( + InMemoryBudgetStore, + _build_scope_key, + _compute_utilization, + _derive_period_key, +) + +# --------------------------------------------------------------------------- +# InMemoryBudgetStore +# --------------------------------------------------------------------------- + + +class TestInMemoryBudgetStore: + def test_single_record_under_limit(self) -> None: + # Given: store with a $10 daily limit (1000 cents) + rules = [BudgetLimitRule(limit=1000, window_seconds=WINDOW_DAILY)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 1700000000.0) + + # When: record 300 cents of usage + results = store.record_and_check(scope={}, input_tokens=100, output_tokens=50, cost=300) + + # Then: not breached, ratio ~0.3 + assert len(results) == 1 + assert not results[0].exceeded + assert results[0].utilization == pytest.approx(0.3, abs=0.01) + + def test_accumulation_triggers_breach(self) -> None: + # Given: store with 1000-cent limit + rules = [BudgetLimitRule(limit=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 1700000000.0) + + # When: record 600 + 500 = 1100 cents + store.record_and_check(scope={}, input_tokens=100, output_tokens=50, cost=600) + results = store.record_and_check(scope={}, input_tokens=100, output_tokens=50, cost=500) + + # Then: exceeded + assert results[0].exceeded is True + assert results[0].spent == 1100 + + def test_scope_isolation(self) -> None: + # Given: per-agent limits + rules = [ + BudgetLimitRule(scope={"agent": "a"}, limit=1000), + 
BudgetLimitRule(scope={"agent": "b"}, limit=1000), + ] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 1700000000.0) + + # When: agent-a records 900, agent-b records 100 + results_a = store.record_and_check( + scope={"agent": "a"}, input_tokens=0, output_tokens=0, cost=900 + ) + results_b = store.record_and_check( + scope={"agent": "b"}, input_tokens=0, output_tokens=0, cost=100 + ) + + # Then: agent-a near limit, agent-b well under + assert results_a[0].spent == 900 + assert results_b[0].spent == 100 + assert not results_b[0].exceeded + + def test_period_isolation(self) -> None: + # Given: daily limit, clock at two different days + rules = [BudgetLimitRule(limit=1000, window_seconds=WINDOW_DAILY)] + day1 = 1700000000.0 + day2 = day1 + WINDOW_DAILY + + # When: record on day 1, then day 2 + store = InMemoryBudgetStore(rules=rules, clock=lambda: day1) + store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=800) + + store._clock = lambda: day2 + results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=300) + + # Then: day 2 is a fresh period + assert results[0].spent == 300 + assert not results[0].exceeded + + def test_exceeded_exact_limit(self) -> None: + # Given: 1000-cent limit + rules = [BudgetLimitRule(limit=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: spend exactly 1000 + results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1000) + + # Then: exceeded (>= not >) + assert results[0].exceeded is True + + def test_token_only_limit(self) -> None: + # Given: 1000-token limit, no cost limit + rules = [BudgetLimitRule(limit_tokens=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: consume 600+500 = 1100 tokens + results = store.record_and_check(scope={}, input_tokens=600, output_tokens=500, cost=0) + + # Then: exceeded + assert results[0].exceeded is True + assert results[0].spent_tokens == 1100 + + def 
test_no_matching_rules(self) -> None: + # Given: rule for agent=summarizer only + rules = [BudgetLimitRule(scope={"agent": "summarizer"}, limit=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: step from agent=other + results = store.record_and_check( + scope={"agent": "other"}, input_tokens=100, output_tokens=50, cost=999 + ) + + # Then: no snapshots (rule didn't match) + assert results == [] + + def test_group_by_user(self) -> None: + # Given: global rule with group_by=user_id + rules = [BudgetLimitRule(group_by="user_id", limit=500)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: two users each spend + store.record_and_check(scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=400) + results_u1 = store.record_and_check( + scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=200 + ) + results_u2 = store.record_and_check( + scope={"user_id": "u2"}, input_tokens=0, output_tokens=0, cost=300 + ) + + # Then: u1 exceeded, u2 not + assert results_u1[0].exceeded is True + assert results_u2[0].exceeded is False + + def test_thread_safety(self) -> None: + # Given: high-limit rule and 10 concurrent threads + rules = [BudgetLimitRule(limit=1_000_000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + errors: list[str] = [] + + def record_many() -> None: + try: + for _ in range(100): + store.record_and_check(scope={}, input_tokens=1, output_tokens=1, cost=1) + except Exception as exc: + errors.append(str(exc)) + + # When: 10 threads x 100 calls + threads = [threading.Thread(target=record_many) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Then: no errors, totals correct + assert errors == [] + snap = store.get_snapshot("__global__", _derive_period_key(None, 0.0), limit=1_000_000) + assert snap.spent_tokens == 2000 + assert snap.spent == 1000 + + def test_max_buckets_fail_closed(self) -> None: + # Given: store limited to 3 buckets with 
group_by=user_id + rules = [BudgetLimitRule(group_by="user_id", limit=100_000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0, max_buckets=3) + + # When: 5 different users try to record + exceeded_count = 0 + for i in range(5): + results = store.record_and_check( + scope={"user_id": f"u{i}"}, input_tokens=1, output_tokens=1, cost=1 + ) + if results and results[0].exceeded: + exceeded_count += 1 + + # Then: first 3 succeed, last 2 fail-closed + assert exceeded_count == 2 + + def test_reset_all(self) -> None: + # Given: store with recorded usage + rules = [BudgetLimitRule(limit=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store.record_and_check(scope={}, input_tokens=10, output_tokens=10, cost=100) + + # When: reset all + store.reset() + + # Then: empty + snap = store.get_snapshot("__global__", "", limit=1000) + assert snap.spent == 0 + + +# --------------------------------------------------------------------------- +# Utility functions +# --------------------------------------------------------------------------- + + +class TestUtilities: + def test_compute_utilization_no_limits(self) -> None: + assert _compute_utilization(100, 10000, None, None) == 0.0 + + def test_compute_utilization_spend_only(self) -> None: + # Given: 500 of 1000 spent + assert _compute_utilization(500, 0, 1000, None) == pytest.approx(0.5) + + def test_compute_utilization_clamped(self) -> None: + assert _compute_utilization(2000, 0, 1000, None) == pytest.approx(1.0) + + def test_derive_period_key_none(self) -> None: + assert _derive_period_key(None, 0.0) == "" + + def test_derive_period_key_daily(self) -> None: + # Given: 1700000000 / 86400 = 19675 (truncated) + key = _derive_period_key(WINDOW_DAILY, 1700000000.0) + assert key == "P86400:19675" + + def test_derive_period_key_weekly(self) -> None: + key = _derive_period_key(WINDOW_WEEKLY, 1700000000.0) + assert key.startswith("P604800:") + + def test_build_scope_key_global(self) -> None: + assert 
_build_scope_key({}, None, {}) == "__global__" + + def test_build_scope_key_with_scope(self) -> None: + key = _build_scope_key({"channel": "slack"}, None, {}) + assert key == "channel=slack" + + def test_build_scope_key_with_group_by(self) -> None: + key = _build_scope_key({"channel": "slack"}, "user_id", {"user_id": "u1"}) + assert key == "channel=slack|user_id=u1" + + def test_build_scope_key_group_by_missing(self) -> None: + key = _build_scope_key({}, "user_id", {}) + assert key == "__global__" + + def test_extract_tokens_standard(self) -> None: + data = {"usage": {"input_tokens": 100, "output_tokens": 50}} + assert _extract_tokens(data, None) == (100, 50) + + def test_extract_tokens_openai(self) -> None: + data = {"usage": {"prompt_tokens": 80, "completion_tokens": 40}} + assert _extract_tokens(data, None) == (80, 40) + + def test_extract_tokens_none(self) -> None: + assert _extract_tokens(None, None) == (0, 0) + + +# --------------------------------------------------------------------------- +# BudgetLimitRule config validation +# --------------------------------------------------------------------------- + + +class TestBudgetLimitRuleConfig: + def test_valid_rule(self) -> None: + rule = BudgetLimitRule(limit=1000) + assert rule.limit == 1000 + assert rule.currency == Currency.USD + + def test_no_limit_rejected(self) -> None: + with pytest.raises(ValidationError, match="At least one"): + BudgetLimitRule() + + def test_negative_limit_rejected(self) -> None: + with pytest.raises(ValidationError, match="positive"): + BudgetLimitRule(limit=-1) + + def test_zero_limit_rejected(self) -> None: + with pytest.raises(ValidationError, match="positive"): + BudgetLimitRule(limit=0) + + def test_negative_limit_tokens_rejected(self) -> None: + with pytest.raises(ValidationError, match="positive"): + BudgetLimitRule(limit_tokens=-1) + + def test_negative_window_seconds_rejected(self) -> None: + with pytest.raises(ValidationError, match="positive"): + 
BudgetLimitRule(limit=1000, window_seconds=-1) + + def test_zero_window_seconds_rejected(self) -> None: + with pytest.raises(ValidationError, match="positive"): + BudgetLimitRule(limit=1000, window_seconds=0) + + def test_token_only_rule(self) -> None: + rule = BudgetLimitRule(limit_tokens=5000) + assert rule.limit is None + assert rule.limit_tokens == 5000 + + def test_currency_enum(self) -> None: + rule = BudgetLimitRule(limit=1000, currency=Currency.EUR) + assert rule.currency == Currency.EUR + + def test_currency_from_string(self) -> None: + rule = BudgetLimitRule(limit=1000, currency="tokens") + assert rule.currency == Currency.TOKENS + + def test_empty_limits_rejected(self) -> None: + with pytest.raises(ValidationError): + BudgetEvaluatorConfig(limits=[]) + + def test_window_constants(self) -> None: + assert WINDOW_DAILY == 86400 + assert WINDOW_WEEKLY == 604800 + assert WINDOW_MONTHLY == 2592000 + + +# --------------------------------------------------------------------------- +# BudgetEvaluator integration +# --------------------------------------------------------------------------- + + +class TestBudgetEvaluator: + def _make_evaluator(self, **kwargs: Any) -> BudgetEvaluator: + config = BudgetEvaluatorConfig(**kwargs) + return BudgetEvaluator(config) + + @pytest.mark.asyncio + async def test_single_call_under_budget(self) -> None: + # Given: evaluator with $10 limit (1000 cents) + ev = self._make_evaluator(limits=[{"limit": 1000}]) + + # When: evaluate with usage data (cost field is ignored without pricing/model_path) + result = await ev.evaluate({"usage": {"input_tokens": 100, "output_tokens": 50}}) + + # Then: not matched + assert result.matched is False + assert result.confidence == 1.0 + + @pytest.mark.asyncio + async def test_accumulate_past_budget(self) -> None: + # Given: evaluator with 50-cent limit and pricing table + ev = self._make_evaluator( + limits=[{"limit": 50}], + pricing={"gpt-4": {"input_per_1k": 30.0, "output_per_1k": 60.0}}, + 
model_path="model", + ) + + # When: two calls with tokens costing 27 cents each + # cost = ceil(300*30/1000 + 300*60/1000) = ceil(9+18) = 27 + # total = 27+27 = 54 > 50 + step = {"model": "gpt-4", "usage": {"input_tokens": 300, "output_tokens": 300}} + await ev.evaluate(step) + result = await ev.evaluate(step) + + # Then: matched (54 > 50) + assert result.matched is True + assert result.metadata is not None + + @pytest.mark.asyncio + async def test_group_by_user(self) -> None: + # Given: per-user 1000-cent budget with pricing table + # pricing: 200 cents per 1k input tokens + ev = self._make_evaluator( + limits=[{"group_by": "user_id", "limit": 1000}], + pricing={"gpt-4": {"input_per_1k": 200.0, "output_per_1k": 0.0}}, + model_path="model", + metadata_paths={"user_id": "user_id"}, + ) + + # When: u1 spends 800+300=1100 cents, u2 spends 300 cents + # 4000 input tokens * 200/1000 = 800 cents + # 1500 input tokens * 200/1000 = 300 cents + def _step(tokens: int, user: str) -> dict: + return { + "model": "gpt-4", + "usage": {"input_tokens": tokens, "output_tokens": 0}, + "user_id": user, + } + + await ev.evaluate(_step(4000, "u1")) + r1 = await ev.evaluate(_step(1500, "u1")) + r2 = await ev.evaluate(_step(1500, "u2")) + + # Then: u1 exceeded (1100 > 1000), u2 not (300 < 1000) + assert r1.matched is True + assert r2.matched is False + + @pytest.mark.asyncio + async def test_token_only_limit(self) -> None: + # Given: 500 token limit + ev = self._make_evaluator(limits=[{"limit_tokens": 500}]) + + # When: consume 600 tokens + result = await ev.evaluate({"usage": {"input_tokens": 300, "output_tokens": 300}}) + + # Then: exceeded + assert result.matched is True + + @pytest.mark.asyncio + async def test_no_data_returns_not_matched(self) -> None: + ev = self._make_evaluator(limits=[{"limit": 1000}]) + result = await ev.evaluate(None) + assert result.matched is False + + @pytest.mark.asyncio + async def test_confidence_always_one(self) -> None: + # Given: evaluator with 
1000-cent limit and pricing table + # pricing: 200 cents per 1k input tokens + ev = self._make_evaluator( + limits=[{"limit": 1000}], + pricing={"gpt-4": {"input_per_1k": 200.0, "output_per_1k": 0.0}}, + model_path="model", + ) + + # When: first call costs 50 cents (250 tokens), second costs 960 cents (4800 tokens) + def _step(tokens: int) -> dict: + return {"model": "gpt-4", "usage": {"input_tokens": tokens, "output_tokens": 0}} + + r1 = await ev.evaluate(_step(250)) + r2 = await ev.evaluate(_step(4800)) + + # Then: confidence is always 1.0 + assert r1.confidence == 1.0 + assert r2.confidence == 1.0 + + @pytest.mark.asyncio + async def test_cost_computed_from_pricing_table(self) -> None: + # Given: evaluator with pricing table and 100-cent cost limit + ev = self._make_evaluator( + limits=[{"limit": 100}], + pricing={"gpt-4": {"input_per_1k": 30.0, "output_per_1k": 60.0}}, + model_path="model", + ) + + # When: evaluate with known model and tokens + # cost = ceil(100*30/1000 + 200*60/1000) = ceil(3+12) = 15 cents + result = await ev.evaluate( + { + "model": "gpt-4", + "usage": {"input_tokens": 100, "output_tokens": 200}, + } + ) + + # Then: not matched (15 < 100), cost tracked in metadata + assert result.matched is False + assert result.metadata is not None + assert result.metadata["cost"] == 15 + + @pytest.mark.asyncio + async def test_unknown_model_cost_zero(self) -> None: + # Given: evaluator with pricing table but data from an unknown model + ev = self._make_evaluator( + limits=[{"limit": 100}], + pricing={"gpt-4": {"input_per_1k": 30.0, "output_per_1k": 60.0}}, + model_path="model", + ) + + # When: evaluate with a model not in the pricing table + result = await ev.evaluate( + { + "model": "unknown-model", + "usage": {"input_tokens": 1000, "output_tokens": 1000}, + } + ) + + # Then: not matched (cost=0 because model not in pricing) + assert result.matched is False + assert result.metadata is not None + assert result.metadata["cost"] == 0 + + +# 
--------------------------------------------------------------------------- +# Security / adversarial tests +# --------------------------------------------------------------------------- + + +class TestBudgetAdversarial: + def test_scope_key_injection_pipe(self) -> None: + # Given: malicious user_id with pipe + key = _build_scope_key({"ch": "slack"}, "uid", {"uid": "u1|ch=admin"}) + + # Then: pipe is percent-encoded, no injection + parts = key.split("|") + assert len(parts) == 2 + assert "ch=admin" not in parts + + def test_scope_key_no_collision(self) -> None: + key1 = _build_scope_key({}, "uid", {"uid": "a|b"}) + key2 = _build_scope_key({}, "uid", {"uid": "a_b"}) + assert key1 != key2 + + def test_extract_by_path_rejects_dunder(self) -> None: + from agent_control_evaluator_budget.budget.evaluator import _extract_by_path + + assert _extract_by_path({"a": 1}, "__class__") is None + + def test_group_by_without_metadata_skips_rule(self) -> None: + # Given: rule with group_by=user_id but no user_id in scope + rules = [BudgetLimitRule(group_by="user_id", limit=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: step without user_id + results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=999) + + # Then: rule skipped + assert results == [] + + def test_two_rules_same_scope_no_double_count(self) -> None: + # Given: two global rules with different limit types + rules = [ + BudgetLimitRule(limit=1000), + BudgetLimitRule(limit_tokens=5000), + ] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: record once + results = store.record_and_check(scope={}, input_tokens=100, output_tokens=100, cost=100) + + # Then: both rules get snapshot, but usage recorded only once + assert len(results) == 2 + assert results[0].spent == 100 # not 200 + assert results[1].spent_tokens == 200 # not 400 + + def test_different_currency_separate_buckets(self) -> None: + # Given: two rules with same scope but different 
currencies + rules = [ + BudgetLimitRule(limit=1000, currency=Currency.USD), + BudgetLimitRule(limit=2000, currency=Currency.EUR), + ] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: record once + results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500) + + # Then: each currency gets its own bucket, both record the cost + assert len(results) == 2 + assert results[0].spent == 500 + assert results[1].spent == 500 + + def test_negative_cost_not_recorded(self) -> None: + # Given: store with 1000-cent limit + rules = [BudgetLimitRule(limit=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: record positive then negative cost + store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500) + results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=-200) + + # Then: negative cost is added (store is dumb; validation is caller's job) + # If this is undesirable, evaluator must reject negatives before calling store + assert results[0].spent == 300 + + def test_window_seconds_boundary_alignment(self) -> None: + # Given: hourly window, clock at boundary-1 and boundary + rules = [BudgetLimitRule(limit=1000, window_seconds=3600)] + boundary = 3600 * 100 # exact hour boundary + + # When: record just before and at boundary + store = InMemoryBudgetStore(rules=rules, clock=lambda: boundary - 1) + store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500) + + store._clock = lambda: boundary + results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500) + + # Then: boundary crossing starts fresh period + assert results[0].spent == 500 # not 1000 + + +class TestConfigValidationEdgeCases: + def test_zero_limit_tokens_rejected(self) -> None: + # Given/When: zero token limit + with pytest.raises(ValidationError, match="positive"): + BudgetLimitRule(limit_tokens=0) + + def test_invalid_currency_rejected(self) -> None: + # 
Given/When: invalid currency string + with pytest.raises(ValidationError): + BudgetLimitRule(limit=1000, currency="btc") + + +class TestBoolGuard: + """bool is a subclass of int in Python -- must be rejected.""" + + def test_extract_tokens_rejects_bool(self) -> None: + # Given: data with bool tokens + data = {"usage": {"input_tokens": True, "output_tokens": False}} + + # When/Then: bools are not accepted as token counts + assert _extract_tokens(data, None) == (0, 0) diff --git a/pyproject.toml b/pyproject.toml index 95baef8d..645f6229 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,3 +83,9 @@ tag_format = "v{version}" # feat = minor, fix/perf/refactor = patch, breaking (!) = major allowed_tags = ["feat", "fix", "perf", "chore", "docs", "style", "refactor", "test", "ci"] patch_tags = ["fix", "perf", "chore", "refactor"] + +[dependency-groups] +dev = [ + "pytest>=9.0.2", + "pytest-asyncio>=1.3.0", +] From cd473e8d628779e2111450a7e5979b3511da7c93 Mon Sep 17 00:00:00 2001 From: amabito Date: Fri, 10 Apr 2026 12:12:27 +0900 Subject: [PATCH 2/3] feat(budget): address R3 review -- async ABC, TTL prune, defensive guards Respond to lan17's R3 review on PR #144 with the mechanical items that do not depend on pending config-layer decisions (limit model, budget_id, unknown_model_behavior). 
Changes: - Migrate BudgetStore from Protocol to async ABC with __init_subclass__ guard that walks the MRO to reject sync overrides at class creation - InMemoryBudgetStore: async wrapper around sync helper, threading.Lock retained for CPU-bound critical section - TTL prune for stale period buckets on rollover, runs before max_buckets capacity check so rollover at capacity reclaims space - Monotonic prune watermark (rejects backwards clock) - _compute_utilization low-side clamp to [0.0, 1.0] (refund semantic) - Defensive guards: NaN/Inf cost and clock coerced to 0.0, negative token counts clamped to 0 - Revert root pyproject.toml (remove unrelated [dependency-groups], restore version 7.3.1) - Remove clear_budget_stores from __all__ (testing utility) - Document token attribution intent (single int -> output-only) Tests: 67 -> 91 (24 new: async migration, TTL prune coverage, adversarial guards, ABC contract enforcement) --- .../budget/__init__.py | 4 + .../budget/evaluator.py | 103 +- .../budget/memory_store.py | 141 ++- .../budget/store.py | 75 +- .../budget/tests/budget/test_budget.py | 1008 +++++++++++++++-- pyproject.toml | 6 - 6 files changed, 1185 insertions(+), 152 deletions(-) diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py index b0e6f6d4..c82d2647 100644 --- a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py @@ -5,6 +5,10 @@ from agent_control_evaluator_budget.budget.memory_store import InMemoryBudgetStore from agent_control_evaluator_budget.budget.store import BudgetSnapshot, BudgetStore +# Note: clear_budget_stores is a testing utility and is intentionally not +# re-exported here. 
Import it directly from the evaluator submodule in tests: +# from agent_control_evaluator_budget.budget.evaluator import clear_budget_stores + __all__ = [ "BudgetEvaluator", "BudgetEvaluatorConfig", diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py index 6d1ca128..c0d6c517 100644 --- a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py @@ -3,12 +3,19 @@ Deterministic evaluator: confidence is always 1.0, matched is True when any configured limit is exceeded. Utilization ratio and spend breakdown are returned in result metadata, not in confidence. + +The evaluator is stateless. Budget state lives in a module-level store +registry, independent of the evaluator instance cache in _factory.py. +This prevents silent state loss on LRU eviction and avoids cross-control +leakage when different controls share the same config. """ from __future__ import annotations +import json import logging import math +import threading from typing import Any from agent_control_evaluators._base import Evaluator, EvaluatorMetadata @@ -17,9 +24,61 @@ from .config import BudgetEvaluatorConfig from .memory_store import InMemoryBudgetStore +from .store import BudgetStore logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Module-level store registry +# +# Decoupled from the evaluator instance cache so that LRU eviction in +# _factory.py does not destroy accumulated budget state. The registry +# is keyed by a stable config hash. Two controls with identical config +# intentionally share a budget pool (same config = same budget). +# --------------------------------------------------------------------------- + +# NOTE: The registry is unbounded. 
In practice a deployment has a finite +# set of budget configs. If dynamic config generation becomes a concern, +# add a max-size cap with LRU eviction here. +_STORE_REGISTRY: dict[str, BudgetStore] = {} +_STORE_REGISTRY_LOCK = threading.Lock() + + +def _config_key(config: BudgetEvaluatorConfig) -> str: + """Build a stable key for the store registry from evaluator config. + + The limits list is sorted before hashing so that two configs with + semantically identical rules in different order share a store. + """ + config_dict = config.model_dump(mode="json") + config_dict["limits"] = sorted( + config_dict["limits"], + key=lambda r: json.dumps(r, sort_keys=True, default=str), + ) + return f"budget:{json.dumps(config_dict, sort_keys=True, default=str)}" + + +def get_or_create_store(config: BudgetEvaluatorConfig) -> BudgetStore: + """Get or create a store for the given config, thread-safe.""" + key = _config_key(config) + with _STORE_REGISTRY_LOCK: + store = _STORE_REGISTRY.get(key) + if store is None: + store = InMemoryBudgetStore(rules=config.limits) + _STORE_REGISTRY[key] = store + return store + + +def clear_budget_stores() -> None: + """Clear all budget stores. Useful for testing.""" + with _STORE_REGISTRY_LOCK: + _STORE_REGISTRY.clear() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + def _extract_by_path(data: Any, path: str) -> Any: """Extract a value from nested data using dot-notation path.""" @@ -50,6 +109,10 @@ def _extract_tokens(data: Any, token_path: str | None) -> tuple[int, int]: if token_path: val = _extract_by_path(data, token_path) if isinstance(val, int) and not isinstance(val, bool) and val >= 0: + # When token_path resolves to a single int we cannot distinguish + # input vs output. 
Attribute the whole count to output because + # output rates are typically higher than input rates in pricing + # tables, so this over-estimates cost rather than under-estimates. return 0, val if isinstance(val, dict): data = val @@ -78,19 +141,23 @@ def _estimate_cost( input_tokens: int, output_tokens: int, pricing: dict[str, dict[str, float]] | None, -) -> int: - """Estimate cost in minor units from model pricing table. Returns 0 if unknown.""" +) -> float: + """Estimate cost in cents (USD) from model pricing table. + + Returns a float for precision. Rounding happens at snapshot time, + not per call. + """ if not model or not pricing: - return 0 + return 0.0 rates = pricing.get(model) if not rates: - return 0 + return 0.0 input_rate = rates.get("input_per_1k", 0.0) output_rate = rates.get("output_per_1k", 0.0) cost = (input_tokens * input_rate + output_tokens * output_rate) / 1000.0 if not math.isfinite(cost) or cost < 0: - return 0 - return math.ceil(cost) + return 0.0 + return cost def _extract_metadata(data: Any, metadata_paths: dict[str, str]) -> dict[str, str]: @@ -103,6 +170,11 @@ def _extract_metadata(data: Any, metadata_paths: dict[str, str]) -> dict[str, st return result +# --------------------------------------------------------------------------- +# Evaluator +# --------------------------------------------------------------------------- + + @register_evaluator class BudgetEvaluator(Evaluator[BudgetEvaluatorConfig]): """Tracks cumulative LLM token and cost usage per scope and time window. @@ -110,21 +182,17 @@ class BudgetEvaluator(Evaluator[BudgetEvaluatorConfig]): Deterministic evaluator: matched=True when any configured limit is exceeded, confidence=1.0 always. - The evaluator is stateful -- it accumulates usage in a BudgetStore. - The store is created per evaluator config and is thread-safe. + The evaluator is stateless. Budget state is managed by a module-level + store registry (get_or_create_store), not by the evaluator instance. 
""" metadata = EvaluatorMetadata( name="budget", - version="2.0.0", + version="3.0.0", description="Cumulative LLM token and cost budget tracking", ) config_model = BudgetEvaluatorConfig - def __init__(self, config: BudgetEvaluatorConfig) -> None: - super().__init__(config) - self._store = InMemoryBudgetStore(rules=config.limits) - async def evaluate(self, data: Any) -> EvaluatorResult: """Evaluate step data against all configured budget limits.""" if data is None: @@ -146,7 +214,8 @@ async def evaluate(self, data: Any) -> EvaluatorResult: step_metadata = _extract_metadata(data, self.config.metadata_paths) - snapshots = self._store.record_and_check( + store = get_or_create_store(self.config) + snapshots = await store.record_and_check( scope=step_metadata, input_tokens=input_tokens, output_tokens=output_tokens, @@ -156,7 +225,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult: breached: list[dict[str, Any]] = [] all_snaps: list[dict[str, Any]] = [] - for i, snap in enumerate(snapshots): + for snap in snapshots: snap_info = { "spent": snap.spent, "spent_tokens": snap.spent_tokens, @@ -180,7 +249,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult: "all_snapshots": all_snaps, "input_tokens": input_tokens, "output_tokens": output_tokens, - "cost": cost, + "cost": round(cost, 6), }, ) @@ -193,7 +262,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult: "all_snapshots": all_snaps, "input_tokens": input_tokens, "output_tokens": output_tokens, - "cost": cost, + "cost": round(cost, 6), "max_utilization": round(max_util, 4), }, ) diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py index 21130cf8..df1784eb 100644 --- a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py @@ -6,13 +6,14 @@ from __future__ 
import annotations +import math import threading import time from collections.abc import Callable from dataclasses import dataclass from .config import BudgetLimitRule -from .store import BudgetSnapshot +from .store import BudgetSnapshot, BudgetStore, round_spent def _sanitize_scope_value(val: str) -> str: @@ -34,6 +35,20 @@ def _build_scope_key( return "|".join(parts) if parts else "__global__" +def _parse_period_key(key: str) -> tuple[int, int] | None: + """Parse 'P{window}:{index}' into (window_seconds, bucket_index). + + Returns None for empty/cumulative keys. + """ + if not key or not key.startswith("P"): + return None + try: + window_part, index_part = key[1:].split(":", 1) + return int(window_part), int(index_part) + except (ValueError, IndexError): + return None + + def _derive_period_key(window_seconds: int | None, now: float) -> str: """Derive a period key from window_seconds and a timestamp. @@ -58,17 +73,22 @@ def _scope_matches(rule: BudgetLimitRule, scope: dict[str, str]) -> bool: def _compute_utilization( - spent: int, + spent: float, spent_tokens: int, limit: int | None, limit_tokens: int | None, ) -> float: - """Return max(spend_ratio, token_ratio) clamped to [0.0, 1.0].""" + """Return max(spend_ratio, token_ratio) clamped to [0.0, 1.0]. + + The low-side clamp is load-bearing: under refund semantics the internal + `spent` accumulator may go negative, which would otherwise produce a + negative ratio and violate the BudgetSnapshot.utilization contract. 
+ """ ratios: list[float] = [] if limit is not None and limit > 0: - ratios.append(min(spent / limit, 1.0)) + ratios.append(max(0.0, min(spent / limit, 1.0))) if limit_tokens is not None and limit_tokens > 0: - ratios.append(min(spent_tokens / limit_tokens, 1.0)) + ratios.append(max(0.0, min(spent_tokens / limit_tokens, 1.0))) return max(ratios) if ratios else 0.0 @@ -76,7 +96,7 @@ def _compute_utilization( class _Bucket: """Internal mutable accumulator for a single (scope, period) pair.""" - spent: int = 0 + spent: float = 0.0 input_tokens: int = 0 output_tokens: int = 0 @@ -85,17 +105,21 @@ def total_tokens(self) -> int: return self.input_tokens + self.output_tokens -class InMemoryBudgetStore: +class InMemoryBudgetStore(BudgetStore): """Thread-safe in-memory budget store. Initialized with a list of BudgetLimitRule. Derives period keys internally from window_seconds + injected clock. - NOTE: Currency conversion is not handled here. The cost integer - passed to record_and_check is assumed to be in the same unit as - the rule's currency. Cross-currency conversion (e.g. USD->EUR) - is the caller's responsibility and will be addressed when cost - calculation moves into the evaluator (pending design review). + Cost is accumulated as float for precision. Integer rounding + happens only at snapshot time for display/reporting. + + TTL prune: on new period rollover per window, buckets older than + `current - 1` for that window are dropped. This keeps memory bounded + for long-running deployments with windowed rules. + + `max_buckets` remains as a backstop for high-cardinality group_by + explosions that TTL cannot protect against. 
""" _DEFAULT_MAX_BUCKETS = 100_000 @@ -110,20 +134,45 @@ def __init__( self._rules = rules self._clock = clock self._lock = threading.Lock() - self._buckets: dict[tuple[str, str, str], _Bucket] = {} + self._buckets: dict[tuple[str, str], _Bucket] = {} self._max_buckets = max_buckets + self._last_pruned_period: dict[int, int] = {} - def record_and_check( + async def record_and_check( self, scope: dict[str, str], input_tokens: int, output_tokens: int, - cost: int, + cost: float, ) -> list[BudgetSnapshot]: """Atomically record usage and return snapshots for all matching rules.""" + return self._record_and_check_sync(scope, input_tokens, output_tokens, cost) + + def _record_and_check_sync( + self, + scope: dict[str, str], + input_tokens: int, + output_tokens: int, + cost: float, + ) -> list[BudgetSnapshot]: + """Sync implementation of record_and_check. + + NaN/Inf cost is coerced to 0.0 defensively. Once NaN enters a + bucket's float accumulator, all subsequent additions produce NaN + and `nan >= limit` is always False (IEEE 754), permanently + disabling budget enforcement for that bucket. + """ + if not math.isfinite(cost): + cost = 0.0 + # Token counts have no refund semantics; clamp to non-negative + # to prevent negative injection from resetting the accumulator. 
+ input_tokens = max(0, input_tokens) + output_tokens = max(0, output_tokens) now = self._clock() + if not math.isfinite(now): + now = 0.0 snapshots: list[BudgetSnapshot] = [] - recorded_pairs: set[tuple[str, str, str]] = set() + recorded_pairs: set[tuple[str, str]] = set() with self._lock: for rule in self._rules: @@ -132,9 +181,7 @@ def record_and_check( scope_key = _build_scope_key(rule.scope, rule.group_by, scope) period_key = _derive_period_key(rule.window_seconds, now) - cur = rule.currency - currency_key = cur.value if hasattr(cur, "value") else str(cur) - pair = (scope_key, period_key, currency_key) + pair = (scope_key, period_key) if pair not in recorded_pairs: bucket = self._get_or_create_bucket(pair) @@ -157,8 +204,14 @@ def record_and_check( recorded_pairs.add(pair) else: bucket = self._buckets.get(pair) - if bucket is None: - continue + # Defensive: this branch is unreachable under current + # invariants (recorded_pairs only contains pairs whose + # bucket was successfully created, and self._lock prevents + # concurrent deletion). If a future refactor violates + # this, the assertion surfaces it. 
+ assert bucket is not None, ( + f"bucket for {pair!r} was in recorded_pairs but missing from _buckets" + ) total_tokens = bucket.total_tokens utilization = _compute_utilization( @@ -172,7 +225,7 @@ def record_and_check( snapshots.append( BudgetSnapshot( - spent=bucket.spent, + spent=round_spent(bucket.spent), spent_tokens=total_tokens, limit=rule.limit, limit_tokens=rule.limit_tokens, @@ -189,10 +242,9 @@ def get_snapshot( period_key: str, limit: int | None = None, limit_tokens: int | None = None, - currency: str = "usd", ) -> BudgetSnapshot: """Read current budget state without recording usage.""" - key = (scope_key, period_key, currency) + key = (scope_key, period_key) with self._lock: bucket = self._buckets.get(key) if bucket is None: @@ -212,7 +264,7 @@ def get_snapshot( if limit_tokens is not None and total_tokens >= limit_tokens: exceeded = True return BudgetSnapshot( - spent=bucket.spent, + spent=round_spent(bucket.spent), spent_tokens=total_tokens, limit=limit, limit_tokens=limit_tokens, @@ -225,6 +277,7 @@ def reset(self, scope_key: str | None = None, period_key: str | None = None) -> with self._lock: if scope_key is None and period_key is None: self._buckets.clear() + self._last_pruned_period.clear() return keys_to_remove = [ k @@ -235,11 +288,45 @@ def reset(self, scope_key: str | None = None, period_key: str | None = None) -> for k in keys_to_remove: del self._buckets[k] - def _get_or_create_bucket(self, key: tuple[str, str, str]) -> _Bucket | None: - """Get or create a bucket. Returns None if max_buckets reached.""" + def _get_or_create_bucket(self, key: tuple[str, str]) -> _Bucket | None: + """Get or create a bucket. Returns None if max_buckets reached. + + On period rollover (new windowed bucket with a forward period index), + stale buckets for the same window (bucket_index < current - 1) are + pruned BEFORE the max_buckets capacity check, so that a rollover at + capacity can free space rather than fail closed. 
Cross-scope pruning + is intentional: all stale same-window buckets are dropped regardless + of scope key, since the period has expired globally. + + The watermark `_last_pruned_period[window]` only advances forward; + a backwards clock does not trigger spurious prune work. + + Caller must hold self._lock. + """ bucket = self._buckets.get(key) if bucket is not None: return bucket + + # TTL prune runs BEFORE the max_buckets check so that rollover at + # capacity can reclaim space rather than fail closed permanently. + parsed = _parse_period_key(key[1]) + if parsed is not None: + window, index = parsed + last_pruned = self._last_pruned_period.get(window) + # Only advance on forward progress. Backwards clock is a no-op; + # the previously established watermark still protects us. + if last_pruned is None or index > last_pruned: + stale_keys = [ + k + for k in self._buckets + if (kp := _parse_period_key(k[1])) is not None + and kp[0] == window + and kp[1] < index - 1 + ] + for k in stale_keys: + del self._buckets[k] + self._last_pruned_period[window] = index + if len(self._buckets) >= self._max_buckets: return None bucket = _Bucket() diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py index 0c767d3b..9d58f76f 100644 --- a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py @@ -1,4 +1,4 @@ -"""BudgetStore protocol -- interface for budget storage backends. +"""BudgetStore abstract base class -- interface for budget storage backends. Implementations must provide atomic record-and-check: a single call that records usage and returns the current totals. 
This prevents @@ -10,8 +10,11 @@ from __future__ import annotations +import inspect +import math +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Protocol, runtime_checkable +from typing import Any @dataclass(frozen=True) @@ -19,9 +22,9 @@ class BudgetSnapshot: """Immutable view of budget state at a point in time. Attributes: - spent: Cumulative spend in minor units (e.g. cents for USD). + spent: Cumulative spend in cents (USD), floor-truncated from float. spent_tokens: Cumulative tokens (input + output) in this scope+period. - limit: Configured spend ceiling in minor units, or None if uncapped. + limit: Configured spend ceiling in cents, or None if uncapped. limit_tokens: Configured token ceiling, or None if uncapped. utilization: max(spend_ratio, token_ratio) clamped to [0.0, 1.0]. 0.0 when no limits are set. @@ -36,22 +39,73 @@ class BudgetSnapshot: exceeded: bool -@runtime_checkable -class BudgetStore(Protocol): - """Protocol for budget storage backends. +def round_spent(value: float) -> int: + """Truncate accumulated float spend to integer cents for display. + + Uses floor truncation (not rounding) so that the displayed spent + value never exceeds the actual float. This prevents a contradictory + snapshot where spent >= limit but exceeded is False. + """ + if not math.isfinite(value) or value < 0: + return 0 + return int(value) + + +class BudgetStore(ABC): + """Abstract base class for budget storage backends. The store is initialized with a list of BudgetLimitRule and derives period keys internally from window_seconds + current time. Callers pass only usage data: scope dict, input_tokens, output_tokens, cost. + + Negative `cost` values are permitted and reduce accumulated spend (refund + semantics). `round_spent()` floors the displayed snapshot spend to 0 for + negative accumulators, but the internal float accumulator may go negative + so that a subsequent positive charge cancels correctly.
Validation of + cost >= 0 is NOT performed at the store boundary; it is the caller's + responsibility if strict positive accounting is required. + + Implementations should be safe to call from async contexts. + InMemoryBudgetStore wraps a sync critical section under threading.Lock + because the work is CPU-bound and brief; distributed backends + (Redis/Postgres) should use native async I/O. + + Subclasses must override `record_and_check` with a coroutine function + (`async def`). A sync override is rejected at class creation time rather + than failing silently at the first `await` site in production. """ - def record_and_check( + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + # Walk the MRO to find the nearest override of record_and_check. + # Checking only cls.__dict__ misses mixin-inherited sync overrides + # that satisfy ABC's abstractmethod check but silently break at the + # first `await` call site. + method = None + for base in cls.__mro__: + if base is BudgetStore: + break + if "record_and_check" in base.__dict__: + raw = base.__dict__["record_and_check"] + # Unwrap staticmethod/classmethod descriptors so that + # inspect.iscoroutinefunction sees the underlying function. + method = getattr(raw, "__func__", raw) + break + if method is not None and not inspect.iscoroutinefunction(method): + raise TypeError( + f"{cls.__name__}.record_and_check must be an async def " + "(coroutine function); got a sync function. BudgetStore is " + "an async ABC." + ) + + @abstractmethod + async def record_and_check( self, scope: dict[str, str], input_tokens: int, output_tokens: int, - cost: int, + cost: float, ) -> list[BudgetSnapshot]: """Atomically record usage and return snapshots for all matching rules. @@ -59,9 +113,8 @@ def record_and_check( scope: Scope dimensions from the step (e.g. {"agent": "summarizer"}). input_tokens: Input tokens consumed by this call. output_tokens: Output tokens consumed by this call. 
- cost: Cost in minor units (e.g. cents for USD). + cost: Cost in cents (USD), as a float for precision. Returns: List of BudgetSnapshot, one per matching rule. """ - ... diff --git a/evaluators/contrib/budget/tests/budget/test_budget.py b/evaluators/contrib/budget/tests/budget/test_budget.py index 561f4edf..86adae5c 100644 --- a/evaluators/contrib/budget/tests/budget/test_budget.py +++ b/evaluators/contrib/budget/tests/budget/test_budget.py @@ -17,11 +17,12 @@ WINDOW_WEEKLY, BudgetEvaluatorConfig, BudgetLimitRule, - Currency, ) from agent_control_evaluator_budget.budget.evaluator import ( BudgetEvaluator, _extract_tokens, + clear_budget_stores, + get_or_create_store, ) from agent_control_evaluator_budget.budget.memory_store import ( InMemoryBudgetStore, @@ -30,39 +31,53 @@ _derive_period_key, ) + +@pytest.fixture(autouse=True) +def _clean_store_registry() -> None: + """Clear the module-level store registry before each test.""" + clear_budget_stores() + + # --------------------------------------------------------------------------- # InMemoryBudgetStore # --------------------------------------------------------------------------- class TestInMemoryBudgetStore: - def test_single_record_under_limit(self) -> None: + @pytest.mark.asyncio + async def test_single_record_under_limit(self) -> None: # Given: store with a $10 daily limit (1000 cents) rules = [BudgetLimitRule(limit=1000, window_seconds=WINDOW_DAILY)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 1700000000.0) # When: record 300 cents of usage - results = store.record_and_check(scope={}, input_tokens=100, output_tokens=50, cost=300) + results = await store.record_and_check( + scope={}, input_tokens=100, output_tokens=50, cost=300.0 + ) # Then: not breached, ratio ~0.3 assert len(results) == 1 assert not results[0].exceeded assert results[0].utilization == pytest.approx(0.3, abs=0.01) - def test_accumulation_triggers_breach(self) -> None: + @pytest.mark.asyncio + async def 
test_accumulation_triggers_breach(self) -> None: # Given: store with 1000-cent limit rules = [BudgetLimitRule(limit=1000)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 1700000000.0) # When: record 600 + 500 = 1100 cents - store.record_and_check(scope={}, input_tokens=100, output_tokens=50, cost=600) - results = store.record_and_check(scope={}, input_tokens=100, output_tokens=50, cost=500) + await store.record_and_check(scope={}, input_tokens=100, output_tokens=50, cost=600.0) + results = await store.record_and_check( + scope={}, input_tokens=100, output_tokens=50, cost=500.0 + ) # Then: exceeded assert results[0].exceeded is True assert results[0].spent == 1100 - def test_scope_isolation(self) -> None: + @pytest.mark.asyncio + async def test_scope_isolation(self) -> None: # Given: per-agent limits rules = [ BudgetLimitRule(scope={"agent": "a"}, limit=1000), @@ -71,11 +86,11 @@ def test_scope_isolation(self) -> None: store = InMemoryBudgetStore(rules=rules, clock=lambda: 1700000000.0) # When: agent-a records 900, agent-b records 100 - results_a = store.record_and_check( - scope={"agent": "a"}, input_tokens=0, output_tokens=0, cost=900 + results_a = await store.record_and_check( + scope={"agent": "a"}, input_tokens=0, output_tokens=0, cost=900.0 ) - results_b = store.record_and_check( - scope={"agent": "b"}, input_tokens=0, output_tokens=0, cost=100 + results_b = await store.record_and_check( + scope={"agent": "b"}, input_tokens=0, output_tokens=0, cost=100.0 ) # Then: agent-a near limit, agent-b well under @@ -83,7 +98,8 @@ def test_scope_isolation(self) -> None: assert results_b[0].spent == 100 assert not results_b[0].exceeded - def test_period_isolation(self) -> None: + @pytest.mark.asyncio + async def test_period_isolation(self) -> None: # Given: daily limit, clock at two different days rules = [BudgetLimitRule(limit=1000, window_seconds=WINDOW_DAILY)] day1 = 1700000000.0 @@ -91,63 +107,75 @@ def test_period_isolation(self) -> None: # When: record on 
day 1, then day 2 store = InMemoryBudgetStore(rules=rules, clock=lambda: day1) - store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=800) + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=800.0) store._clock = lambda: day2 - results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=300) + results = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=300.0 + ) # Then: day 2 is a fresh period assert results[0].spent == 300 assert not results[0].exceeded - def test_exceeded_exact_limit(self) -> None: + @pytest.mark.asyncio + async def test_exceeded_exact_limit(self) -> None: # Given: 1000-cent limit rules = [BudgetLimitRule(limit=1000)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) # When: spend exactly 1000 - results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1000) + results = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=1000.0 + ) # Then: exceeded (>= not >) assert results[0].exceeded is True - def test_token_only_limit(self) -> None: + @pytest.mark.asyncio + async def test_token_only_limit(self) -> None: # Given: 1000-token limit, no cost limit rules = [BudgetLimitRule(limit_tokens=1000)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) # When: consume 600+500 = 1100 tokens - results = store.record_and_check(scope={}, input_tokens=600, output_tokens=500, cost=0) + results = await store.record_and_check( + scope={}, input_tokens=600, output_tokens=500, cost=0.0 + ) # Then: exceeded assert results[0].exceeded is True assert results[0].spent_tokens == 1100 - def test_no_matching_rules(self) -> None: + @pytest.mark.asyncio + async def test_no_matching_rules(self) -> None: # Given: rule for agent=summarizer only rules = [BudgetLimitRule(scope={"agent": "summarizer"}, limit=1000)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) # When: step from agent=other - results = 
store.record_and_check( - scope={"agent": "other"}, input_tokens=100, output_tokens=50, cost=999 + results = await store.record_and_check( + scope={"agent": "other"}, input_tokens=100, output_tokens=50, cost=999.0 ) # Then: no snapshots (rule didn't match) assert results == [] - def test_group_by_user(self) -> None: + @pytest.mark.asyncio + async def test_group_by_user(self) -> None: # Given: global rule with group_by=user_id rules = [BudgetLimitRule(group_by="user_id", limit=500)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) # When: two users each spend - store.record_and_check(scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=400) - results_u1 = store.record_and_check( - scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=200 + await store.record_and_check( + scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=400.0 + ) + results_u1 = await store.record_and_check( + scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=200.0 ) - results_u2 = store.record_and_check( - scope={"user_id": "u2"}, input_tokens=0, output_tokens=0, cost=300 + results_u2 = await store.record_and_check( + scope={"user_id": "u2"}, input_tokens=0, output_tokens=0, cost=300.0 ) # Then: u1 exceeded, u2 not @@ -156,14 +184,20 @@ def test_group_by_user(self) -> None: def test_thread_safety(self) -> None: # Given: high-limit rule and 10 concurrent threads + # Each thread calls asyncio.run(store.record_and_check(...)) -- the async + # method wraps a sync critical section, so threading.Lock prevents races. 
rules = [BudgetLimitRule(limit=1_000_000)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) errors: list[str] = [] + import asyncio + def record_many() -> None: try: for _ in range(100): - store.record_and_check(scope={}, input_tokens=1, output_tokens=1, cost=1) + asyncio.run( + store.record_and_check(scope={}, input_tokens=1, output_tokens=1, cost=1.0) + ) except Exception as exc: errors.append(str(exc)) @@ -180,7 +214,8 @@ def record_many() -> None: assert snap.spent_tokens == 2000 assert snap.spent == 1000 - def test_max_buckets_fail_closed(self) -> None: + @pytest.mark.asyncio + async def test_max_buckets_fail_closed(self) -> None: # Given: store limited to 3 buckets with group_by=user_id rules = [BudgetLimitRule(group_by="user_id", limit=100_000)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0, max_buckets=3) @@ -188,8 +223,8 @@ def test_max_buckets_fail_closed(self) -> None: # When: 5 different users try to record exceeded_count = 0 for i in range(5): - results = store.record_and_check( - scope={"user_id": f"u{i}"}, input_tokens=1, output_tokens=1, cost=1 + results = await store.record_and_check( + scope={"user_id": f"u{i}"}, input_tokens=1, output_tokens=1, cost=1.0 ) if results and results[0].exceeded: exceeded_count += 1 @@ -197,11 +232,12 @@ def test_max_buckets_fail_closed(self) -> None: # Then: first 3 succeed, last 2 fail-closed assert exceeded_count == 2 - def test_reset_all(self) -> None: + @pytest.mark.asyncio + async def test_reset_all(self) -> None: # Given: store with recorded usage rules = [BudgetLimitRule(limit=1000)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) - store.record_and_check(scope={}, input_tokens=10, output_tokens=10, cost=100) + await store.record_and_check(scope={}, input_tokens=10, output_tokens=10, cost=100.0) # When: reset all store.reset() @@ -210,6 +246,36 @@ def test_reset_all(self) -> None: snap = store.get_snapshot("__global__", "", limit=1000) assert snap.spent == 0 + 
@pytest.mark.asyncio + async def test_float_accumulation_precision(self) -> None: + # Given: store with 1-cent limit + rules = [BudgetLimitRule(limit=1)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: 100 calls each costing 0.003 cents (total = 0.3 cents) + for _ in range(100): + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=0.003) + + # Then: not exceeded (0.3 < 1), no ceil-per-call overcount + snap = store.get_snapshot("__global__", "", limit=1) + assert not snap.exceeded + assert snap.spent == 0 # round(0.3) = 0 + + @pytest.mark.asyncio + async def test_float_accumulation_eventual_breach(self) -> None: + # Given: store with 1-cent limit + rules = [BudgetLimitRule(limit=1)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: 400 calls each costing 0.003 cents (total = 1.2 cents) + for _ in range(400): + results = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=0.003 + ) + + # Then: exceeded (1.2 >= 1) + assert results[0].exceeded is True + # --------------------------------------------------------------------------- # Utility functions @@ -218,51 +284,92 @@ def test_reset_all(self) -> None: class TestUtilities: def test_compute_utilization_no_limits(self) -> None: - assert _compute_utilization(100, 10000, None, None) == 0.0 + # Given/When: no limits set / Then: 0.0 + assert _compute_utilization(100.0, 10000, None, None) == 0.0 def test_compute_utilization_spend_only(self) -> None: - # Given: 500 of 1000 spent - assert _compute_utilization(500, 0, 1000, None) == pytest.approx(0.5) + # Given: 500 of 1000 spent / Then: 0.5 + assert _compute_utilization(500.0, 0, 1000, None) == pytest.approx(0.5) def test_compute_utilization_clamped(self) -> None: - assert _compute_utilization(2000, 0, 1000, None) == pytest.approx(1.0) + # Given: overspent / Then: clamped to 1.0 + assert _compute_utilization(2000.0, 0, 1000, None) == pytest.approx(1.0) + + def 
test_compute_utilization_negative_clamped_to_zero(self) -> None: + # Given: refund made the accumulator go negative + # When: utilization is computed + # Then: clamped to 0.0 (BudgetSnapshot.utilization contract) + assert _compute_utilization(-150.0, 0, 100, None) == 0.0 + # And: negative tokens (not currently reachable but defensively clamped) + assert _compute_utilization(0.0, -50, None, 100) == 0.0 + + def test_parse_period_key_valid(self) -> None: + # Given: well-formed period key / Then: parsed tuple + from agent_control_evaluator_budget.budget.memory_store import _parse_period_key + + assert _parse_period_key("P86400:19675") == (86400, 19675) + assert _parse_period_key("P3600:0") == (3600, 0) + + def test_parse_period_key_malformed(self) -> None: + # Given: empty, missing, or non-numeric period keys + # When: parsed + # Then: None returned (never raises) + from agent_control_evaluator_budget.budget.memory_store import _parse_period_key + + assert _parse_period_key("") is None # cumulative sentinel + assert _parse_period_key("P") is None # no separator + assert _parse_period_key("P:1") is None # empty window + assert _parse_period_key("P86400:") is None # empty index + assert _parse_period_key("Pabc:1") is None # non-numeric window + assert _parse_period_key("P86400:xyz") is None # non-numeric index + assert _parse_period_key("X86400:1") is None # wrong prefix + assert _parse_period_key("PP86400:1") is None # double P def test_derive_period_key_none(self) -> None: + # Given: no window / Then: empty key assert _derive_period_key(None, 0.0) == "" def test_derive_period_key_daily(self) -> None: - # Given: 1700000000 / 86400 = 19675 (truncated) + # Given: daily window at 1700000000 / Then: epoch-aligned key key = _derive_period_key(WINDOW_DAILY, 1700000000.0) assert key == "P86400:19675" def test_derive_period_key_weekly(self) -> None: + # Given: weekly window / Then: key starts with P604800: key = _derive_period_key(WINDOW_WEEKLY, 1700000000.0) assert 
key.startswith("P604800:") def test_build_scope_key_global(self) -> None: + # Given: empty scope / Then: __global__ assert _build_scope_key({}, None, {}) == "__global__" def test_build_scope_key_with_scope(self) -> None: + # Given: channel scope / Then: channel=slack key = _build_scope_key({"channel": "slack"}, None, {}) assert key == "channel=slack" def test_build_scope_key_with_group_by(self) -> None: + # Given: scope + group_by / Then: combined key key = _build_scope_key({"channel": "slack"}, "user_id", {"user_id": "u1"}) assert key == "channel=slack|user_id=u1" def test_build_scope_key_group_by_missing(self) -> None: + # Given: group_by field not in scope / Then: __global__ key = _build_scope_key({}, "user_id", {}) assert key == "__global__" def test_extract_tokens_standard(self) -> None: + # Given: standard token fields / Then: extracted data = {"usage": {"input_tokens": 100, "output_tokens": 50}} assert _extract_tokens(data, None) == (100, 50) def test_extract_tokens_openai(self) -> None: + # Given: OpenAI-style fields / Then: extracted data = {"usage": {"prompt_tokens": 80, "completion_tokens": 40}} assert _extract_tokens(data, None) == (80, 40) def test_extract_tokens_none(self) -> None: + # Given: None data / Then: (0, 0) assert _extract_tokens(None, None) == (0, 0) @@ -273,52 +380,53 @@ def test_extract_tokens_none(self) -> None: class TestBudgetLimitRuleConfig: def test_valid_rule(self) -> None: + # Given/When: valid limit / Then: accepted rule = BudgetLimitRule(limit=1000) assert rule.limit == 1000 - assert rule.currency == Currency.USD def test_no_limit_rejected(self) -> None: + # Given/When: no limit or limit_tokens / Then: rejected with pytest.raises(ValidationError, match="At least one"): BudgetLimitRule() def test_negative_limit_rejected(self) -> None: + # Given/When: negative limit / Then: rejected with pytest.raises(ValidationError, match="positive"): BudgetLimitRule(limit=-1) def test_zero_limit_rejected(self) -> None: + # Given/When: zero limit 
/ Then: rejected with pytest.raises(ValidationError, match="positive"): BudgetLimitRule(limit=0) def test_negative_limit_tokens_rejected(self) -> None: + # Given/When: negative limit_tokens / Then: rejected with pytest.raises(ValidationError, match="positive"): BudgetLimitRule(limit_tokens=-1) def test_negative_window_seconds_rejected(self) -> None: + # Given/When: negative window_seconds / Then: rejected with pytest.raises(ValidationError, match="positive"): BudgetLimitRule(limit=1000, window_seconds=-1) def test_zero_window_seconds_rejected(self) -> None: + # Given/When: zero window_seconds / Then: rejected with pytest.raises(ValidationError, match="positive"): BudgetLimitRule(limit=1000, window_seconds=0) def test_token_only_rule(self) -> None: + # Given/When: limit_tokens only / Then: accepted, limit is None rule = BudgetLimitRule(limit_tokens=5000) assert rule.limit is None assert rule.limit_tokens == 5000 - def test_currency_enum(self) -> None: - rule = BudgetLimitRule(limit=1000, currency=Currency.EUR) - assert rule.currency == Currency.EUR - - def test_currency_from_string(self) -> None: - rule = BudgetLimitRule(limit=1000, currency="tokens") - assert rule.currency == Currency.TOKENS - def test_empty_limits_rejected(self) -> None: + # Given/When: empty limits list / Then: rejected with pytest.raises(ValidationError): BudgetEvaluatorConfig(limits=[]) def test_window_constants(self) -> None: + # Given/When/Then: constants have expected values assert WINDOW_DAILY == 86400 assert WINDOW_WEEKLY == 604800 assert WINDOW_MONTHLY == 2592000 @@ -339,7 +447,7 @@ async def test_single_call_under_budget(self) -> None: # Given: evaluator with $10 limit (1000 cents) ev = self._make_evaluator(limits=[{"limit": 1000}]) - # When: evaluate with usage data (cost field is ignored without pricing/model_path) + # When: evaluate with usage data result = await ev.evaluate({"usage": {"input_tokens": 100, "output_tokens": 50}}) # Then: not matched @@ -356,8 +464,8 @@ async def 
test_accumulate_past_budget(self) -> None: ) # When: two calls with tokens costing 27 cents each - # cost = ceil(300*30/1000 + 300*60/1000) = ceil(9+18) = 27 - # total = 27+27 = 54 > 50 + # cost = (300*30 + 300*60) / 1000 = 27.0 + # total = 27 + 27 = 54 > 50 step = {"model": "gpt-4", "usage": {"input_tokens": 300, "output_tokens": 300}} await ev.evaluate(step) result = await ev.evaluate(step) @@ -369,7 +477,6 @@ async def test_accumulate_past_budget(self) -> None: @pytest.mark.asyncio async def test_group_by_user(self) -> None: # Given: per-user 1000-cent budget with pricing table - # pricing: 200 cents per 1k input tokens ev = self._make_evaluator( limits=[{"group_by": "user_id", "limit": 1000}], pricing={"gpt-4": {"input_per_1k": 200.0, "output_per_1k": 0.0}}, @@ -378,8 +485,6 @@ async def test_group_by_user(self) -> None: ) # When: u1 spends 800+300=1100 cents, u2 spends 300 cents - # 4000 input tokens * 200/1000 = 800 cents - # 1500 input tokens * 200/1000 = 300 cents def _step(tokens: int, user: str) -> dict: return { "model": "gpt-4", @@ -408,6 +513,7 @@ async def test_token_only_limit(self) -> None: @pytest.mark.asyncio async def test_no_data_returns_not_matched(self) -> None: + # Given: evaluator / When: None data / Then: not matched ev = self._make_evaluator(limits=[{"limit": 1000}]) result = await ev.evaluate(None) assert result.matched is False @@ -415,14 +521,13 @@ async def test_no_data_returns_not_matched(self) -> None: @pytest.mark.asyncio async def test_confidence_always_one(self) -> None: # Given: evaluator with 1000-cent limit and pricing table - # pricing: 200 cents per 1k input tokens ev = self._make_evaluator( limits=[{"limit": 1000}], pricing={"gpt-4": {"input_per_1k": 200.0, "output_per_1k": 0.0}}, model_path="model", ) - # When: first call costs 50 cents (250 tokens), second costs 960 cents (4800 tokens) + # When: first call costs 50 cents, second costs 960 cents def _step(tokens: int) -> dict: return {"model": "gpt-4", "usage": 
{"input_tokens": tokens, "output_tokens": 0}} @@ -443,7 +548,7 @@ async def test_cost_computed_from_pricing_table(self) -> None: ) # When: evaluate with known model and tokens - # cost = ceil(100*30/1000 + 200*60/1000) = ceil(3+12) = 15 cents + # cost = (100*30 + 200*60) / 1000 = 15.0 cents result = await ev.evaluate( { "model": "gpt-4", @@ -454,7 +559,7 @@ async def test_cost_computed_from_pricing_table(self) -> None: # Then: not matched (15 < 100), cost tracked in metadata assert result.matched is False assert result.metadata is not None - assert result.metadata["cost"] == 15 + assert result.metadata["cost"] == pytest.approx(15.0, abs=0.01) @pytest.mark.asyncio async def test_unknown_model_cost_zero(self) -> None: @@ -476,7 +581,85 @@ async def test_unknown_model_cost_zero(self) -> None: # Then: not matched (cost=0 because model not in pricing) assert result.matched is False assert result.metadata is not None - assert result.metadata["cost"] == 0 + assert result.metadata["cost"] == 0.0 + + @pytest.mark.asyncio + async def test_small_cost_no_overcount(self) -> None: + # Given: evaluator with 1-cent limit, pricing yields 0.003 cents per call + ev = self._make_evaluator( + limits=[{"limit": 1}], + pricing={"gpt-4": {"input_per_1k": 0.03, "output_per_1k": 0.0}}, + model_path="model", + ) + step = {"model": "gpt-4", "usage": {"input_tokens": 100, "output_tokens": 0}} + + # When: 100 calls (total cost = 0.3 cents, should NOT exceed 1 cent) + for _ in range(100): + result = await ev.evaluate(step) + + # Then: not exceeded (float accumulation, no per-call ceil) + assert result.matched is False + + +# --------------------------------------------------------------------------- +# Store registry +# --------------------------------------------------------------------------- + + +class TestStoreRegistry: + def test_same_config_returns_same_store(self) -> None: + # Given: two configs with identical parameters + config = BudgetEvaluatorConfig(limits=[{"limit": 1000}]) + + # 
When: get store twice + store1 = get_or_create_store(config) + store2 = get_or_create_store(config) + + # Then: same object + assert store1 is store2 + + def test_different_config_returns_different_store(self) -> None: + # Given: two configs with different limits + config1 = BudgetEvaluatorConfig(limits=[{"limit": 1000}]) + config2 = BudgetEvaluatorConfig(limits=[{"limit": 2000}]) + + # When: get stores + store1 = get_or_create_store(config1) + store2 = get_or_create_store(config2) + + # Then: different objects + assert store1 is not store2 + + def test_clear_budget_stores(self) -> None: + # Given: a registered store + config = BudgetEvaluatorConfig(limits=[{"limit": 1000}]) + store1 = get_or_create_store(config) + + # When: clear all stores + clear_budget_stores() + store2 = get_or_create_store(config) + + # Then: new store (old one is gone) + assert store1 is not store2 + + @pytest.mark.asyncio + async def test_evaluator_uses_registry(self) -> None: + # Given: two evaluators with same config + config = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + pricing={"gpt-4": {"input_per_1k": 100.0, "output_per_1k": 0.0}}, + model_path="model", + ) + ev1 = BudgetEvaluator(config) + ev2 = BudgetEvaluator(config) + + # When: ev1 records usage, ev2 checks + step = {"model": "gpt-4", "usage": {"input_tokens": 500, "output_tokens": 0}} + await ev1.evaluate(step) + result = await ev2.evaluate(step) + + # Then: ev2 sees ev1's accumulated spend (shared store via registry) + assert result.matched is True # 50 + 50 = 100 >= 100 # --------------------------------------------------------------------------- @@ -504,18 +687,22 @@ def test_extract_by_path_rejects_dunder(self) -> None: assert _extract_by_path({"a": 1}, "__class__") is None - def test_group_by_without_metadata_skips_rule(self) -> None: + @pytest.mark.asyncio + async def test_group_by_without_metadata_skips_rule(self) -> None: # Given: rule with group_by=user_id but no user_id in scope rules = 
[BudgetLimitRule(group_by="user_id", limit=1000)] store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) # When: step without user_id - results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=999) + results = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=999.0 + ) # Then: rule skipped assert results == [] - def test_two_rules_same_scope_no_double_count(self) -> None: + @pytest.mark.asyncio + async def test_two_rules_same_scope_no_double_count(self) -> None: # Given: two global rules with different limit types rules = [ BudgetLimitRule(limit=1000), @@ -524,53 +711,44 @@ def test_two_rules_same_scope_no_double_count(self) -> None: store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) # When: record once - results = store.record_and_check(scope={}, input_tokens=100, output_tokens=100, cost=100) + results = await store.record_and_check( + scope={}, input_tokens=100, output_tokens=100, cost=100.0 + ) # Then: both rules get snapshot, but usage recorded only once assert len(results) == 2 assert results[0].spent == 100 # not 200 assert results[1].spent_tokens == 200 # not 400 - def test_different_currency_separate_buckets(self) -> None: - # Given: two rules with same scope but different currencies - rules = [ - BudgetLimitRule(limit=1000, currency=Currency.USD), - BudgetLimitRule(limit=2000, currency=Currency.EUR), - ] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) - - # When: record once - results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500) - - # Then: each currency gets its own bucket, both record the cost - assert len(results) == 2 - assert results[0].spent == 500 - assert results[1].spent == 500 - - def test_negative_cost_not_recorded(self) -> None: + @pytest.mark.asyncio + async def test_negative_cost_reduces_spend(self) -> None: # Given: store with 1000-cent limit rules = [BudgetLimitRule(limit=1000)] store = InMemoryBudgetStore(rules=rules, 
clock=lambda: 0.0) # When: record positive then negative cost - store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500) - results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=-200) + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500.0) + results = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=-200.0 + ) - # Then: negative cost is added (store is dumb; validation is caller's job) - # If this is undesirable, evaluator must reject negatives before calling store + # Then: negative cost reduces spend (store does not clamp; validation is caller's job) assert results[0].spent == 300 - def test_window_seconds_boundary_alignment(self) -> None: + @pytest.mark.asyncio + async def test_window_seconds_boundary_alignment(self) -> None: # Given: hourly window, clock at boundary-1 and boundary rules = [BudgetLimitRule(limit=1000, window_seconds=3600)] boundary = 3600 * 100 # exact hour boundary # When: record just before and at boundary store = InMemoryBudgetStore(rules=rules, clock=lambda: boundary - 1) - store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500) + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500.0) store._clock = lambda: boundary - results = store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500) + results = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=500.0 + ) # Then: boundary crossing starts fresh period assert results[0].spent == 500 # not 1000 @@ -582,11 +760,6 @@ def test_zero_limit_tokens_rejected(self) -> None: with pytest.raises(ValidationError, match="positive"): BudgetLimitRule(limit_tokens=0) - def test_invalid_currency_rejected(self) -> None: - # Given/When: invalid currency string - with pytest.raises(ValidationError): - BudgetLimitRule(limit=1000, currency="btc") - class TestBoolGuard: """bool is a subclass of int in Python -- must be 
rejected.""" @@ -597,3 +770,656 @@ def test_extract_tokens_rejects_bool(self) -> None: # When/Then: bools are not accepted as token counts assert _extract_tokens(data, None) == (0, 0) + + +# --------------------------------------------------------------------------- +# Store registry robustness +# --------------------------------------------------------------------------- + + +class TestStoreRegistryRobustness: + def test_concurrent_get_or_create_store(self) -> None: + # Given: 10 threads requesting the same config concurrently + config = BudgetEvaluatorConfig(limits=[{"limit": 1000}]) + stores: list[Any] = [] + lock = threading.Lock() + + def get_store() -> None: + s = get_or_create_store(config) + with lock: + stores.append(s) + + # When: 10 threads call get_or_create_store simultaneously + threads = [threading.Thread(target=get_store) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Then: all threads got the same store object + assert len(stores) == 10 + assert all(s is stores[0] for s in stores) + + @pytest.mark.asyncio + async def test_evaluator_cache_eviction_preserves_budget_state(self) -> None: + # Given: evaluator that has recorded usage + from agent_control_evaluators._factory import ( + clear_evaluator_cache, + ) + + config = BudgetEvaluatorConfig( + limits=[{"limit": 1000}], + pricing={"gpt-4": {"input_per_1k": 100.0, "output_per_1k": 0.0}}, + model_path="model", + ) + ev = BudgetEvaluator(config) + step = {"model": "gpt-4", "usage": {"input_tokens": 500, "output_tokens": 0}} + await ev.evaluate(step) + + # When: simulate LRU eviction by clearing the evaluator cache + clear_evaluator_cache() + + # Then: budget state survives (stored in module-level registry, not on evaluator) + ev2 = BudgetEvaluator(config) + result = await ev2.evaluate(step) + + # 500 tokens * 100 cents/1k = 50.0 cents per call. + # Two calls = 100.0 cents total. limit=1000, so not exceeded. 
+ # Key assertion: state IS preserved across evaluator re-creation. + assert result.metadata is not None + assert result.metadata["cost"] == pytest.approx(50.0, abs=0.1) + # The all_snapshots should show accumulated spend from both calls + snaps = result.metadata["all_snapshots"] + assert snaps[0]["spent"] == 100 # round(50.0 + 50.0) = 100, not 50 + + +# --------------------------------------------------------------------------- +# _estimate_cost edge cases +# --------------------------------------------------------------------------- + + +class TestRoundingBoundary: + @pytest.mark.asyncio + async def test_spent_half_cent_below_limit_not_exceeded(self) -> None: + # Given: store with 1000-cent limit + rules = [BudgetLimitRule(limit=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: spend 999.5 cents (just below limit) + results = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=999.5 + ) + + # Then: not exceeded (999.5 < 1000), spent display < limit + assert results[0].exceeded is False + assert results[0].spent < results[0].limit # no contradiction + + @pytest.mark.asyncio + async def test_spent_display_never_exceeds_actual(self) -> None: + # Given: store with 100-cent limit + rules = [BudgetLimitRule(limit=100)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: spend 99.9 cents + results = await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=99.9) + + # Then: floor truncation means spent=99, not rounded to 100 + assert results[0].spent == 99 + assert results[0].exceeded is False + + +class TestConfigKeyOrdering: + def test_limits_order_does_not_affect_store_identity(self) -> None: + # Given: two configs with same rules in different order + rule_a = {"limit": 1000, "scope": {"agent": "a"}} + rule_b = {"limit": 2000, "scope": {"agent": "b"}} + config1 = BudgetEvaluatorConfig(limits=[rule_a, rule_b]) + config2 = BudgetEvaluatorConfig(limits=[rule_b, rule_a]) + 
+ # When: get stores for both + store1 = get_or_create_store(config1) + store2 = get_or_create_store(config2) + + # Then: same store (order-independent) + assert store1 is store2 + + +class TestEstimateCostEdgeCases: + def test_nan_rate_returns_zero(self) -> None: + from agent_control_evaluator_budget.budget.evaluator import _estimate_cost + + # Given: pricing table with NaN rate + pricing = {"gpt-4": {"input_per_1k": float("nan"), "output_per_1k": 0.0}} + + # When: estimate cost + cost = _estimate_cost("gpt-4", 1000, 0, pricing) + + # Then: returns 0.0 (NaN guard) + assert cost == 0.0 + + def test_inf_rate_returns_zero(self) -> None: + from agent_control_evaluator_budget.budget.evaluator import _estimate_cost + + # Given: pricing table with Inf rate + pricing = {"gpt-4": {"input_per_1k": float("inf"), "output_per_1k": 0.0}} + + # When: estimate cost + cost = _estimate_cost("gpt-4", 1000, 0, pricing) + + # Then: returns 0.0 (Inf guard) + assert cost == 0.0 + + def test_negative_rate_returns_zero(self) -> None: + from agent_control_evaluator_budget.budget.evaluator import _estimate_cost + + # Given: pricing table with negative rate + pricing = {"gpt-4": {"input_per_1k": -10.0, "output_per_1k": 0.0}} + + # When: estimate cost + cost = _estimate_cost("gpt-4", 1000, 0, pricing) + + # Then: returns 0.0 (negative guard) + assert cost == 0.0 + + +# --------------------------------------------------------------------------- +# Nested model_path extraction +# --------------------------------------------------------------------------- + + +class TestNestedModelPath: + @pytest.mark.asyncio + async def test_nested_model_path(self) -> None: + # Given: evaluator with nested model_path + ev = BudgetEvaluator( + BudgetEvaluatorConfig( + limits=[{"limit": 1000}], + pricing={"gpt-4": {"input_per_1k": 100.0, "output_per_1k": 0.0}}, + model_path="llm.model_name", + ) + ) + + # When: evaluate with nested model structure + result = await ev.evaluate( + { + "llm": {"model_name": 
"gpt-4"}, + "usage": {"input_tokens": 500, "output_tokens": 0}, + } + ) + + # Then: model resolved correctly, cost computed + assert result.metadata is not None + assert result.metadata["cost"] == pytest.approx(50.0, abs=0.1) + + +# --------------------------------------------------------------------------- +# TTL prune tests +# --------------------------------------------------------------------------- + + +class TestTTLPrune: + @pytest.mark.asyncio + async def test_ttl_prune_drops_old_period_on_rollover(self) -> None: + # Given: store with daily window. Day N, N+1, N+2 timestamps. + day_seconds = WINDOW_DAILY + day_n = 1700000000.0 + # Align to exact day boundary + day_n = (int(day_n) // day_seconds) * day_seconds + day_n1 = day_n + day_seconds + day_n2 = day_n + 2 * day_seconds + + rules = [BudgetLimitRule(limit=10_000, window_seconds=day_seconds)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + + # When: record on day N + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + # record on day N+1 + store._clock = lambda: day_n1 + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=2.0) + # record on day N+2 -- should prune day N + store._clock = lambda: day_n2 + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=3.0) + + # Then: only buckets for day N+1 and N+2 remain for that scope + with store._lock: + period_keys = [k[1] for k in store._buckets] + + day_n_key = _derive_period_key(day_seconds, day_n) + day_n1_key = _derive_period_key(day_seconds, day_n1) + day_n2_key = _derive_period_key(day_seconds, day_n2) + + assert day_n_key not in period_keys, "Day N bucket should be pruned" + assert day_n1_key in period_keys, "Day N+1 bucket must be retained" + assert day_n2_key in period_keys, "Day N+2 bucket must be retained" + + @pytest.mark.asyncio + async def test_ttl_prune_preserves_cumulative_buckets(self) -> None: + # Given: store with both cumulative (window=None) 
and daily rules + day_seconds = WINDOW_DAILY + day_n = (int(1700000000.0) // day_seconds) * day_seconds + + rules = [ + BudgetLimitRule(limit=10_000), # cumulative (window_seconds=None) + BudgetLimitRule(limit=10_000, window_seconds=day_seconds), + ] + store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + + # When: record on 3 consecutive days + for i in range(3): + store._clock = lambda i=i: day_n + i * day_seconds + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + + # Then: cumulative bucket (empty period key) must survive + with store._lock: + period_keys = [k[1] for k in store._buckets] + + assert "" in period_keys, "Cumulative bucket (period_key='') must not be pruned" + + @pytest.mark.asyncio + async def test_ttl_prune_preserves_other_windows(self) -> None: + # Given: store with hourly and daily rules + hour = 3600 + day = WINDOW_DAILY + t0 = (int(1700000000.0) // day) * day # align to day boundary + + rules = [ + BudgetLimitRule(limit=10_000, window_seconds=hour), + BudgetLimitRule(limit=100_000, window_seconds=day), + ] + store = InMemoryBudgetStore(rules=rules, clock=lambda: t0) + + # When: roll hours many times (within same day) + for h in range(5): + store._clock = lambda h=h: t0 + h * hour + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + + # Then: daily bucket must survive hourly rollovers + day_key = _derive_period_key(day, t0) + with store._lock: + period_keys = [k[1] for k in store._buckets] + + assert day_key in period_keys, "Daily bucket must survive hourly rollovers" + + # When: roll day (prune old hourly buckets) + t_day2 = t0 + day + store._clock = lambda: t_day2 + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + + with store._lock: + period_keys_after = [k[1] for k in store._buckets] + + # Then: old hour-0 through hour-3 (index < current_hour-1) should be pruned + # daily bucket survives (different window) + day_key2 = 
_derive_period_key(day, t_day2) + assert day_key2 in period_keys_after or day_key in period_keys_after, ( + "At least one daily bucket must survive" + ) + # hour 0 key should be gone (it's >1 period behind the new day's hour-0) + hour0_key = _derive_period_key(hour, t0) + # hour0 is many hours before t_day2's first hour -- must be pruned + assert hour0_key not in period_keys_after, "Old hourly buckets should be pruned" + + @pytest.mark.asyncio + async def test_ttl_prune_no_rescan_within_period(self) -> None: + # Given: store with daily window. After a rollover, subsequent records + # within the same new period must NOT trigger another prune scan. + day_seconds = WINDOW_DAILY + day_n = (int(1700000000.0) // day_seconds) * day_seconds + day_n1 = day_n + day_seconds + + rules = [BudgetLimitRule(limit=10_000, window_seconds=day_seconds)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + + # Roll over to day N+1 + store._clock = lambda: day_n1 + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + + # Capture _last_pruned_period state after first record of new period + with store._lock: + snapshot_index = dict(store._last_pruned_period) + + # When: record many more times within the same new period + for _ in range(10): + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + + # Then: _last_pruned_period unchanged (no rescan occurred) + with store._lock: + after_index = dict(store._last_pruned_period) + + assert after_index == snapshot_index, "Prune scan must not repeat within same period" + + @pytest.mark.asyncio + async def test_ttl_prune_sparse_rollover(self) -> None: + # Given: daily rule, first record at index 5, then jump to index 100 + day = WINDOW_DAILY + day_n = (int(1700000000.0) // day) * day + rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] + store = InMemoryBudgetStore(rules=rules, 
clock=lambda: day_n) + + # When: record at baseline + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + # Jump forward ~95 days (any stale indices must be swept in one scan) + for i in range(1, 6): + store._clock = lambda i=i: day_n + i * day + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + # Large gap -- should prune everything older than index-1 + far = day_n + 100 * day + store._clock = lambda: far + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + + # Then: only current (index 100) and previous-valid bucket survive for that window + with store._lock: + period_keys = [k[1] for k in store._buckets if k[1].startswith("P")] + far_key = _derive_period_key(day, far) + assert far_key in period_keys + # Nothing from the early batch (indices 0..5) should remain + for i in range(6): + old_key = _derive_period_key(day, day_n + i * day) + assert old_key not in period_keys, f"stale index {i} must be pruned" + + @pytest.mark.asyncio + async def test_ttl_prune_reset_clears_prune_state(self) -> None: + # Given: store that has pruned once + day = WINDOW_DAILY + day_n = (int(1700000000.0) // day) * day + rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + store._clock = lambda: day_n + 2 * day + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + with store._lock: + assert day in store._last_pruned_period + + # When: full reset + store.reset() + + # Then: _last_pruned_period is cleared so that a future rollover + # re-enables pruning against fresh state + with store._lock: + assert store._last_pruned_period == {} + + # And: a fresh rollover sequence prunes again (watermark advances) + store._clock = lambda: day_n + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, 
cost=1.0) + store._clock = lambda: day_n + 2 * day + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + with store._lock: + assert store._last_pruned_period.get(day) is not None + + @pytest.mark.asyncio + async def test_ttl_prune_partial_reset_preserves_prune_state(self) -> None: + # Given: store that has pruned once + day = WINDOW_DAILY + day_n = (int(1700000000.0) // day) * day + rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + store._clock = lambda: day_n + 2 * day + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + with store._lock: + before = dict(store._last_pruned_period) + + # When: partial reset (scope-scoped) + store.reset(scope_key="__global__") + + # Then: prune state preserved (partial reset does not clobber watermark) + with store._lock: + assert store._last_pruned_period == before + + @pytest.mark.asyncio + async def test_ttl_prune_cross_scope(self) -> None: + # Given: group_by user, two users recording on the same day + day = WINDOW_DAILY + day_n = (int(1700000000.0) // day) * day + rules = [ + BudgetLimitRule(limit=10_000, window_seconds=day, group_by="user_id"), + ] + store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + await store.record_and_check( + scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=1.0 + ) + await store.record_and_check( + scope={"user_id": "u2"}, input_tokens=0, output_tokens=0, cost=1.0 + ) + + # Pre-condition: both users have distinct buckets on day N + day_n_key = _derive_period_key(day, day_n) + with store._lock: + day_n_scope_keys = [k[0] for k in store._buckets if k[1] == day_n_key] + assert "user_id=u1" in day_n_scope_keys, "u1 must have its own bucket" + assert "user_id=u2" in day_n_scope_keys, "u2 must have its own bucket" + + # When: only u1 records on day N+2 (triggers 
prune) + store._clock = lambda: day_n + 2 * day + await store.record_and_check( + scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=1.0 + ) + + # Then: u2's day-N bucket is also pruned -- the period expired globally, + # not per-scope. This is intentional: the prune sweeps all same-window + # stale buckets regardless of which scope triggered it. + day_n_key = _derive_period_key(day, day_n) + with store._lock: + period_keys = [k for k in store._buckets if k[1] == day_n_key] + assert period_keys == [], "u2 day-N bucket must be pruned by u1's rollover" + + @pytest.mark.asyncio + async def test_ttl_prune_respects_max_buckets_after_rollover(self) -> None: + # Given: store with max_buckets=2 (hard cap). Record on day N and N+1 + # fills capacity. On day N+2 the prune must free the day-N slot BEFORE + # the max_buckets check, otherwise rollover permanently fails closed. + day = WINDOW_DAILY + day_n = (int(1700000000.0) // day) * day + rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] + store = InMemoryBudgetStore( + rules=rules, clock=lambda: day_n, max_buckets=2 + ) + + # When: fill 2 buckets + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + store._clock = lambda: day_n + day + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + # Day N+2 at capacity -- prune must free space + store._clock = lambda: day_n + 2 * day + snaps = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) + + # Then: day N+2 record succeeded (not fail-closed) and day-N bucket is gone + assert len(snaps) == 1 + assert not snaps[0].exceeded + with store._lock: + period_keys = [k[1] for k in store._buckets] + day_n_key = _derive_period_key(day, day_n) + assert day_n_key not in period_keys, "stale day-N bucket must be pruned to free slot" + + @pytest.mark.asyncio + async def test_ttl_prune_backwards_clock_is_noop(self) -> None: + # Given: store that pruned at day N+5 (watermark = 
index 5) + day = WINDOW_DAILY + day_n = (int(1700000000.0) // day) * day + rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + store._clock = lambda: day_n + 5 * day + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + with store._lock: + watermark_before = store._last_pruned_period.get(day) + assert watermark_before is not None + + # When: clock jumps backwards to day N+2 and creates a new bucket there + store._clock = lambda: day_n + 2 * day + await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + + # Then: watermark did NOT drop (monotonic advance only) + with store._lock: + watermark_after = store._last_pruned_period.get(day) + assert watermark_after == watermark_before, ( + "backwards clock must not lower the prune watermark" + ) + + +class TestBudgetStoreABC: + def test_subclass_with_sync_override_rejected_at_class_creation(self) -> None: + # Given: a subclass that overrides record_and_check with a sync def + # When: the class body is evaluated + # Then: TypeError is raised, surfacing the contract violation at + # class-creation time rather than failing silently at the first + # `await` call site in production. 
+ from agent_control_evaluator_budget.budget.store import BudgetSnapshot, BudgetStore + + with pytest.raises(TypeError, match="must be an async def"): + + class BrokenStore(BudgetStore): # type: ignore[unused-ignore] + def record_and_check( # noqa: D401, ANN001 + self, + scope: dict[str, str], + input_tokens: int, + output_tokens: int, + cost: float, + ) -> list[BudgetSnapshot]: + return [] + + def test_subclass_with_async_override_accepted(self) -> None: + # Given/When: a subclass that overrides with async def + # Then: class creation succeeds and the subclass can be instantiated + from agent_control_evaluator_budget.budget.store import BudgetSnapshot, BudgetStore + + class GoodStore(BudgetStore): + async def record_and_check( + self, + scope: dict[str, str], + input_tokens: int, + output_tokens: int, + cost: float, + ) -> list[BudgetSnapshot]: + return [] + + # And: instances pass nominal isinstance against the ABC + instance = GoodStore() + assert isinstance(instance, BudgetStore) + + def test_subclass_without_override_accepted_at_class_creation(self) -> None: + # Given/When: a subclass that does NOT override record_and_check + # Then: class creation succeeds (__init_subclass__ method=None path). + # ABC enforces the abstractmethod at instantiation, not class creation. 
+ from agent_control_evaluator_budget.budget.store import BudgetStore + + class PartialStore(BudgetStore): + pass # no override; abstractmethod prevents instantiation + + # And: instantiation is blocked by ABC, not our __init_subclass__ + with pytest.raises(TypeError, match="abstract method"): + PartialStore() + + def test_mixin_sync_override_rejected(self) -> None: + # Given: a sync mixin that provides record_and_check, and a subclass + # that inherits it via MRO without overriding in its own __dict__ + # When: class creation is attempted + # Then: __init_subclass__ walks MRO and catches the sync mixin override + from agent_control_evaluator_budget.budget.store import BudgetStore + + class SyncMixin: + def record_and_check(self, scope, input_tokens, output_tokens, cost): + return [] + + with pytest.raises(TypeError, match="must be an async def"): + + class MixinStore(SyncMixin, BudgetStore): + pass + + +class TestNaNCostDefense: + @pytest.mark.asyncio + async def test_nan_cost_coerced_to_zero(self) -> None: + # Given: store with a cost limit + rules = [BudgetLimitRule(limit=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: NaN cost is injected directly (bypassing _estimate_cost) + await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=float("nan") + ) + # And: a subsequent valid charge arrives + snaps = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=500.0 + ) + + # Then: the NaN was coerced to 0.0; the accumulator is 500, not NaN + assert snaps[0].spent == 500 + assert not snaps[0].exceeded + + @pytest.mark.asyncio + async def test_inf_cost_coerced_to_zero(self) -> None: + # Given: store with a cost limit + rules = [BudgetLimitRule(limit=1000)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + + # When: Inf cost is injected + await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=float("inf") + ) + snaps = await store.record_and_check( + 
scope={}, input_tokens=0, output_tokens=0, cost=100.0 + ) + + # Then: Inf was coerced to 0.0; the accumulator is 100 + assert snaps[0].spent == 100 + assert not snaps[0].exceeded + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ("neg_input", "neg_output"), + [(-50, 0), (0, -50), (-30, -20)], + ids=["neg_input_only", "neg_output_only", "both_negative"], + ) + async def test_negative_tokens_clamped_to_zero( + self, neg_input: int, neg_output: int + ) -> None: + # Given: store with a token limit, filled to 90 tokens + rules = [BudgetLimitRule(limit_tokens=100)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + await store.record_and_check( + scope={}, input_tokens=90, output_tokens=0, cost=0.0 + ) + + # When: inject negative input/output tokens + snaps = await store.record_and_check( + scope={}, input_tokens=neg_input, output_tokens=neg_output, cost=0.0 + ) + + # Then: negative tokens clamped to 0; accumulator stays at 90 + assert snaps[0].spent_tokens == 90 + assert not snaps[0].exceeded + + @pytest.mark.asyncio + async def test_nan_clock_does_not_crash(self) -> None: + # Given: store with a windowed rule AND a clock that returns NaN + rules = [BudgetLimitRule(limit=1000, window_seconds=WINDOW_DAILY)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: float("nan")) + + # When: record_and_check is called (would raise OverflowError in + # _derive_period_key without the guard) + snaps = await store.record_and_check( + scope={}, input_tokens=0, output_tokens=0, cost=100.0 + ) + + # Then: no crash; maps to epoch-zero period, budget still enforced + assert len(snaps) == 1 + assert snaps[0].spent == 100 + + @pytest.mark.asyncio + async def test_inf_clock_does_not_crash(self) -> None: + # Given: clock returning Inf + rules = [BudgetLimitRule(limit=1000, window_seconds=WINDOW_DAILY)] + store = InMemoryBudgetStore(rules=rules, clock=lambda: float("inf")) + + # When: record_and_check is called + snaps = await store.record_and_check( + scope={}, 
input_tokens=0, output_tokens=0, cost=100.0 + ) + + # Then: no crash + assert len(snaps) == 1 + assert snaps[0].spent == 100 diff --git a/pyproject.toml b/pyproject.toml index 645f6229..95baef8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,9 +83,3 @@ tag_format = "v{version}" # feat = minor, fix/perf/refactor = patch, breaking (!) = major allowed_tags = ["feat", "fix", "perf", "chore", "docs", "style", "refactor", "test", "ci"] patch_tags = ["fix", "perf", "chore", "refactor"] - -[dependency-groups] -dev = [ - "pytest>=9.0.2", - "pytest-asyncio>=1.3.0", -] From d12b1ec58963ebe9e3f5bc5b9fb49e9f4ede2389 Mon Sep 17 00:00:00 2001 From: amabito Date: Sat, 11 Apr 2026 18:41:00 +0900 Subject: [PATCH 3/3] feat(budget): R5 -- limit_unit, budget_id, ModelPricing, unknown_model_behavior Phase C: - C1: replace limit/limit_tokens dual-ceiling with limit + limit_unit (usd_cents|tokens) - C2: add budget_id:str='default' to BudgetEvaluatorConfig; same id shares bucket state, different id is fully isolated - C3: drop _config_key hash-based store key; registry now keyed by f'budget:{budget_id}' - C4: introduce ModelPricing(EvaluatorConfig) with input_per_1k/output_per_1k; pricing field is now dict[str,ModelPricing]; require pricing when any rule uses limit_unit='usd_cents' - C5: store contract redesign -- InMemoryBudgetStore owns bucket state only; rules are passed per call so same budget_id pools share buckets while each evaluator uses its own rules Phase D: - D1: add unknown_model_behavior:Literal['block','warn']='block' to BudgetEvaluatorConfig - D2: block/warn triggers only for cost-based rules with pricing configured and model absent; token-only rules are unaffected - D3: README rewrite with complete config example, scope/group_by, budget pools, pricing, dual-ceiling pattern, single-process-only caveat Tests: 100 passing (was 91 in R4) --- evaluators/contrib/budget/README.md | 137 ++++- .../budget/__init__.py | 8 +- .../budget/config.py | 72 ++- 
.../budget/evaluator.py | 55 +- .../budget/memory_store.py | 60 +- .../budget/store.py | 26 +- .../budget/tests/budget/test_budget.py | 549 +++++++++++++----- 7 files changed, 664 insertions(+), 243 deletions(-) diff --git a/evaluators/contrib/budget/README.md b/evaluators/contrib/budget/README.md index ddd159e8..fa16f876 100644 --- a/evaluators/contrib/budget/README.md +++ b/evaluators/contrib/budget/README.md @@ -1,3 +1,136 @@ -# Budget Evaluator +# agent-control-evaluator-budget -Cumulative LLM cost and token budget tracking for agent-control. +Budget evaluator for agent-control that tracks cumulative LLM token and cost usage per scope and time window. + +## Install + +```bash +pip install agent-control-evaluator-budget +``` + +## Quickstart + +```python +from agent_control_evaluator_budget.budget import ( + BudgetEvaluatorConfig, + BudgetLimitRule, + ModelPricing, +) + +config = BudgetEvaluatorConfig( + budget_id="support-daily", + limits=[ + BudgetLimitRule( + scope={"agent": "support"}, + group_by="user_id", + window_seconds=86_400, + limit=500, + limit_unit="usd_cents", + ), + BudgetLimitRule( + scope={"agent": "support"}, + group_by="user_id", + window_seconds=86_400, + limit=50_000, + limit_unit="tokens", + ), + ], + pricing={ + "gpt-4.1-mini": ModelPricing(input_per_1k=0.04, output_per_1k=0.16), + }, + model_path="model", + metadata_paths={ + "agent": "metadata.agent", + "user_id": "metadata.user_id", + }, + unknown_model_behavior="block", +) +``` + +The evaluator reads token usage from standard fields such as `usage.input_tokens` and `usage.output_tokens`. Configure `token_path` only when your event shape uses a custom location. + +## Scope and group_by + +Each `BudgetLimitRule` has a static `scope` and an optional `group_by` field. + +`scope` filters which events a rule applies to. A rule with `scope={"agent": "support"}` only applies when extracted metadata contains `agent="support"`. An empty scope is global. 
+ +`group_by` creates independent buckets per extracted metadata value. The common per-user pattern is: + +```python +BudgetLimitRule( + scope={"agent": "support"}, + group_by="user_id", + window_seconds=86_400, + limit=500, + limit_unit="usd_cents", +) +``` + +With `metadata_paths={"user_id": "metadata.user_id"}`, each user gets a separate daily budget inside the support scope. + +## Budget pools + +`budget_id` identifies the accumulated budget pool. + +Evaluators with the same `budget_id` share accumulated spend and token totals across all evaluator instances. Each evaluator still evaluates using its own configured rules -- the shared state is the bucket (the rolling sum), not the rule set. Evaluators with different `budget_id` values are fully isolated. + +Use stable names such as `support-daily`, `billing-global`, or `tenant-acme-monthly`. Avoid generating a new `budget_id` per request unless each request should have an isolated budget. + +## Pricing + +`ModelPricing` stores cost rates in cents per 1K tokens: + +```python +ModelPricing(input_per_1k=0.04, output_per_1k=0.16) +``` + +`input_per_1k` is applied to input tokens. `output_per_1k` is applied to output tokens. + +Pricing is required when any rule uses `limit_unit="usd_cents"`. Token-only rules can omit pricing. If an event uses a model that is not in the pricing table and a cost rule exists, `unknown_model_behavior="block"` fails closed. Use `"warn"` to log a warning and treat the cost as 0. 
+ +## Dual Ceiling Pattern + +Use two evaluators when cost and token ceilings need independent control records or different `budget_id` pools: + +```python +cost_config = BudgetEvaluatorConfig( + budget_id="support-cost-daily", + limits=[ + BudgetLimitRule( + scope={"agent": "support"}, + group_by="user_id", + window_seconds=86_400, + limit=500, + limit_unit="usd_cents", + ) + ], + pricing={ + "gpt-4.1-mini": ModelPricing(input_per_1k=0.04, output_per_1k=0.16), + }, + model_path="model", + metadata_paths={"agent": "metadata.agent", "user_id": "metadata.user_id"}, +) + +token_config = BudgetEvaluatorConfig( + budget_id="support-token-daily", + limits=[ + BudgetLimitRule( + scope={"agent": "support"}, + group_by="user_id", + window_seconds=86_400, + limit=50_000, + limit_unit="tokens", + ) + ], + metadata_paths={"agent": "metadata.agent", "user_id": "metadata.user_id"}, +) +``` + +This pattern lets cost and token budgets reset, alert, and roll out independently. A single evaluator can also contain both rules when one shared pool and one control result are sufficient. + +## Limitations + +`InMemoryBudgetStore` is single-process only. State is lost on restart and is not shared across workers or pods. + +Use a distributed store for production deployments that run multiple processes, multiple workers, or multiple pods. 
diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py index c82d2647..c747c4f8 100644 --- a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/__init__.py @@ -1,6 +1,10 @@ """Budget evaluator for per-agent LLM cost and token tracking.""" -from agent_control_evaluator_budget.budget.config import BudgetEvaluatorConfig +from agent_control_evaluator_budget.budget.config import ( + BudgetEvaluatorConfig, + BudgetLimitRule, + ModelPricing, +) from agent_control_evaluator_budget.budget.evaluator import BudgetEvaluator from agent_control_evaluator_budget.budget.memory_store import InMemoryBudgetStore from agent_control_evaluator_budget.budget.store import BudgetSnapshot, BudgetStore @@ -12,7 +16,9 @@ __all__ = [ "BudgetEvaluator", "BudgetEvaluatorConfig", + "BudgetLimitRule", "BudgetSnapshot", "BudgetStore", "InMemoryBudgetStore", + "ModelPricing", ] diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/config.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/config.py index 6a261f43..b98b7407 100644 --- a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/config.py +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/config.py @@ -2,7 +2,7 @@ from __future__ import annotations -from enum import Enum +from typing import Literal from agent_control_evaluators._base import EvaluatorConfig from pydantic import Field, field_validator, model_validator @@ -17,12 +17,11 @@ WINDOW_MONTHLY = 2592000 # 30 days -class Currency(str, Enum): - """Supported budget currencies.""" +class ModelPricing(EvaluatorConfig): + """Per-model token pricing in cents per 1K tokens.""" - USD = "usd" - EUR = "eur" - TOKENS = "tokens" + input_per_1k: float = 0.0 + output_per_1k: float = 0.0 class 
BudgetLimitRule(EvaluatorConfig): @@ -43,39 +42,24 @@ class BudgetLimitRule(EvaluatorConfig): each user gets their own budget. None = shared/global limit. window_seconds: Time window for accumulation in seconds. None = cumulative (no reset). See WINDOW_* constants. - limit: Maximum spend in the window, in minor units (e.g. cents - for USD). None = uncapped on this dimension. - currency: Currency for the limit. Defaults to USD. - limit_tokens: Maximum tokens in the window. None = uncapped. + limit: Maximum usage in the window. Interpreted by limit_unit. + limit_unit: Unit for limit. usd_cents checks spend; tokens checks + input + output tokens. """ scope: dict[str, str] = Field(default_factory=dict) group_by: str | None = None window_seconds: int | None = None - limit: int | None = None - currency: Currency = Currency.USD - limit_tokens: int | None = None - - @model_validator(mode="after") - def at_least_one_limit(self) -> "BudgetLimitRule": - if self.limit is None and self.limit_tokens is None: - raise ValueError("At least one of limit or limit_tokens must be set") - return self + limit: int + limit_unit: Literal["usd_cents", "tokens"] = "usd_cents" @field_validator("limit") @classmethod - def validate_limit(cls, v: int | None) -> int | None: - if v is not None and v <= 0: + def validate_limit(cls, v: int) -> int: + if v <= 0: raise ValueError("limit must be a positive integer") return v - @field_validator("limit_tokens") - @classmethod - def validate_limit_tokens(cls, v: int | None) -> int | None: - if v is not None and v <= 0: - raise ValueError("limit_tokens must be positive") - return v - @field_validator("window_seconds") @classmethod def validate_window_seconds(cls, v: int | None) -> int | None: @@ -89,9 +73,13 @@ class BudgetEvaluatorConfig(EvaluatorConfig): Attributes: limits: List of budget limit rules. Each is checked independently. - pricing: Optional model pricing table. Maps model name to per-1K - token rates. 
Used to derive cost in USD from token counts and - model name. + budget_id: Unique budget pool identifier. Same budget_id shares + accumulated spend. Different budget_id is fully isolated. + unknown_model_behavior: What to do when a model is not found in the + pricing table and a cost-based rule exists. block=fail closed, + warn=log warning and treat cost as 0. + pricing: Optional model pricing table. Maps model name to ModelPricing. + Used to derive cost in USD cents from token counts and model name. token_path: Dot-notation path to extract token usage from step data (e.g. "usage.total_tokens"). If None, looks for standard fields (input_tokens, output_tokens, total_tokens, usage). @@ -101,7 +89,27 @@ class BudgetEvaluatorConfig(EvaluatorConfig): """ limits: list[BudgetLimitRule] = Field(min_length=1) - pricing: dict[str, dict[str, float]] | None = None + budget_id: str = Field( + default="default", + description=( + "Unique budget pool identifier. Same budget_id shares accumulated spend. " + "Different budget_id is fully isolated." + ), + ) + unknown_model_behavior: Literal["block", "warn"] = Field( + default="block", + description=( + "What to do when a model is not found in the pricing table and a cost-based " + "rule exists. block=fail closed, warn=log warning and treat cost as 0."
+ ), + ) + pricing: dict[str, ModelPricing] | None = None token_path: str | None = None model_path: str | None = None metadata_paths: dict[str, str] = Field(default_factory=dict) + + @model_validator(mode="after") + def require_pricing_for_cost_rules(self) -> "BudgetEvaluatorConfig": + if self.pricing is None and any(rule.limit_unit == "usd_cents" for rule in self.limits): + raise ValueError('pricing is required when any rule uses limit_unit="usd_cents"') + return self diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py index c0d6c517..2799bc9a 100644 --- a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/evaluator.py @@ -7,12 +7,11 @@ The evaluator is stateless. Budget state lives in a module-level store registry, independent of the evaluator instance cache in _factory.py. This prevents silent state loss on LRU eviction and avoids cross-control -leakage when different controls share the same config. +leakage when different controls use different budget_id values. """ from __future__ import annotations -import json import logging import math import threading @@ -22,7 +21,7 @@ from agent_control_evaluators._registry import register_evaluator from agent_control_models import EvaluatorResult -from .config import BudgetEvaluatorConfig +from .config import BudgetEvaluatorConfig, ModelPricing from .memory_store import InMemoryBudgetStore from .store import BudgetStore @@ -33,8 +32,8 @@ # # Decoupled from the evaluator instance cache so that LRU eviction in # _factory.py does not destroy accumulated budget state. The registry -# is keyed by a stable config hash. Two controls with identical config -# intentionally share a budget pool (same config = same budget). +# is keyed by budget_id. 
Controls with the same budget_id intentionally +# share accumulated spend; different budget_id values are isolated. # --------------------------------------------------------------------------- # NOTE: The registry is unbounded. In practice a deployment has a finite @@ -44,27 +43,13 @@ _STORE_REGISTRY_LOCK = threading.Lock() -def _config_key(config: BudgetEvaluatorConfig) -> str: - """Build a stable key for the store registry from evaluator config. - - The limits list is sorted before hashing so that two configs with - semantically identical rules in different order share a store. - """ - config_dict = config.model_dump(mode="json") - config_dict["limits"] = sorted( - config_dict["limits"], - key=lambda r: json.dumps(r, sort_keys=True, default=str), - ) - return f"budget:{json.dumps(config_dict, sort_keys=True, default=str)}" - - def get_or_create_store(config: BudgetEvaluatorConfig) -> BudgetStore: """Get or create a store for the given config, thread-safe.""" - key = _config_key(config) + key = f"budget:{config.budget_id}" with _STORE_REGISTRY_LOCK: store = _STORE_REGISTRY.get(key) if store is None: - store = InMemoryBudgetStore(rules=config.limits) + store = InMemoryBudgetStore() _STORE_REGISTRY[key] = store return store @@ -140,7 +125,7 @@ def _estimate_cost( model: str | None, input_tokens: int, output_tokens: int, - pricing: dict[str, dict[str, float]] | None, + pricing: dict[str, ModelPricing] | None, ) -> float: """Estimate cost in cents (USD) from model pricing table. 
@@ -152,8 +137,8 @@ def _estimate_cost( rates = pricing.get(model) if not rates: return 0.0 - input_rate = rates.get("input_per_1k", 0.0) - output_rate = rates.get("output_per_1k", 0.0) + input_rate = rates.input_per_1k + output_rate = rates.output_per_1k cost = (input_tokens * input_rate + output_tokens * output_rate) / 1000.0 if not math.isfinite(cost) or cost < 0: return 0.0 @@ -211,11 +196,31 @@ async def evaluate(self, data: Any) -> EvaluatorResult: model = str(val) cost = _estimate_cost(model, input_tokens, output_tokens, self.config.pricing) + model_known = model is None or self.config.pricing is None or model in self.config.pricing + has_cost_rule = any(rule.limit_unit == "usd_cents" for rule in self.config.limits) + if not model_known and has_cost_rule: + if self.config.unknown_model_behavior == "block": + return EvaluatorResult( + matched=True, + confidence=1.0, + message=f"Unknown model blocked: {model}", + metadata={ + "unknown_model": model, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + ) + logger.warning( + "Budget evaluator: unknown model %r, treating cost as 0 " + "(unknown_model_behavior=warn)", + model, + ) step_metadata = _extract_metadata(data, self.config.metadata_paths) store = get_or_create_store(self.config) snapshots = await store.record_and_check( + rules=self.config.limits, scope=step_metadata, input_tokens=input_tokens, output_tokens=output_tokens, @@ -230,7 +235,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult: "spent": snap.spent, "spent_tokens": snap.spent_tokens, "limit": snap.limit, - "limit_tokens": snap.limit_tokens, + "limit_unit": snap.limit_unit, "utilization": round(snap.utilization, 4), "exceeded": snap.exceeded, } diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py index df1784eb..241d8ef4 100644 --- 
a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/memory_store.py @@ -76,20 +76,19 @@ def _compute_utilization( spent: float, spent_tokens: int, limit: int | None, - limit_tokens: int | None, + limit_unit: str, ) -> float: - """Return max(spend_ratio, token_ratio) clamped to [0.0, 1.0]. + """Return the selected usage ratio clamped to [0.0, 1.0]. The low-side clamp is load-bearing: under refund semantics the internal `spent` accumulator may go negative, which would otherwise produce a negative ratio and violate the BudgetSnapshot.utilization contract. """ - ratios: list[float] = [] - if limit is not None and limit > 0: - ratios.append(max(0.0, min(spent / limit, 1.0))) - if limit_tokens is not None and limit_tokens > 0: - ratios.append(max(0.0, min(spent_tokens / limit_tokens, 1.0))) - return max(ratios) if ratios else 0.0 + if limit_unit == "tokens": + ratio = spent_tokens / limit if limit else 0.0 + else: + ratio = spent / limit if limit else 0.0 + return max(0.0, min(ratio, 1.0)) @dataclass @@ -108,8 +107,9 @@ def total_tokens(self) -> int: class InMemoryBudgetStore(BudgetStore): """Thread-safe in-memory budget store. - Initialized with a list of BudgetLimitRule. Derives period keys - internally from window_seconds + injected clock. + Owns bucket state and derives period keys internally from + window_seconds + injected clock. Callers provide the rules to evaluate + on each record operation. Cost is accumulated as float for precision. Integer rounding happens only at snapshot time for display/reporting. 
@@ -126,12 +126,10 @@ class InMemoryBudgetStore(BudgetStore): def __init__( self, - rules: list[BudgetLimitRule], *, clock: Callable[[], float] = time.time, max_buckets: int = _DEFAULT_MAX_BUCKETS, ) -> None: - self._rules = rules self._clock = clock self._lock = threading.Lock() self._buckets: dict[tuple[str, str], _Bucket] = {} @@ -140,16 +138,18 @@ def __init__( async def record_and_check( self, + rules: list[BudgetLimitRule], scope: dict[str, str], input_tokens: int, output_tokens: int, cost: float, ) -> list[BudgetSnapshot]: """Atomically record usage and return snapshots for all matching rules.""" - return self._record_and_check_sync(scope, input_tokens, output_tokens, cost) + return self._record_and_check_sync(rules, scope, input_tokens, output_tokens, cost) def _record_and_check_sync( self, + rules: list[BudgetLimitRule], scope: dict[str, str], input_tokens: int, output_tokens: int, @@ -175,7 +175,7 @@ def _record_and_check_sync( recorded_pairs: set[tuple[str, str]] = set() with self._lock: - for rule in self._rules: + for rule in rules: if not _scope_matches(rule, scope): continue @@ -192,9 +192,9 @@ def _record_and_check_sync( spent=0, spent_tokens=0, limit=rule.limit, - limit_tokens=rule.limit_tokens, utilization=1.0, exceeded=True, + limit_unit=rule.limit_unit, ) ) continue @@ -215,22 +215,21 @@ def _record_and_check_sync( total_tokens = bucket.total_tokens utilization = _compute_utilization( - bucket.spent, total_tokens, rule.limit, rule.limit_tokens + bucket.spent, total_tokens, rule.limit, rule.limit_unit ) - exceeded = False - if rule.limit is not None and bucket.spent >= rule.limit: - exceeded = True - if rule.limit_tokens is not None and total_tokens >= rule.limit_tokens: - exceeded = True + if rule.limit_unit == "tokens": + exceeded = total_tokens >= rule.limit + else: + exceeded = bucket.spent >= rule.limit snapshots.append( BudgetSnapshot( spent=round_spent(bucket.spent), spent_tokens=total_tokens, limit=rule.limit, - 
limit_tokens=rule.limit_tokens, utilization=utilization, exceeded=exceeded, + limit_unit=rule.limit_unit, ) ) @@ -241,7 +240,7 @@ def get_snapshot( scope_key: str, period_key: str, limit: int | None = None, - limit_tokens: int | None = None, + limit_unit: str = "usd_cents", ) -> BudgetSnapshot: """Read current budget state without recording usage.""" key = (scope_key, period_key) @@ -252,24 +251,23 @@ def get_snapshot( spent=0, spent_tokens=0, limit=limit, - limit_tokens=limit_tokens, utilization=0.0, exceeded=False, + limit_unit=limit_unit, ) total_tokens = bucket.total_tokens - utilization = _compute_utilization(bucket.spent, total_tokens, limit, limit_tokens) - exceeded = False - if limit is not None and bucket.spent >= limit: - exceeded = True - if limit_tokens is not None and total_tokens >= limit_tokens: - exceeded = True + utilization = _compute_utilization(bucket.spent, total_tokens, limit, limit_unit) + if limit_unit == "tokens": + exceeded = bool(limit is not None and total_tokens >= limit) + else: + exceeded = bool(limit is not None and bucket.spent >= limit) return BudgetSnapshot( spent=round_spent(bucket.spent), spent_tokens=total_tokens, limit=limit, - limit_tokens=limit_tokens, utilization=utilization, exceeded=exceeded, + limit_unit=limit_unit, ) def reset(self, scope_key: str | None = None, period_key: str | None = None) -> None: diff --git a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py index 9d58f76f..6564ead3 100644 --- a/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py +++ b/evaluators/contrib/budget/src/agent_control_evaluator_budget/budget/store.py @@ -14,7 +14,10 @@ import math from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .config import BudgetLimitRule @dataclass(frozen=True) @@ -24,19 +27,18 
@@ class BudgetSnapshot: Attributes: spent: Cumulative spend in cents (USD), rounded from float. spent_tokens: Cumulative tokens (input + output) in this scope+period. - limit: Configured spend ceiling in cents, or None if uncapped. - limit_tokens: Configured token ceiling, or None if uncapped. - utilization: max(spend_ratio, token_ratio) clamped to [0.0, 1.0]. - 0.0 when no limits are set. - exceeded: True when any limit is breached. + limit: Configured ceiling, interpreted by limit_unit. + utilization: Selected usage ratio clamped to [0.0, 1.0]. + exceeded: True when the configured limit is breached. + limit_unit: Unit used to interpret limit. """ spent: int spent_tokens: int limit: int | None - limit_tokens: int | None utilization: float exceeded: bool + limit_unit: str = "usd_cents" def round_spent(value: float) -> int: @@ -54,10 +56,10 @@ def round_spent(value: float) -> int: class BudgetStore(ABC): """Abstract base class for budget storage backends. - The store is initialized with a list of BudgetLimitRule and derives - period keys internally from window_seconds + current time. - - Callers pass only usage data: scope dict, input_tokens, output_tokens, cost. + The store owns bucket state and derives period keys internally from + window_seconds + current time. Callers pass the rules to evaluate for + each record operation along with usage data: scope dict, input_tokens, + output_tokens, cost. Negative `cost` values are permitted and reduce accumulated spend (refund semantics). `round_spent()` floors the displayed snapshot spend to 0 for @@ -102,6 +104,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: @abstractmethod async def record_and_check( self, + rules: list[BudgetLimitRule], scope: dict[str, str], input_tokens: int, output_tokens: int, @@ -110,6 +113,7 @@ async def record_and_check( """Atomically record usage and return snapshots for all matching rules. Args: + rules: Rules to evaluate against the shared bucket state. 
scope: Scope dimensions from the step (e.g. {"agent": "summarizer"}). input_tokens: Input tokens consumed by this call. output_tokens: Output tokens consumed by this call. diff --git a/evaluators/contrib/budget/tests/budget/test_budget.py b/evaluators/contrib/budget/tests/budget/test_budget.py index 86adae5c..f84eb728 100644 --- a/evaluators/contrib/budget/tests/budget/test_budget.py +++ b/evaluators/contrib/budget/tests/budget/test_budget.py @@ -17,6 +17,7 @@ WINDOW_WEEKLY, BudgetEvaluatorConfig, BudgetLimitRule, + ModelPricing, ) from agent_control_evaluator_budget.budget.evaluator import ( BudgetEvaluator, @@ -48,11 +49,11 @@ class TestInMemoryBudgetStore: async def test_single_record_under_limit(self) -> None: # Given: store with a $10 daily limit (1000 cents) rules = [BudgetLimitRule(limit=1000, window_seconds=WINDOW_DAILY)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 1700000000.0) + store = InMemoryBudgetStore(clock=lambda: 1700000000.0) # When: record 300 cents of usage results = await store.record_and_check( - scope={}, input_tokens=100, output_tokens=50, cost=300.0 + rules=rules, scope={}, input_tokens=100, output_tokens=50, cost=300.0 ) # Then: not breached, ratio ~0.3 @@ -64,12 +65,14 @@ async def test_single_record_under_limit(self) -> None: async def test_accumulation_triggers_breach(self) -> None: # Given: store with 1000-cent limit rules = [BudgetLimitRule(limit=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 1700000000.0) + store = InMemoryBudgetStore(clock=lambda: 1700000000.0) # When: record 600 + 500 = 1100 cents - await store.record_and_check(scope={}, input_tokens=100, output_tokens=50, cost=600.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=100, output_tokens=50, cost=600.0 + ) results = await store.record_and_check( - scope={}, input_tokens=100, output_tokens=50, cost=500.0 + rules=rules, scope={}, input_tokens=100, output_tokens=50, cost=500.0 ) # Then: exceeded @@ -83,14 +86,14 @@ async 
def test_scope_isolation(self) -> None: BudgetLimitRule(scope={"agent": "a"}, limit=1000), BudgetLimitRule(scope={"agent": "b"}, limit=1000), ] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 1700000000.0) + store = InMemoryBudgetStore(clock=lambda: 1700000000.0) # When: agent-a records 900, agent-b records 100 results_a = await store.record_and_check( - scope={"agent": "a"}, input_tokens=0, output_tokens=0, cost=900.0 + rules=rules, scope={"agent": "a"}, input_tokens=0, output_tokens=0, cost=900.0 ) results_b = await store.record_and_check( - scope={"agent": "b"}, input_tokens=0, output_tokens=0, cost=100.0 + rules=rules, scope={"agent": "b"}, input_tokens=0, output_tokens=0, cost=100.0 ) # Then: agent-a near limit, agent-b well under @@ -106,12 +109,14 @@ async def test_period_isolation(self) -> None: day2 = day1 + WINDOW_DAILY # When: record on day 1, then day 2 - store = InMemoryBudgetStore(rules=rules, clock=lambda: day1) - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=800.0) + store = InMemoryBudgetStore(clock=lambda: day1) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=800.0 + ) store._clock = lambda: day2 results = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=300.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=300.0 ) # Then: day 2 is a fresh period @@ -122,11 +127,11 @@ async def test_period_isolation(self) -> None: async def test_exceeded_exact_limit(self) -> None: # Given: 1000-cent limit rules = [BudgetLimitRule(limit=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: spend exactly 1000 results = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=1000.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1000.0 ) # Then: exceeded (>= not >) @@ -135,12 +140,12 @@ async def test_exceeded_exact_limit(self) 
-> None: @pytest.mark.asyncio async def test_token_only_limit(self) -> None: # Given: 1000-token limit, no cost limit - rules = [BudgetLimitRule(limit_tokens=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + rules = [BudgetLimitRule(limit=1000, limit_unit="tokens")] + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: consume 600+500 = 1100 tokens results = await store.record_and_check( - scope={}, input_tokens=600, output_tokens=500, cost=0.0 + rules=rules, scope={}, input_tokens=600, output_tokens=500, cost=0.0 ) # Then: exceeded @@ -151,11 +156,11 @@ async def test_token_only_limit(self) -> None: async def test_no_matching_rules(self) -> None: # Given: rule for agent=summarizer only rules = [BudgetLimitRule(scope={"agent": "summarizer"}, limit=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: step from agent=other results = await store.record_and_check( - scope={"agent": "other"}, input_tokens=100, output_tokens=50, cost=999.0 + rules=rules, scope={"agent": "other"}, input_tokens=100, output_tokens=50, cost=999.0 ) # Then: no snapshots (rule didn't match) @@ -165,17 +170,17 @@ async def test_no_matching_rules(self) -> None: async def test_group_by_user(self) -> None: # Given: global rule with group_by=user_id rules = [BudgetLimitRule(group_by="user_id", limit=500)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: two users each spend await store.record_and_check( - scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=400.0 + rules=rules, scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=400.0 ) results_u1 = await store.record_and_check( - scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=200.0 + rules=rules, scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=200.0 ) results_u2 = await store.record_and_check( - scope={"user_id": "u2"}, 
input_tokens=0, output_tokens=0, cost=300.0 + rules=rules, scope={"user_id": "u2"}, input_tokens=0, output_tokens=0, cost=300.0 ) # Then: u1 exceeded, u2 not @@ -184,10 +189,10 @@ async def test_group_by_user(self) -> None: def test_thread_safety(self) -> None: # Given: high-limit rule and 10 concurrent threads - # Each thread calls asyncio.run(store.record_and_check(...)) -- the async + # Each thread calls asyncio.run(store.record_and_check(rules=rules, ...)) -- the async # method wraps a sync critical section, so threading.Lock prevents races. rules = [BudgetLimitRule(limit=1_000_000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) errors: list[str] = [] import asyncio @@ -196,7 +201,9 @@ def record_many() -> None: try: for _ in range(100): asyncio.run( - store.record_and_check(scope={}, input_tokens=1, output_tokens=1, cost=1.0) + store.record_and_check( + rules=rules, scope={}, input_tokens=1, output_tokens=1, cost=1.0 + ) ) except Exception as exc: errors.append(str(exc)) @@ -218,13 +225,13 @@ def record_many() -> None: async def test_max_buckets_fail_closed(self) -> None: # Given: store limited to 3 buckets with group_by=user_id rules = [BudgetLimitRule(group_by="user_id", limit=100_000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0, max_buckets=3) + store = InMemoryBudgetStore(clock=lambda: 0.0, max_buckets=3) # When: 5 different users try to record exceeded_count = 0 for i in range(5): results = await store.record_and_check( - scope={"user_id": f"u{i}"}, input_tokens=1, output_tokens=1, cost=1.0 + rules=rules, scope={"user_id": f"u{i}"}, input_tokens=1, output_tokens=1, cost=1.0 ) if results and results[0].exceeded: exceeded_count += 1 @@ -236,8 +243,10 @@ async def test_max_buckets_fail_closed(self) -> None: async def test_reset_all(self) -> None: # Given: store with recorded usage rules = [BudgetLimitRule(limit=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 
0.0) - await store.record_and_check(scope={}, input_tokens=10, output_tokens=10, cost=100.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=10, output_tokens=10, cost=100.0 + ) # When: reset all store.reset() @@ -250,11 +259,13 @@ async def test_reset_all(self) -> None: async def test_float_accumulation_precision(self) -> None: # Given: store with 1-cent limit rules = [BudgetLimitRule(limit=1)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: 100 calls each costing 0.003 cents (total = 0.3 cents) for _ in range(100): - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=0.003) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=0.003 + ) # Then: not exceeded (0.3 < 1), no ceil-per-call overcount snap = store.get_snapshot("__global__", "", limit=1) @@ -265,12 +276,12 @@ async def test_float_accumulation_precision(self) -> None: async def test_float_accumulation_eventual_breach(self) -> None: # Given: store with 1-cent limit rules = [BudgetLimitRule(limit=1)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: 400 calls each costing 0.003 cents (total = 1.2 cents) for _ in range(400): results = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=0.003 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=0.003 ) # Then: exceeded (1.2 >= 1) @@ -285,23 +296,23 @@ async def test_float_accumulation_eventual_breach(self) -> None: class TestUtilities: def test_compute_utilization_no_limits(self) -> None: # Given/When: no limits set / Then: 0.0 - assert _compute_utilization(100.0, 10000, None, None) == 0.0 + assert _compute_utilization(100.0, 10000, None, "usd_cents") == 0.0 def test_compute_utilization_spend_only(self) -> None: # Given: 500 of 1000 spent / Then: 
0.5 - assert _compute_utilization(500.0, 0, 1000, None) == pytest.approx(0.5) + assert _compute_utilization(500.0, 0, 1000, "usd_cents") == pytest.approx(0.5) def test_compute_utilization_clamped(self) -> None: # Given: overspent / Then: clamped to 1.0 - assert _compute_utilization(2000.0, 0, 1000, None) == pytest.approx(1.0) + assert _compute_utilization(2000.0, 0, 1000, "usd_cents") == pytest.approx(1.0) def test_compute_utilization_negative_clamped_to_zero(self) -> None: # Given: refund made the accumulator go negative # When: utilization is computed # Then: clamped to 0.0 (BudgetSnapshot.utilization contract) - assert _compute_utilization(-150.0, 0, 100, None) == 0.0 + assert _compute_utilization(-150.0, 0, 100, "usd_cents") == 0.0 # And: negative tokens (not currently reachable but defensively clamped) - assert _compute_utilization(0.0, -50, None, 100) == 0.0 + assert _compute_utilization(0.0, -50, 100, "tokens") == 0.0 def test_parse_period_key_valid(self) -> None: # Given: well-formed period key / Then: parsed tuple @@ -385,8 +396,8 @@ def test_valid_rule(self) -> None: assert rule.limit == 1000 def test_no_limit_rejected(self) -> None: - # Given/When: no limit or limit_tokens / Then: rejected - with pytest.raises(ValidationError, match="At least one"): + # Given/When: no limit / Then: rejected + with pytest.raises(ValidationError, match="Field required"): BudgetLimitRule() def test_negative_limit_rejected(self) -> None: @@ -399,11 +410,6 @@ def test_zero_limit_rejected(self) -> None: with pytest.raises(ValidationError, match="positive"): BudgetLimitRule(limit=0) - def test_negative_limit_tokens_rejected(self) -> None: - # Given/When: negative limit_tokens / Then: rejected - with pytest.raises(ValidationError, match="positive"): - BudgetLimitRule(limit_tokens=-1) - def test_negative_window_seconds_rejected(self) -> None: # Given/When: negative window_seconds / Then: rejected with pytest.raises(ValidationError, match="positive"): @@ -415,10 +421,10 @@ def 
test_zero_window_seconds_rejected(self) -> None: BudgetLimitRule(limit=1000, window_seconds=0) def test_token_only_rule(self) -> None: - # Given/When: limit_tokens only / Then: accepted, limit is None - rule = BudgetLimitRule(limit_tokens=5000) - assert rule.limit is None - assert rule.limit_tokens == 5000 + # Given/When: token limit_unit / Then: accepted + rule = BudgetLimitRule(limit=5000, limit_unit="tokens") + assert rule.limit == 5000 + assert rule.limit_unit == "tokens" def test_empty_limits_rejected(self) -> None: # Given/When: empty limits list / Then: rejected @@ -432,6 +438,22 @@ def test_window_constants(self) -> None: assert WINDOW_MONTHLY == 2592000 +class TestModelPricing: + def test_model_pricing_validation_requires_pricing_for_cost_rules(self) -> None: + # Given: a cost-based rule without pricing + # When/Then: config validation rejects it + with pytest.raises(ValidationError, match="pricing is required"): + BudgetEvaluatorConfig(limits=[BudgetLimitRule(limit=100)]) + + def test_model_pricing_token_rule_no_pricing_ok(self) -> None: + # Given: a token-only rule without pricing + # When: config is created + config = BudgetEvaluatorConfig(limits=[BudgetLimitRule(limit=100, limit_unit="tokens")]) + + # Then: no pricing table is required + assert config.pricing is None + + # --------------------------------------------------------------------------- # BudgetEvaluator integration # --------------------------------------------------------------------------- @@ -445,7 +467,7 @@ def _make_evaluator(self, **kwargs: Any) -> BudgetEvaluator: @pytest.mark.asyncio async def test_single_call_under_budget(self) -> None: # Given: evaluator with $10 limit (1000 cents) - ev = self._make_evaluator(limits=[{"limit": 1000}]) + ev = self._make_evaluator(limits=[{"limit": 1000}], pricing={}) # When: evaluate with usage data result = await ev.evaluate({"usage": {"input_tokens": 100, "output_tokens": 50}}) @@ -503,7 +525,7 @@ def _step(tokens: int, user: str) -> dict: 
@pytest.mark.asyncio async def test_token_only_limit(self) -> None: # Given: 500 token limit - ev = self._make_evaluator(limits=[{"limit_tokens": 500}]) + ev = self._make_evaluator(limits=[{"limit": 500, "limit_unit": "tokens"}]) # When: consume 600 tokens result = await ev.evaluate({"usage": {"input_tokens": 300, "output_tokens": 300}}) @@ -514,7 +536,7 @@ async def test_token_only_limit(self) -> None: @pytest.mark.asyncio async def test_no_data_returns_not_matched(self) -> None: # Given: evaluator / When: None data / Then: not matched - ev = self._make_evaluator(limits=[{"limit": 1000}]) + ev = self._make_evaluator(limits=[{"limit": 1000}], pricing={}) result = await ev.evaluate(None) assert result.matched is False @@ -563,11 +585,12 @@ async def test_cost_computed_from_pricing_table(self) -> None: @pytest.mark.asyncio async def test_unknown_model_cost_zero(self) -> None: - # Given: evaluator with pricing table but data from an unknown model + # Given: evaluator with warn mode and data from an unknown model ev = self._make_evaluator( limits=[{"limit": 100}], pricing={"gpt-4": {"input_per_1k": 30.0, "output_per_1k": 60.0}}, model_path="model", + unknown_model_behavior="warn", ) # When: evaluate with a model not in the pricing table @@ -601,6 +624,167 @@ async def test_small_cost_no_overcount(self) -> None: assert result.matched is False +class TestBudgetIdSemantics: + @pytest.mark.asyncio + async def test_same_budget_id_shares_store(self) -> None: + # Given: two evaluators with the same budget_id + config1 = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + budget_id="shared", + pricing={"gpt-4": {"input_per_1k": 100.0, "output_per_1k": 0.0}}, + model_path="model", + ) + config2 = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + budget_id="shared", + pricing={"gpt-4": {"input_per_1k": 100.0, "output_per_1k": 0.0}}, + model_path="model", + ) + ev1 = BudgetEvaluator(config1) + ev2 = BudgetEvaluator(config2) + step = {"model": "gpt-4", "usage": {"input_tokens": 
500, "output_tokens": 0}} + + # When: each evaluator records a 50-cent call + first = await ev1.evaluate(step) + second = await ev2.evaluate(step) + + # Then: spend is shared and the second call reaches the 100-cent limit + assert first.matched is False + assert second.matched is True + + @pytest.mark.asyncio + async def test_different_budget_id_isolates_store(self) -> None: + # Given: two evaluators with different budget_id values + config1 = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + budget_id="pool-a", + pricing={"gpt-4": {"input_per_1k": 100.0, "output_per_1k": 0.0}}, + model_path="model", + ) + config2 = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + budget_id="pool-b", + pricing={"gpt-4": {"input_per_1k": 100.0, "output_per_1k": 0.0}}, + model_path="model", + ) + ev1 = BudgetEvaluator(config1) + ev2 = BudgetEvaluator(config2) + step = {"model": "gpt-4", "usage": {"input_tokens": 500, "output_tokens": 0}} + + # When: each evaluator records a 50-cent call + first = await ev1.evaluate(step) + second = await ev2.evaluate(step) + + # Then: each pool remains below the 100-cent limit independently + assert first.matched is False + assert second.matched is False + + +class TestUnknownModelBehavior: + @pytest.mark.asyncio + async def test_unknown_model_block_default(self) -> None: + # Given: a cost rule with pricing that does not include the incoming model + config = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + pricing={"gpt-4": {"input_per_1k": 10.0, "output_per_1k": 20.0}}, + model_path="model", + ) + evaluator = BudgetEvaluator(config) + + # When: the step uses an unknown model + result = await evaluator.evaluate( + {"model": "unknown-model", "usage": {"input_tokens": 100, "output_tokens": 50}} + ) + + # Then: the evaluator fails closed and reports the unknown model + assert result.matched is True + assert result.metadata is not None + assert result.metadata["unknown_model"] == "unknown-model" + + @pytest.mark.asyncio + async def 
test_unknown_model_warn(self) -> None: + # Given: a cost rule configured to warn on unknown models + config = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + pricing={"gpt-4": {"input_per_1k": 10.0, "output_per_1k": 20.0}}, + model_path="model", + unknown_model_behavior="warn", + ) + evaluator = BudgetEvaluator(config) + + # When: the step uses an unknown model + result = await evaluator.evaluate( + {"model": "unknown-model", "usage": {"input_tokens": 100, "output_tokens": 50}} + ) + + # Then: the evaluator treats cost as 0 and does not block + assert result.matched is False + assert result.metadata is not None + assert result.metadata["cost"] == 0.0 + assert result.metadata["all_snapshots"][0]["spent_tokens"] == 150 + + @pytest.mark.asyncio + async def test_unknown_model_token_only_unaffected(self) -> None: + # Given: a token-only rule with a pricing table that does not include + # the incoming model and the default block setting + config = BudgetEvaluatorConfig( + limits=[{"limit": 1000, "limit_unit": "tokens"}], + pricing={}, + model_path="model", + ) + evaluator = BudgetEvaluator(config) + + # When: the step uses an unknown model below the token limit + result = await evaluator.evaluate( + {"model": "unknown-model", "usage": {"input_tokens": 100, "output_tokens": 50}} + ) + + # Then: unknown-model blocking is not applied without a cost rule + assert result.matched is False + assert result.metadata is not None + assert result.metadata["all_snapshots"][0]["spent_tokens"] == 150 + + @pytest.mark.asyncio + async def test_pricing_lookup_is_case_sensitive(self) -> None: + # Given: pricing for lowercase gpt-4 + config = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + pricing={"gpt-4": {"input_per_1k": 10.0, "output_per_1k": 20.0}}, + model_path="model", + ) + evaluator = BudgetEvaluator(config) + + # When: the step uses a differently cased model name + result = await evaluator.evaluate( + {"model": "GPT-4", "usage": {"input_tokens": 100, "output_tokens": 50}} 
+ ) + + # Then: lookup is case-sensitive and the default behavior fails closed + assert result.matched is True + assert result.metadata is not None + assert result.metadata["unknown_model"] == "GPT-4" + + @pytest.mark.asyncio + async def test_known_model_not_blocked(self) -> None: + # Given: a cost rule whose pricing includes the incoming model + config = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + pricing={"gpt-4": {"input_per_1k": 10.0, "output_per_1k": 20.0}}, + model_path="model", + ) + evaluator = BudgetEvaluator(config) + + # When: the step uses the known model + result = await evaluator.evaluate( + {"model": "gpt-4", "usage": {"input_tokens": 100, "output_tokens": 50}} + ) + + # Then: normal budget evaluation runs + assert result.matched is False + assert result.metadata is not None + assert "unknown_model" not in result.metadata + + # --------------------------------------------------------------------------- # Store registry # --------------------------------------------------------------------------- @@ -609,7 +793,7 @@ async def test_small_cost_no_overcount(self) -> None: class TestStoreRegistry: def test_same_config_returns_same_store(self) -> None: # Given: two configs with identical parameters - config = BudgetEvaluatorConfig(limits=[{"limit": 1000}]) + config = BudgetEvaluatorConfig(limits=[{"limit": 1000}], pricing={}) # When: get store twice store1 = get_or_create_store(config) @@ -618,10 +802,10 @@ def test_same_config_returns_same_store(self) -> None: # Then: same object assert store1 is store2 - def test_different_config_returns_different_store(self) -> None: - # Given: two configs with different limits - config1 = BudgetEvaluatorConfig(limits=[{"limit": 1000}]) - config2 = BudgetEvaluatorConfig(limits=[{"limit": 2000}]) + def test_different_budget_id_returns_different_store(self) -> None: + # Given: two configs with different budget ids + config1 = BudgetEvaluatorConfig(limits=[{"limit": 1000}], budget_id="a", pricing={}) + config2 = 
BudgetEvaluatorConfig(limits=[{"limit": 1000}], budget_id="b", pricing={}) # When: get stores store1 = get_or_create_store(config1) @@ -632,7 +816,7 @@ def test_different_config_returns_different_store(self) -> None: def test_clear_budget_stores(self) -> None: # Given: a registered store - config = BudgetEvaluatorConfig(limits=[{"limit": 1000}]) + config = BudgetEvaluatorConfig(limits=[{"limit": 1000}], pricing={}) store1 = get_or_create_store(config) # When: clear all stores @@ -661,6 +845,39 @@ async def test_evaluator_uses_registry(self) -> None: # Then: ev2 sees ev1's accumulated spend (shared store via registry) assert result.matched is True # 50 + 50 = 100 >= 100 + @pytest.mark.asyncio + async def test_same_budget_id_shares_buckets_but_not_rules(self) -> None: + # Given: two configs sharing budget_id but using different limits + pricing = {"gpt-4": {"input_per_1k": 100.0, "output_per_1k": 0.0}} + config1 = BudgetEvaluatorConfig( + limits=[{"limit": 100}], + budget_id="shared", + pricing=pricing, + model_path="model", + ) + config2 = BudgetEvaluatorConfig( + limits=[{"limit": 1000}], + budget_id="shared", + pricing=pricing, + model_path="model", + ) + ev1 = BudgetEvaluator(config1) + ev2 = BudgetEvaluator(config2) + step = {"model": "gpt-4", "usage": {"input_tokens": 600, "output_tokens": 0}} + + # When: the first evaluator records 60 cents, then the second records + # another 60 cents into the same budget bucket + first = await ev1.evaluate(step) + second = await ev2.evaluate(step) + + # Then: the second evaluator sees shared bucket state (120 cents) but + # evaluates against its own 1000-cent rule, not config1's 100-cent rule. 
+ assert first.matched is False + assert second.matched is False + assert second.metadata is not None + assert second.metadata["all_snapshots"][0]["spent"] == 120 + assert second.metadata["all_snapshots"][0]["limit"] == 1000 + # --------------------------------------------------------------------------- # Security / adversarial tests @@ -691,11 +908,11 @@ def test_extract_by_path_rejects_dunder(self) -> None: async def test_group_by_without_metadata_skips_rule(self) -> None: # Given: rule with group_by=user_id but no user_id in scope rules = [BudgetLimitRule(group_by="user_id", limit=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: step without user_id results = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=999.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=999.0 ) # Then: rule skipped @@ -706,13 +923,13 @@ async def test_two_rules_same_scope_no_double_count(self) -> None: # Given: two global rules with different limit types rules = [ BudgetLimitRule(limit=1000), - BudgetLimitRule(limit_tokens=5000), + BudgetLimitRule(limit=5000, limit_unit="tokens"), ] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: record once results = await store.record_and_check( - scope={}, input_tokens=100, output_tokens=100, cost=100.0 + rules=rules, scope={}, input_tokens=100, output_tokens=100, cost=100.0 ) # Then: both rules get snapshot, but usage recorded only once @@ -724,12 +941,14 @@ async def test_two_rules_same_scope_no_double_count(self) -> None: async def test_negative_cost_reduces_spend(self) -> None: # Given: store with 1000-cent limit rules = [BudgetLimitRule(limit=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: record positive then negative cost - await store.record_and_check(scope={}, input_tokens=0, 
output_tokens=0, cost=500.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=500.0 + ) results = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=-200.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=-200.0 ) # Then: negative cost reduces spend (store does not clamp; validation is caller's job) @@ -742,12 +961,14 @@ async def test_window_seconds_boundary_alignment(self) -> None: boundary = 3600 * 100 # exact hour boundary # When: record just before and at boundary - store = InMemoryBudgetStore(rules=rules, clock=lambda: boundary - 1) - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=500.0) + store = InMemoryBudgetStore(clock=lambda: boundary - 1) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=500.0 + ) store._clock = lambda: boundary results = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=500.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=500.0 ) # Then: boundary crossing starts fresh period @@ -755,10 +976,10 @@ async def test_window_seconds_boundary_alignment(self) -> None: class TestConfigValidationEdgeCases: - def test_zero_limit_tokens_rejected(self) -> None: + def test_zero_token_limit_rejected(self) -> None: # Given/When: zero token limit with pytest.raises(ValidationError, match="positive"): - BudgetLimitRule(limit_tokens=0) + BudgetLimitRule(limit=0, limit_unit="tokens") class TestBoolGuard: @@ -780,7 +1001,7 @@ def test_extract_tokens_rejects_bool(self) -> None: class TestStoreRegistryRobustness: def test_concurrent_get_or_create_store(self) -> None: # Given: 10 threads requesting the same config concurrently - config = BudgetEvaluatorConfig(limits=[{"limit": 1000}]) + config = BudgetEvaluatorConfig(limits=[{"limit": 1000}], pricing={}) stores: list[Any] = [] lock = threading.Lock() @@ -843,11 +1064,11 @@ class TestRoundingBoundary: 
async def test_spent_half_cent_below_limit_not_exceeded(self) -> None: # Given: store with 1000-cent limit rules = [BudgetLimitRule(limit=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: spend 999.5 cents (just below limit) results = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=999.5 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=999.5 ) # Then: not exceeded (999.5 < 1000), spent display < limit @@ -858,10 +1079,12 @@ async def test_spent_half_cent_below_limit_not_exceeded(self) -> None: async def test_spent_display_never_exceeds_actual(self) -> None: # Given: store with 100-cent limit rules = [BudgetLimitRule(limit=100)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: spend 99.9 cents - results = await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=99.9) + results = await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=99.9 + ) # Then: floor truncation means spent=99, not rounded to 100 assert results[0].spent == 99 @@ -869,12 +1092,12 @@ async def test_spent_display_never_exceeds_actual(self) -> None: class TestConfigKeyOrdering: - def test_limits_order_does_not_affect_store_identity(self) -> None: - # Given: two configs with same rules in different order + def test_limits_order_does_not_affect_same_budget_id_store_identity(self) -> None: + # Given: two configs with same budget_id and rules in different order rule_a = {"limit": 1000, "scope": {"agent": "a"}} rule_b = {"limit": 2000, "scope": {"agent": "b"}} - config1 = BudgetEvaluatorConfig(limits=[rule_a, rule_b]) - config2 = BudgetEvaluatorConfig(limits=[rule_b, rule_a]) + config1 = BudgetEvaluatorConfig(limits=[rule_a, rule_b], budget_id="ordered", pricing={}) + config2 = BudgetEvaluatorConfig(limits=[rule_b, rule_a], budget_id="ordered", pricing={}) # 
When: get stores for both store1 = get_or_create_store(config1) @@ -889,7 +1112,7 @@ def test_nan_rate_returns_zero(self) -> None: from agent_control_evaluator_budget.budget.evaluator import _estimate_cost # Given: pricing table with NaN rate - pricing = {"gpt-4": {"input_per_1k": float("nan"), "output_per_1k": 0.0}} + pricing = {"gpt-4": ModelPricing(input_per_1k=float("nan"), output_per_1k=0.0)} # When: estimate cost cost = _estimate_cost("gpt-4", 1000, 0, pricing) @@ -901,7 +1124,7 @@ def test_inf_rate_returns_zero(self) -> None: from agent_control_evaluator_budget.budget.evaluator import _estimate_cost # Given: pricing table with Inf rate - pricing = {"gpt-4": {"input_per_1k": float("inf"), "output_per_1k": 0.0}} + pricing = {"gpt-4": ModelPricing(input_per_1k=float("inf"), output_per_1k=0.0)} # When: estimate cost cost = _estimate_cost("gpt-4", 1000, 0, pricing) @@ -913,7 +1136,7 @@ def test_negative_rate_returns_zero(self) -> None: from agent_control_evaluator_budget.budget.evaluator import _estimate_cost # Given: pricing table with negative rate - pricing = {"gpt-4": {"input_per_1k": -10.0, "output_per_1k": 0.0}} + pricing = {"gpt-4": ModelPricing(input_per_1k=-10.0, output_per_1k=0.0)} # When: estimate cost cost = _estimate_cost("gpt-4", 1000, 0, pricing) @@ -969,16 +1192,22 @@ async def test_ttl_prune_drops_old_period_on_rollover(self) -> None: day_n2 = day_n + 2 * day_seconds rules = [BudgetLimitRule(limit=10_000, window_seconds=day_seconds)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + store = InMemoryBudgetStore(clock=lambda: day_n) # When: record on day N - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # record on day N+1 store._clock = lambda: day_n1 - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=2.0) + await store.record_and_check( + rules=rules, scope={}, 
input_tokens=0, output_tokens=0, cost=2.0 + ) # record on day N+2 -- should prune day N store._clock = lambda: day_n2 - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=3.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=3.0 + ) # Then: only buckets for day N+1 and N+2 remain for that scope with store._lock: @@ -1002,12 +1231,14 @@ async def test_ttl_prune_preserves_cumulative_buckets(self) -> None: BudgetLimitRule(limit=10_000), # cumulative (window_seconds=None) BudgetLimitRule(limit=10_000, window_seconds=day_seconds), ] - store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + store = InMemoryBudgetStore(clock=lambda: day_n) # When: record on 3 consecutive days for i in range(3): store._clock = lambda i=i: day_n + i * day_seconds - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Then: cumulative bucket (empty period key) must survive with store._lock: @@ -1026,12 +1257,14 @@ async def test_ttl_prune_preserves_other_windows(self) -> None: BudgetLimitRule(limit=10_000, window_seconds=hour), BudgetLimitRule(limit=100_000, window_seconds=day), ] - store = InMemoryBudgetStore(rules=rules, clock=lambda: t0) + store = InMemoryBudgetStore(clock=lambda: t0) # When: roll hours many times (within same day) for h in range(5): store._clock = lambda h=h: t0 + h * hour - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Then: daily bucket must survive hourly rollovers day_key = _derive_period_key(day, t0) @@ -1043,7 +1276,9 @@ async def test_ttl_prune_preserves_other_windows(self) -> None: # When: roll day (prune old hourly buckets) t_day2 = t0 + day store._clock = lambda: t_day2 - await store.record_and_check(scope={}, 
input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) with store._lock: period_keys_after = [k[1] for k in store._buckets] @@ -1068,12 +1303,16 @@ async def test_ttl_prune_no_rescan_within_period(self) -> None: day_n1 = day_n + day_seconds rules = [BudgetLimitRule(limit=10_000, window_seconds=day_seconds)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + store = InMemoryBudgetStore(clock=lambda: day_n) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Roll over to day N+1 store._clock = lambda: day_n1 - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Capture _last_pruned_period state after first record of new period with store._lock: @@ -1081,7 +1320,9 @@ async def test_ttl_prune_no_rescan_within_period(self) -> None: # When: record many more times within the same new period for _ in range(10): - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Then: _last_pruned_period unchanged (no rescan occurred) with store._lock: @@ -1095,18 +1336,24 @@ async def test_ttl_prune_sparse_rollover(self) -> None: day = WINDOW_DAILY day_n = (int(1700000000.0) // day) * day rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + store = InMemoryBudgetStore(clock=lambda: day_n) # When: record at baseline - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Jump 
forward ~95 days (any stale indices must be swept in one scan) for i in range(1, 6): store._clock = lambda i=i: day_n + i * day - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Large gap -- should prune everything older than index-1 far = day_n + 100 * day store._clock = lambda: far - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Then: only current (index 100) and previous-valid bucket survive for that window with store._lock: @@ -1124,10 +1371,14 @@ async def test_ttl_prune_reset_clears_prune_state(self) -> None: day = WINDOW_DAILY day_n = (int(1700000000.0) // day) * day rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + store = InMemoryBudgetStore(clock=lambda: day_n) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) store._clock = lambda: day_n + 2 * day - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) with store._lock: assert day in store._last_pruned_period @@ -1141,9 +1392,13 @@ async def test_ttl_prune_reset_clears_prune_state(self) -> None: # And: a fresh rollover sequence prunes again (watermark advances) store._clock = lambda: day_n - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) store._clock = lambda: day_n + 2 * day - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, 
cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) with store._lock: assert store._last_pruned_period.get(day) is not None @@ -1153,10 +1408,14 @@ async def test_ttl_prune_partial_reset_preserves_prune_state(self) -> None: day = WINDOW_DAILY day_n = (int(1700000000.0) // day) * day rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + store = InMemoryBudgetStore(clock=lambda: day_n) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) store._clock = lambda: day_n + 2 * day - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) with store._lock: before = dict(store._last_pruned_period) @@ -1175,12 +1434,12 @@ async def test_ttl_prune_cross_scope(self) -> None: rules = [ BudgetLimitRule(limit=10_000, window_seconds=day, group_by="user_id"), ] - store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) + store = InMemoryBudgetStore(clock=lambda: day_n) await store.record_and_check( - scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=1.0 + rules=rules, scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=1.0 ) await store.record_and_check( - scope={"user_id": "u2"}, input_tokens=0, output_tokens=0, cost=1.0 + rules=rules, scope={"user_id": "u2"}, input_tokens=0, output_tokens=0, cost=1.0 ) # Pre-condition: both users have distinct buckets on day N @@ -1193,7 +1452,7 @@ async def test_ttl_prune_cross_scope(self) -> None: # When: only u1 records on day N+2 (triggers prune) store._clock = lambda: day_n + 2 * day await store.record_and_check( - scope={"user_id": "u1"}, input_tokens=0, output_tokens=0, cost=1.0 + rules=rules, scope={"user_id": "u1"}, 
input_tokens=0, output_tokens=0, cost=1.0 ) # Then: u2's day-N bucket is also pruned -- the period expired globally, @@ -1212,18 +1471,20 @@ async def test_ttl_prune_respects_max_buckets_after_rollover(self) -> None: day = WINDOW_DAILY day_n = (int(1700000000.0) // day) * day rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] - store = InMemoryBudgetStore( - rules=rules, clock=lambda: day_n, max_buckets=2 - ) + store = InMemoryBudgetStore(clock=lambda: day_n, max_buckets=2) # When: fill 2 buckets - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) store._clock = lambda: day_n + day - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Day N+2 at capacity -- prune must free space store._clock = lambda: day_n + 2 * day snaps = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=1.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 ) # Then: day N+2 record succeeded (not fail-closed) and day-N bucket is gone @@ -1240,17 +1501,23 @@ async def test_ttl_prune_backwards_clock_is_noop(self) -> None: day = WINDOW_DAILY day_n = (int(1700000000.0) // day) * day rules = [BudgetLimitRule(limit=10_000, window_seconds=day)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: day_n) - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + store = InMemoryBudgetStore(clock=lambda: day_n) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) store._clock = lambda: day_n + 5 * day - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) with store._lock: 
watermark_before = store._last_pruned_period.get(day) assert watermark_before is not None # When: clock jumps backwards to day N+2 and creates a new bucket there store._clock = lambda: day_n + 2 * day - await store.record_and_check(scope={}, input_tokens=0, output_tokens=0, cost=1.0) + await store.record_and_check( + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=1.0 + ) # Then: watermark did NOT drop (monotonic advance only) with store._lock: @@ -1274,6 +1541,7 @@ def test_subclass_with_sync_override_rejected_at_class_creation(self) -> None: class BrokenStore(BudgetStore): # type: ignore[unused-ignore] def record_and_check( # noqa: D401, ANN001 self, + rules: list[BudgetLimitRule], scope: dict[str, str], input_tokens: int, output_tokens: int, @@ -1289,6 +1557,7 @@ def test_subclass_with_async_override_accepted(self) -> None: class GoodStore(BudgetStore): async def record_and_check( self, + rules: list[BudgetLimitRule], scope: dict[str, str], input_tokens: int, output_tokens: int, @@ -1321,7 +1590,7 @@ def test_mixin_sync_override_rejected(self) -> None: from agent_control_evaluator_budget.budget.store import BudgetStore class SyncMixin: - def record_and_check(self, scope, input_tokens, output_tokens, cost): + def record_and_check(self, rules, scope, input_tokens, output_tokens, cost): return [] with pytest.raises(TypeError, match="must be an async def"): @@ -1335,15 +1604,15 @@ class TestNaNCostDefense: async def test_nan_cost_coerced_to_zero(self) -> None: # Given: store with a cost limit rules = [BudgetLimitRule(limit=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: NaN cost is injected directly (bypassing _estimate_cost) await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=float("nan") + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=float("nan") ) # And: a subsequent valid charge arrives snaps = await store.record_and_check( - 
scope={}, input_tokens=0, output_tokens=0, cost=500.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=500.0 ) # Then: the NaN was coerced to 0.0; the accumulator is 500, not NaN @@ -1354,14 +1623,14 @@ async def test_nan_cost_coerced_to_zero(self) -> None: async def test_inf_cost_coerced_to_zero(self) -> None: # Given: store with a cost limit rules = [BudgetLimitRule(limit=1000)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + store = InMemoryBudgetStore(clock=lambda: 0.0) # When: Inf cost is injected await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=float("inf") + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=float("inf") ) snaps = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=100.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=100.0 ) # Then: Inf was coerced to 0.0; the accumulator is 100 @@ -1374,19 +1643,17 @@ async def test_inf_cost_coerced_to_zero(self) -> None: [(-50, 0), (0, -50), (-30, -20)], ids=["neg_input_only", "neg_output_only", "both_negative"], ) - async def test_negative_tokens_clamped_to_zero( - self, neg_input: int, neg_output: int - ) -> None: + async def test_negative_tokens_clamped_to_zero(self, neg_input: int, neg_output: int) -> None: # Given: store with a token limit, filled to 90 tokens - rules = [BudgetLimitRule(limit_tokens=100)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: 0.0) + rules = [BudgetLimitRule(limit=100, limit_unit="tokens")] + store = InMemoryBudgetStore(clock=lambda: 0.0) await store.record_and_check( - scope={}, input_tokens=90, output_tokens=0, cost=0.0 + rules=rules, scope={}, input_tokens=90, output_tokens=0, cost=0.0 ) # When: inject negative input/output tokens snaps = await store.record_and_check( - scope={}, input_tokens=neg_input, output_tokens=neg_output, cost=0.0 + rules=rules, scope={}, input_tokens=neg_input, output_tokens=neg_output, cost=0.0 ) # Then: negative tokens 
clamped to 0; accumulator stays at 90 @@ -1397,12 +1664,12 @@ async def test_negative_tokens_clamped_to_zero( async def test_nan_clock_does_not_crash(self) -> None: # Given: store with a windowed rule AND a clock that returns NaN rules = [BudgetLimitRule(limit=1000, window_seconds=WINDOW_DAILY)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: float("nan")) + store = InMemoryBudgetStore(clock=lambda: float("nan")) # When: record_and_check is called (would raise OverflowError in # _derive_period_key without the guard) snaps = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=100.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=100.0 ) # Then: no crash; maps to epoch-zero period, budget still enforced @@ -1413,11 +1680,11 @@ async def test_nan_clock_does_not_crash(self) -> None: async def test_inf_clock_does_not_crash(self) -> None: # Given: clock returning Inf rules = [BudgetLimitRule(limit=1000, window_seconds=WINDOW_DAILY)] - store = InMemoryBudgetStore(rules=rules, clock=lambda: float("inf")) + store = InMemoryBudgetStore(clock=lambda: float("inf")) # When: record_and_check is called snaps = await store.record_and_check( - scope={}, input_tokens=0, output_tokens=0, cost=100.0 + rules=rules, scope={}, input_tokens=0, output_tokens=0, cost=100.0 ) # Then: no crash