diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py
index 2eeb00de..378aa852 100644
--- a/py/src/braintrust/framework.py
+++ b/py/src/braintrust/framework.py
@@ -9,7 +9,7 @@
 import traceback
 import warnings
 from collections import defaultdict
-from collections.abc import Awaitable, Callable, Coroutine, Iterable, Iterator, Sequence
+from collections.abc import Awaitable, Callable, Coroutine, Iterable, Iterator, Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from multiprocessing import cpu_count
@@ -1154,7 +1154,10 @@ def parse_filters(filters: list[str]) -> list[Filter]:
 def evaluate_filter(object, filter: Filter):
     key = object
     for p in filter.path:
-        key = key.get(p)
+        if isinstance(key, Mapping):
+            key = key.get(p)
+        else:
+            key = getattr(key, p, None)
     if key is None:
         return False
     return filter.pattern.match(serialize_json_with_plain_string(key)) is not None
diff --git a/py/src/braintrust/test_framework.py b/py/src/braintrust/test_framework.py
index 7d33eda0..4f6175d2 100644
--- a/py/src/braintrust/test_framework.py
+++ b/py/src/braintrust/test_framework.py
@@ -1,4 +1,5 @@
 import importlib.util
+import re
 from typing import List
 from unittest.mock import MagicMock
 
@@ -11,6 +12,9 @@
     EvalHooks,
     EvalResultWithSummary,
     Evaluator,
+    Filter,
+    evaluate_filter,
+    parse_filters,
     run_evaluator,
 )
 from .score import Score, Scorer
@@ -626,3 +630,77 @@ async def test_run_evaluator_empty_dataset_warns(capsys):
     captured = capsys.readouterr()
     assert "Warning" in captured.err
     assert "empty" in captured.err.lower()
+
+
+class TestEvaluateFilter:
+    """Regression tests for https://github.com/braintrustdata/braintrust-sdk-python/issues/207."""
+
+    @pytest.mark.parametrize(
+        "datum",
+        [
+            {"input": "hello", "metadata": {"name": "foo"}},
+            EvalCase(input="hello", metadata={"name": "foo"}),
+        ],
+        ids=["dict", "evalcase"],
+    )
+    def test_evaluate_filter_match(self, datum):
+        f = Filter(path=["metadata", "name"], pattern=re.compile("foo"))
+        assert evaluate_filter(datum, f) is True
+
+    @pytest.mark.parametrize(
+        "datum",
+        [
+            {"input": "hello", "metadata": {"name": "bar"}},
+            EvalCase(input="hello", metadata={"name": "bar"}),
+        ],
+        ids=["dict", "evalcase"],
+    )
+    def test_evaluate_filter_no_match(self, datum):
+        f = Filter(path=["metadata", "name"], pattern=re.compile("foo"))
+        assert evaluate_filter(datum, f) is False
+
+    @pytest.mark.parametrize(
+        "datum",
+        [
+            {"input": "hello"},
+            EvalCase(input="hello"),
+        ],
+        ids=["dict", "evalcase"],
+    )
+    def test_evaluate_filter_missing_key(self, datum):
+        f = Filter(path=["metadata", "name"], pattern=re.compile("foo"))
+        assert evaluate_filter(datum, f) is False
+
+    def test_evaluate_filter_nested_metadata(self):
+        datum = EvalCase(input="hello", metadata={"priority": "P0", "owner": "alice"})
+        f = Filter(path=["metadata", "priority"], pattern=re.compile("^P0$"))
+        assert evaluate_filter(datum, f) is True
+
+    def test_evaluate_filter_input_field(self):
+        datum = EvalCase(input={"text": "hello world"}, metadata={"name": "foo"})
+        f = Filter(path=["input", "text"], pattern=re.compile("hello"))
+        assert evaluate_filter(datum, f) is True
+
+    @pytest.mark.asyncio
+    async def test_run_evaluator_with_filter_and_evalcase(self):
+        data = [
+            EvalCase(input="hello", metadata={"name": "foo"}),
+            EvalCase(input="world", metadata={"name": "bar"}),
+        ]
+
+        evaluator = Evaluator(
+            project_name="test-project",
+            eval_name="test-filter-evalcase",
+            data=data,
+            task=lambda x: x,
+            scores=[],
+            experiment_name=None,
+            metadata=None,
+        )
+
+        filters = parse_filters(["metadata.name=foo"])
+        result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=filters)
+
+        # Only the "foo" case should pass the filter
+        assert len(result.results) == 1
+        assert result.results[0].input == "hello"
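
For reference, a minimal standalone sketch of the lookup rule the patch applies while walking a filter path: mapping values keep using .get(), and anything else falls back to getattr(), which is what lets attribute-style records such as EvalCase match filters the same way plain dicts do. The Case and lookup names below are hypothetical stand-ins for illustration only, not the SDK's own types.

import re
from collections.abc import Mapping
from dataclasses import dataclass, field


@dataclass
class Case:
    """Hypothetical stand-in for an attribute-style record like EvalCase."""
    input: object
    metadata: dict = field(default_factory=dict)


def lookup(obj, path):
    """Resolve a dotted filter path: Mapping values via .get(), others via getattr()."""
    key = obj
    for part in path:
        if isinstance(key, Mapping):
            key = key.get(part)
        else:
            key = getattr(key, part, None)
    return key


pattern = re.compile("foo")
for datum in ({"metadata": {"name": "foo"}}, Case(input="hi", metadata={"name": "foo"})):
    value = lookup(datum, ["metadata", "name"])
    print(value is not None and pattern.match(value) is not None)  # True for both shapes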