From 0a423ecf181bcefc9cb26103f161115ab89d3965 Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Mon, 6 Apr 2026 14:17:17 -0600 Subject: [PATCH] support data and model eval params --- .../dev/braintrust/devserver/Devserver.java | 111 ++++++++-- .../dev/braintrust/devserver/RemoteEval.java | 100 ++------- .../braintrust/devserver/RequestContext.java | 2 +- .../main/java/dev/braintrust/eval/Eval.java | 45 +++- .../dev/braintrust/eval/ParameterDef.java | 126 +++++++++++ .../java/dev/braintrust/eval/Parameters.java | 91 ++++++++ .../braintrust/eval/ScorerBrainstoreImpl.java | 19 +- .../main/java/dev/braintrust/eval/Task.java | 31 ++- .../java/dev/braintrust/eval/TaskResult.java | 10 +- .../braintrust/devserver/DevserverTest.java | 110 +++++++++- .../java/dev/braintrust/eval/EvalTest.java | 39 ++++ .../dev/braintrust/eval/ParametersTest.java | 205 ++++++++++++++++++ .../examples/RemoteEvalWithParamsExample.java | 84 +++++++ 13 files changed, 858 insertions(+), 115 deletions(-) create mode 100644 braintrust-sdk/src/main/java/dev/braintrust/eval/ParameterDef.java create mode 100644 braintrust-sdk/src/main/java/dev/braintrust/eval/Parameters.java create mode 100644 braintrust-sdk/src/test/java/dev/braintrust/eval/ParametersTest.java create mode 100644 examples/src/main/java/dev/braintrust/examples/RemoteEvalWithParamsExample.java diff --git a/braintrust-sdk/src/main/java/dev/braintrust/devserver/Devserver.java b/braintrust-sdk/src/main/java/dev/braintrust/devserver/Devserver.java index d787d279..893359ea 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/devserver/Devserver.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/devserver/Devserver.java @@ -3,6 +3,7 @@ import static dev.braintrust.json.BraintrustJsonMapper.fromJson; import static dev.braintrust.json.BraintrustJsonMapper.toJson; +import com.fasterxml.jackson.databind.node.NullNode; import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpHandler; import com.sun.net.httpserver.HttpServer; @@ -180,32 +181,36 @@ private void handleList(HttpExchange exchange) throws IOException { Map metadata = new LinkedHashMap<>(); - Map> parametersMap = new LinkedHashMap<>(); - for (Map.Entry paramEntry : - eval.getParameters().entrySet()) { - String paramName = paramEntry.getKey(); - RemoteEval.Parameter param = paramEntry.getValue(); + // Serialize parameters in the container format + if (eval.getParameters().isEmpty()) { + metadata.put("parameters", NullNode.getInstance()); + } else { + Map> schemaMap = new LinkedHashMap<>(); + for (ParameterDef param : eval.getParameters()) { + Map paramMetadata = new LinkedHashMap<>(); + paramMetadata.put("type", param.type().toString().toLowerCase()); - Map paramMetadata = new LinkedHashMap<>(); - paramMetadata.put("type", param.getType().getValue()); + if (param.schema() != null) { + paramMetadata.put("schema", param.schema()); + } - if (param.getDescription() != null) { - paramMetadata.put("description", param.getDescription()); - } + if (param.defaultValue() != null) { + paramMetadata.put("default", param.defaultValue()); + } - if (param.getDefaultValue() != null) { - paramMetadata.put("default", param.getDefaultValue()); - } + if (param.description() != null) { + paramMetadata.put("description", param.description()); + } - // Only include schema for data type parameters - if (param.getType() == RemoteEval.ParameterType.DATA - && param.getSchema() != null) { - paramMetadata.put("schema", param.getSchema()); + schemaMap.put(param.name(), paramMetadata); } - parametersMap.put(paramName, paramMetadata); + Map parametersContainer = new LinkedHashMap<>(); + parametersContainer.put("type", "braintrust.staticParameters"); + parametersContainer.put("schema", schemaMap); + parametersContainer.put("source", NullNode.getInstance()); + metadata.put("parameters", parametersContainer); } - metadata.put("parameters", parametersMap); // Add scores (list of scorer names) List> scores = new ArrayList<>(); @@ -245,7 +250,14 @@ private void handleEval(HttpExchange exchange) throws IOException { try { InputStream requestBody = exchange.getRequestBody(); var requestBodyString = new String(requestBody.readAllBytes(), StandardCharsets.UTF_8); - EvalRequest request = fromJson(requestBodyString, EvalRequest.class); + EvalRequest request; + try { + request = fromJson(requestBodyString, EvalRequest.class); + } catch (Exception e) { + sendResponse( + exchange, 400, "text/plain", "Invalid request body: " + e.getMessage()); + return; + } // Validate evaluator exists RemoteEval eval = evals.get(request.getName()); @@ -376,6 +388,14 @@ private void handleStreamingEval( var tracer = BraintrustTracing.getTracer(); + // Merge parameters: evaluator defaults + request overrides + final Parameters mergedParameters = + new Parameters( + eval.getParameters(), + null == request.getParameters() + ? Map.of() + : request.getParameters()); + // Execute task and scorers for each case final Map> scoresByName = new ConcurrentHashMap<>(); final var parentInfo = extractParentInfo(request); @@ -414,7 +434,9 @@ private void handleStreamingEval( .makeCurrent()) { var task = eval.getTask(); try { - taskResult = task.apply(datasetCase); + taskResult = + task.apply( + datasetCase, mergedParameters); } catch (Exception e) { taskSpan.setStatus( StatusCode.ERROR, e.getMessage()); @@ -431,6 +453,21 @@ private void handleStreamingEval( "Task threw exception for input: " + datasetCase.input(), e); + // Set eval span attributes so Braintrust can + // resolve the trace + setEvalSpanAttributesForError( + evalSpan, + braintrustParent, + braintrustGeneration, + datasetCase); + // Send progress event even on error so the + // Playground can link to the trace + sendProgressEvent( + os, + evalSpan.getSpanContext().getSpanId(), + datasetCase.origin(), + eval.getName(), + null); // run scoreForTaskException on each scorer List> allScorersForError = new ArrayList<>(eval.getScorers()); @@ -578,6 +615,38 @@ private void setEvalSpanAttributes( "braintrust.output_json", toJson(Map.of("output", taskResult.result()))); } + /** + * Sets eval span attributes when the task threw an exception. Similar to {@link + * #setEvalSpanAttributes} but does not require a TaskResult. + */ + private void setEvalSpanAttributesForError( + Span evalSpan, + BraintrustUtils.Parent braintrustParent, + String braintrustGeneration, + DatasetCase datasetCase) { + var spanAttrs = new LinkedHashMap<>(); + spanAttrs.put("type", "eval"); + spanAttrs.put("name", "eval"); + if (braintrustGeneration != null) { + spanAttrs.put("generation", braintrustGeneration); + } + evalSpan.setAttribute(PARENT, braintrustParent.toParentValue()) + .setAttribute("braintrust.span_attributes", toJson(spanAttrs)) + .setAttribute("braintrust.input_json", toJson(Map.of("input", datasetCase.input()))) + .setAttribute("braintrust.expected_json", toJson(datasetCase.expected())); + + if (datasetCase.origin().isPresent()) { + evalSpan.setAttribute("braintrust.origin", toJson(datasetCase.origin().get())); + } + if (!datasetCase.tags().isEmpty()) { + evalSpan.setAttribute( + AttributeKey.stringArrayKey("braintrust.tags"), datasetCase.tags()); + } + if (!datasetCase.metadata().isEmpty()) { + evalSpan.setAttribute("braintrust.metadata", toJson(datasetCase.metadata())); + } + } + private void setTaskSpanAttributes( Span taskSpan, BraintrustUtils.Parent braintrustParent, diff --git a/braintrust-sdk/src/main/java/dev/braintrust/devserver/RemoteEval.java b/braintrust-sdk/src/main/java/dev/braintrust/devserver/RemoteEval.java index 75d569bc..9a2075c6 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/devserver/RemoteEval.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/devserver/RemoteEval.java @@ -1,11 +1,14 @@ package dev.braintrust.devserver; +import dev.braintrust.eval.DatasetCase; +import dev.braintrust.eval.ParameterDef; +import dev.braintrust.eval.Parameters; import dev.braintrust.eval.Scorer; import dev.braintrust.eval.Task; +import dev.braintrust.eval.TaskResult; import java.util.*; import java.util.function.Function; import javax.annotation.Nonnull; -import javax.annotation.Nullable; import lombok.Builder; import lombok.Getter; import lombok.Singular; @@ -36,8 +39,8 @@ public class RemoteEval { */ @Singular @Nonnull private final List> scorers; - /** Optional parameters that can be configured from the UI */ - @Singular @Nonnull private final Map parameters; + /** Optional parameter definitions that can be configured from the UI */ + @Singular @Nonnull private final List> parameters; public static class Builder { /** @@ -48,84 +51,29 @@ public static class Builder { */ public Builder taskFunction(Function taskFn) { return task( - datasetCase -> { - var result = taskFn.apply(datasetCase.input()); - return new dev.braintrust.eval.TaskResult<>(result, datasetCase); + new Task<>() { + @Override + public TaskResult apply( + DatasetCase datasetCase, Parameters parameters) + throws Exception { + var result = taskFn.apply(datasetCase.input()); + return new TaskResult<>(result, datasetCase, parameters); + } }); } /** Build the RemoteEval */ public RemoteEval build() { - // can add build hooks here later if desired - return internalBuild(); - } - } - - /** Represents a configurable parameter for the evaluator */ - @Getter - @lombok.Builder(builderClassName = "Builder") - public static class Parameter { - /** Type of parameter: "prompt" or "data" */ - @Nonnull private final ParameterType type; - - /** Optional description of the parameter */ - @Nullable private final String description; - - /** Optional default value for the parameter */ - @Nullable private final Object defaultValue; - - /** - * JSON Schema for data type parameters. Only applicable when type is DATA. Should be a Map - * representing a JSON Schema object. - */ - @Nullable private final Map schema; - - public static Parameter promptParameter(String description, Object defaultValue) { - return Parameter.builder() - .type(ParameterType.PROMPT) - .description(description) - .defaultValue(defaultValue) - .build(); - } - - public static Parameter promptParameter(Object defaultValue) { - return promptParameter(null, defaultValue); - } - - public static Parameter dataParameter( - String description, Map schema, Object defaultValue) { - return Parameter.builder() - .type(ParameterType.DATA) - .description(description) - .schema(schema) - .defaultValue(defaultValue) - .build(); - } - - public static Parameter dataParameter(Map schema, Object defaultValue) { - return dataParameter(null, schema, defaultValue); - } - - public static Parameter dataParameter(Map schema) { - return dataParameter(null, schema, null); - } - } - - /** Parameter type enumeration */ - public enum ParameterType { - /** Prompt parameter (for LLM prompts) */ - PROMPT("prompt"), - /** Data parameter (for other configuration data) */ - DATA("data"); - - private final String value; - - ParameterType(String value) { - this.value = value; - } - - public String getValue() { - return value; + var result = internalBuild(); + // Validate parameter names are unique + var seen = new HashSet(); + for (var param : result.getParameters()) { + if (!seen.add(param.name())) { + throw new IllegalArgumentException( + "Duplicate parameter name: '" + param.name() + "'"); + } + } + return result; } } } diff --git a/braintrust-sdk/src/main/java/dev/braintrust/devserver/RequestContext.java b/braintrust-sdk/src/main/java/dev/braintrust/devserver/RequestContext.java index 4213c6e1..debfce32 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/devserver/RequestContext.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/devserver/RequestContext.java @@ -12,7 +12,7 @@ */ @Getter @Builder -public class RequestContext { +class RequestContext { /** Validated origin from CORS */ private final String appOrigin; diff --git a/braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java b/braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java index fa7660b2..52a4de22 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java @@ -38,6 +38,7 @@ public final class Eval { private final @Nonnull List> scorers; private final @Nonnull List tags; private final @Nonnull Map metadata; + private final @Nonnull Parameters parameters; private Eval(Builder builder) { this.experimentName = builder.experimentName; @@ -59,6 +60,7 @@ private Eval(Builder builder) { this.scorers = List.copyOf(builder.scorers); this.tags = List.copyOf(builder.tags); this.metadata = Map.copyOf(builder.metadata); + this.parameters = builder.buildParameters(); } /** Runs the evaluation and returns results. */ @@ -129,7 +131,7 @@ private void evalOne(String experimentId, DatasetCase datasetCase .startSpan(); try (var unused = BraintrustContext.ofExperiment(experimentId, taskSpan).makeCurrent()) { - taskResult = task.apply(datasetCase); + taskResult = task.apply(datasetCase, parameters); rootSpan.setAttribute( "braintrust.output_json", toJson(Map.of("output", taskResult.result()))); @@ -252,6 +254,8 @@ public static final class Builder { private @Nullable Tracer tracer = null; private @Nullable Task task; private @Nonnull List> scorers = List.of(); + private @Nonnull List> parameterDefs = List.of(); + private @Nonnull Map parameterValues = Map.of(); private @Nonnull List tags = List.of(); private @Nonnull Map metadata = Map.of(); @@ -335,9 +339,10 @@ public Builder taskFunction(Function taskFn) { new Task<>() { @Override public TaskResult apply( - DatasetCase datasetCase) { + DatasetCase datasetCase, Parameters parameters) + throws Exception { var result = taskFn.apply(datasetCase.input()); - return new TaskResult<>(result, datasetCase); + return new TaskResult<>(result, datasetCase, parameters); } }); } @@ -365,5 +370,39 @@ public Builder metadata(Map metadata) { this.metadata = Map.copyOf(metadata); return this; } + + /** + * Sets parameter definitions for this eval. Default values from the definitions are used + * unless overridden via {@link #parameterValues(Map)}. + */ + @SuppressWarnings("rawtypes") + public Builder parameters(ParameterDef... parameterDefs) { + this.parameterDefs = List.of(parameterDefs); + return this; + } + + /** Sets parameter definitions for this eval. */ + public Builder parameters(List> parameterDefs) { + this.parameterDefs = List.copyOf(parameterDefs); + return this; + } + + /** + * Sets explicit parameter values, overriding any defaults from parameter definitions. Keys + * not present here fall back to the default value from the corresponding {@link + * ParameterDef}. + */ + public Builder parameterValues(Map values) { + this.parameterValues = Map.copyOf(values); + return this; + } + + /** Builds the merged Parameters from definitions and explicit values. */ + private Parameters buildParameters() { + if (parameterDefs.isEmpty() && parameterValues.isEmpty()) { + return Parameters.empty(); + } + return new Parameters(parameterDefs, parameterValues); + } } } diff --git a/braintrust-sdk/src/main/java/dev/braintrust/eval/ParameterDef.java b/braintrust-sdk/src/main/java/dev/braintrust/eval/ParameterDef.java new file mode 100644 index 00000000..8718bfb4 --- /dev/null +++ b/braintrust-sdk/src/main/java/dev/braintrust/eval/ParameterDef.java @@ -0,0 +1,126 @@ +package dev.braintrust.eval; + +import dev.braintrust.json.BraintrustJsonMapper; +import java.util.Map; +import java.util.Objects; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Definition of a named parameter that can be configured from the Braintrust Playground UI. + * + *

Parameter definitions declare what parameters an evaluator accepts, their types, defaults, and + * descriptions. The Playground uses these to render appropriate UI controls. + * + * @param name the parameter name (used as the key in the merged parameters map) + * @param type the parameter type: {@code "data"} for generic values, {@code "model"} for a model + * picker + * @param defaultValue optional default value used when the request does not include this parameter + * @param description optional human-readable description shown in the Playground UI + * @param schema optional JSON Schema fragment describing the value shape (e.g., {@code {"type": + * "string"}}). Only applicable for {@code "data"} type parameters. + */ +public record ParameterDef( + @Nonnull String name, + @Nonnull Type type, + @Nullable T defaultValue, + @Nullable String description, + @Nullable Map schema) { + + public ParameterDef { + Objects.requireNonNull(name); + Objects.requireNonNull(type); + } + + public static ParameterDef data(@Nonnull String name, @Nonnull T defaultValue) { + return data(name, defaultValue, null); + } + + public static ParameterDef data( + @Nonnull String name, @Nonnull T defaultValue, @Nullable String description) { + return (ParameterDef) data(name, defaultValue.getClass(), defaultValue, description); + } + + /** + * Create a data parameter definition with an explicit value class (for when no default is + * provided or the type can't be inferred from the default). + * + * @param valueClass the Java class of the parameter value (e.g., {@code String.class}, {@code + * Map.class}) + */ + public static ParameterDef data( + @Nonnull String name, + @Nonnull Class valueClass, + @Nullable T defaultValue, + @Nullable String description) { + var dataType = DataType.ofClass(valueClass); + if (null == dataType) { + throw new RuntimeException("unsupported parameter value class: " + valueClass); + } + if (DataType.OBJECT.equals(dataType)) { + // fail fast if the class can't serialize + try { + BraintrustJsonMapper.get().constructType(valueClass); + // Try deserializing an empty object — catches missing default constructor, + // missing @JsonCreator, etc. + BraintrustJsonMapper.fromJson("{}", valueClass); + } catch (Exception e) { + throw new RuntimeException( + "invalid object data type. Class is not deserializable by" + + " BraintrustJsonMapper: " + + valueClass, + e); + } + } + return new ParameterDef<>( + name, + Type.DATA, + defaultValue, + description, + Map.of("type", dataType.name().toLowerCase())); + } + + /** Create a model parameter definition. The default value is a model name string. */ + public static ParameterDef model(String name, String defaultValue) { + return model(name, defaultValue, null); + } + + /** Create a model parameter definition with a description. */ + public static ParameterDef model(String name, String defaultValue, String description) { + return new ParameterDef<>(name, Type.MODEL, defaultValue, description, null); + } + + public enum Type { + DATA, + // TODO: prompts not supported yet + // PROMPT, + MODEL + } + + enum DataType { + STRING, + NUMBER, + BOOLEAN, + OBJECT, + ARRAY; + + static @Nullable DataType of(Object parameterValue) { + if (null == parameterValue) return null; + else if (parameterValue instanceof String) return DataType.STRING; + else if (parameterValue instanceof Number) return DataType.NUMBER; + else if (parameterValue instanceof Boolean) return DataType.BOOLEAN; + else if (parameterValue instanceof Iterable) return DataType.ARRAY; + else return DataType.OBJECT; + } + + static @Nullable DataType ofClass(Class clazz) { + if (String.class.isAssignableFrom(clazz)) return STRING; + if (Number.class.isAssignableFrom(clazz)) return NUMBER; + if (Boolean.class.isAssignableFrom(clazz)) return BOOLEAN; + if (Iterable.class.isAssignableFrom(clazz)) return ARRAY; + if (Map.class.isAssignableFrom(clazz)) return OBJECT; + // Assume any other class is a Jackson-serializable POJO + return OBJECT; + } + } +} diff --git a/braintrust-sdk/src/main/java/dev/braintrust/eval/Parameters.java b/braintrust-sdk/src/main/java/dev/braintrust/eval/Parameters.java new file mode 100644 index 00000000..30d8f3ae --- /dev/null +++ b/braintrust-sdk/src/main/java/dev/braintrust/eval/Parameters.java @@ -0,0 +1,91 @@ +package dev.braintrust.eval; + +import dev.braintrust.json.BraintrustJsonMapper; +import java.util.*; +import javax.annotation.Nonnull; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; + +/** + * Holds the merged parameter values for a single eval run, along with the parameter definitions. + * + *

Parameter values are the result of merging evaluator defaults with request overrides. This + * class provides typed accessors so task and scorer implementations don't need to cast. + */ +@Slf4j +public class Parameters { + private static final Parameters EMPTY = + new Parameters(Collections.emptyList(), Collections.emptyMap()); + + /** -- GETTER -- Returns the merged parameter values as an unmodifiable map. */ + @Getter private final Map merged; + + public Parameters(List> definitions, Map requestParams) { + var remaining = new LinkedHashMap<>(requestParams); + Map merged = new LinkedHashMap<>(); + for (var def : definitions) { + var paramVal = remaining.remove(def.name()); + if (null == paramVal) { + paramVal = def.defaultValue(); + } + if (null != paramVal) { + merged.put(def.name(), paramVal); + } + } + if (!remaining.isEmpty()) { + log.warn("unknown param names found in eval request: {}", remaining.keySet()); + } + this.merged = Collections.unmodifiableMap(new LinkedHashMap<>(merged)); + // NOTE: we're not holding on to definitions outside of the constructor, but we may wish to + // surface them later + } + + /** Returns an empty {@code Parameters} instance with no values or definitions. */ + public static Parameters empty() { + return EMPTY; + } + + /** Returns true if no parameter values are present. */ + public boolean isEmpty() { + return merged.isEmpty(); + } + + /** Returns true if a value exists for the given key. */ + public boolean has(String key) { + return merged.containsKey(key); + } + + /** Returns the raw value for the given key, or null if absent. */ + public T get(@Nonnull String key, @Nonnull Class paramClass) { + Objects.requireNonNull(key); + Objects.requireNonNull(paramClass); + if (!has(key)) { + throw new RuntimeException("param not found: " + key); + } + var rawParam = merged.get(key); + if (rawParam == null) { + return null; + } + // Coerce integer types to floating point if requested + if (rawParam instanceof Number number) { + if (paramClass == Double.class || paramClass == double.class) { + return (T) (Double) number.doubleValue(); + } + if (paramClass == Float.class || paramClass == float.class) { + return (T) (Float) number.floatValue(); + } + } + var actualClass = rawParam.getClass(); + if (paramClass.isAssignableFrom(actualClass)) { + return (T) rawParam; + } + // Auto-convert using Jackson (e.g., Map -> POJO when Playground sends JSON objects) + try { + return BraintrustJsonMapper.get().convertValue(rawParam, paramClass); + } catch (IllegalArgumentException e) { + throw new ClassCastException( + "cannot convert param \"%s\" (%s) to %s: %s" + .formatted(key, actualClass, paramClass, e.getMessage())); + } + } +} diff --git a/braintrust-sdk/src/main/java/dev/braintrust/eval/ScorerBrainstoreImpl.java b/braintrust-sdk/src/main/java/dev/braintrust/eval/ScorerBrainstoreImpl.java index 3baceecf..c6f1d5fb 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/eval/ScorerBrainstoreImpl.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/eval/ScorerBrainstoreImpl.java @@ -60,14 +60,17 @@ public List score(TaskResult taskResult) { // Build parent span components for distributed tracing (as object, not base64 string) Object parent = buildParentSpanComponents(); - var request = - BraintrustApiClient.FunctionInvokeRequest.of( - taskResult.datasetCase().input(), - taskResult.result(), - taskResult.datasetCase().expected(), - taskResult.datasetCase().metadata(), - version, - parent); + // Build scorer args map with optional parameters + var scorerArgs = new java.util.LinkedHashMap(); + scorerArgs.put("input", taskResult.datasetCase().input()); + scorerArgs.put("output", taskResult.result()); + scorerArgs.put("expected", taskResult.datasetCase().expected()); + scorerArgs.put("metadata", taskResult.datasetCase().metadata()); + if (!taskResult.parameters().isEmpty()) { + scorerArgs.put("parameters", taskResult.parameters().getMerged()); + } + + var request = new BraintrustApiClient.FunctionInvokeRequest(scorerArgs, version, parent); Object result = apiClient.invokeFunction(getFunctionId(), request); return parseScoreResult(result); diff --git a/braintrust-sdk/src/main/java/dev/braintrust/eval/Task.java b/braintrust-sdk/src/main/java/dev/braintrust/eval/Task.java index 0e0dfa96..9072a9b1 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/eval/Task.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/eval/Task.java @@ -6,17 +6,44 @@ *

If the task throws an exception, the error is recorded on the span and each scorer's {@link * Scorer#scoreForTaskException} method is invoked instead of {@link Scorer#score}. * + *

Tasks may optionally accept {@link Parameters} by overriding the two-arg {@link + * #apply(DatasetCase, Parameters)} method. Tasks that don't need parameters can override the + * single-arg {@link #apply(DatasetCase)} method instead — the two-arg version delegates to it by + * default. + * * @param type of the input data * @param type of the output data */ public interface Task { /** - * Executes this task against a single dataset case and returns the result. + * Executes this task against a single dataset case, with access to merged eval parameters. + * + *

Override this method if your task needs to read parameter values (e.g., model name, + * temperature). The default implementation delegates to {@link #apply(DatasetCase)}, ignoring + * parameters. + * + * @param datasetCase the dataset case to evaluate + * @param parameters the merged parameter values for this eval run + * @return the task result containing the output and the originating dataset case + * @throws Exception if the task fails, the error will be recorded on the span and scoring will + * fall back to {@link Scorer#scoreForTaskException} + */ + TaskResult apply(DatasetCase datasetCase, Parameters parameters) + throws Exception; + + /** + * Executes this task against a single dataset case. + * + *

Override this method for tasks that do not need parameter values. If you need parameters, + * override {@link #apply(DatasetCase, Parameters)} instead. * * @param datasetCase the dataset case to evaluate * @return the task result containing the output and the originating dataset case * @throws Exception if the task fails, the error will be recorded on the span and scoring will * fall back to {@link Scorer#scoreForTaskException} */ - TaskResult apply(DatasetCase datasetCase) throws Exception; + default TaskResult apply(DatasetCase datasetCase) + throws Exception { + return apply(datasetCase, Parameters.empty()); + } } diff --git a/braintrust-sdk/src/main/java/dev/braintrust/eval/TaskResult.java b/braintrust-sdk/src/main/java/dev/braintrust/eval/TaskResult.java index 55fe3e49..f60b89f2 100644 --- a/braintrust-sdk/src/main/java/dev/braintrust/eval/TaskResult.java +++ b/braintrust-sdk/src/main/java/dev/braintrust/eval/TaskResult.java @@ -1,8 +1,16 @@ package dev.braintrust.eval; +import javax.annotation.Nonnull; + /** Result from a single task run. */ public record TaskResult( /** task output */ OUTPUT result, /** The dataset case the task ran against to produce the result */ - DatasetCase datasetCase) {} + DatasetCase datasetCase, + /** The merged parameter values for this eval run */ + @Nonnull Parameters parameters) { + public TaskResult(OUTPUT result, DatasetCase datasetCase) { + this(result, datasetCase, Parameters.empty()); + } +} diff --git a/braintrust-sdk/src/test/java/dev/braintrust/devserver/DevserverTest.java b/braintrust-sdk/src/test/java/dev/braintrust/devserver/DevserverTest.java index 49034c33..2eae57cb 100644 --- a/braintrust-sdk/src/test/java/dev/braintrust/devserver/DevserverTest.java +++ b/braintrust-sdk/src/test/java/dev/braintrust/devserver/DevserverTest.java @@ -38,6 +38,7 @@ class DevserverTest { private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); private static final String REMOTE_EVAL_NAME = "food-type-classifier"; + private static final String PARAM_EVAL_NAME = "param-eval"; private static final String TASK_ERROR_EVAL_NAME = "task-error-eval"; private static final String SCORER_ERROR_EVAL_NAME = "scorer-error-eval"; private static final BraintrustUtils.Parent PLAYGROUND_PARENT = @@ -110,10 +111,36 @@ public List score( .scorer(Scorer.of("working_scorer", (expected, result) -> 1.0)) .build(); + // Eval with parameters — task returns the model param as its output + RemoteEval paramEval = + RemoteEval.builder() + .name(PARAM_EVAL_NAME) + .parameter(dev.braintrust.eval.ParameterDef.data("model", "gpt-4")) + .parameter(dev.braintrust.eval.ParameterDef.data("temperature", 0.5)) + .task( + new dev.braintrust.eval.Task<>() { + @Override + public dev.braintrust.eval.TaskResult apply( + dev.braintrust.eval.DatasetCase + datasetCase, + dev.braintrust.eval.Parameters parameters) + throws Exception { + // Echo both params so tests can verify defaults + overrides + String model = parameters.get("model", String.class); + Double temp = parameters.get("temperature", Double.class); + String output = model + ":" + temp; + return new dev.braintrust.eval.TaskResult<>( + output, datasetCase, parameters); + } + }) + .scorer(Scorer.of("static_scorer", (expected, result) -> 1.0)) + .build(); + server = Devserver.builder() .config(testHarness.braintrust().config()) .registerEval(testEval) + .registerEval(paramEval) .registerEval(taskErrorEval) .registerEval(scorerErrorEval) .host("localhost") @@ -305,7 +332,24 @@ void testStreamingEval() throws Exception { } // Get exported spans from test harness (since devserver uses global tracer) - List exportedSpans = testHarness.awaitExportedSpans(); + // Filter to only spans belonging to this test by finding trace IDs from eval spans + // with the right generation tag + List allSpans = testHarness.awaitExportedSpans(); + var traceIds = + allSpans.stream() + .filter( + s -> { + var attrs = + s.getAttributes() + .get( + AttributeKey.stringKey( + "braintrust.span_attributes")); + return attrs != null && attrs.contains("test-gen-1"); + }) + .map(s -> s.getTraceId()) + .collect(java.util.stream.Collectors.toSet()); + List exportedSpans = + allSpans.stream().filter(s -> traceIds.contains(s.getTraceId())).toList(); assertFalse(exportedSpans.isEmpty(), "Should have exported spans"); // We should have 2 eval traces (one per dataset case), each with task, scores, and custom @@ -653,10 +697,11 @@ void testTaskErrorHandling() throws Exception { assertEquals(1, summaryEvents.size(), "Should have 1 summary event"); assertEquals(1, doneEvents.size(), "Should have 1 done event"); - // Only the good case should produce a progress event (bad case task throws before progress) + // Both cases should produce a progress event (error case sends progress with null output + // so the Playground can link to the trace) List> progressEvents = events.stream().filter(e -> "progress".equals(e.get("event"))).toList(); - assertEquals(1, progressEvents.size(), "Only the successful case should send a progress"); + assertEquals(2, progressEvents.size(), "Both cases should send a progress event"); // Verify summary includes the fallback score from scoreForTaskException (default 0.0) JsonNode summaryData = JSON_MAPPER.readTree(summaryEvents.get(0).get("data")); @@ -926,6 +971,65 @@ void testScorerErrorHandling() throws Exception { "working scorer should produce 1.0"); } + @Test + void testParameterDefaultsAndOverrides() throws Exception { + EvalRequest evalRequest = new EvalRequest(); + evalRequest.setName(PARAM_EVAL_NAME); + evalRequest.setStream(true); + + // Override temperature but let model use its default + evalRequest.setParameters(Map.of("temperature", 0.9)); + + EvalRequest.DataSpec dataSpec = new EvalRequest.DataSpec(); + EvalRequest.EvalCaseData case1 = new EvalRequest.EvalCaseData(); + case1.setInput("hello"); + case1.setExpected("world"); + dataSpec.setData(List.of(case1)); + evalRequest.setData(dataSpec); + + Map parentSpec = + Map.of( + "object_type", PLAYGROUND_PARENT.type(), + "object_id", PLAYGROUND_PARENT.id(), + "propagated_event", + Map.of("span_attributes", Map.of("generation", "test-gen-params"))); + evalRequest.setParent(parentSpec); + + String requestBody = JSON_MAPPER.writeValueAsString(evalRequest); + + HttpURLConnection conn = + (HttpURLConnection) new URI(TEST_URL + "/eval").toURL().openConnection(); + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setRequestProperty("x-bt-auth-token", testHarness.braintrustApiKey()); + conn.setRequestProperty("x-bt-project-id", TestHarness.defaultProjectId()); + conn.setRequestProperty("x-bt-org-name", TestHarness.defaultOrgName()); + conn.setDoOutput(true); + + conn.getOutputStream().write(requestBody.getBytes(StandardCharsets.UTF_8)); + conn.getOutputStream().flush(); + + assertEquals(200, conn.getResponseCode()); + + List> events = readSSEEvents(conn); + + List> progressEvents = + events.stream().filter(e -> "progress".equals(e.get("event"))).toList(); + assertEquals(1, progressEvents.size(), "Should have 1 progress event"); + + // Task echoes "model:temperature" — model should be default "gpt-4", + // temperature should be the overridden 0.9 + JsonNode progressData = JSON_MAPPER.readTree(progressEvents.get(0).get("data")); + String taskOutput = progressData.get("data").asText(); + assertEquals( + "\"gpt-4:0.9\"", + taskOutput, + "Task should receive default model and overridden temperature"); + + assertEquals(1, events.stream().filter(e -> "summary".equals(e.get("event"))).count()); + assertEquals(1, events.stream().filter(e -> "done".equals(e.get("event"))).count()); + } + /** Helper to read SSE events from an HttpURLConnection response. */ private List> readSSEEvents(HttpURLConnection conn) throws Exception { BufferedReader reader = diff --git a/braintrust-sdk/src/test/java/dev/braintrust/eval/EvalTest.java b/braintrust-sdk/src/test/java/dev/braintrust/eval/EvalTest.java index a74699d2..cb58ed32 100644 --- a/braintrust-sdk/src/test/java/dev/braintrust/eval/EvalTest.java +++ b/braintrust-sdk/src/test/java/dev/braintrust/eval/EvalTest.java @@ -649,4 +649,43 @@ public List score( ((Number) workingScoresJson.get("working_scorer")).doubleValue(), "working scorer should produce 1.0 (exact match)"); } + + @Test + @SneakyThrows + public void evalPassesParametersToTask() { + var receivedModel = new java.util.concurrent.atomic.AtomicReference(); + var receivedTemp = new java.util.concurrent.atomic.AtomicReference(); + + var eval = + testHarness + .braintrust() + .evalBuilder() + .name("unit-test-eval") + .cases(DatasetCase.of("hello", "world")) + .parameters( + ParameterDef.data("model", "gpt-4", "Model to use"), + ParameterDef.data("temperature", 0.5, "Sampling temperature")) + // Override temperature, keep model default + .parameterValues(Map.of("temperature", 0.9)) + .task( + new Task<>() { + @Override + public TaskResult apply( + DatasetCase datasetCase, + Parameters parameters) + throws Exception { + receivedModel.set(parameters.get("model", String.class)); + receivedTemp.set( + parameters.get("temperature", Double.class)); + return new TaskResult<>("world", datasetCase, parameters); + } + }) + .scorers(Scorer.of("exact", (expected, result) -> 1.0)) + .build(); + + eval.run(); + + assertEquals("gpt-4", receivedModel.get(), "should receive default model value"); + assertEquals(0.9, receivedTemp.get(), "should receive overridden temperature value"); + } } diff --git a/braintrust-sdk/src/test/java/dev/braintrust/eval/ParametersTest.java b/braintrust-sdk/src/test/java/dev/braintrust/eval/ParametersTest.java new file mode 100644 index 00000000..5cc9c572 --- /dev/null +++ b/braintrust-sdk/src/test/java/dev/braintrust/eval/ParametersTest.java @@ -0,0 +1,205 @@ +package dev.braintrust.eval; + +import static org.junit.jupiter.api.Assertions.*; + +import com.fasterxml.jackson.annotation.JsonProperty; +import dev.braintrust.devserver.RemoteEval; +import java.util.*; +import org.junit.jupiter.api.Test; + +public class ParametersTest { + + @Test + void emptyParametersHasNoValues() { + Parameters params = Parameters.empty(); + assertFalse(params.has("anything")); + assertTrue(params.isEmpty()); + } + + @Test + void getReturnsDefaultWhenNoRequestOverride() { + var params = new Parameters(List.of(ParameterDef.data("model", "gpt-4")), Map.of()); + assertTrue(params.has("model")); + assertEquals("gpt-4", params.get("model", String.class)); + } + + @Test + void requestOverridesDefault() { + var params = + new Parameters( + List.of(ParameterDef.data("model", "gpt-4")), Map.of("model", "gpt-5")); + assertEquals("gpt-5", params.get("model", String.class)); + } + + @Test + void throwsWhenParamNotDefined() { + var params = new Parameters(List.of(ParameterDef.data("model", "gpt-4")), Map.of()); + assertThrows(Exception.class, () -> params.get("temperature", String.class)); + } + + @Test + void numericValues() { + var params = + new Parameters( + List.of( + ParameterDef.data("temperature", 0.7), + ParameterDef.data("max_tokens", 100)), + Map.of()); + + assertEquals(0.7, params.get("temperature", Double.class)); + assertEquals(100, params.get("max_tokens", Integer.class)); + } + + @Test + void booleanValues() { + var params = new Parameters(List.of(ParameterDef.data("verbose", true)), Map.of()); + assertTrue(params.get("verbose", Boolean.class)); + } + + @Test + void intCoercesToDouble() { + // JSON deserializes 1 as Integer, but caller wants Double + var params = new Parameters(List.of(ParameterDef.data("temp", 1)), Map.of()); + assertEquals(1.0, params.get("temp", Double.class)); + } + + @Test + void intCoercesToFloat() { + var params = new Parameters(List.of(ParameterDef.data("temp", 1)), Map.of()); + assertEquals(1.0f, params.get("temp", Float.class)); + } + + @Test + void longCoercesToDouble() { + var params = new Parameters(List.of(ParameterDef.data("count", 100L)), Map.of()); + assertEquals(100.0, params.get("count", Double.class)); + } + + @Test + void classCastThrows() { + var params = new Parameters(List.of(ParameterDef.data("model", "gpt-4")), Map.of()); + assertThrows(ClassCastException.class, () -> params.get("model", Integer.class)); + } + + @Test + void unknownRequestParamsAreIgnored() { + var requestParams = new LinkedHashMap(); + requestParams.put("model", "gpt-4"); + requestParams.put("unknown_key", "should be dropped"); + var params = new Parameters(List.of(ParameterDef.data("model", "gpt-4")), requestParams); + assertTrue(params.has("model")); + assertFalse(params.has("unknown_key")); + } + + @Test + void paramWithNullDefaultAndNoRequestValueIsAbsent() { + var params = + new Parameters( + List.of(ParameterDef.data("foo", String.class, null, null)), Map.of()); + assertFalse(params.has("foo")); + } + + @Test + void listParameterDefInfersArraySchema() { + var def = ParameterDef.data("tags", List.of("foo", "bar"), "Tag list"); + assertEquals("tags", def.name()); + assertEquals(Map.of("type", "array"), def.schema()); + assertEquals(List.of("foo", "bar"), def.defaultValue()); + } + + @Test + void emptyIsSingleton() { + assertSame(Parameters.empty(), Parameters.empty()); + } + + @Test + void remoteEvalBuilderRejectsDuplicateParameterNames() { + assertThrows( + Exception.class, + () -> + RemoteEval.builder() + .name("test") + .taskFunction(input -> "out") + .parameter(ParameterDef.model("model", "gpt-4")) + .parameter(ParameterDef.data("model", "gpt-5")) + .build()); + } + + /** A complex Jackson-serializable POJO for testing OBJECT data type. */ + record ModelConfig( + @JsonProperty("model_name") String modelName, + @JsonProperty("max_tokens") int maxTokens, + @JsonProperty("stop_sequences") List stopSequences) {} + + @Test + void complexObjectParameterDef() { + var defaultConfig = new ModelConfig("gpt-4", 1024, List.of("\n", "###")); + var def = ParameterDef.data("config", defaultConfig, "Model configuration"); + + assertEquals("config", def.name()); + assertEquals(ParameterDef.Type.DATA, def.type()); + assertEquals(Map.of("type", "object"), def.schema()); + assertEquals("Model configuration", def.description()); + assertSame(defaultConfig, def.defaultValue()); + } + + @Test + void complexObjectParameterMergesFromRequest() { + // Playground would send a JSON object which Jackson deserializes as a Map + var expectedConfig = new ModelConfig("gpt-5", 2048, List.of("END")); + var requestJson = + Map.of("model_name", "gpt-5", "max_tokens", 2048, "stop_sequences", List.of("END")); + + var params = + new Parameters( + List.of(ParameterDef.data("config", ModelConfig.class, null, null)), + Map.of("config", requestJson)); + + // Request override replaces the default entirely + assertEquals(expectedConfig, params.get("config", ModelConfig.class)); + } + + @Test + void dataTypeInferenceCoversAllTypes() { + assertEquals(Map.of("type", "string"), ParameterDef.data("a", "hello").schema()); + assertEquals(Map.of("type", "number"), ParameterDef.data("b", 3.14).schema()); + assertEquals(Map.of("type", "number"), ParameterDef.data("c", 42).schema()); + assertEquals(Map.of("type", "boolean"), ParameterDef.data("d", true).schema()); + assertEquals(Map.of("type", "array"), ParameterDef.data("e", List.of(1, 2)).schema()); + assertEquals( + Map.of("type", "object"), + ParameterDef.data("f", new ModelConfig("x", 1, List.of())).schema()); + assertEquals(Map.of("type", "object"), ParameterDef.data("g", Map.of("k", "v")).schema()); + } + + @Test + void autoConvertFailsForNonDeserializableType() { + // A class with no Jackson-friendly constructor or annotations + class Opaque { + private final int x; + + Opaque(int x) { + this.x = x; + } + } + assertThrows( + Exception.class, + () -> + ParameterDef.data( + "val", + Opaque.class, + new Opaque(33), + "a custom object that doesn't serialize"), + "parameter defs must throw for objects that don't serialize"); + // Also fails fast with null default + assertThrows( + Exception.class, + () -> + ParameterDef.data( + "val", + Opaque.class, + null, + "a custom object that doesn't serialize"), + "parameter defs must throw for objects that don't serialize"); + } +} diff --git a/examples/src/main/java/dev/braintrust/examples/RemoteEvalWithParamsExample.java b/examples/src/main/java/dev/braintrust/examples/RemoteEvalWithParamsExample.java new file mode 100644 index 00000000..419fcd36 --- /dev/null +++ b/examples/src/main/java/dev/braintrust/examples/RemoteEvalWithParamsExample.java @@ -0,0 +1,84 @@ +package dev.braintrust.examples; + +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import dev.braintrust.Braintrust; +import dev.braintrust.devserver.Devserver; +import dev.braintrust.devserver.RemoteEval; +import dev.braintrust.eval.*; +import dev.braintrust.instrumentation.openai.BraintrustOpenAI; +import java.util.List; + +/** Simple Dev Server for Remote Evals */ +public class RemoteEvalWithParamsExample { + public static void main(String[] args) throws Exception { + var braintrust = Braintrust.get(); + var openTelemetry = braintrust.openTelemetryCreate(); + var openAIClient = BraintrustOpenAI.wrapOpenAI(openTelemetry, OpenAIOkHttpClient.fromEnv()); + + RemoteEval foodTypeEval = + RemoteEval.builder() + .name("food-type-classifier") + .task( + (datasetCase, parameters) -> { + var food = datasetCase.input(); + var request = + ChatCompletionCreateParams.builder() + .model(parameters.get("model", String.class)) + .addSystemMessage("Return a one word answer") + .addUserMessage( + "What kind of food is " + food + "?") + .temperature( + parameters.get( + "temperature", Double.class)) + .build(); + var response = + openAIClient.chat().completions().create(request); + var responseText = + response.choices() + .get(0) + .message() + .content() + .orElseThrow() + .toLowerCase(); + return new TaskResult<>(responseText, datasetCase, parameters); + }) + .scorers( + List.of( + Scorer.of("static_scorer", (expected, result) -> 0.7), + Scorer.of( + "close_enough_match", + (expected, result) -> + expected.trim() + .equalsIgnoreCase( + result.trim()) + ? 1.0 + : 0.0))) + .parameters( + List.of( + ParameterDef.model( + "model", "gpt-4o-mini", "openai model to use"), + ParameterDef.data("temperature", 0.0, "model temperature"))) + .build(); + + Devserver devserver = + Devserver.builder() + .config(braintrust.config()) + .registerEval(foodTypeEval) + .host("localhost") // set to 0.0.0.0 to bind all interfaces + .port(8301) + .build(); + + Runtime.getRuntime() + .addShutdownHook( + new Thread( + () -> { + System.out.println("Shutting down..."); + devserver.stop(); + System.out.flush(); + System.err.flush(); + })); + System.out.println("Starting Braintrust dev server"); + devserver.start(); + } +}