diff --git a/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java b/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java index eec115127..3244b40e9 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java +++ b/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java @@ -45,7 +45,7 @@ public final class FunctionCallingUtils { private static final Logger logger = LoggerFactory.getLogger(FunctionCallingUtils.class); - private static final ObjectMapper objectMapper = JsonBaseModel.getMapper(); + private static final ObjectMapper defaultObjectMapper = JsonBaseModel.getMapper(); /** Holds the state during a single schema generation process to handle caching and recursion. */ private static class SchemaGenerationContext { @@ -162,7 +162,20 @@ private static Schema buildSchemaFromParameter(Parameter param) { * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. */ public static Schema buildSchemaFromType(Type type) { - return buildSchemaRecursive(objectMapper.constructType(type), new SchemaGenerationContext()); + return buildSchemaFromType(type, defaultObjectMapper); + } + + /** + * Builds a Schema from a Java Type, creating a new context for the generation process. + * + * @param type The Java {@link Type} to convert into a Schema. + * @param objectMapper The {@link ObjectMapper} to use for introspecting types. + * @return The generated {@link Schema}. + * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. + */ + public static Schema buildSchemaFromType(Type type, ObjectMapper objectMapper) { + return buildSchemaRecursive( + objectMapper.constructType(type), new SchemaGenerationContext(), objectMapper); } /** @@ -173,7 +186,8 @@ public static Schema buildSchemaFromType(Type type) { * @return The generated {@link Schema}. * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. */ - private static Schema buildSchemaRecursive(JavaType javaType, SchemaGenerationContext context) { + private static Schema buildSchemaRecursive( + JavaType javaType, SchemaGenerationContext context, ObjectMapper objectMapper) { if (context.isProcessing(javaType)) { logger.warn("Type {} is recursive. Omitting from schema.", javaType.toCanonical()); return Schema.builder() @@ -194,7 +208,9 @@ private static Schema buildSchemaRecursive(JavaType javaType, SchemaGenerationCo Class rawClass = javaType.getRawClass(); if (javaType.isCollectionLikeType() && List.class.isAssignableFrom(rawClass)) { - builder.type("ARRAY").items(buildSchemaRecursive(javaType.getContentType(), context)); + builder + .type("ARRAY") + .items(buildSchemaRecursive(javaType.getContentType(), context, objectMapper)); } else if (javaType.isMapLikeType()) { builder.type("OBJECT"); } else if (String.class.equals(rawClass)) { @@ -232,7 +248,8 @@ private static Schema buildSchemaRecursive(JavaType javaType, SchemaGenerationCo for (BeanPropertyDefinition property : beanDescription.findProperties()) { AnnotatedMember member = property.getPrimaryMember(); if (member != null) { - properties.put(property.getName(), buildSchemaRecursive(member.getType(), context)); + properties.put( + property.getName(), buildSchemaRecursive(member.getType(), context, objectMapper)); if (property.isRequired()) { required.add(property.getName()); } diff --git a/core/src/main/java/com/google/adk/tools/FunctionTool.java b/core/src/main/java/com/google/adk/tools/FunctionTool.java index a6167ee46..40046aa30 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/FunctionTool.java @@ -44,12 +44,13 @@ public class FunctionTool extends BaseTool { private static final Logger logger = LoggerFactory.getLogger(FunctionTool.class); - private static final ObjectMapper objectMapper = JsonBaseModel.getMapper(); + private static final ObjectMapper defaultObjectMapper = JsonBaseModel.getMapper(); private final @Nullable Object instance; private final Method func; private final FunctionDeclaration funcDeclaration; private final boolean requireConfirmation; + private final ObjectMapper objectMapper; public static FunctionTool create(Object instance, Method func) { return create(instance, func, /* requireConfirmation= */ false); @@ -166,11 +167,25 @@ private static boolean wasCompiledWithDefaultParameterNames(Method func) { } protected FunctionTool(@Nullable Object instance, Method func, boolean isLongRunning) { - this(instance, func, isLongRunning, /* requireConfirmation= */ false); + this(instance, func, isLongRunning, /* requireConfirmation= */ false, defaultObjectMapper); } protected FunctionTool( @Nullable Object instance, Method func, boolean isLongRunning, boolean requireConfirmation) { + this(instance, func, isLongRunning, requireConfirmation, defaultObjectMapper); + } + + protected FunctionTool( + @Nullable Object instance, Method func, boolean isLongRunning, ObjectMapper objectMapper) { + this(instance, func, isLongRunning, /* requireConfirmation= */ false, objectMapper); + } + + protected FunctionTool( + @Nullable Object instance, + Method func, + boolean isLongRunning, + boolean requireConfirmation, + ObjectMapper objectMapper) { super( func.isAnnotationPresent(Annotations.Schema.class) && !func.getAnnotation(Annotations.Schema.class).name().isEmpty() @@ -193,6 +208,7 @@ protected FunctionTool( FunctionCallingUtils.buildFunctionDeclaration( this.func, ImmutableList.of("toolContext", "inputStream")); this.requireConfirmation = requireConfirmation; + this.objectMapper = objectMapper; } @Override @@ -365,7 +381,7 @@ private static Class getTypeClass(Type type, String paramName) { } } - private static List createList(List values, Class type) { + private List createList(List values, Class type) { List list = new ArrayList<>(); // List of parameterized type is not supported. if (type == null) { @@ -387,7 +403,7 @@ private static List createList(List values, Class type) { return list; } - private static Object castValue(Object value, Class type) { + private Object castValue(Object value, Class type) { if (type.equals(Integer.class) || type.equals(int.class)) { if (value instanceof Integer) { return value;