Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
public final class FunctionCallingUtils {

private static final Logger logger = LoggerFactory.getLogger(FunctionCallingUtils.class);
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();
private static final ObjectMapper defaultObjectMapper = JsonBaseModel.getMapper();

/** Holds the state during a single schema generation process to handle caching and recursion. */
private static class SchemaGenerationContext {
Expand Down Expand Up @@ -162,7 +162,20 @@ private static Schema buildSchemaFromParameter(Parameter param) {
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
*/
public static Schema buildSchemaFromType(Type type) {
return buildSchemaRecursive(objectMapper.constructType(type), new SchemaGenerationContext());
return buildSchemaFromType(type, defaultObjectMapper);
}

/**
* Builds a Schema from a Java Type, creating a new context for the generation process.
*
* @param type The Java {@link Type} to convert into a Schema.
* @param objectMapper The {@link ObjectMapper} to use for introspecting types.
* @return The generated {@link Schema}.
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
*/
public static Schema buildSchemaFromType(Type type, ObjectMapper objectMapper) {
return buildSchemaRecursive(
objectMapper.constructType(type), new SchemaGenerationContext(), objectMapper);
}

/**
Expand All @@ -173,7 +186,8 @@ public static Schema buildSchemaFromType(Type type) {
* @return The generated {@link Schema}.
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
*/
private static Schema buildSchemaRecursive(JavaType javaType, SchemaGenerationContext context) {
private static Schema buildSchemaRecursive(
JavaType javaType, SchemaGenerationContext context, ObjectMapper objectMapper) {
if (context.isProcessing(javaType)) {
logger.warn("Type {} is recursive. Omitting from schema.", javaType.toCanonical());
return Schema.builder()
Expand All @@ -194,7 +208,9 @@ private static Schema buildSchemaRecursive(JavaType javaType, SchemaGenerationCo
Class<?> rawClass = javaType.getRawClass();

if (javaType.isCollectionLikeType() && List.class.isAssignableFrom(rawClass)) {
builder.type("ARRAY").items(buildSchemaRecursive(javaType.getContentType(), context));
builder
.type("ARRAY")
.items(buildSchemaRecursive(javaType.getContentType(), context, objectMapper));
} else if (javaType.isMapLikeType()) {
builder.type("OBJECT");
} else if (String.class.equals(rawClass)) {
Expand Down Expand Up @@ -232,7 +248,8 @@ private static Schema buildSchemaRecursive(JavaType javaType, SchemaGenerationCo
for (BeanPropertyDefinition property : beanDescription.findProperties()) {
AnnotatedMember member = property.getPrimaryMember();
if (member != null) {
properties.put(property.getName(), buildSchemaRecursive(member.getType(), context));
properties.put(
property.getName(), buildSchemaRecursive(member.getType(), context, objectMapper));
if (property.isRequired()) {
required.add(property.getName());
}
Expand Down
24 changes: 20 additions & 4 deletions core/src/main/java/com/google/adk/tools/FunctionTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@
public class FunctionTool extends BaseTool {

private static final Logger logger = LoggerFactory.getLogger(FunctionTool.class);
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();
private static final ObjectMapper defaultObjectMapper = JsonBaseModel.getMapper();

private final @Nullable Object instance;
private final Method func;
private final FunctionDeclaration funcDeclaration;
private final boolean requireConfirmation;
private final ObjectMapper objectMapper;

public static FunctionTool create(Object instance, Method func) {
return create(instance, func, /* requireConfirmation= */ false);
Expand Down Expand Up @@ -166,11 +167,25 @@ private static boolean wasCompiledWithDefaultParameterNames(Method func) {
}

protected FunctionTool(@Nullable Object instance, Method func, boolean isLongRunning) {
this(instance, func, isLongRunning, /* requireConfirmation= */ false);
this(instance, func, isLongRunning, /* requireConfirmation= */ false, defaultObjectMapper);
}

protected FunctionTool(
@Nullable Object instance, Method func, boolean isLongRunning, boolean requireConfirmation) {
this(instance, func, isLongRunning, requireConfirmation, defaultObjectMapper);
}

protected FunctionTool(
@Nullable Object instance, Method func, boolean isLongRunning, ObjectMapper objectMapper) {
this(instance, func, isLongRunning, /* requireConfirmation= */ false, objectMapper);
}

protected FunctionTool(
@Nullable Object instance,
Method func,
boolean isLongRunning,
boolean requireConfirmation,
ObjectMapper objectMapper) {
super(
func.isAnnotationPresent(Annotations.Schema.class)
&& !func.getAnnotation(Annotations.Schema.class).name().isEmpty()
Expand All @@ -193,6 +208,7 @@ protected FunctionTool(
FunctionCallingUtils.buildFunctionDeclaration(
this.func, ImmutableList.of("toolContext", "inputStream"));
this.requireConfirmation = requireConfirmation;
this.objectMapper = objectMapper;
}

@Override
Expand Down Expand Up @@ -365,7 +381,7 @@ private static Class<?> getTypeClass(Type type, String paramName) {
}
}

private static List<Object> createList(List<Object> values, Class<?> type) {
private List<Object> createList(List<Object> values, Class<?> type) {
List<Object> list = new ArrayList<>();
// List of parameterized type is not supported.
if (type == null) {
Expand All @@ -387,7 +403,7 @@ private static List<Object> createList(List<Object> values, Class<?> type) {
return list;
}

private static Object castValue(Object value, Class<?> type) {
private Object castValue(Object value, Class<?> type) {
if (type.equals(Integer.class) || type.equals(int.class)) {
if (value instanceof Integer) {
return value;
Expand Down