Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions py/src/braintrust/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import traceback
import warnings
from collections import defaultdict
from collections.abc import Awaitable, Callable, Coroutine, Iterable, Iterator, Sequence
from collections.abc import Awaitable, Callable, Coroutine, Iterable, Iterator, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from multiprocessing import cpu_count
Expand Down Expand Up @@ -1154,7 +1154,10 @@ def parse_filters(filters: list[str]) -> list[Filter]:
def evaluate_filter(object, filter: Filter):
    """Return True if *object* matches *filter*.

    Walks ``filter.path`` one segment at a time, descending through either
    mapping keys or plain attributes (so both dicts and EvalCase-like objects
    are supported). A missing segment (``None``) short-circuits to False.
    The final value is JSON-serialized and tested against ``filter.pattern``.
    """
    current = object
    for segment in filter.path:
        # Dict-like containers are indexed by key; anything else (e.g. an
        # EvalCase dataclass) falls back to attribute lookup.
        if isinstance(current, Mapping):
            current = current.get(segment)
        else:
            current = getattr(current, segment, None)
        if current is None:
            return False
    return filter.pattern.match(serialize_json_with_plain_string(current)) is not None
Expand Down
78 changes: 78 additions & 0 deletions py/src/braintrust/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib.util
import re
from typing import List
from unittest.mock import MagicMock

Expand All @@ -11,6 +12,9 @@
EvalHooks,
EvalResultWithSummary,
Evaluator,
Filter,
evaluate_filter,
parse_filters,
run_evaluator,
)
from .score import Score, Scorer
Expand Down Expand Up @@ -626,3 +630,77 @@ async def test_run_evaluator_empty_dataset_warns(capsys):
captured = capsys.readouterr()
assert "Warning" in captured.err
assert "empty" in captured.err.lower()


class TestEvaluateFilter:
    """Regression tests for https://github.com/braintrustdata/braintrust-sdk-python/issues/207."""

    @staticmethod
    def _metadata_name_filter() -> Filter:
        # Shared filter spec used by the dict/EvalCase parity tests below.
        return Filter(path=["metadata", "name"], pattern=re.compile("foo"))

    @pytest.mark.parametrize(
        "record",
        [
            {"input": "hello", "metadata": {"name": "foo"}},
            EvalCase(input="hello", metadata={"name": "foo"}),
        ],
        ids=["dict", "evalcase"],
    )
    def test_evaluate_filter_match(self, record):
        assert evaluate_filter(record, self._metadata_name_filter()) is True

    @pytest.mark.parametrize(
        "record",
        [
            {"input": "hello", "metadata": {"name": "bar"}},
            EvalCase(input="hello", metadata={"name": "bar"}),
        ],
        ids=["dict", "evalcase"],
    )
    def test_evaluate_filter_no_match(self, record):
        assert evaluate_filter(record, self._metadata_name_filter()) is False

    @pytest.mark.parametrize(
        "record",
        [
            {"input": "hello"},
            EvalCase(input="hello"),
        ],
        ids=["dict", "evalcase"],
    )
    def test_evaluate_filter_missing_key(self, record):
        # No "metadata" at all: the path walk hits None and must return False.
        assert evaluate_filter(record, self._metadata_name_filter()) is False

    def test_evaluate_filter_nested_metadata(self):
        record = EvalCase(input="hello", metadata={"priority": "P0", "owner": "alice"})
        flt = Filter(path=["metadata", "priority"], pattern=re.compile("^P0$"))
        assert evaluate_filter(record, flt) is True

    def test_evaluate_filter_input_field(self):
        record = EvalCase(input={"text": "hello world"}, metadata={"name": "foo"})
        flt = Filter(path=["input", "text"], pattern=re.compile("hello"))
        assert evaluate_filter(record, flt) is True

    @pytest.mark.asyncio
    async def test_run_evaluator_with_filter_and_evalcase(self):
        cases = [
            EvalCase(input="hello", metadata={"name": "foo"}),
            EvalCase(input="world", metadata={"name": "bar"}),
        ]

        evaluator = Evaluator(
            project_name="test-project",
            eval_name="test-filter-evalcase",
            data=cases,
            task=lambda value: value,
            scores=[],
            experiment_name=None,
            metadata=None,
        )

        summary = await run_evaluator(
            experiment=None,
            evaluator=evaluator,
            position=None,
            filters=parse_filters(["metadata.name=foo"]),
        )

        # Only the "foo" case should pass the filter
        assert len(summary.results) == 1
        assert summary.results[0].input == "hello"
Loading