diff --git a/core/pom.xml b/core/pom.xml index d31a2691b..c6592fa05 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -205,10 +205,29 @@ maven-surefire-plugin + basic test + + + false + + + + + vertex-ai-rag-retrieval + + test + + + + true + + + VertexAiRagRetrievalTest#processLlmRequest_gemini2Model_addVertexRagStoreToConfig, VertexAiRagRetrievalTest#processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig + apigee-llm diff --git a/core/src/main/java/com/google/adk/tools/retrieval/VertexAiRagRetrieval.java b/core/src/main/java/com/google/adk/tools/retrieval/VertexAiRagRetrieval.java index b36a05d10..16f11a1f8 100644 --- a/core/src/main/java/com/google/adk/tools/retrieval/VertexAiRagRetrieval.java +++ b/core/src/main/java/com/google/adk/tools/retrieval/VertexAiRagRetrieval.java @@ -20,6 +20,7 @@ import com.google.adk.models.LlmRequest; import com.google.adk.tools.ToolContext; +import com.google.adk.utils.ModelNameUtils; import com.google.cloud.aiplatform.v1.RagContexts; import com.google.cloud.aiplatform.v1.RagQuery; import com.google.cloud.aiplatform.v1.RetrieveContextsRequest; @@ -105,10 +106,9 @@ public VertexAiRagRetrieval( public Completable processLlmRequest( LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { LlmRequest llmRequest = llmRequestBuilder.build(); - // Use Gemini built-in Vertex AI RAG tool for Gemini 2 models or when using Vertex AI API Model + // Use Gemini built-in Vertex AI RAG tool for Gemini models when using Vertex AI API Model boolean useVertexAi = Boolean.parseBoolean(System.getenv("GOOGLE_GENAI_USE_VERTEXAI")); - if (useVertexAi - && (llmRequest.model().isPresent() && llmRequest.model().get().startsWith("gemini-2"))) { + if (useVertexAi && llmRequest.model().filter(ModelNameUtils::isGeminiModel).isPresent()) { GenerateContentConfig config = llmRequest.config().orElseGet(() -> GenerateContentConfig.builder().build()); ImmutableList.Builder toolsBuilder = ImmutableList.builder(); diff --git a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java index 6f04a7ef8..8246751b9 100644 --- a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java +++ b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java @@ -208,7 +208,7 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { "projects/test-project/locations/us-central1", ragResources, vectorDistanceThreshold); - LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-1-pro"); + LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("other-model"); ToolContext toolContext = buildToolContext(); GenerateContentConfig initialConfig = GenerateContentConfig.builder().build(); llmRequestBuilder.config(initialConfig);