|
8 | 8 | from braintrust.test_helpers import has_devserver_installed |
9 | 9 |
|
10 | 10 |
|
# True when pydantic is importable; used to gate pydantic-dependent tests.
# find_spec probes for the distribution without actually importing it, so
# merely collecting this module never triggers pydantic's import side effects.
import importlib.util

HAS_PYDANTIC = importlib.util.find_spec("pydantic") is not None
| 12 | + |
| 13 | + |
11 | 14 | @pytest.fixture |
12 | 15 | def client(): |
13 | 16 | """Create test client using the real simple_eval.py example.""" |
@@ -205,3 +208,67 @@ def test_eval_error_handling(client, api_key, org_name): |
205 | 208 | error = response.json() |
206 | 209 | assert "error" in error |
207 | 210 | assert "not found" in error["error"].lower() |
| 211 | + |
| 212 | + |
@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
def test_eval_uses_inline_request_parameters(api_key, org_name, monkeypatch):
    """POST /eval with inline ``parameters`` forwards them to the eval loop.

    Patches out ``cached_login`` and ``EvalAsync`` so no network call or real
    eval run happens; the fake ``EvalAsync`` asserts the request's inline
    parameters arrive unchanged and drives the task once through a minimal
    hooks object that exposes those parameters.
    """
    import inspect

    from braintrust import Evaluator
    from braintrust.devserver import server as devserver_module
    from braintrust.devserver.server import create_app
    from braintrust.logger import BraintrustState
    from pydantic import BaseModel
    from starlette.testclient import TestClient

    class RequiredInt(BaseModel):
        # Parameter schema with no default: the request body MUST supply it.
        value: int

    def task(input: str, hooks) -> dict[str, Any]:
        # Reads the inline parameter via hooks to prove it was threaded through.
        return {"input": input, "num_samples": hooks.parameters["num_samples_without_default"]}

    evaluator = Evaluator(
        project_name="test-math-eval",
        eval_name="inline-parameter-eval",
        data=lambda: [{"input": "What is 2+2?", "expected": "4"}],
        task=task,
        scores=[],
        experiment_name=None,
        metadata=None,
        parameters={"num_samples_without_default": RequiredInt},
    )

    async def fake_cached_login(**_kwargs):
        # Skip real auth; the handler only needs a state object back.
        return BraintrustState()

    class FakeSummary:
        def as_dict(self):
            return {"experiment_name": "inline-parameter-eval", "project_name": "test-math-eval", "scores": {}}

    class FakeResult:
        summary = FakeSummary()

    async def fake_eval_async(*, task, data, parameters, **_kwargs):
        # The inline request parameters must reach the eval loop unchanged.
        assert parameters == {"num_samples_without_default": 1}
        datum = data[0]
        hooks = type("Hooks", (), {"parameters": parameters, "report_progress": lambda self, _progress: None})()
        # BUGFIX: the evaluator's task defined above is synchronous and returns
        # a plain dict, which is not awaitable — the previous unconditional
        # ``await task(...)`` raised TypeError. Await only when the task really
        # returns an awaitable (also covers an async-wrapped task, if the
        # devserver ever passes one — TODO confirm what EvalAsync receives).
        result = task(datum["input"], hooks)
        if inspect.isawaitable(result):
            await result
        return FakeResult()

    monkeypatch.setattr(devserver_module, "cached_login", fake_cached_login)
    monkeypatch.setattr(devserver_module, "EvalAsync", fake_eval_async)

    response = TestClient(create_app([evaluator])).post(
        "/eval",
        headers={
            "x-bt-auth-token": api_key,
            "x-bt-org-name": org_name,
            "Content-Type": "application/json",
        },
        json={
            "name": "inline-parameter-eval",
            "stream": False,
            "parameters": {"num_samples_without_default": 1},
            "data": [{"input": "What is 2+2?", "expected": "4"}],
        },
    )

    assert response.status_code == 200
0 commit comments