A robust system for large-scale text inference using Vertex AI (Gemini).
This repository hosts a modular framework to orchestrate large-scale batch and live inference requests to Gemini models.
- 📦 Installation
- 🚀 Quick Start
- ⚙️ Core Functions
- 💡 Example Usage
- 📖 Package Overview
- 🏗️ System Architecture
- 📝 Key Concepts and Features
- 🧩 Core Components
- 📊 Table Schema
- 📄 License
## 📦 Installation

You can install the package directly from GitHub:

```bash
pip install git+https://github.com/ericzhao28/easyinference.git
```

- Set up your credentials for GCP and Vertex AI:

```bash
gcloud auth application-default login
```

- Configure the necessary environment variables:

```bash
# Google Cloud Platform Configuration
export GCP_PROJECT_ID="your-project-id"
export GCP_PROJECT_NUM="123456789012"
export GCP_REGION="us-central1"
export VERTEX_BUCKET="your-gcs-bucket"
export GEMINI_API_KEY=""
# SQL Configuration
export TABLE_NAME="your-table"
export SQL_DATABASE_NAME="your-database"
export SQL_USER="db-user"
export SQL_PASSWORD="your-password"
export SQL_INSTANCE_CONNECTION_NAME="project-id:region:instance-name"
export POOL_SIZE="50"
# Local Postgres Configuration (Optional)
export DB_TYPE="local"
export LOCAL_POSTGRES_HOST="localhost"
export LOCAL_POSTGRES_PORT="5432"
# Additional Configuration
export COOLDOWN_SECONDS="1.0"
export MAX_RETRIES="8"
export BATCH_TIMEOUT_HOURS="3"
export ROUND_ROBIN_ENABLED="false"
```

Alternatively, you can use the provided example.env file:

- Copy `example.env` to `.env`
- Update the values in `.env` with your configuration
- Use `python-dotenv` to load these variables in your code
- Make sure you set your environment variables before importing `easyinference`. Otherwise, run `easyinference.reload_config()` after setting your environment variables.
## 🚀 Quick Start

- Initialize the database connection:

```python
from easyinference import initialize_query_connection

# Initialize the database connection before using any inference functions
initialize_query_connection()
```

- Import and use the package:

```python
from dotenv import load_dotenv # pip install python-dotenv
# Load environment variables from .env file (if using this approach)
load_dotenv()
from easyinference import inference, individual_inference, run_clearing_inference, reload_config, initialize_query_connection
# Initialize the database connection
initialize_query_connection()
```

## ⚙️ Core Functions

Main async function for batch processing multiple datapoints:

```python
async def inference(
    prompt_functions: List[Callable[[Any], str]],  # Functions that convert datapoints to prompt text
    datapoints: List[Any],                         # List of data items to process
    tags: Optional[List[str]] = None,              # Identifier tags for tracking
    duplication_indices: Optional[List[int]] = None,  # Indices for running datapoints multiple times
    run_fast: bool = True,                         # If True, makes direct API calls; if False, queues for batch
    allow_failure: bool = False,                   # If True, continues after max retries with error messages
    attempts_cap: int = 8,                         # Maximum number of retry attempts
    temperature: float = 0,                        # Temperature parameter for generation
    max_output_tokens: int = 65535,                # Maximum tokens to generate in response
    thinking_budget_tokens: int = 32768,           # Token budget for the model's internal reasoning
    system_prompt: str = "",                       # System prompt to guide model behavior
    model: str = "gemini-2.5-pro-preview-06-05",   # Generative model to use
    batch_size: int = 1000,                        # Max concurrent requests or batch job size
    run_fast_timeout: float = 200,                 # Timeout in seconds for fast mode calls
    cooldown_seconds: float = 1.0,                 # Base wait time between retries
    batch_timeout_hours: int = 3,                  # Max runtime before restarting
    round_robin_enabled: bool = False,             # Whether to cycle through regions
    round_robin_options: List[str] = ["us-central1", "us-west1", "us-east1", "us-west4", "us-east4", "us-east5", "us-south1"],  # Region options for cycling
    initial_histories: Optional[List[dict]] = None,  # Starting conversation histories for the inference sessions
) -> tuple[List[tuple], str]  # Returns ([([response 1, ...], [query 1, ...]) for each datapoint], launch_timestamp_tag)
```
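
When several prompt functions are supplied, each datapoint yields one query/response pair per function (per the return shape above). A minimal sketch of a two-step sequence; the function names and the `article` field are illustrative, and it assumes configuration and the database connection are already set up:

```python
import asyncio

from easyinference import inference

# Illustrative two-step sequence: the second prompt continues the
# conversation started by the first.
def ask_summary(dp):
    return f"Summarize this article: {dp['article']}"

def ask_title(dp):
    return "Now suggest a short title for that article."

async def main():
    results, launch_tag = await inference(
        prompt_functions=[ask_summary, ask_title],
        datapoints=[{"article": "..."}],
        tags=["summaries", "v1"],
    )
    responses, queries = results[0]  # one (responses, queries) pair per datapoint
    print(responses[0])              # output of the first prompt
    print(responses[1])              # output of the second prompt

asyncio.run(main())
```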
For processing a single datapoint through multiple prompt functions:

```python
async def individual_inference(
    prompt_functions: List[Callable[[Any], str]],  # Functions that convert the datapoint to prompt text
    datapoint: Any,                                # Data to process
    tags: Optional[List[str]] = None,              # Identifier tags for tracking
    optional_tags: Optional[List[str]] = None,     # Additional tags not used for lookup
    duplication_index: int = 0,                    # Index to distinguish duplicate runs
    run_fast: bool = True,                         # If True, makes direct API calls; if False, queues for batch
    allow_failure: bool = False,                   # If True, continues after max retries with error messages
    attempts_cap: int = 8,                         # Maximum number of retry attempts
    temperature: float = 0,                        # Temperature parameter for generation
    max_output_tokens: int = 65535,                # Maximum tokens to generate in response
    thinking_budget_tokens: int = 32768,           # Token budget for the model's internal reasoning
    system_prompt: str = "",                       # System prompt to guide model behavior
    model: str = "gemini-2.5-pro-preview-06-05",   # Generative model to use
    run_fast_timeout: float = 200,                 # Timeout in seconds for fast mode calls
    cooldown_seconds: float = 1.0,                 # Base wait time between retries
    round_robin_enabled: bool = False,             # Whether to cycle through regions
    round_robin_options: List[str] = ["us-central1", "us-west1", "us-east1", "us-west4", "us-east4", "us-east5", "us-south1"],  # Region options for cycling
    initial_history_json: Optional[dict] = None,   # Starting conversation history for the inference session
) -> tuple[List[str], List[str]]  # Returns ([response 1, response 2, ...], [query 1, query 2, ...])
```
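
A minimal sketch of calling it on one datapoint; the prompt function and the `text` field are illustrative:

```python
import asyncio

from easyinference import individual_inference

async def main():
    # Process one datapoint through a single prompt function.
    responses, queries = await individual_inference(
        prompt_functions=[lambda dp: f"Please explain: {dp['text']}"],
        datapoint={"text": "What is gradient descent?"},
        tags=["explanation", "v1"],
        run_fast=True,
    )
    print(queries[0], "->", responses[0])

asyncio.run(main())
```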
For managing batch inference jobs:

```python
async def run_clearing_inference(
    tag: str,                     # Unique identifier tag for the batch
    batch_size: int,              # Maximum number of requests per batch job
    run_batch_jobs: bool,         # Whether to launch new batch jobs
    batch_timeout_hours: int = 3  # Maximum runtime hours before restarting
) -> None
```
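
A sketch of how this might pair with slow mode: queue requests with `run_fast=False`, then clear them under the launch timestamp tag that `inference` returns. This flow is an assumption based on the return signature and the automatic batch timestamp tag described under Table Schema:

```python
import asyncio

from easyinference import inference, run_clearing_inference

async def main():
    # Queue requests for batch processing rather than live calls.
    results, launch_tag = await inference(
        prompt_functions=[lambda dp: f"Please explain: {dp['text']}"],
        datapoints=[{"text": "What is machine learning?"}],
        run_fast=False,
    )

    # Submit and monitor batch jobs for rows queued under that tag.
    await run_clearing_inference(
        tag=launch_tag,
        batch_size=1000,
        run_batch_jobs=True,
    )

asyncio.run(main())
```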
For reloading the config after setting environment variables:

```python
def reload_config() -> None
```

## 💡 Example Usage

```python
import asyncio
from dotenv import load_dotenv
from easyinference import inference, reload_config, initialize_query_connection
load_dotenv()
reload_config()
# Initialize the database connection before using any inference functions
initialize_query_connection()
async def process_data():
    # Define data and prompt function
    datapoints = [
        {"text": "What is machine learning?"},
        {"text": "Explain neural networks"}
    ]

    def create_prompt(dp):
        return f"Please explain: {dp['text']}"

    # Run inference
    results, timestamp = await inference(
        prompt_functions=[create_prompt],
        datapoints=datapoints,
        tags=["explanation", "v1"],
        run_fast=True
    )

    # Process results: each datapoint yields a (responses, queries) pair
    first_datapoint_result, second_datapoint_result = results
    responses, queries = first_datapoint_result
    for query, response in zip(queries, responses):
        print(f"Query: {query}")
        print(f"Response: {response}")

    return results

# Run the async function
results = asyncio.run(process_data())
```

## 📖 Package Overview

Goal: We provide a scalable and robust pipeline to handle:
- ✨ Conversation-based inference requests to Gemini models
- ✨ Failure tracking and retry logic to ensure stable operation
- ✨ Asynchronous or synchronous methods for generating text from the model
We accomplish this by:
- Storing every inference "step" in a PostgreSQL table, which captures the query text, model parameters, conversation history, and final responses (or errors).
- Separating "fast" live calls vs. "slow" batch-based calls.
- Monitoring the status of batch inference jobs, so you can schedule or restart them if they take too long.
- Allowing different usage patterns: single datapoint or bulk processing, with multi-prompt sequences, concurrency caps, and retries.
## 🏗️ System Architecture

```
                    ┌────────────────────┐
                    │  Your Application  │
                    └─────────┬──────────┘
                              │
                ┌─────────────┴────────────┐
                │                          │
                ▼                          ▼
     ┌──────────────────────┐   ┌──────────────────────┐
     │ Individual Inference │   │      Inference       │
     │        (Fast)        │◄--│                      │
     └──────────┬───────────┘   └──────────┬───────────┘
                │                          │
                │                          ▼
                │               ┌──────────────────────┐
                │               │    Batch Clearing    │
                │               │     (monitoring)     │
                │               └──────────┬───────────┘
                │                          │
                ▼                          ▼
   ┌────────────────────────┐   ┌────────────────────────┐
   │ Vertex AI (Gemini API) │   │ Vertex AI (Gemini API) │
   │      (Live Calls)      │   │      (Batch Job)       │
   └────────────┬───────────┘   └────────────┬───────────┘
                │                            │
                └─────────────┬──────────────┘
                              ▼
                   ┌─────────────────────┐
                   │      PostgreSQL     │
                   │     Master Table    │
                   └─────────────────────┘
```
- `Individual Inference` manages a single datapoint and a sequence of prompts.
- `Inference` is a bulk orchestrator that calls individual inference on multiple datapoints.
- `Clearing Inference` takes unprocessed/failed rows and triggers additional attempts (live or batch). It also monitors batch jobs and handles timeouts.
## 📝 Key Concepts and Features

**Conversation History**

Stored in PostgreSQL under `history_json` as a JSON object:

```json
{
  "history": [
    {"role": "user", "parts": {"text": "Hello, how are you?"}},
    {"role": "model", "parts": {"text": "I am fine. How can I help?"}}
  ]
}
```

This helps Vertex continue the same conversation context across multiple queries without duplication.
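
The same shape can be used to seed a session, for example via `individual_inference`'s `initial_history_json` parameter. A sketch; the conversation content is made up:

```python
import asyncio

from easyinference import individual_inference

# Prior context in the documented history format (contents are illustrative).
history = {
    "history": [
        {"role": "user", "parts": {"text": "We are discussing PostgreSQL."}},
        {"role": "model", "parts": {"text": "Understood. What would you like to know?"}},
    ]
}

async def main():
    # Continue the conversation from that history.
    responses, queries = await individual_inference(
        prompt_functions=[lambda dp: dp["question"]],
        datapoint={"question": "How do indexes speed up lookups?"},
        initial_history_json=history,
    )
    print(responses[0])

asyncio.run(main())
```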
**Generation Parameters**

Stored under `generation_params_json` (JSON):

```json
{
  "temperature": 0.7,
  "max_output_tokens": 65535,
  "system_prompt": "You are a helpful assistant..."
}
```

**Duplication Index**

An integer marking whether a row is an exact duplicate of an earlier row (e.g., a re-run). Defaults to 0.
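
One plausible use, sketched below: `inference`'s `duplication_indices` parameter lets the same datapoint run more than once without the rows colliding in the content hash. The prompt and values are illustrative:

```python
import asyncio

from easyinference import inference

async def main():
    # The same datapoint appears twice; distinct duplication indices keep
    # the two rows from being deduplicated against each other.
    results, launch_tag = await inference(
        prompt_functions=[lambda dp: f"Please explain: {dp['text']}"],
        datapoints=[{"text": "entropy"}, {"text": "entropy"}],
        duplication_indices=[0, 1],
        temperature=1.0,  # nonzero temperature so the two runs can differ
    )

asyncio.run(main())
```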
**Tags**

A list of strings (alphabetically sorted) representing categories or labels applied to a request (e.g. ["admin", "api-v1"]).
This can help in filtering or grouping by usage scenario.

**Request Cause**

Either "intentional" (explicit user request) or "backup" (an automatic fallback).
**Status Tracking**

`last_status` can be "PENDING", "RUNNING", "FAILED", "SUCCEEDED", or "WAITING". `failure_count` tracks how many attempts have failed so far, and `attempts_cap` sets the max allowed.
**Content Hash**

A hash of (Model, History, Query, GenerationParams, DuplicationIndex) for deduplicating or resuming.
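
Conceptually, the hash is computed along these lines. This is an illustrative sketch only; the library's exact serialization and field order may differ:

```python
import hashlib
import json

def content_hash(model, history, query, generation_params, duplication_index):
    """Illustrative only: deterministically serialize the identifying
    fields, then take a SHA-256 digest of the result."""
    payload = json.dumps(
        [model, history, query, generation_params, duplication_index],
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```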
**Run Modes**

- Run Fast: calls the Vertex API directly, returning the result in real-time.
- Run Slow: queues up the request for a batch job. The `run_clearing_inference` function handles job submission and monitoring.
## 🧩 Core Components

Before using any inference functions, you must initialize the database connection by calling:

```python
from easyinference import initialize_query_connection

initialize_query_connection()
```

This sets up the necessary connections to the PostgreSQL database for tracking inference requests.
**Row Schema**

- Defines a `ConvoRow` data class that mirrors each column in the table.
- Enumerations for `RequestStatus` and `RequestCause`.
**Database Configuration Options**

EasyInference supports three database configuration options:

- Google Cloud SQL (default)
- Local Postgres, for development or when using your own database infrastructure
- No Database, for simple inference without tracking or batch processing

Choose your database type by setting the `DB_TYPE` environment variable:

```bash
# Use Google Cloud SQL (default)
export DB_TYPE="gcp"
# Use local Postgres
export DB_TYPE="local"
export LOCAL_POSTGRES_HOST="localhost" # Or your Postgres server address
export LOCAL_POSTGRES_PORT="5432" # Or your Postgres server port
# Use no database
export DB_TYPE="none"
```

For Google Cloud SQL or local Postgres, set the required database parameters:

```bash
# Required for both GCP and local Postgres options
export SQL_DATABASE_NAME="your-database"
export SQL_USER="db-user"
export SQL_PASSWORD="your-password"
export TABLE_NAME="your-table"
# Only required for GCP option
export SQL_INSTANCE_CONNECTION_NAME="project-id:region:instance-name"
```

Initialize the database connection as usual:

```python
from easyinference import initialize_query_connection

initialize_query_connection()
```

You can easily switch between database types in your Python code:

```python
import os
from easyinference import reload_config, initialize_query_connection
# Switch to local Postgres
os.environ["DB_TYPE"] = "local"
reload_config()
initialize_query_connection()
# Later, switch to Google Cloud SQL
os.environ["DB_TYPE"] = "gcp"
reload_config()
initialize_query_connection()
# Or disable database operations entirely
os.environ["DB_TYPE"] = "none"
reload_config()
initialize_query_connection()
```

When `DB_TYPE="none"`, EasyInference operates without any database tracking. In this mode:
- No database connection is established
- Batch inference is not available (will raise an error)
- Tagged inference is not available (will raise an error)
- Only direct, synchronous inference calls without tags are supported
**Database Operations**

- Helper functions to insert, update, or read rows from PostgreSQL.
- Includes concurrency checks so you don't overwrite a "SUCCEEDED" row with "FAILED."
- Functions for connecting to PostgreSQL, creating tables, and querying data.
**Inference Engine**

- Implements both `individual_inference` and `inference` functions.
- Contains `run_chat_inference_async` for "fast" calls with built-in retry/backoff.
- Implements `run_clearing_inference`, which handles both batch submission and monitoring.
- Manages the logic for deduplicating (by content hash), incrementing failure counts, and handling partial successes.
**Configuration**

- Configuration settings for database connections, retry logic, and batch operations.
- Contains defaults for constants like `MAX_RETRIES`, `BATCH_TIMEOUT_HOURS`, and various connection parameters.
## 📊 Table Schema

Your master PostgreSQL table has the following columns:
| Column Name | Type | Description |
|---|---|---|
| row_id | INTEGER | Auto-incrementing primary key |
| content_hash | STRING | SHA-256 hash of key fields for deduplication |
| history_json | JSON | JSON storing prior conversation messages in a format with the key "history" |
| query | STRING | User's latest query that needs a response |
| model | STRING | Full path of the model (e.g. "gemini-2.5-pro-preview-06-05") |
| generation_params_json | JSON | JSON storing generation settings, e.g. {"temperature":0.7,"max_output_tokens":8192,"system_prompt":"..."} |
| duplication_index | INTEGER | Used to mark re-runs or explicit duplicates. Defaults to 0 |
| tags | ARRAY(STRING) | A sorted list of tags (e.g. ["api-v1","testing"]) |
| request_cause | STRING | "intentional" or "backup". Uses the RequestCause enum |
| request_timestamp | STRING | ISO 8601 timestamp ("2025-02-25T12:34:56Z") |
| access_timestamps | ARRAY(STRING) | List of ISO 8601 timestamps of each read/update |
| attempts_metadata_json | ARRAY(JSON) | JSON array of prior attempts, storing batch info and error messages |
| response_json | JSON | JSON containing the final successful response if available. Example: {"text":"...response..."} |
| current_batch | STRING | The ID of any currently running batch job. Can be NULL |
| last_status | STRING | "PENDING", "RUNNING", "FAILED", "SUCCEEDED", or "WAITING" |
| failure_count | INTEGER | How many times this row has failed so far |
| attempts_cap | INTEGER | The maximum number of times we will re-try |
| notes | STRING | Optional free-text notes |
| insertion_timestamp | TIMESTAMP | When the row was inserted into the database |
**Content Hash**

- SHA-256 over the combination of (Model, History, Query, GenerationParams, DuplicationIndex).
- Ensures we don't re-run the same content multiple times unless we want to.

**Tags**

- A query can have tags like ["api-v1","admin-request"]. The system enforces that the tag list is alphabetically sorted.
- For batch mode, a timestamp tag is automatically added for tracking.
## 📄 License

This project is provided under the MIT License.
Feel free to modify or extend the code to suit your deployment and usage requirements.