Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 90 additions & 21 deletions braintrust-sdk/src/main/java/dev/braintrust/devserver/Devserver.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static dev.braintrust.json.BraintrustJsonMapper.fromJson;
import static dev.braintrust.json.BraintrustJsonMapper.toJson;

import com.fasterxml.jackson.databind.node.NullNode;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpHandler;
import com.sun.net.httpserver.HttpServer;
Expand Down Expand Up @@ -180,32 +181,36 @@ private void handleList(HttpExchange exchange) throws IOException {

Map<String, Object> metadata = new LinkedHashMap<>();

Map<String, Map<String, Object>> parametersMap = new LinkedHashMap<>();
for (Map.Entry<String, RemoteEval.Parameter> paramEntry :
eval.getParameters().entrySet()) {
String paramName = paramEntry.getKey();
RemoteEval.Parameter param = paramEntry.getValue();
// Serialize parameters in the container format
if (eval.getParameters().isEmpty()) {
metadata.put("parameters", NullNode.getInstance());
} else {
Map<String, Map<String, Object>> schemaMap = new LinkedHashMap<>();
for (ParameterDef<?> param : eval.getParameters()) {
Map<String, Object> paramMetadata = new LinkedHashMap<>();
paramMetadata.put("type", param.type().toString().toLowerCase());

Map<String, Object> paramMetadata = new LinkedHashMap<>();
paramMetadata.put("type", param.getType().getValue());
if (param.schema() != null) {
paramMetadata.put("schema", param.schema());
}

if (param.getDescription() != null) {
paramMetadata.put("description", param.getDescription());
}
if (param.defaultValue() != null) {
paramMetadata.put("default", param.defaultValue());
}

if (param.getDefaultValue() != null) {
paramMetadata.put("default", param.getDefaultValue());
}
if (param.description() != null) {
paramMetadata.put("description", param.description());
}

// Only include schema for data type parameters
if (param.getType() == RemoteEval.ParameterType.DATA
&& param.getSchema() != null) {
paramMetadata.put("schema", param.getSchema());
schemaMap.put(param.name(), paramMetadata);
}

parametersMap.put(paramName, paramMetadata);
Map<String, Object> parametersContainer = new LinkedHashMap<>();
parametersContainer.put("type", "braintrust.staticParameters");
parametersContainer.put("schema", schemaMap);
parametersContainer.put("source", NullNode.getInstance());
metadata.put("parameters", parametersContainer);
}
metadata.put("parameters", parametersMap);

// Add scores (list of scorer names)
List<Map<String, String>> scores = new ArrayList<>();
Expand Down Expand Up @@ -245,7 +250,14 @@ private void handleEval(HttpExchange exchange) throws IOException {
try {
InputStream requestBody = exchange.getRequestBody();
var requestBodyString = new String(requestBody.readAllBytes(), StandardCharsets.UTF_8);
EvalRequest request = fromJson(requestBodyString, EvalRequest.class);
EvalRequest request;
try {
request = fromJson(requestBodyString, EvalRequest.class);
} catch (Exception e) {
sendResponse(
exchange, 400, "text/plain", "Invalid request body: " + e.getMessage());
return;
}

// Validate evaluator exists
RemoteEval eval = evals.get(request.getName());
Expand Down Expand Up @@ -376,6 +388,14 @@ private <I, O> void handleStreamingEval(

var tracer = BraintrustTracing.getTracer();

// Merge parameters: evaluator defaults + request overrides
final Parameters mergedParameters =
new Parameters(
eval.getParameters(),
null == request.getParameters()
? Map.of()
: request.getParameters());

// Execute task and scorers for each case
final Map<String, List<Double>> scoresByName = new ConcurrentHashMap<>();
final var parentInfo = extractParentInfo(request);
Expand Down Expand Up @@ -414,7 +434,9 @@ private <I, O> void handleStreamingEval(
.makeCurrent()) {
var task = eval.getTask();
try {
taskResult = task.apply(datasetCase);
taskResult =
task.apply(
datasetCase, mergedParameters);
} catch (Exception e) {
taskSpan.setStatus(
StatusCode.ERROR, e.getMessage());
Expand All @@ -431,6 +453,21 @@ private <I, O> void handleStreamingEval(
"Task threw exception for input: "
+ datasetCase.input(),
e);
// Set eval span attributes so Braintrust can
// resolve the trace
setEvalSpanAttributesForError(
evalSpan,
braintrustParent,
braintrustGeneration,
datasetCase);
// Send progress event even on error so the
// Playground can link to the trace
sendProgressEvent(
os,
evalSpan.getSpanContext().getSpanId(),
datasetCase.origin(),
eval.getName(),
null);
// run scoreForTaskException on each scorer
List<Scorer<I, O>> allScorersForError =
new ArrayList<>(eval.getScorers());
Expand Down Expand Up @@ -578,6 +615,38 @@ private void setEvalSpanAttributes(
"braintrust.output_json", toJson(Map.of("output", taskResult.result())));
}

/**
 * Sets eval span attributes when the task threw an exception. Similar to {@link
 * #setEvalSpanAttributes} but does not require a TaskResult, so no output attribute is
 * written.
 *
 * <p>Always written: the parent, {@code braintrust.span_attributes} (type and name {@code
 * "eval"}, plus {@code generation} when provided), {@code braintrust.input_json} and {@code
 * braintrust.expected_json}. Origin, tags and metadata are written only when the dataset case
 * carries them.
 *
 * @param evalSpan span that receives the attributes
 * @param braintrustParent parent whose {@code toParentValue()} is stored under {@code PARENT}
 * @param braintrustGeneration generation to record in the span attributes; skipped when null
 * @param datasetCase dataset case supplying input, expected value, origin, tags and metadata
 */
private void setEvalSpanAttributesForError(
        Span evalSpan,
        BraintrustUtils.Parent braintrustParent,
        String braintrustGeneration,
        DatasetCase<?, ?> datasetCase) {
    // Declared with explicit type arguments: `var` combined with the diamond operator would
    // infer LinkedHashMap<Object, Object> and lose key/value type checking.
    Map<String, Object> spanAttrs = new LinkedHashMap<>();
    spanAttrs.put("type", "eval");
    spanAttrs.put("name", "eval");
    if (braintrustGeneration != null) {
        spanAttrs.put("generation", braintrustGeneration);
    }
    evalSpan.setAttribute(PARENT, braintrustParent.toParentValue())
            .setAttribute("braintrust.span_attributes", toJson(spanAttrs))
            .setAttribute("braintrust.input_json", toJson(Map.of("input", datasetCase.input())))
            .setAttribute("braintrust.expected_json", toJson(datasetCase.expected()));

    // Optional case details are only attached when present / non-empty.
    if (datasetCase.origin().isPresent()) {
        evalSpan.setAttribute("braintrust.origin", toJson(datasetCase.origin().get()));
    }
    if (!datasetCase.tags().isEmpty()) {
        evalSpan.setAttribute(
                AttributeKey.stringArrayKey("braintrust.tags"), datasetCase.tags());
    }
    if (!datasetCase.metadata().isEmpty()) {
        evalSpan.setAttribute("braintrust.metadata", toJson(datasetCase.metadata()));
    }
}

private void setTaskSpanAttributes(
Span taskSpan,
BraintrustUtils.Parent braintrustParent,
Expand Down
100 changes: 24 additions & 76 deletions braintrust-sdk/src/main/java/dev/braintrust/devserver/RemoteEval.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package dev.braintrust.devserver;

import dev.braintrust.eval.DatasetCase;
import dev.braintrust.eval.ParameterDef;
import dev.braintrust.eval.Parameters;
import dev.braintrust.eval.Scorer;
import dev.braintrust.eval.Task;
import dev.braintrust.eval.TaskResult;
import java.util.*;
import java.util.function.Function;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.Builder;
import lombok.Getter;
import lombok.Singular;
Expand Down Expand Up @@ -36,8 +39,8 @@ public class RemoteEval<INPUT, OUTPUT> {
*/
@Singular @Nonnull private final List<Scorer<INPUT, OUTPUT>> scorers;

/** Optional parameters that can be configured from the UI */
@Singular @Nonnull private final Map<String, Parameter> parameters;
/** Optional parameter definitions that can be configured from the UI */
@Singular @Nonnull private final List<ParameterDef<?>> parameters;

public static class Builder<INPUT, OUTPUT> {
/**
Expand All @@ -48,84 +51,29 @@ public static class Builder<INPUT, OUTPUT> {
*/
public Builder<INPUT, OUTPUT> taskFunction(Function<INPUT, OUTPUT> taskFn) {
return task(
datasetCase -> {
var result = taskFn.apply(datasetCase.input());
return new dev.braintrust.eval.TaskResult<>(result, datasetCase);
new Task<>() {
@Override
public TaskResult<INPUT, OUTPUT> apply(
DatasetCase<INPUT, OUTPUT> datasetCase, Parameters parameters)
throws Exception {
var result = taskFn.apply(datasetCase.input());
return new TaskResult<>(result, datasetCase, parameters);
}
});
}

/** Build the RemoteEval */
public RemoteEval<INPUT, OUTPUT> build() {
// can add build hooks here later if desired
return internalBuild();
}
}

/** A configurable evaluator parameter, created through the static factories below. */
@Getter
@lombok.Builder(builderClassName = "Builder")
public static class Parameter {
    /** Kind of parameter: "prompt" or "data". */
    @Nonnull private final ParameterType type;

    /** Description of the parameter, if any. */
    @Nullable private final String description;

    /** Default value for the parameter, if any. */
    @Nullable private final Object defaultValue;

    /**
     * JSON Schema for data type parameters, expressed as a Map representing a JSON Schema
     * object. Only applicable when the type is DATA.
     */
    @Nullable private final Map<String, Object> schema;

    /** Creates a PROMPT parameter with a description and a default value. */
    public static Parameter promptParameter(String description, Object defaultValue) {
        return create(ParameterType.PROMPT, description, null, defaultValue);
    }

    /** Creates a PROMPT parameter with a default value and no description. */
    public static Parameter promptParameter(Object defaultValue) {
        return create(ParameterType.PROMPT, null, null, defaultValue);
    }

    /** Creates a DATA parameter with a description, a schema and a default value. */
    public static Parameter dataParameter(
            String description, Map<String, Object> schema, Object defaultValue) {
        return create(ParameterType.DATA, description, schema, defaultValue);
    }

    /** Creates a DATA parameter with a schema and a default value, without a description. */
    public static Parameter dataParameter(Map<String, Object> schema, Object defaultValue) {
        return create(ParameterType.DATA, null, schema, defaultValue);
    }

    /** Creates a DATA parameter with only a schema. */
    public static Parameter dataParameter(Map<String, Object> schema) {
        return create(ParameterType.DATA, null, schema, null);
    }

    /** Single construction path shared by every public factory. */
    private static Parameter create(
            ParameterType kind,
            String description,
            Map<String, Object> schema,
            Object defaultValue) {
        return Parameter.builder()
                .type(kind)
                .description(description)
                .schema(schema)
                .defaultValue(defaultValue)
                .build();
    }
}

/** Parameter type enumeration */
public enum ParameterType {
/** Prompt parameter (for LLM prompts) */
PROMPT("prompt"),
/** Data parameter (for other configuration data) */
DATA("data");

private final String value;

ParameterType(String value) {
this.value = value;
}

public String getValue() {
return value;
var result = internalBuild();
// Validate parameter names are unique
var seen = new HashSet<String>();
for (var param : result.getParameters()) {
if (!seen.add(param.name())) {
throw new IllegalArgumentException(
"Duplicate parameter name: '" + param.name() + "'");
}
}
return result;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
*/
@Getter
@Builder
public class RequestContext {
class RequestContext {
/** Validated origin from CORS */
private final String appOrigin;

Expand Down
45 changes: 42 additions & 3 deletions braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public final class Eval<INPUT, OUTPUT> {
private final @Nonnull List<Scorer<INPUT, OUTPUT>> scorers;
private final @Nonnull List<String> tags;
private final @Nonnull Map<String, Object> metadata;
private final @Nonnull Parameters parameters;

private Eval(Builder<INPUT, OUTPUT> builder) {
this.experimentName = builder.experimentName;
Expand All @@ -59,6 +60,7 @@ private Eval(Builder<INPUT, OUTPUT> builder) {
this.scorers = List.copyOf(builder.scorers);
this.tags = List.copyOf(builder.tags);
this.metadata = Map.copyOf(builder.metadata);
this.parameters = builder.buildParameters();
}

/** Runs the evaluation and returns results. */
Expand Down Expand Up @@ -129,7 +131,7 @@ private void evalOne(String experimentId, DatasetCase<INPUT, OUTPUT> datasetCase
.startSpan();
try (var unused =
BraintrustContext.ofExperiment(experimentId, taskSpan).makeCurrent()) {
taskResult = task.apply(datasetCase);
taskResult = task.apply(datasetCase, parameters);
rootSpan.setAttribute(
"braintrust.output_json",
toJson(Map.of("output", taskResult.result())));
Expand Down Expand Up @@ -252,6 +254,8 @@ public static final class Builder<INPUT, OUTPUT> {
private @Nullable Tracer tracer = null;
private @Nullable Task<INPUT, OUTPUT> task;
private @Nonnull List<Scorer<INPUT, OUTPUT>> scorers = List.of();
private @Nonnull List<ParameterDef<?>> parameterDefs = List.of();
private @Nonnull Map<String, Object> parameterValues = Map.of();
private @Nonnull List<String> tags = List.of();
private @Nonnull Map<String, Object> metadata = Map.of();

Expand Down Expand Up @@ -335,9 +339,10 @@ public Builder<INPUT, OUTPUT> taskFunction(Function<INPUT, OUTPUT> taskFn) {
new Task<>() {
@Override
public TaskResult<INPUT, OUTPUT> apply(
DatasetCase<INPUT, OUTPUT> datasetCase) {
DatasetCase<INPUT, OUTPUT> datasetCase, Parameters parameters)
throws Exception {
var result = taskFn.apply(datasetCase.input());
return new TaskResult<>(result, datasetCase);
return new TaskResult<>(result, datasetCase, parameters);
}
});
}
Expand Down Expand Up @@ -365,5 +370,39 @@ public Builder<INPUT, OUTPUT> metadata(Map<String, Object> metadata) {
this.metadata = Map.copyOf(metadata);
return this;
}

/**
 * Sets parameter definitions for this eval, replacing any definitions set earlier. Default
 * values from the definitions are used unless overridden via {@link #parameterValues(Map)}.
 *
 * @param parameterDefs the parameter definitions; stored as an unmodifiable list
 * @return this builder
 * @throws NullPointerException if the array or any of its elements is null
 */
public Builder<INPUT, OUTPUT> parameters(ParameterDef<?>... parameterDefs) {
    // No @SuppressWarnings needed: ParameterDef<?> is not a raw type and is reifiable, so
    // this varargs signature produces no rawtypes or unchecked warning to suppress.
    this.parameterDefs = List.of(parameterDefs);
    return this;
}

/**
 * Replaces this eval's parameter definitions with an unmodifiable copy of the given list.
 *
 * @param parameterDefs the parameter definitions to copy
 * @return this builder
 */
public Builder<INPUT, OUTPUT> parameters(List<ParameterDef<?>> parameterDefs) {
    final List<ParameterDef<?>> snapshot = List.copyOf(parameterDefs);
    this.parameterDefs = snapshot;
    return this;
}

/**
 * Supplies explicit parameter values that take precedence over the defaults declared by the
 * parameter definitions. Any key missing from this map falls back to the default value of
 * its corresponding {@link ParameterDef}.
 *
 * @param values explicit values keyed by parameter name; an unmodifiable copy is stored
 * @return this builder
 */
public Builder<INPUT, OUTPUT> parameterValues(Map<String, Object> values) {
    final Map<String, Object> snapshot = Map.copyOf(values);
    this.parameterValues = snapshot;
    return this;
}

/**
 * Produces the Parameters for this eval from the configured definitions and explicit values.
 * When neither definitions nor values were configured, the shared empty instance is returned.
 */
private Parameters buildParameters() {
    final boolean nothingConfigured = parameterDefs.isEmpty() && parameterValues.isEmpty();
    return nothingConfigured
            ? Parameters.empty()
            : new Parameters(parameterDefs, parameterValues);
}
}
}
Loading
Loading