diff --git a/core/pom.xml b/core/pom.xml
index d31a2691b..c6592fa05 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -205,10 +205,29 @@
maven-surefire-plugin
+
basic
test
+
+
+ false
+
+
+
+
+ vertex-ai-rag-retrieval
+
+ test
+
+
+
+ true
+
+
+ VertexAiRagRetrievalTest#processLlmRequest_gemini2Model_addVertexRagStoreToConfig, VertexAiRagRetrievalTest#processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig
+
apigee-llm
diff --git a/core/src/main/java/com/google/adk/tools/retrieval/VertexAiRagRetrieval.java b/core/src/main/java/com/google/adk/tools/retrieval/VertexAiRagRetrieval.java
index b36a05d10..16f11a1f8 100644
--- a/core/src/main/java/com/google/adk/tools/retrieval/VertexAiRagRetrieval.java
+++ b/core/src/main/java/com/google/adk/tools/retrieval/VertexAiRagRetrieval.java
@@ -20,6 +20,7 @@
import com.google.adk.models.LlmRequest;
import com.google.adk.tools.ToolContext;
+import com.google.adk.utils.ModelNameUtils;
import com.google.cloud.aiplatform.v1.RagContexts;
import com.google.cloud.aiplatform.v1.RagQuery;
import com.google.cloud.aiplatform.v1.RetrieveContextsRequest;
@@ -105,10 +106,9 @@ public VertexAiRagRetrieval(
public Completable processLlmRequest(
LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) {
LlmRequest llmRequest = llmRequestBuilder.build();
- // Use Gemini built-in Vertex AI RAG tool for Gemini 2 models or when using Vertex AI API Model
+ // Use Gemini built-in Vertex AI RAG tool for Gemini models when using Vertex AI API Model
boolean useVertexAi = Boolean.parseBoolean(System.getenv("GOOGLE_GENAI_USE_VERTEXAI"));
- if (useVertexAi
- && (llmRequest.model().isPresent() && llmRequest.model().get().startsWith("gemini-2"))) {
+ if (useVertexAi && llmRequest.model().filter(ModelNameUtils::isGeminiModel).isPresent()) {
GenerateContentConfig config =
llmRequest.config().orElseGet(() -> GenerateContentConfig.builder().build());
ImmutableList.Builder toolsBuilder = ImmutableList.builder();
diff --git a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java
index 6f04a7ef8..8246751b9 100644
--- a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java
+++ b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java
@@ -208,7 +208,7 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() {
"projects/test-project/locations/us-central1",
ragResources,
vectorDistanceThreshold);
- LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-1-pro");
+ LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("other-model");
ToolContext toolContext = buildToolContext();
GenerateContentConfig initialConfig = GenerateContentConfig.builder().build();
llmRequestBuilder.config(initialConfig);