From de5be95ec6ee50d285c9ecfabb59e3bfbd53ac9c Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 20 May 2026 09:56:54 -0700 Subject: [PATCH] chore: merge GA and preview templates in agentplatform. PiperOrigin-RevId: 918506513 --- .../unit_agentplatform_llama_index_py311.cfg | 13 + agentplatform/agent_engines/templates/adk.py | 36 +- .../agent_engines/templates/llama_index.py | 555 ++++++++++++++++++ noxfile.py | 35 ++ .../frameworks/test_frameworks_adk.py | 192 ++++++ .../frameworks/test_frameworks_llama_index.py | 345 +++++++++++ 6 files changed, 1175 insertions(+), 1 deletion(-) create mode 100644 .kokoro/presubmit/unit_agentplatform_llama_index_py311.cfg create mode 100644 agentplatform/agent_engines/templates/llama_index.py create mode 100644 tests/unit/agentplatform/frameworks/test_frameworks_llama_index.py diff --git a/.kokoro/presubmit/unit_agentplatform_llama_index_py311.cfg b/.kokoro/presubmit/unit_agentplatform_llama_index_py311.cfg new file mode 100644 index 0000000000..8a61813327 --- /dev/null +++ b/.kokoro/presubmit/unit_agentplatform_llama_index_py311.cfg @@ -0,0 +1,13 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +# Run unit tests for LlamaIndex on Python 3.11 +env_vars: { + key: "NOX_SESSION" + value: "unit_agentplatform_llama_index-3.11" +} + +# Run unit tests in parallel, splitting up by file +env_vars: { + key: "PYTEST_ADDOPTS" + value: "-n=auto --dist=loadscope" +} diff --git a/agentplatform/agent_engines/templates/adk.py b/agentplatform/agent_engines/templates/adk.py index c68cc1d45d..2fce37de61 100644 --- a/agentplatform/agent_engines/templates/adk.py +++ b/agentplatform/agent_engines/templates/adk.py @@ -86,6 +86,15 @@ except (ImportError, AttributeError): BaseMemoryService = Any + try: + from google.adk.auth.credential_service.base_credential_service import ( + BaseCredentialService, + ) + + BaseCredentialService = BaseCredentialService + except (ImportError, AttributeError): + BaseCredentialService = Any + try: from opentelemetry.sdk import trace @@ -682,6 +691,9 @@ def __init__( session_service_builder: Optional[Callable[..., "BaseSessionService"]] = None, artifact_service_builder: Optional[Callable[..., "BaseArtifactService"]] = None, memory_service_builder: Optional[Callable[..., "BaseMemoryService"]] = None, + credential_service_builder: Optional[ + Callable[..., "BaseCredentialService"] + ] = None, instrumentor_builder: Optional[Callable[..., Any]] = None, ): """An ADK Application. @@ -715,6 +727,9 @@ def __init__( Defaults to a callable that returns InMemoryMemoryService when running locally and VertexAiMemoryBankService when running on Agent Engine. + credential_service_builder (Callable[..., BaseCredentialService]): + Optional. A callable that returns an ADK credential service. + Defaults to a callable that returns InMemoryCredentialService. instrumentor_builder (Callable[..., Any]): Optional. Callable that returns a new instrumentor. This can be used for customizing the instrumentation logic of the Agent. @@ -759,6 +774,7 @@ def __init__( "session_service_builder": session_service_builder, "artifact_service_builder": artifact_service_builder, "memory_service_builder": memory_service_builder, + "credential_service_builder": credential_service_builder, "instrumentor_builder": instrumentor_builder, "express_mode_api_key": ( initializer.global_config.api_key or os.environ.get("GOOGLE_API_KEY") @@ -912,6 +928,9 @@ def clone(self): session_service_builder=self._tmpl_attrs.get("session_service_builder"), artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"), memory_service_builder=self._tmpl_attrs.get("memory_service_builder"), + credential_service_builder=self._tmpl_attrs.get( + "credential_service_builder" + ), instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"), ) @@ -924,6 +943,9 @@ def set_up(self): InMemoryArtifactService, ) from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + from google.adk.auth.credential_service.in_memory_credential_service import ( + InMemoryCredentialService, + ) os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" project = self._tmpl_attrs.get("project") @@ -1078,6 +1100,12 @@ def set_up(self): else: self._tmpl_attrs["memory_service"] = InMemoryMemoryService() + credential_service_builder = self._tmpl_attrs.get("credential_service_builder") + if credential_service_builder: + self._tmpl_attrs["credential_service"] = credential_service_builder() + else: + self._tmpl_attrs["credential_service"] = InMemoryCredentialService() + self._tmpl_attrs["runner"] = Runner( app=self._tmpl_attrs.get("app"), agent=( @@ -1114,6 +1142,7 @@ def set_up(self): session_service=self._tmpl_attrs.get("in_memory_session_service"), artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"), memory_service=self._tmpl_attrs.get("in_memory_memory_service"), + credential_service=self._tmpl_attrs.get("credential_service"), ) async def async_stream_query( @@ -1183,11 +1212,16 @@ async def async_stream_query( from google.adk.events.event import Event session_service = self._tmpl_attrs.get("session_service") + session_obj = await session_service.get_session( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + ) for event in session_events: if not isinstance(event, Event): event = Event.model_validate(event) await session_service.append_event( - session=session, + session=session_obj, event=event, ) diff --git a/agentplatform/agent_engines/templates/llama_index.py b/agentplatform/agent_engines/templates/llama_index.py new file mode 100644 index 0000000000..3bf13dd697 --- /dev/null +++ b/agentplatform/agent_engines/templates/llama_index.py @@ -0,0 +1,555 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Mapping, + Optional, + Sequence, + Union, +) + +if TYPE_CHECKING: + try: + from llama_index.core.base.query_pipeline import query + from llama_index.core.llms import function_calling + from llama_index.core import query_pipeline + + FunctionCallingLLM = function_calling.FunctionCallingLLM + QueryComponent = query.QUERY_COMPONENT_TYPE + QueryPipeline = query_pipeline.QueryPipeline + except ImportError: + FunctionCallingLLM = Any + QueryComponent = Any + QueryPipeline = Any + + try: + from opentelemetry.sdk import trace + + TracerProvider = trace.TracerProvider + SpanProcessor = trace.SpanProcessor + SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor + except ImportError: + TracerProvider = Any + SpanProcessor = Any + SynchronousMultiSpanProcessor = Any + + +def _default_model_builder( + model_name: str, + *, + project: str, + location: str, + model_kwargs: Optional[Mapping[str, Any]] = None, +) -> "FunctionCallingLLM": + """Creates a default model builder for LlamaIndex.""" + import agentplatform + from google.cloud.aiplatform import initializer + from llama_index.llms import google_genai + + model_kwargs = model_kwargs or {} + model = google_genai.GoogleGenAI( + model=model_name, + agentplatform_config={"project": project, "location": location}, + **model_kwargs, + ) + current_project = initializer.global_config.project + current_location = initializer.global_config.location + agentplatform.init(project=current_project, location=current_location) + return model + + +def _default_runnable_builder( + model: "FunctionCallingLLM", + *, + system_instruction: Optional[str] = None, + prompt: Optional["QueryComponent"] = None, + retriever: Optional["QueryComponent"] = None, + response_synthesizer: Optional["QueryComponent"] = None, + runnable_kwargs: Optional[Mapping[str, Any]] = None, +) -> "QueryPipeline": + """Creates a default runnable builder for LlamaIndex.""" + try: + from llama_index.core.query_pipeline import QueryPipeline + except ImportError: + raise ImportError( + "Please call 'pip install google-cloud-aiplatform[llama_index]'." + ) + + prompt = prompt or _default_prompt( + system_instruction=system_instruction, + ) + pipeline = QueryPipeline(**runnable_kwargs) + pipeline_modules = { + "prompt": prompt, + "model": model, + } + if retriever: + pipeline_modules["retriever"] = retriever + if response_synthesizer: + pipeline_modules["response_synthesizer"] = response_synthesizer + + pipeline.add_modules(pipeline_modules) + pipeline.add_link("prompt", "model") + if "retriever" in pipeline_modules: + pipeline.add_link("model", "retriever") + if "response_synthesizer" in pipeline_modules: + pipeline.add_link("model", "response_synthesizer", dest_key="query_str") + if "retriever" in pipeline_modules: + pipeline.add_link("retriever", "response_synthesizer", dest_key="nodes") + + return pipeline + + +def _default_prompt( + system_instruction: Optional[str] = None, +) -> "QueryComponent": + """Creates a default prompt template for LlamaIndex. + + Handles both system instruction and user input. + + Args: + system_instruction (str, optional): The system instruction to use. + + Returns: + QueryComponent: The LlamaIndex QueryComponent. + """ + try: + from llama_index.core import prompts + from llama_index.core.base.llms import types + except ImportError: + raise ImportError( + "Please call 'pip install google-cloud-aiplatform[llama_index]'." + ) + + # Define a prompt template + message_templates = [] + if system_instruction: + message_templates.append( + types.ChatMessage(role=types.MessageRole.SYSTEM, content=system_instruction) + ) + # Add user input message + message_templates.append( + types.ChatMessage(role=types.MessageRole.USER, content="{input}") + ) + + # Create the prompt template + return prompts.ChatPromptTemplate(message_templates=message_templates) + + +def _override_active_span_processor( + tracer_provider: "TracerProvider", + active_span_processor: "SynchronousMultiSpanProcessor", +): + """Overrides the active span processor. + + When working with multiple LlamaIndexQueryPipelineAgents in the same + environment, it's crucial to manage trace exports carefully. + Each agent needs its own span processor tied to a unique project ID. + While we add a new span processor for each agent, this can lead to + unexpected behavior. + For instance, with two agents linked to different projects, traces from the + second agent might be sent to both projects. + To prevent this and guarantee traces go to the correct project, we overwrite + the active span processor whenever a new LlamaIndexQueryPipelineAgent is + created. + + Args: + tracer_provider (TracerProvider): + The tracer provider to use for the project. + active_span_processor (SynchronousMultiSpanProcessor): + The active span processor overrides the tracer provider's + active span processor. + """ + if tracer_provider._active_span_processor: + tracer_provider._active_span_processor.shutdown() + tracer_provider._active_span_processor = active_span_processor + + +class LlamaIndexQueryPipelineAgent: + """A LlamaIndex Query Pipeline Agent. + + This agent uses a query pipeline for LLAIndex, including prompt, model, + retrieval and summarization steps. More details can be found in + https://docs.llamaindex.ai/en/stable/module_guides/querying/pipeline/. + """ + + agent_framework = "llama-index" + + def __init__( + self, + model: str, + *, + system_instruction: Optional[str] = None, + prompt: Optional["QueryComponent"] = None, + model_kwargs: Optional[Mapping[str, Any]] = None, + model_builder: Optional[Callable[..., "FunctionCallingLLM"]] = None, + retriever_kwargs: Optional[Mapping[str, Any]] = None, + retriever_builder: Optional[Callable[..., "QueryComponent"]] = None, + response_synthesizer_kwargs: Optional[Mapping[str, Any]] = None, + response_synthesizer_builder: Optional[Callable[..., "QueryComponent"]] = None, + runnable_kwargs: Optional[Mapping[str, Any]] = None, + runnable_builder: Optional[Callable[..., "QueryPipeline"]] = None, + enable_tracing: bool = False, + ): + """Initializes the LlamaIndexQueryPipelineAgent. + + Under-the-hood, assuming .set_up() is called, this will correspond to + ```python + # model_builder + model = model_builder(model_name, project, location, model_kwargs) + + # runnable_builder + runnable = runnable_builder( + prompt=prompt, + model=model, + retriever=retriever_builder(model, retriever_kwargs), + response_synthesizer=response_synthesizer_builder( + model, response_synthesizer_kwargs + ), + runnable_kwargs=runnable_kwargs, + ) + ``` + + When everything is based on their default values, this corresponds to a + query pipeline `Prompt - Model`: + ```python + # Default Model Builder + model = google_genai.GoogleGenAI( + model=model_name, + agentplatform_config={ + "project": initializer.global_config.project, + "location": initializer.global_config.location, + }, + ) + + # Default Prompt Builder + prompt = prompts.ChatPromptTemplate( + message_templates=[ + types.ChatMessage( + role=types.MessageRole.USER, + content="{input}", + ), + ], + ) + + # Default Runnable Builder + runnable = QueryPipeline( + modules = { + "prompt": prompt, + "model": model, + }, + ) + pipeline.add_link("prompt", "model") + ``` + + When `system_instruction` is specified, the prompt will be updated to + include the system instruction. + ```python + # Updated Prompt Builder + prompt = prompts.ChatPromptTemplate( + message_templates=[ + types.ChatMessage( + role=types.MessageRole.SYSTEM, + content=system_instruction, + ), + types.ChatMessage( + role=types.MessageRole.USER, + content="{input}", + ), + ], + ) + ``` + + When all inputs are specified, this corresponds to a query pipeline + `Prompt - Model - Retriever - Summarizer`: + ```python + runnable = QueryPipeline( + modules = { + "prompt": prompt, + "model": model, + "retriever": retriever_builder(retriever_kwargs), + "response_synthesizer": response_synthesizer_builder( + response_synthesizer_kwargs + ), + }, + ) + pipeline.add_link("prompt", "model") + pipeline.add_link("model", "retriever") + pipeline.add_link("model", "response_synthesizer", dest_key="query_str") + pipeline.add_link("retriever", "response_synthesizer", dest_key="nodes") + ``` + + Args: + model (str): + The name of the model (e.g. "gemini-1.0-pro"). + system_instruction (str): + Optional. The system instruction to use for the agent. + prompt (llama_index.core.base.query_pipeline.query.QUERY_COMPONENT_TYPE): + Optional. The prompt template for the model. + model_kwargs (Mapping[str, Any]): + Optional. Keyword arguments for the model constructor of the + google_genai.GoogleGenAI. An example of a model_kwargs is: + ```python + { + # api_key (string): The API key for the GoogleGenAI model. + # The API can also be fetched from the GOOGLE_API_KEY + # environment variable. If `agentplatform_config` is provided, + # the API key is ignored. + "api_key": "your_api_key", + # temperature (float): Sampling temperature, it controls the + # degree of randomness in token selection. If not provided, + # the default temperature is 0.1. + "temperature": 0.1, + # context_window (int): The context window of the model. + # If not provided, the default context window is 200000. + "context_window": 200000, + # max_tokens (int): Token limit determines the maximum + # amount of text output from one prompt. If not provided, + # the default max_tokens is 256. + "max_tokens": 256, + # is_function_calling_model (bool): Whether the model is a + # function calling model. If not provided, the default + # is_function_calling_model is True. + "is_function_calling_model": True, + } + ``` + model_builder (Callable): + Optional. Callable that returns a language model. + retriever_kwargs (Mapping[str, Any]): + Optional. Keyword arguments for the retriever constructor. + retriever_builder (Callable): + Optional. Callable that returns a retriever object. + response_synthesizer_kwargs (Mapping[str, Any]): + Optional. Keyword arguments for the response synthesizer constructor. + response_synthesizer_builder (Callable): + Optional. Callable that returns a response_synthesizer object. + runnable_kwargs (Mapping[str, Any]): + Optional. Keyword arguments for the runnable constructor. + runnable_builder (Callable): + Optional. Callable that returns a runnable (query pipeline). + enable_tracing (bool): + Optional. Whether to enable tracing. Defaults to False. + """ + from google.cloud.aiplatform import initializer + + self._project = initializer.global_config.project + self._location = initializer.global_config.location + self._model_name = model + self._system_instruction = system_instruction + self._prompt = prompt + + self._model = None + self._model_kwargs = model_kwargs or {} + self._model_builder = model_builder + + self._retriever = None + self._retriever_kwargs = retriever_kwargs or {} + self._retriever_builder = retriever_builder + + self._response_synthesizer = None + self._response_synthesizer_kwargs = response_synthesizer_kwargs or {} + self._response_synthesizer_builder = response_synthesizer_builder + + self._runnable = None + self._runnable_kwargs = runnable_kwargs or {} + self._runnable_builder = runnable_builder + + self._instrumentor = None + self._enable_tracing = enable_tracing + + def set_up(self): + """Sets up the agent for execution of queries at runtime. + + It initializes the model, connects it with the prompt template, + retriever and response_synthesizer. + + This method should not be called for an object that being passed to + the ReasoningEngine service for deployment, as it initializes clients + that can not be serialized. + """ + if self._enable_tracing: + from agentplatform._genai.agent_engines import _agent_engines_utils + + cloud_trace_exporter = _agent_engines_utils._import_cloud_trace_exporter_or_warn() + cloud_trace_v2 = _agent_engines_utils._import_cloud_trace_v2_or_warn() + openinference_llama_index = ( + _agent_engines_utils._import_openinference_llama_index_or_warn() + ) + opentelemetry = _agent_engines_utils._import_opentelemetry_or_warn() + opentelemetry_sdk_trace = _agent_engines_utils._import_opentelemetry_sdk_trace_or_warn() + if all( + ( + cloud_trace_exporter, + cloud_trace_v2, + openinference_llama_index, + opentelemetry, + opentelemetry_sdk_trace, + ) + ): + import google.auth + + credentials, _ = google.auth.default() + span_exporter = cloud_trace_exporter.CloudTraceSpanExporter( + project_id=self._project, + client=cloud_trace_v2.TraceServiceClient( + credentials=credentials.with_quota_project(self._project), + ), + ) + span_processor: SpanProcessor = ( + opentelemetry_sdk_trace.export.SimpleSpanProcessor( + span_exporter=span_exporter, + ) + ) + tracer_provider: TracerProvider = ( + opentelemetry.trace.get_tracer_provider() + ) + # Get the appropriate tracer provider: + # 1. If _TRACER_PROVIDER is already set, use that. + # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment + # variable is set, use that. + # 3. As a final fallback, use _PROXY_TRACER_PROVIDER. + # If none of the above is set, we log a warning, and + # create a tracer provider. + if not tracer_provider: + from google.cloud.aiplatform import base + + _LOGGER = base.Logger(__name__) + _LOGGER.warning( + "No tracer provider. By default, " + "we should get one of the following providers: " + "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, " + "or _PROXY_TRACER_PROVIDER." + ) + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids AttributeError: + # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no + # attribute 'add_span_processor'. + if _agent_engines_utils.is_noop_or_proxy_tracer_provider(tracer_provider): + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids OpenTelemetry client already exists error. + _override_active_span_processor( + tracer_provider, + opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(), + ) + tracer_provider.add_span_processor(span_processor) + # Keep the instrumentation up-to-date. + # When creating multiple LlamaIndexQueryPipelineAgents, + # we need to keep the instrumentation up-to-date. + # We deliberately override the instrument each time, + # so that if different agents end up using different + # instrumentations, we guarantee that the user is always + # working with the most recent agent's instrumentation. + self._instrumentor = openinference_llama_index.LlamaIndexInstrumentor() + if self._instrumentor.is_instrumented_by_opentelemetry: + self._instrumentor.uninstrument() + self._instrumentor.instrument() + else: + from google.cloud.aiplatform import base + + _LOGGER = base.Logger(__name__) + _LOGGER.warning( + "enable_tracing=True but proceeding with tracing disabled " + "because not all packages for tracing have been installed" + ) + + model_builder = self._model_builder or _default_model_builder + self._model = model_builder( + model_name=self._model_name, + model_kwargs=self._model_kwargs, + project=self._project, + location=self._location, + ) + + if self._retriever_builder: + self._retriever = self._retriever_builder( + model=self._model, + retriever_kwargs=self._retriever_kwargs, + ) + + if self._response_synthesizer_builder: + self._response_synthesizer = self._response_synthesizer_builder( + model=self._model, + response_synthesizer_kwargs=self._response_synthesizer_kwargs, + ) + + runnable_builder = self._runnable_builder or _default_runnable_builder + self._runnable = runnable_builder( + prompt=self._prompt, + model=self._model, + system_instruction=self._system_instruction, + retriever=self._retriever, + response_synthesizer=self._response_synthesizer, + runnable_kwargs=self._runnable_kwargs, + ) + + def clone(self) -> "LlamaIndexQueryPipelineAgent": + """Returns a clone of the LlamaIndexQueryPipelineAgent.""" + import copy + + return LlamaIndexQueryPipelineAgent( + model=self._model_name, + system_instruction=self._system_instruction, + prompt=copy.deepcopy(self._prompt), + model_kwargs=copy.deepcopy(self._model_kwargs), + model_builder=self._model_builder, + retriever_kwargs=copy.deepcopy(self._retriever_kwargs), + retriever_builder=self._retriever_builder, + response_synthesizer_kwargs=copy.deepcopy( + self._response_synthesizer_kwargs + ), + response_synthesizer_builder=self._response_synthesizer_builder, + runnable_kwargs=copy.deepcopy(self._runnable_kwargs), + runnable_builder=self._runnable_builder, + enable_tracing=self._enable_tracing, + ) + + def query( + self, + input: Union[str, Mapping[str, Any]], + **kwargs: Any, + ) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]: + """Queries the Agent with the given input and config. + + Args: + input (Union[str, Mapping[str, Any]]): + Required. The input to be passed to the Agent. + **kwargs: + Optional. Any additional keyword arguments to be passed to the + `.invoke()` method of the corresponding AgentExecutor. + + Returns: + The output of querying the Agent with the given input and config. + """ + from agentplatform._genai.agent_engines import _agent_engines_utils + + if isinstance(input, str): + input = {"input": input} + + if not self._runnable: + self.set_up() + + if kwargs.get("batch"): + nest_asyncio = _agent_engines_utils._import_nest_asyncio_or_warn() + nest_asyncio.apply() + + return _agent_engines_utils.to_json_serializable_llama_index_object( + self._runnable.run(**input, **kwargs) + ) diff --git a/noxfile.py b/noxfile.py index aed97736a1..f7f1935c9e 100644 --- a/noxfile.py +++ b/noxfile.py @@ -109,6 +109,7 @@ "unit_agentplatform_adk", "unit_agentplatform_langchain", "unit_agentplatform_ag2", + "unit_agentplatform_llama_index", "system", "cover", "lint", @@ -394,6 +395,40 @@ def unit_agentplatform_ag2(session): ) +@nox.session(python=UNIT_TEST_TEMPLATES_PYTHON_VERSIONS) +def unit_agentplatform_llama_index(session): + # Install all test dependencies, then install this package in-place. + + constraints_path = str( + CURRENT_DIRECTORY / "testing" / "constraints-llama-index.txt" + ) + standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES + session.install(*standard_deps, "-c", constraints_path) + + # Install llama_index extras + session.install("-e", ".[llama_index_testing]", "-c", constraints_path) + + # Run py.test against the unit tests. + session.run( + "py.test", + "--quiet", + "--junitxml=unit_agentplatform_llama_index_sponge_log.xml", + "--cov=google", + "--cov-append", + "--cov-config=.coveragerc", + "--cov-report=", + "--cov-fail-under=0", + os.path.join( + "tests", + "unit", + "agentplatform", + "frameworks", + "test_frameworks_llama_index.py", + ), + *session.posargs, + ) + + @nox.session(python=UNIT_TEST_TEMPLATES_PYTHON_VERSIONS) def unit_langchain(session): # Install all test dependencies, then install this package in-place. diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_adk.py b/tests/unit/agentplatform/frameworks/test_frameworks_adk.py index c7256a7c89..c1599b2e45 100644 --- a/tests/unit/agentplatform/frameworks/test_frameworks_adk.py +++ b/tests/unit/agentplatform/frameworks/test_frameworks_adk.py @@ -58,6 +58,7 @@ def __init__(self, name: str, model: str): _TEST_LOCATION = "us-central1" _TEST_PROJECT = "test-project" _TEST_PROJECT_ID = "test-project-id" +_TEST_MODEL = "gemini-2.0-flash" _TEST_API_KEY = "test-api-key" _TEST_MODEL = "gemini-2.0-flash" _TEST_USER_ID = "test_user_id" @@ -91,6 +92,83 @@ def __init__(self, name: str, model: str): "streaming_mode": "sse", "max_llm_calls": 500, } +_TEST_SESSION_EVENTS = [ + { + "author": "user", + "content": { + "parts": [ + { + "text": "What is the exchange rate from US dollars to " + "Swedish krona on 2025-09-25?" + } + ], + "role": "user", + }, + "id": "8967297909049524224", + "invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065", + "timestamp": 1765832134.629513, + }, + { + "author": "currency_exchange_agent", + "content": { + "parts": [ + { + "functionCall": { + "args": { + "currency_date": "2025-09-25", + "currency_from": "USD", + "currency_to": "SEK", + }, + "id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7", + "name": "get_exchange_rate", + } + } + ], + "role": "model", + }, + "id": "3155402589927899136", + "invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065", + "timestamp": 1765832134.723713, + }, + { + "author": "currency_exchange_agent", + "content": { + "parts": [ + { + "functionResponse": { + "id": "adk-136738ad-9e57-4cfb-8e23-b0f3e50a37d7", + "name": "get_exchange_rate", + "response": { + "amount": 1, + "base": "USD", + "date": "2025-09-25", + "rates": {"SEK": 9.4118}, + }, + } + } + ], + "role": "user", + }, + "id": "1678221912150376448", + "invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065", + "timestamp": 1765832135.764961, + }, + { + "author": "currency_exchange_agent", + "content": { + "parts": [ + { + "text": "The exchange rate from US dollars to Swedish " + "krona on 2025-09-25 is 1 USD to 9.4118 SEK." + } + ], + "role": "model", + }, + "id": "2470855446567583744", + "invocationId": "e-308f65d7-a99f-41e3-b80d-40feb5f1b065", + "timestamp": 1765832135.853299, + }, +] _TEST_STAGING_BUCKET = "gs://test-bucket" _TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" @@ -349,8 +427,16 @@ def test_set_up( ): app = adk_template.AdkApp(agent=_TEST_AGENT) assert app._tmpl_attrs.get("runner") is None + assert app._tmpl_attrs.get("session_service") is None + assert app._tmpl_attrs.get("artifact_service") is None + assert app._tmpl_attrs.get("memory_service") is None + assert app._tmpl_attrs.get("credential_service") is None app.set_up() assert app._tmpl_attrs.get("runner") is not None + assert app._tmpl_attrs.get("session_service") is not None + assert app._tmpl_attrs.get("artifact_service") is not None + assert app._tmpl_attrs.get("memory_service") is not None + assert app._tmpl_attrs.get("credential_service") is not None def test_clone( self, @@ -431,6 +517,51 @@ async def test_async_stream_query( events.append(event) assert len(events) == 1 + @pytest.mark.asyncio + async def test_async_stream_query_with_empty_session_events( + self, + get_project_id_mock: mock.Mock, + ): + app = adk_template.AdkApp( + agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) + ) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = [] + async for event in app.async_stream_query( + user_id=_TEST_USER_ID, + session_events=[], + message="test message", + ): + events.append(event) + assert app._tmpl_attrs.get("session_service") is not None + sessions = app.list_sessions(user_id=_TEST_USER_ID) + assert len(sessions.sessions) == 1 + + @pytest.mark.asyncio + async def test_async_stream_query_with_session_events( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = adk_template.AdkApp( + agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) + ) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = [] + async for event in app.async_stream_query( + user_id=_TEST_USER_ID, + session_events=_TEST_SESSION_EVENTS, + message="on the day after that?", + ): + events.append(event) + assert app._tmpl_attrs.get("session_service") is not None + sessions = app.list_sessions(user_id=_TEST_USER_ID) + assert len(sessions.sessions) == 1 + @pytest.mark.asyncio @mock.patch.dict( os.environ, @@ -646,6 +777,54 @@ def test_delete_session(self, get_project_id_mock: mock.Mock): response0 = app.list_sessions(user_id=_TEST_USER_ID) assert not response0.sessions + @pytest.mark.asyncio + async def test_async_add_session_to_memory( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = adk_template.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("memory_service") is None + session = app.create_session(user_id=_TEST_USER_ID) + app._tmpl_attrs["runner"] = _MockRunner() + + from google.adk.events.event import Event + + session_obj = await app._tmpl_attrs["session_service"].get_session( + app_name=app._app_name(), + user_id=_TEST_USER_ID, + session_id=session["id"], + ) + await app._tmpl_attrs["session_service"].append_event( + session=session_obj, + event=Event( + author="user", + content={ + "parts": [{"text": "My cat's name is Garfield"}], + "role": "user", + }, + ), + ) + + list( + app.stream_query( + user_id=_TEST_USER_ID, + session_id=session["id"], + message="My cat's name is Garfield", + ) + ) + await app.async_add_session_to_memory( + session=app.get_session( + user_id=_TEST_USER_ID, + session_id=session["id"], + ) + ) + response = await app.async_search_memory( + user_id=_TEST_USER_ID, + query=_TEST_SEARCH_MEMORY_QUERY, + ) + assert len(response.memories) >= 1 + @pytest.mark.asyncio async def test_async_add_session_to_memory_dict( self, @@ -984,6 +1163,14 @@ async def test_raise_get_session_not_found_error(self, get_project_id_mock): session_id="test_session_id", ) + def test_stream_query_invalid_message_type(self): + app = adk_template.AdkApp(agent=_TEST_AGENT) + with pytest.raises( + TypeError, + match="message must be a string or a dictionary representing a Content object.", + ): + list(app.stream_query(user_id=_TEST_USER_ID, message=123)) + @pytest.mark.asyncio async def test_async_stream_query_invalid_message_type(self): app = adk_template.AdkApp(agent=_TEST_AGENT) @@ -1210,6 +1397,11 @@ def test_update_default_telemetry_enablement( class TestAdkAppMtls: """Test cases for mTLS functionality in AdkApp.""" + def setup_method(self): + import opentelemetry.trace + + opentelemetry.trace._TRACER_PROVIDER = None + def test_use_client_cert_effective_with_should_use_client_cert(self): """Verifies that it respects the google-auth mTLS enablement check.""" with mock.patch.object( diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_llama_index.py b/tests/unit/agentplatform/frameworks/test_frameworks_llama_index.py new file mode 100644 index 0000000000..aaef28749b --- /dev/null +++ b/tests/unit/agentplatform/frameworks/test_frameworks_llama_index.py @@ -0,0 +1,345 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +import json +from unittest import mock + +from google import auth +import agentplatform +from google.cloud.aiplatform import initializer +from agentplatform.agent_engines.templates import ( + llama_index, +) +from agentplatform._genai.agent_engines import _agent_engines_utils + +from llama_index.core import prompts +from llama_index.core.base.llms import types + +import pytest + +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_MODEL = "gemini-1.0-pro" +_TEST_SYSTEM_INSTRUCTION = "You are a helpful bot." + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + credentials_mock = mock.Mock() + credentials_mock.with_quota_project.return_value = None + google_auth_mock.return_value = ( + credentials_mock, + _TEST_PROJECT, + ) + yield google_auth_mock + + +@pytest.fixture +def agentplatform_init_mock(): + with mock.patch.object(agentplatform, "init") as agentplatform_init_mock: + yield agentplatform_init_mock + + +@pytest.fixture +def json_loads_mock(): + with mock.patch.object(json, "loads") as json_loads_mock: + yield json_loads_mock + + +@pytest.fixture +def model_builder_mock(): + with mock.patch.object( + llama_index, + "_default_model_builder", + ) as model_builder_mock: + yield model_builder_mock + + +@pytest.fixture +def cloud_trace_exporter_mock(): + with mock.patch.object( + _agent_engines_utils, + "_import_cloud_trace_exporter_or_warn", + ) as cloud_trace_exporter_mock: + yield cloud_trace_exporter_mock + + +@pytest.fixture +def tracer_provider_mock(): + with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock: + yield tracer_provider_mock + + +@pytest.fixture +def simple_span_processor_mock(): + with mock.patch( + "opentelemetry.sdk.trace.export.SimpleSpanProcessor" + ) as simple_span_processor_mock: + yield simple_span_processor_mock + + +@pytest.fixture +def llama_index_instrumentor_mock(): + with mock.patch.object( + _agent_engines_utils, + "_import_openinference_llama_index_or_warn", + ) as llama_index_instrumentor_mock: + yield llama_index_instrumentor_mock + + +@pytest.fixture +def llama_index_instrumentor_none_mock(): + with mock.patch.object( + _agent_engines_utils, + "_import_openinference_llama_index_or_warn", + ) as llama_index_instrumentor_mock: + llama_index_instrumentor_mock.return_value = None + yield llama_index_instrumentor_mock + + +@pytest.fixture +def nest_asyncio_apply_mock(): + with mock.patch.object( + _agent_engines_utils, + "_import_nest_asyncio_or_warn", + ) as nest_asyncio_apply_mock: + yield nest_asyncio_apply_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestLlamaIndexQueryPipelineAgent: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(agentplatform) + agentplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + self.prompt = prompts.ChatPromptTemplate( + message_templates=[ + types.ChatMessage( + role=types.MessageRole.SYSTEM, + content=_TEST_SYSTEM_INSTRUCTION, + ), + types.ChatMessage( + role=types.MessageRole.USER, + content="{input}", + ), + ], + ) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_initialization(self): + agent = llama_index.LlamaIndexQueryPipelineAgent(model=_TEST_MODEL) + assert agent._model_name == _TEST_MODEL + assert agent._project == _TEST_PROJECT + assert agent._location == _TEST_LOCATION + assert agent._runnable is None + + def test_set_up(self): + agent = llama_index.LlamaIndexQueryPipelineAgent( + model=_TEST_MODEL, + prompt=self.prompt, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + assert agent._runnable is None + agent.set_up() + assert agent._runnable is not None + + def test_clone(self): + agent = llama_index.LlamaIndexQueryPipelineAgent( + model=_TEST_MODEL, + prompt=self.prompt, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + agent.set_up() + assert agent._runnable is not None + agent_clone = agent.clone() + assert agent._runnable is not None + assert agent_clone._runnable is None + agent_clone.set_up() + assert agent_clone._runnable is not None + + def test_query(self, json_loads_mock): + agent = llama_index.LlamaIndexQueryPipelineAgent( + model=_TEST_MODEL, + prompt=self.prompt, + ) + agent._runnable = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._runnable, attribute="run") + agent.query(input="test query") + mocks.assert_has_calls([mock.call.run.run(input="test query")]) + + def test_query_with_kwargs(self, json_loads_mock): + agent = llama_index.LlamaIndexQueryPipelineAgent( + model=_TEST_MODEL, + prompt=self.prompt, + ) + agent._runnable = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._runnable, attribute="run") + agent.query(input="test query", test_arg=123) + mocks.assert_has_calls([mock.call.run.run(input="test query", test_arg=123)]) + + def test_query_with_kwargs_and_input_dict(self, json_loads_mock): + agent = llama_index.LlamaIndexQueryPipelineAgent( + model=_TEST_MODEL, + prompt=self.prompt, + ) + agent._runnable = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._runnable, attribute="run") + agent.query(input={"input": "test query"}) + mocks.assert_has_calls([mock.call.run.run(input="test query")]) + + def test_query_with_batch_input(self, json_loads_mock, nest_asyncio_apply_mock): + agent = llama_index.LlamaIndexQueryPipelineAgent( + model=_TEST_MODEL, + prompt=self.prompt, + ) + agent._runnable = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._runnable, attribute="run") + agent.query(input={"input": ["test query 1", "test query 2"]}, batch=True) + mocks.assert_has_calls( + [mock.call.run.run(input=["test query 1", "test query 2"], batch=True)] + ) + nest_asyncio_apply_mock.assert_called_once() + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing( + self, + caplog, + cloud_trace_exporter_mock, + tracer_provider_mock, + simple_span_processor_mock, + llama_index_instrumentor_mock, + ): + agent = llama_index.LlamaIndexQueryPipelineAgent( + model=_TEST_MODEL, + prompt=self.prompt, + enable_tracing=True, + ) + assert agent._instrumentor is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert agent._instrumentor is not None + # assert ( + # "enable_tracing=True but proceeding with tracing disabled" + # not in caplog.text + # ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing_warning(self, caplog, llama_index_instrumentor_none_mock): + agent = llama_index.LlamaIndexQueryPipelineAgent( + model=_TEST_MODEL, + prompt=self.prompt, + enable_tracing=True, + ) + assert agent._instrumentor is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + +class TestToJsonSerializableLlamaIndexObject: + """Tests for `_utils.to_json_serializable_llama_index_object`.""" + + def test_llama_index_response(self): + mock_response: _agent_engines_utils.LlamaIndexResponse = mock.Mock( + spec=_agent_engines_utils.LlamaIndexResponse + ) + mock_response.response = "test response" + mock_response.source_nodes = [ + mock.Mock( + spec=_agent_engines_utils.LlamaIndexBaseModel, + model_dump_json=lambda: '{"name": "model1"}', + ), + mock.Mock( + spec=_agent_engines_utils.LlamaIndexBaseModel, + model_dump_json=lambda: '{"name": "model2"}', + ), + ] + mock_response.metadata = {"key": "value"} + + want = { + "response": "test response", + "source_nodes": ['{"name": "model1"}', '{"name": "model2"}'], + "metadata": {"key": "value"}, + } + got = _agent_engines_utils.to_json_serializable_llama_index_object(mock_response) + assert got == want + + def test_llama_index_chat_response(self): + mock_chat_response: _agent_engines_utils.LlamaIndexChatResponse = mock.Mock( + spec=_agent_engines_utils.LlamaIndexChatResponse + ) + mock_chat_response.message = mock.Mock( + spec=_agent_engines_utils.LlamaIndexBaseModel, + model_dump_json=lambda: '{"content": "chat message"}', + ) + + want = {"content": "chat message"} + got = _agent_engines_utils.to_json_serializable_llama_index_object(mock_chat_response) + assert got == want + + def test_llama_index_base_model(self): + mock_base_model: _agent_engines_utils.LlamaIndexBaseModel = mock.Mock( + spec=_agent_engines_utils.LlamaIndexBaseModel + ) + mock_base_model.model_dump_json = lambda: '{"name": "test_model"}' + + want = {"name": "test_model"} + got = _agent_engines_utils.to_json_serializable_llama_index_object(mock_base_model) + assert got == want + + def test_sequence_of_llama_index_base_model(self): + mock_base_model1: _agent_engines_utils.LlamaIndexBaseModel = mock.Mock( + spec=_agent_engines_utils.LlamaIndexBaseModel + ) + mock_base_model1.model_dump_json = lambda: '{"name": "test_model1"}' + mock_base_model2: _agent_engines_utils.LlamaIndexBaseModel = mock.Mock( + spec=_agent_engines_utils.LlamaIndexBaseModel + ) + mock_base_model2.model_dump_json = lambda: '{"name": "test_model2"}' + mock_base_model_list = [mock_base_model1, mock_base_model2] + + want = [{"name": "test_model1"}, {"name": "test_model2"}] + got = _agent_engines_utils.to_json_serializable_llama_index_object(mock_base_model_list) + assert got == want + + def test_sequence_of_mixed_types(self): + mock_base_model: _agent_engines_utils.LlamaIndexBaseModel = mock.Mock( + spec=_agent_engines_utils.LlamaIndexBaseModel + ) + mock_base_model.model_dump_json = lambda: '{"name": "test_model"}' + mock_string = "test_string" + mock_list = [mock_base_model, mock_string] + + want = [{"name": "test_model"}, "test_string"] + got = _agent_engines_utils.to_json_serializable_llama_index_object(mock_list) + assert got == want + + def test_other_type(self): + test_dict = {"name": "test_model"} + want = "{'name': 'test_model'}" + got = _agent_engines_utils.to_json_serializable_llama_index_object(test_dict) + assert got == want