diff --git a/src/__tests__/factory.test.ts b/src/__tests__/factory.test.ts
index 0ae6a44..ceed2aa 100755
--- a/src/__tests__/factory.test.ts
+++ b/src/__tests__/factory.test.ts
@@ -18,17 +18,20 @@ import { defaultLatencyHistogram } from '../utils/latency-histogram';
 const mockOpenAIProvider = {
   name: 'openai',
   models: ['gpt-4', 'gpt-3.5-turbo'],
-  supportsStreaming: true,
-  supportsTools: true,
-  supportsBatching: true,
-  generateResponse: vi.fn().mockResolvedValue({
+  supportsStreaming: true,
+  supportsTools: true,
+  supportsBatching: true,
+  supportsVision: true,
+  generateResponse: vi.fn().mockResolvedValue({
     message: 'OpenAI response',
     usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.001 },
     model: 'gpt-3.5-turbo',
     provider: 'openai',
     responseTime: 1000
-  } as LLMResponse),
-  validateConfig: vi.fn().mockReturnValue(true),
+  } as LLMResponse),
+  streamResponse: vi.fn(),
+  getProviderBalance: vi.fn(),
+  validateConfig: vi.fn().mockReturnValue(true),
   getModels: vi.fn().mockReturnValue(['gpt-4', 'gpt-3.5-turbo']),
   estimateCost: vi.fn().mockReturnValue(0.001),
   healthCheck: vi.fn().mockResolvedValue(true),
@@ -47,17 +50,20 @@ const mockOpenAIProvider = {
 const mockAnthropicProvider = {
   name: 'anthropic',
   models: ['claude-3-haiku-20240307', 'claude-3-sonnet-20240229'],
-  supportsStreaming: true,
-  supportsTools: true,
-  supportsBatching: false,
-  generateResponse: vi.fn().mockResolvedValue({
+  supportsStreaming: true,
+  supportsTools: true,
+  supportsBatching: false,
+  supportsVision: true,
+  generateResponse: vi.fn().mockResolvedValue({
     message: 'Anthropic response',
     usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.002 },
     model: 'claude-3-haiku-20240307',
     provider: 'anthropic',
     responseTime: 1200
-  } as LLMResponse),
-  validateConfig: vi.fn().mockReturnValue(true),
+  } as LLMResponse),
+  streamResponse: vi.fn(),
+  getProviderBalance: vi.fn(),
+  validateConfig: vi.fn().mockReturnValue(true),
   getModels: vi.fn().mockReturnValue(['claude-3-haiku-20240307', 'claude-3-sonnet-20240229']),
   estimateCost: vi.fn().mockReturnValue(0.002),
   healthCheck: vi.fn().mockResolvedValue(true),
@@ -76,17 +82,20 @@ const mockAnthropicProvider = {
 const mockCloudflareProvider = {
   name: 'cloudflare',
   models: ['@cf/meta/llama-3.1-8b-instruct'],
-  supportsStreaming: true,
-  supportsTools: false,
-  supportsBatching: true,
-  generateResponse: vi.fn().mockResolvedValue({
+  supportsStreaming: true,
+  supportsTools: false,
+  supportsBatching: true,
+  supportsVision: false,
+  generateResponse: vi.fn().mockResolvedValue({
     message: 'Cloudflare response',
     usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.0001 },
     model: '@cf/meta/llama-3.1-8b-instruct',
     provider: 'cloudflare',
     responseTime: 800
-  } as LLMResponse),
-  validateConfig: vi.fn().mockReturnValue(true),
+  } as LLMResponse),
+  streamResponse: vi.fn(),
+  getProviderBalance: vi.fn(),
+  validateConfig: vi.fn().mockReturnValue(true),
   getModels: vi.fn().mockReturnValue(['@cf/meta/llama-3.1-8b-instruct']),
   estimateCost: vi.fn().mockReturnValue(0.0001),
   healthCheck: vi.fn().mockResolvedValue(true),
@@ -138,6 +147,17 @@ describe('LLMProviderFactory', () => {
       provider: 'openai',
       responseTime: 1000
     } as LLMResponse);
+    mockOpenAIProvider.streamResponse.mockReset().mockResolvedValue(new ReadableStream({
+      start(controller) {
+        controller.enqueue('OpenAI stream');
+        controller.close();
+      }
+    }));
+    mockOpenAIProvider.getProviderBalance.mockReset().mockResolvedValue({
+      provider: 'openai',
+      status: 'available',
+      source: 'provider_api'
+    });
     mockAnthropicProvider.generateResponse.mockReset().mockResolvedValue({
       message: 'Anthropic response',
       usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.002 },
@@ -145,6 +165,17 @@ describe('LLMProviderFactory', () => {
       provider: 'anthropic',
       responseTime: 1200
     } as LLMResponse);
+    mockAnthropicProvider.streamResponse.mockReset().mockResolvedValue(new ReadableStream({
+      start(controller) {
+        controller.enqueue('Anthropic stream');
+        controller.close();
+      }
+    }));
+    mockAnthropicProvider.getProviderBalance.mockReset().mockResolvedValue({
+      provider: 'anthropic',
+      status: 'available',
+      source: 'provider_api'
+    });
     mockCloudflareProvider.generateResponse.mockReset().mockResolvedValue({
       message: 'Cloudflare response',
       usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.0001 },
@@ -152,6 +183,17 @@ describe('LLMProviderFactory', () => {
       provider: 'cloudflare',
       responseTime: 800
     } as LLMResponse);
+    mockCloudflareProvider.streamResponse.mockReset().mockResolvedValue(new ReadableStream({
+      start(controller) {
+        controller.enqueue('Cloudflare stream');
+        controller.close();
+      }
+    }));
+    mockCloudflareProvider.getProviderBalance.mockReset().mockResolvedValue({
+      provider: 'cloudflare',
+      status: 'unavailable',
+      source: 'not_supported'
+    });
 
     factory = new LLMProviderFactory({
       openai: { apiKey: 'test-openai-key' },
@@ -184,7 +226,7 @@ describe('LLMProviderFactory', () => {
     });
   });
 
-  describe('Response Generation', () => {
+  describe('Response Generation', () => {
     it('should generate response using available provider', async () => {
       const response = await factory.generateResponse(testRequest);
 
@@ -347,6 +389,169 @@ describe('LLMProviderFactory', () => {
       );
     });
   });
+
+  describe('Streaming, tools, classification, and vision', () => {
+    async function readStream(stream: ReadableStream<string>): Promise<string> {
+      const reader = stream.getReader();
+      let output = '';
+      while (true) {
+        const { done, value } = await reader.read();
+        if (done) break;
+        output += value;
+      }
+      return output;
+    }
+
+    it('should stream through the factory and fallback before the first chunk', async () => {
+      const streamFactory = new LLMProviderFactory({
+        openai: { apiKey: 'test-openai-key' },
+        anthropic: { apiKey: 'test-anthropic-key' },
+        defaultProvider: 'openai',
+        costOptimization: false,
+        fallbackRules: [{ condition: 'error', fallbackProvider: 'anthropic' }]
+      });
+
+      mockOpenAIProvider.streamResponse.mockRejectedValueOnce(new Error('stream start failed'));
+
+      const stream = await streamFactory.generateResponseStream(testRequest);
+
+      expect(await readStream(stream)).toBe('Anthropic stream');
+      expect(mockOpenAIProvider.streamResponse).toHaveBeenCalled();
+      expect(mockAnthropicProvider.streamResponse).toHaveBeenCalled();
+    });
+
+    it('should call quota hooks before and after successful dispatch', async () => {
+      const quotaHook = {
+        check: vi.fn().mockResolvedValue({ allowed: true, remainingBudget: 1 }),
+        record: vi.fn().mockResolvedValue(undefined)
+      };
+      const quotaFactory = new LLMProviderFactory({
+        openai: { apiKey: 'test-openai-key' },
+        defaultProvider: 'openai',
+        costOptimization: false,
+        quotaHook
+      });
+
+      await quotaFactory.generateResponse({ ...testRequest, tenantId: 'tenant-1' });
+
+      expect(quotaHook.check).toHaveBeenCalledWith(expect.objectContaining({
+        tenantId: 'tenant-1',
+        provider: 'openai',
+        model: 'gpt-4'
+      }));
+      expect(quotaHook.record).toHaveBeenCalledWith(expect.objectContaining({
+        tenantId: 'tenant-1',
+        provider: 'openai',
+        actualCost: 0.001,
+        inputTokens: 10,
+        outputTokens: 20
+      }));
+    });
+
+    it('should deny dispatch when quota hook rejects the request', async () => {
+      const quotaFactory = new LLMProviderFactory({
+        openai: { apiKey: 'test-openai-key' },
+        defaultProvider: 'openai',
+        costOptimization: false,
+        quotaHook: {
+          check: vi.fn().mockResolvedValue({ allowed: false, reason: 'budget exhausted' }),
+          record: vi.fn()
+        }
+      });
+
+      await expect(quotaFactory.generateResponse(testRequest)).rejects.toThrow('budget exhausted');
+      expect(mockOpenAIProvider.generateResponse).not.toHaveBeenCalled();
+    });
+
+    it('should execute tool loops until the final response has no tool calls', async () => {
+      const toolResponse: LLMResponse = {
+        message: '',
+        usage: { inputTokens: 5, outputTokens: 5, totalTokens: 10, cost: 0.001 },
+        model: 'gpt-3.5-turbo',
+        provider: 'openai',
+        responseTime: 10,
+        finishReason: 'tool_calls',
+        toolCalls: [{
+          id: 'call-1',
+          type: 'function',
+          function: { name: 'lookup', arguments: '{"id":42}' }
+        }]
+      };
+      const finalResponse: LLMResponse = {
+        message: 'done',
+        usage: { inputTokens: 5, outputTokens: 5, totalTokens: 10, cost: 0.001 },
+        model: 'gpt-3.5-turbo',
+        provider: 'openai',
+        responseTime: 10
+      };
+      mockOpenAIProvider.generateResponse
+        .mockResolvedValueOnce(toolResponse)
+        .mockResolvedValueOnce(finalResponse);
+
+      const loopFactory = new LLMProviderFactory({
+        openai: { apiKey: 'test-openai-key' },
+        defaultProvider: 'openai',
+        costOptimization: false
+      });
+      const executor = { execute: vi.fn().mockResolvedValue({ value: 42 }) };
+
+      const response = await loopFactory.generateResponseWithTools(testRequest, executor);
+
+      expect(response.message).toBe('done');
+      expect(executor.execute).toHaveBeenCalledWith('lookup', { id: 42 });
+      expect(mockOpenAIProvider.generateResponse).toHaveBeenLastCalledWith(
+        expect.objectContaining({
+          messages: expect.arrayContaining([
+            expect.objectContaining({
+              toolResults: [{ id: 'call-1', output: '{"value":42}' }]
+            })
+          ])
+        })
+      );
+    });
+
+    it('should classify JSON responses and expose confidence', async () => {
+      mockOpenAIProvider.generateResponse.mockResolvedValueOnce({
+        message: '{"label":"recipe","confidence":0.92}',
+        usage: { inputTokens: 5, outputTokens: 5, totalTokens: 10, cost: 0.001 },
+        model: 'gpt-3.5-turbo',
+        provider: 'openai',
+        responseTime: 10
+      } as LLMResponse);
+
+      const classifyFactory = new LLMProviderFactory({
+        openai: { apiKey: 'test-openai-key' },
+        defaultProvider: 'openai',
+        costOptimization: false
+      });
+
+      const result = await classifyFactory.classify<{ label: string; confidence: number }>('classify this');
+
+      expect(result.data.label).toBe('recipe');
+      expect(result.confidence).toBe(0.92);
+    });
+
+    it('should route image analysis to a vision-capable provider', async () => {
+      const visionFactory = new LLMProviderFactory({
+        anthropic: { apiKey: 'test-anthropic-key' },
+        cloudflare: { ai: {} as Ai },
+        costOptimization: false
+      });
+
+      await visionFactory.analyzeImage({
+        image: { data: 'abc123', mimeType: 'image/jpeg' },
+        prompt: 'Extract recipe text'
+      });
+
+      expect(mockAnthropicProvider.generateResponse).toHaveBeenCalledWith(
+        expect.objectContaining({
+          images: [{ data: 'abc123', mimeType: 'image/jpeg' }],
+          model: 'claude-haiku-4-5-20251001'
+        })
+      );
+      expect(mockCloudflareProvider.generateResponse).not.toHaveBeenCalled();
+    });
+  });
 
   describe('Error Handling', () => {
     it('should handle all providers failing', async () => {
@@ -449,6 +654,33 @@ describe('LLMProviderFactory', () => {
     expect(accumulator!.rateLimits.tpm!.used).toBe(30);
     expect(accumulator!.rateLimits.tpd!.used).toBe(30);
   });
+
+  it('should expose provider balance from the configured ledger', async () => {
+    const ledger = new CreditLedger({
+      budgets: [{
+        provider: 'cloudflare',
+        monthlyBudget: 1,
+        rateLimits: { rpm: 10 }
+      }]
+    });
+    const balanceFactory = new LLMProviderFactory({
+      cloudflare: { ai: {} as Ai },
+      ledger
+    });
+
+    await balanceFactory.generateResponse(testRequest);
+    const balance = await balanceFactory.getProviderBalance('cloudflare');
+
+    expect(balance).toMatchObject({
+      provider: 'cloudflare',
+      status: 'available',
+      source: 'ledger',
+      currentSpend: 0.0001,
+      monthlyBudget: 1,
+      requestCount: 1
+    });
+    expect((balance as { rateLimits: Record<string, { used: number }> }).rateLimits.rpm.used).toBe(1);
+  });
   });
 });
diff --git a/src/__tests__/groq.test.ts b/src/__tests__/groq.test.ts
index a225224..2d8b43c 100644
--- a/src/__tests__/groq.test.ts
+++ b/src/__tests__/groq.test.ts
@@ -80,6 +80,18 @@ describe('GroqProvider', () => {
     });
   });
 
+  describe('getProviderBalance', () => {
+    it('should report Groq billing API as unavailable', async () => {
+      const balance = await provider.getProviderBalance();
+
+      expect(balance).toMatchObject({
+        provider: 'groq',
+        status: 'unavailable',
+        source: 'not_supported'
+      });
+    });
+  });
+
   describe('generateResponse', () => {
     it('should generate a response successfully', async () => {
       mockFetch.mockResolvedValueOnce({
@@ -114,6 +126,51 @@ describe('GroqProvider', () => {
       expect(response.finishReason).toBe('stop');
     });
 
+    it('should forward Cloudflare AI Gateway metadata headers only for gateway base URLs', async () => {
+      provider = new GroqProvider({
+        apiKey: 'test-groq-key',
+        baseUrl: 'https://gateway.ai.cloudflare.com/v1/account/gateway/groq'
+      });
+      mockFetch.mockResolvedValueOnce({
+        ok: true,
+        json: async () => ({
+          id: 'chatcmpl-123',
+          object: 'chat.completion',
+          created: 1700000000,
+          model: 'llama-3.3-70b-versatile',
+          choices: [{
+            index: 0,
+            message: { role: 'assistant', content: 'Hello!' },
+            finish_reason: 'stop'
+          }],
+          usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }
+        }),
+        headers: new Headers({ 'content-type': 'application/json' })
+      });
+
+      await provider.generateResponse({
+        ...testRequest,
+        requestId: 'req-1',
+        tenantId: 'tenant-1',
+        gatewayMetadata: {
+          requestId: 'gw-req-1',
+          cacheKey: 'cache-key',
+          cacheTtl: 60,
+          customMetadata: { feature: 'demo' }
+        }
+      });
+
+      const headers = mockFetch.mock.calls[0][1].headers;
+      expect(headers['cf-aig-cache-key']).toBe('cache-key');
+      expect(headers['cf-aig-cache-ttl']).toBe('60');
+      expect(JSON.parse(headers['cf-aig-metadata'])).toMatchObject({
+        requestId: 'gw-req-1',
+        llmRequestId: 'req-1',
+        tenantId: 'tenant-1',
+        feature: 'demo'
+      });
+    });
+
     it('should send correct request format', async () => {
       mockFetch.mockResolvedValueOnce({
         ok: true,
diff --git a/src/errors.ts b/src/errors.ts
index 119bd80..85640df 100755
--- a/src/errors.ts
+++ b/src/errors.ts
@@ -97,7 +97,7 @@ export class ConfigurationError extends LLMProviderError {
   }
 }
 
-export class CircuitBreakerOpenError extends LLMProviderError {
+export class CircuitBreakerOpenError extends LLMProviderError {
   retryAfterSec: number;
   consecutiveFailures: number;
 
@@ -111,8 +111,20 @@ export class CircuitBreakerOpenError extends LLMProviderError {
     );
     this.retryAfterSec = retryAfterSec;
     this.consecutiveFailures = consecutiveFailures;
-  }
-}
+  }
+}
+
+export class ToolLoopLimitError extends LLMProviderError {
+  constructor(provider: string, message: string = 'Tool loop limit exceeded') {
+    super(message, 'TOOL_LOOP_LIMIT', provider, false, 400);
+  }
+}
+
+export class ToolLoopAbortedError extends LLMProviderError {
+  constructor(provider: string, message: string = 'Tool loop aborted') {
+    super(message, 'TOOL_LOOP_ABORTED', provider, false, 400);
+  }
+}
 
 /**
  * Error factory for creating provider-specific errors from HTTP responses
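A minimal sketch of how a caller might branch on the two new error classes; the import paths and the `factory`/`executor` values are illustrative assumptions, not part of this diff:

```ts
import { ToolLoopAbortedError, ToolLoopLimitError } from './errors';
import type { LLMProviderFactory } from './factory';
import type { LLMRequest, ToolExecutor } from './types';

// Hypothetical caller: cap a tool loop and treat both limit and abort as soft failures.
async function runCappedToolLoop(
  factory: LLMProviderFactory,
  request: LLMRequest,
  executor: ToolExecutor
) {
  try {
    return await factory.generateResponseWithTools(request, executor, {
      maxIterations: 5,
      maxCostUSD: 0.05,
    });
  } catch (error) {
    if (error instanceof ToolLoopLimitError || error instanceof ToolLoopAbortedError) {
      // Both are non-retryable 400-class errors; surface a partial-result UX instead.
      return null;
    }
    throw error;
  }
}
```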
diff --git a/src/factory.ts b/src/factory.ts
index fcd413a..48498d9 100755
--- a/src/factory.ts
+++ b/src/factory.ts
@@ -8,6 +8,9 @@ import type {
   LLMConfig,
   LLMRequest,
   LLMResponse,
+  AnalyzeImageInput,
+  ClassifyOptions,
+  ClassifyResult,
   ProviderConfig,
   OpenAIConfig,
   AnthropicConfig,
@@ -16,7 +19,14 @@ import type {
   GroqConfig,
   FallbackRule,
   ProviderMetrics,
-  CircuitBreakerState
+  CircuitBreakerState,
+  ProviderBalance,
+  QuotaHook,
+  QuotaCheckInput,
+  QuotaRecordInput,
+  ToolExecutor,
+  ToolLoopOptions,
+  ToolLoopState
 } from './types';
 
 import type { Logger } from './utils/logger';
@@ -41,6 +51,8 @@ import {
   AuthenticationError,
   RateLimitError,
   QuotaExceededError,
+  ToolLoopAbortedError,
+  ToolLoopLimitError,
 } from './errors';
 
 export interface ProviderFactoryConfig {
@@ -55,6 +67,9 @@ export interface ProviderFactoryConfig {
   enableCircuitBreaker?: boolean;
   enableRetries?: boolean;
   ledger?: CreditLedger;
+  quotaHook?: QuotaHook;
+  quotaFailPolicy?: 'closed' | 'open';
+  defaultVisionModel?: string;
   logger?: Logger;
   hooks?: ObservabilityHooks;
 }
@@ -197,6 +212,8 @@ export class LLMProviderFactory {
       const providerRequest = this.requestForProvider(request, providerName, providerModels);
       const model = providerRequest.model || provider.models[0] || 'unknown';
 
+      await this.checkQuota(providerName, provider, providerRequest, model);
+
       this.hooks.onRequestStart?.({
         provider: providerName,
         model,
@@ -224,6 +241,7 @@ export class LLMProviderFactory {
         if (this.config.costOptimization || this.config.ledger) {
           this.costTracker.trackCost(providerName, response);
         }
+        this.recordQuota(providerName, response, providerRequest);
 
         this.logger.debug(`[LLMProviderFactory] Successfully used provider: ${providerName}`);
         return response;
@@ -280,6 +298,279 @@ export class LLMProviderFactory {
     );
   }
 
+  async generateResponseStream(request: LLMRequest): Promise<ReadableStream<string>> {
+    const providerChain = this.buildProviderChain({ ...request, stream: true });
+    const providerModels = new Map<string, string[]>();
+    let lastError: Error | null = null;
+    let previousProvider: string | null = null;
+
+    for (let index = 0; index < providerChain.length; index++) {
+      const providerName = providerChain[index];
+      try {
+        const provider = this.providers.get(providerName);
+        if (!provider || !provider.supportsStreaming || !provider.streamResponse) continue;
+        if (defaultExhaustionRegistry.isExhausted(providerName)) continue;
+        if (this.config.enableCircuitBreaker && defaultCircuitBreakerManager.getBreaker(providerName).isOpen()) continue;
+        if (this.config.ledger && this.isLedgerLimited(providerName)) continue;
+
+        if (previousProvider && lastError) {
+          this.hooks.onFallback?.({
+            fromProvider: previousProvider,
+            toProvider: providerName,
+            requestId: request.requestId,
+            reason: lastError.message,
+            errorCode: (lastError as { code?: string }).code,
+            timestamp: Date.now(),
+          });
+        }
+
+        const providerRequest = {
+          ...this.requestForProvider(request, providerName, providerModels),
+          stream: true
+        };
+        const model = providerRequest.model || provider.models[0] || 'unknown';
+        const estimatedCost = await this.checkQuota(providerName, provider, providerRequest, model);
+
+        this.hooks.onRequestStart?.({
+          provider: providerName,
+          model,
+          requestId: request.requestId,
+          tenantId: request.tenantId,
+          timestamp: Date.now(),
+        });
+
+        const startTime = Date.now();
+        const opened = await this.openStreamWithFirstChunk(provider, providerRequest);
+        return this.buildFactoryStream(
+          opened.reader,
+          opened.firstChunk,
+          opened.done,
+          providerName,
+          model,
+          providerRequest,
+          startTime,
+          estimatedCost
+        );
+      } catch (error) {
+        const err = error as Error;
+        lastError = err;
+        previousProvider = providerName;
+
+        this.hooks.onRequestError?.({
+          provider: providerName,
+          model: request.model || 'unknown',
+          requestId: request.requestId,
+          tenantId: request.tenantId,
+          error: err,
+          errorCode: (err as { code?: string }).code,
+          attempt: 1,
+          willRetry: this.shouldFallback(err),
+          timestamp: Date.now(),
+        });
+
+        const fallbackDecision = this.getFallbackDecision(err);
+        if (!fallbackDecision.shouldFallback) {
+          throw error;
+        }
+
+        this.applyFallbackDecision(fallbackDecision, providerName, providerChain, index, providerModels);
+      }
+    }
+
+    throw lastError || new LLMProviderError(
+      'All streaming providers failed',
+      'ALL_PROVIDERS_FAILED',
+      'factory',
+      false
+    );
+  }
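`generateResponseStream` hands back a plain `ReadableStream<string>`, and fallback is only possible before the first chunk has been read. A minimal consumption sketch, assuming a configured factory:

```ts
import type { LLMProviderFactory } from './factory';
import type { LLMRequest } from './types';

async function collect(factory: LLMProviderFactory, request: LLMRequest): Promise<string> {
  const stream = await factory.generateResponseStream(request);
  const reader = stream.getReader();
  let text = '';
  // Once the first chunk is enqueued, later errors surface on the stream itself,
  // not as a fallback to another provider.
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    text += value;
  }
  return text;
}
```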
+  async generateResponseWithTools(
+    request: LLMRequest,
+    toolExecutor: ToolExecutor,
+    opts: ToolLoopOptions = {}
+  ): Promise<LLMResponse> {
+    const maxIterations = opts.maxIterations ?? 10;
+    let cumulativeCost = 0;
+    let messages = [...request.messages];
+
+    let lastResponseCost = 0;
+
+    for (let iteration = 0; iteration <= maxIterations; iteration++) {
+      if (opts.abortSignal?.aborted) {
+        throw new ToolLoopAbortedError('factory');
+      }
+
+      // Pre-flight cost guard: use the previous iteration's cost as an
+      // estimate for the next one. This prevents obvious overshoots where
+      // a single expensive response would blow past the cap. The cap is
+      // still soft (±1 iteration tolerance) because the actual cost is
+      // only known after the response.
+      if (opts.maxCostUSD !== undefined && iteration > 0) {
+        const projectedCost = cumulativeCost + lastResponseCost;
+        if (projectedCost > opts.maxCostUSD) {
+          throw new ToolLoopLimitError(
+            'factory',
+            `Tool loop would exceed max cost ${opts.maxCostUSD} (projected ${projectedCost.toFixed(4)})`
+          );
+        }
+      }
+
+      const response = await this.generateResponse({ ...request, messages });
+      lastResponseCost = response.usage.cost;
+      cumulativeCost += lastResponseCost;
+
+      if (opts.maxCostUSD !== undefined && cumulativeCost > opts.maxCostUSD) {
+        throw new ToolLoopLimitError(
+          'factory',
+          `Tool loop exceeded max cost ${opts.maxCostUSD}`
+        );
+      }
+
+      if (!response.toolCalls || response.toolCalls.length === 0) {
+        return {
+          ...response,
+          metadata: {
+            ...response.metadata,
+            cumulativeCost,
+            toolIterations: iteration
+          }
+        };
+      }
+
+      if (iteration >= maxIterations) {
+        throw new ToolLoopLimitError('factory', `Tool loop exceeded ${maxIterations} iterations`);
+      }
+
+      const toolResults = [];
+      for (const toolCall of response.toolCalls) {
+        if (opts.abortSignal?.aborted) {
+          throw new ToolLoopAbortedError('factory');
+        }
+
+        let parsedArguments: unknown;
+        try {
+          parsedArguments = JSON.parse(toolCall.function.arguments);
+        } catch {
+          parsedArguments = toolCall.function.arguments;
+        }
+
+        try {
+          const output = await toolExecutor.execute(toolCall.function.name, parsedArguments);
+          toolResults.push({
+            id: toolCall.id,
+            output: typeof output === 'string' ? output : JSON.stringify(output)
+          });
+        } catch (error) {
+          toolResults.push({
+            id: toolCall.id,
+            output: '',
+            error: (error as Error).message
+          });
+        }
+      }
+
+      messages = [
+        ...messages,
+        {
+          role: 'assistant',
+          content: response.message,
+          toolCalls: response.toolCalls
+        },
+        {
+          role: 'user',
+          content: '',
+          toolResults
+        }
+      ];
+
+      const state: ToolLoopState = {
+        iteration: iteration + 1,
+        cumulativeCost,
+        messageCount: messages.length,
+        lastToolCalls: response.toolCalls
+      };
+      await opts.onIteration?.(iteration + 1, state);
+    }
+
+    throw new ToolLoopLimitError('factory', `Tool loop exceeded ${maxIterations} iterations`);
+  }
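A toy `ToolExecutor` matching the loop above; the tool name and payload are made up for illustration. Returned objects are `JSON.stringify`'d into tool results, and a thrown error becomes `{ id, output: '', error }` for that single call rather than aborting the loop:

```ts
import type { ToolExecutor } from './types';

const executor: ToolExecutor = {
  async execute(name, args) {
    switch (name) {
      case 'lookup':
        // `args` has already been JSON.parse'd by the loop when possible.
        return { id: (args as { id: number }).id, title: 'Carbonara' };
      default:
        throw new Error(`Unknown tool: ${name}`);
    }
  },
};
```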
+  async classify<T = unknown>(
+    input: string | LLMRequest,
+    options: ClassifyOptions<T> = {}
+  ): Promise<ClassifyResult<T>> {
+    const parser = options.schema && typeof (options.schema as { parse?: unknown }).parse === 'function'
+      ? (options.schema as { parse(data: unknown): T }).parse
+      : undefined;
+    const schemaDescription = options.schema && !parser
+      ? `\nJSON schema:\n${JSON.stringify(options.schema)}`
+      : '';
+    const systemPrompt = options.systemPrompt ||
+      `Classify the input and return only valid JSON.${schemaDescription}`;
+    const request: LLMRequest = typeof input === 'string'
+      ? {
+          messages: [{ role: 'user', content: input }],
+          model: options.model,
+          temperature: options.temperature ?? 0,
+          maxTokens: options.maxTokens,
+          response_format: { type: 'json_object' },
+          systemPrompt,
+          seed: options.seed
+        }
+      : {
+          ...input,
+          model: options.model ?? input.model,
+          temperature: options.temperature ?? input.temperature ?? 0,
+          maxTokens: options.maxTokens ?? input.maxTokens,
+          response_format: { type: 'json_object' },
+          systemPrompt: options.systemPrompt ?? input.systemPrompt ?? systemPrompt,
+          seed: options.seed ?? input.seed
+        };
+
+    const response = await this.generateResponse(request);
+    const parsed = this.parseJsonResponse(response.message);
+    const data = parser ? parser(parsed) : parsed as T;
+    const confidenceValue = (parsed as Record<string, unknown>)[options.confidenceField ?? 'confidence'];
+
+    return {
+      data,
+      confidence: typeof confidenceValue === 'number' ? confidenceValue : undefined,
+      response
+    };
+  }
+
+  async analyzeImage(input: AnalyzeImageInput): Promise<LLMResponse> {
+    return this.generateResponse({
+      messages: [{ role: 'user', content: input.prompt }],
+      images: [input.image],
+      model: input.model ?? this.getDefaultVisionModel(),
+      systemPrompt: input.systemPrompt,
+      temperature: input.temperature,
+      maxTokens: input.maxTokens,
+      response_format: input.response_format,
+      tenantId: input.tenantId,
+      requestId: input.requestId,
+      metadata: input.metadata
+    });
+  }
+
+  async getProviderBalance(provider?: string): Promise<ProviderBalance | Record<string, ProviderBalance>> {
+    if (provider) {
+      const balance = await this.getSingleProviderBalance(provider);
+      this.hooks.onProviderBalance?.({ provider, balance, timestamp: Date.now() });
+      return balance;
+    }
+
+    const result: Record<string, ProviderBalance> = {};
+    for (const providerName of this.providers.keys()) {
+      const balance = await this.getSingleProviderBalance(providerName);
+      result[providerName] = balance;
+      this.hooks.onProviderBalance?.({ provider: providerName, balance, timestamp: Date.now() });
+    }
+    return result;
+  }
+
   /**
    * Build provider chain based on request and configuration
    */
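`classify` accepts either a bare string or a full `LLMRequest`, and any schema object exposing a `parse` method is used as the validator, so a zod schema fits the duck-typed check above. A usage sketch; zod, `factory`, and `base64Jpeg` are assumed for illustration:

```ts
import { z } from 'zod'; // example dependency; anything with .parse() works
import type { LLMProviderFactory } from './factory';

const Verdict = z.object({ label: z.enum(['recipe', 'other']), confidence: z.number() });

async function demo(factory: LLMProviderFactory, base64Jpeg: string) {
  const { data, confidence } = await factory.classify<z.infer<typeof Verdict>>(
    'Preheat the oven to 220°C, then roast the peppers for 40 minutes.',
    { schema: Verdict }
  );
  // data.label === 'recipe'; confidence is read from the default 'confidence' field.

  const vision = await factory.analyzeImage({
    image: { data: base64Jpeg, mimeType: 'image/jpeg' },
    prompt: 'Extract the recipe title from this photo',
  });
  return { data, confidence, vision };
}
```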
@@ -317,14 +608,17 @@ export class LLMProviderFactory {
    * Get prioritized list of providers based on cost optimization and capabilities
    */
   private getPrioritizedProviders(request: LLMRequest): string[] {
+    const visionOnly = (request.images?.length ?? 0) > 0;
     if (!this.config.costOptimization) {
       // Default priority: all configured providers, cheapest first
       return ['cloudflare', 'cerebras', 'groq', 'anthropic', 'openai']
-        .filter(p => this.providers.has(p));
+        .filter(p => this.providers.has(p))
+        .filter(p => !visionOnly || this.providerSupportsVision(p));
     }
 
     // Cost-optimized routing
-    const providers = Array.from(this.providers.keys());
+    const providers = Array.from(this.providers.keys())
+      .filter(p => !visionOnly || this.providerSupportsVision(p));
     const sortedProviders = [...providers].sort((a, b) => {
       const providerA = this.providers.get(a)!;
       const providerB = this.providers.get(b)!;
@@ -629,6 +923,227 @@ export class LLMProviderFactory {
     }
   }
 
+  private async openStreamWithFirstChunk(
+    provider: LLMProvider,
+    request: LLMRequest
+  ): Promise<{ reader: ReadableStreamDefaultReader<string>; firstChunk?: string; done: boolean }> {
+    if (!provider.streamResponse) {
+      throw new ConfigurationError(provider.name, 'Provider does not support streaming');
+    }
+
+    const stream = await provider.streamResponse(request);
+    const reader = stream.getReader();
+    const first = await reader.read();
+    return {
+      reader,
+      firstChunk: first.value,
+      done: first.done
+    };
+  }
+
+  private buildFactoryStream(
+    reader: ReadableStreamDefaultReader<string>,
+    firstChunk: string | undefined,
+    firstDone: boolean,
+    providerName: string,
+    model: string,
+    request: LLMRequest,
+    startTime: number,
+    estimatedCost: number
+  ): ReadableStream<string> {
+    return new ReadableStream<string>({
+      start: async (controller) => {
+        try {
+          if (!firstDone && firstChunk !== undefined) {
+            controller.enqueue(firstChunk);
+          }
+
+          if (!firstDone) {
+            while (true) {
+              const { done, value } = await reader.read();
+              if (done) break;
+              if (value !== undefined) controller.enqueue(value);
+            }
+          }
+
+          const usage = { inputTokens: 0, outputTokens: 0, totalTokens: 0, cost: estimatedCost };
+          this.hooks.onRequestEnd?.({
+            provider: providerName,
+            model,
+            requestId: request.requestId,
+            tenantId: request.tenantId,
+            durationMs: Date.now() - startTime,
+            usage,
+            finishReason: 'stop',
+            timestamp: Date.now(),
+          });
+          this.recordQuotaInput({
+            tenantId: request.tenantId,
+            provider: providerName,
+            model,
+            actualCost: estimatedCost,
+            metadata: request.metadata
+          });
+          controller.close();
+        } catch (error) {
+          controller.error(error);
+        } finally {
+          reader.releaseLock();
+        }
+      }
+    });
+  }
+
+  private async checkQuota(
+    providerName: string,
+    provider: LLMProvider,
+    request: LLMRequest,
+    model: string
+  ): Promise<number> {
+    const estimatedCost = provider.estimateCost(request);
+    if (!this.config.quotaHook) {
+      return estimatedCost;
+    }
+
+    const input: QuotaCheckInput = {
+      tenantId: request.tenantId,
+      provider: providerName,
+      model,
+      estimatedCost,
+      metadata: request.metadata
+    };
+
+    try {
+      const result = await this.config.quotaHook.check(input);
+      this.hooks.onQuotaCheck?.({ input, result, timestamp: Date.now() });
+      if (!result.allowed) {
+        this.hooks.onQuotaDenied?.({ input, reason: result.reason, timestamp: Date.now() });
+        throw new QuotaExceededError(providerName, result.reason || 'Quota hook denied request');
+      }
+    } catch (error) {
+      if (error instanceof QuotaExceededError) {
+        throw error;
+      }
+
+      if ((this.config.quotaFailPolicy ?? 'closed') === 'open') {
+        this.logger.warn(`[LLMProviderFactory] Quota check failed open for ${providerName}:`, (error as Error).message);
+        return estimatedCost;
+      }
+
+      const reason = (error as Error).message;
+      this.hooks.onQuotaDenied?.({ input, reason, timestamp: Date.now() });
+      throw new QuotaExceededError(providerName, reason);
+    }
+
+    return estimatedCost;
+  }
+
+  private recordQuota(providerName: string, response: LLMResponse, request: LLMRequest): void {
+    this.recordQuotaInput({
+      tenantId: request.tenantId,
+      provider: providerName,
+      model: response.model,
+      actualCost: response.usage.cost,
+      inputTokens: response.usage.inputTokens,
+      outputTokens: response.usage.outputTokens,
+      metadata: request.metadata
+    });
+  }
+
+  private recordQuotaInput(input: QuotaRecordInput): void {
+    if (!this.config.quotaHook) return;
+
+    void this.config.quotaHook.record(input).catch(error => {
+      this.logger.warn(`[LLMProviderFactory] Quota record failed for ${input.provider}:`, (error as Error).message);
+    });
+  }
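A sketch of a `QuotaHook` implementation matching the check/record flow above. This is an in-memory per-tenant budget for illustration; a production hook would back it with KV/D1/Redis. `record` is fired without being awaited, so failures there only produce a warning log:

```ts
import type { QuotaHook } from './types';

export function memoryQuotaHook(budgetUSD: number): QuotaHook {
  const spend = new Map<string, number>();
  return {
    async check({ tenantId = 'anonymous', estimatedCost }) {
      const used = spend.get(tenantId) ?? 0;
      if (used + estimatedCost > budgetUSD) {
        return { allowed: false, reason: 'budget exhausted' };
      }
      return { allowed: true, remainingBudget: budgetUSD - used };
    },
    async record({ tenantId = 'anonymous', actualCost }) {
      spend.set(tenantId, (spend.get(tenantId) ?? 0) + actualCost);
    },
  };
}
```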
+  private parseJsonResponse(message: string): unknown {
+    try {
+      return JSON.parse(message);
+    } catch {
+      // Strip markdown fences (```json ... ``` or ``` ... ```) before
+      // falling back to brace extraction so fenced JSON parses cleanly.
+      const fenced = message.replace(/^```(?:json)?\s*\n?/m, '').replace(/\n?```\s*$/m, '');
+      try {
+        return JSON.parse(fenced);
+      } catch {
+        // Last resort: extract outermost braces.
+        const start = fenced.indexOf('{');
+        const end = fenced.lastIndexOf('}');
+        if (start >= 0 && end > start) {
+          try {
+            return JSON.parse(fenced.slice(start, end + 1));
+          } catch {
+            // Fall through to the ConfigurationError below instead of
+            // leaking a raw SyntaxError to the caller.
+          }
+        }
+      }
+      throw new ConfigurationError('factory', 'Classification response was not valid JSON');
+    }
+  }
+
+  private getDefaultVisionModel(): string | undefined {
+    if (this.config.defaultVisionModel) return this.config.defaultVisionModel;
+    if (this.providers.has('anthropic')) return 'claude-haiku-4-5-20251001';
+    if (this.providers.has('openai')) return 'gpt-4o-mini';
+    return undefined;
+  }
+
+  private providerSupportsVision(providerName: string): boolean {
+    return this.providers.get(providerName)?.supportsVision === true;
+  }
+
+  private async getSingleProviderBalance(providerName: string): Promise<ProviderBalance> {
+    const ledgerBalance = this.getLedgerBalance(providerName);
+    if (ledgerBalance) {
+      return ledgerBalance;
+    }
+
+    const provider = this.providers.get(providerName);
+    if (!provider) {
+      return {
+        provider: providerName,
+        status: 'error',
+        source: 'not_supported',
+        message: `Provider '${providerName}' is not configured`
+      };
+    }
+
+    if (provider.getProviderBalance) {
+      return provider.getProviderBalance();
+    }
+
+    return {
+      provider: providerName,
+      status: 'unavailable',
+      source: 'not_supported',
+      message: `Provider '${providerName}' does not expose balance reporting`
+    };
+  }
+  private getLedgerBalance(providerName: string): ProviderBalance | undefined {
+    const acc = this.config.ledger?.getProviderAccumulator(providerName);
+    if (!acc) return undefined;
+
+    const rateLimits: ProviderBalance['rateLimits'] = {};
+    for (const [dimension, window] of Object.entries(acc.rateLimits)) {
+      rateLimits[dimension] = {
+        limit: window.limit,
+        used: window.used,
+        remaining: Math.max(window.limit - window.used, 0)
+      };
+    }
+
+    return {
+      provider: providerName,
+      status: 'available',
+      source: 'ledger',
+      currentSpend: acc.spend,
+      monthlyBudget: acc.budget ?? undefined,
+      remainingBudget: acc.budget === null ? undefined : acc.budget - acc.spend,
+      usedTokens: acc.inputTokens + acc.outputTokens,
+      requestCount: acc.requestCount,
+      rateLimits
+    };
+  }
+
   private isLedgerLimited(providerName: string): boolean {
     if (!this.config.ledger) return false;
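Ledger state takes precedence over provider billing APIs in `getSingleProviderBalance`, so a configured `CreditLedger` always reports `source: 'ledger'`. A consumption sketch; the cast narrows the no-argument overload of the union return type, and `factory` is assumed:

```ts
import type { LLMProviderFactory } from './factory';
import type { ProviderBalance } from './types';

async function logBalances(factory: LLMProviderFactory) {
  const balances = (await factory.getProviderBalance()) as Record<string, ProviderBalance>;
  for (const [name, b] of Object.entries(balances)) {
    if (b.source === 'ledger' && b.remainingBudget !== undefined) {
      console.log(`${name}: $${b.remainingBudget.toFixed(4)} of $${b.monthlyBudget} left`);
    } else {
      console.log(`${name}: ${b.status}${b.message ? ` (${b.message})` : ''}`);
    }
  }
}
```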
diff --git a/src/index.ts b/src/index.ts
index fbcb4b2..fd06fe8 100755
--- a/src/index.ts
+++ b/src/index.ts
@@ -5,10 +5,12 @@
  */
 
 // Core types
-export type {
-  LLMProvider,
-  LLMRequest,
-  LLMResponse,
+export type {
+  LLMProvider,
+  LLMImageInput,
+  GatewayMetadata,
+  LLMRequest,
+  LLMResponse,
   LLMMessage,
   LLMConfig,
   TokenUsage,
@@ -29,10 +31,22 @@ export type {
   CircuitBreakerState,
   RetryConfig,
   CostConfig,
-  LLMError,
-  StreamChunk,
-  StreamResponse,
-  BatchRequest,
+  LLMError,
+  StreamChunk,
+  StreamResponse,
+  QuotaHook,
+  QuotaCheckInput,
+  QuotaCheckResult,
+  QuotaRecordInput,
+  ToolExecutor,
+  ToolLoopOptions,
+  ToolLoopState,
+  ClassifyOptions,
+  ClassifyResult,
+  AnalyzeImageInput,
+  ProviderBalance,
+  RateLimitBalance,
+  BatchRequest,
   BatchResponse,
   BatchJob
 } from './types';
@@ -56,7 +70,17 @@ export type { ProviderFactoryConfig, CostAnalytics, ProviderHealthEntry } from './factory';
 
 // Local imports for use within this file
 import { LLMProviderFactory } from './factory';
 import type { ProviderFactoryConfig, CostAnalytics, ProviderHealthEntry } from './factory';
-import type { LLMProvider, LLMRequest, LLMResponse } from './types';
+import type {
+  AnalyzeImageInput,
+  ClassifyOptions,
+  ClassifyResult,
+  LLMProvider,
+  LLMRequest,
+  LLMResponse,
+  ProviderBalance,
+  ToolExecutor,
+  ToolLoopOptions
+} from './types';
 import { createCostOptimizedFactory } from './factory';
 import { ConfigurationError } from './errors';
 
@@ -72,11 +96,13 @@ export {
   NetworkError,
   ServerError,
   ContentFilterError,
-  TokenLimitError,
-  ConfigurationError,
-  CircuitBreakerOpenError,
-  LLMErrorFactory
-} from './errors';
+  TokenLimitError,
+  ConfigurationError,
+  CircuitBreakerOpenError,
+  ToolLoopLimitError,
+  ToolLoopAbortedError,
+  LLMErrorFactory
+} from './errors';
 
 // Image generation
 export { ImageProvider, normalizeAiResponse } from './image/index';
@@ -97,9 +123,12 @@ export type {
   RetryEvent,
   FallbackEvent,
   CircuitStateChangeEvent,
-  QuotaExhaustedEvent,
-  BudgetThresholdEvent,
-} from './utils/hooks';
+  QuotaExhaustedEvent,
+  BudgetThresholdEvent,
+  QuotaCheckEvent,
+  QuotaDeniedEvent,
+  ProviderBalanceEvent,
+} from './utils/hooks';
 
 // Exhaustion registry
 export { ExhaustionRegistry, defaultExhaustionRegistry } from './utils/exhaustion';
@@ -153,11 +182,13 @@ export interface FromEnvOverrides {
   defaultProvider?: ProviderFactoryConfig['defaultProvider'];
   costOptimization?: boolean;
   enableCircuitBreaker?: boolean;
-  enableRetries?: boolean;
-  fallbackRules?: ProviderFactoryConfig['fallbackRules'];
-  ledger?: ProviderFactoryConfig['ledger'];
-  hooks?: ProviderFactoryConfig['hooks'];
-}
+  enableRetries?: boolean;
+  fallbackRules?: ProviderFactoryConfig['fallbackRules'];
+  ledger?: ProviderFactoryConfig['ledger'];
+  quotaHook?: ProviderFactoryConfig['quotaHook'];
+  quotaFailPolicy?: ProviderFactoryConfig['quotaFailPolicy'];
+  hooks?: ProviderFactoryConfig['hooks'];
+}
 
 /**
  * Main LLMProviders class for easy usage
@@ -231,11 +262,17 @@ export class LLMProviders {
     if (overrides.fallbackRules !== undefined) {
       config.fallbackRules = overrides.fallbackRules;
     }
-    if (overrides.ledger !== undefined) {
-      config.ledger = overrides.ledger;
-    }
-    if (overrides.hooks !== undefined) {
-      config.hooks = overrides.hooks;
+    if (overrides.ledger !== undefined) {
+      config.ledger = overrides.ledger;
+    }
+    if (overrides.quotaHook !== undefined) {
+      config.quotaHook = overrides.quotaHook;
+    }
+    if (overrides.quotaFailPolicy !== undefined) {
+      config.quotaFailPolicy = overrides.quotaFailPolicy;
+    }
+    if (overrides.hooks !== undefined) {
+      config.hooks = overrides.hooks;
     }
 
     return new LLMProviders(config);
@@ -244,12 +281,39 @@ export class LLMProviders {
   /**
    * Generate response with automatic provider selection and fallback
    */
-  async generateResponse(request: LLMRequest): Promise<LLMResponse> {
-    return this.factory.generateResponse(request);
-  }
-
-  /**
-   * Get specific provider instance
+  async generateResponse(request: LLMRequest): Promise<LLMResponse> {
+    return this.factory.generateResponse(request);
+  }
+
+  async generateResponseStream(request: LLMRequest): Promise<ReadableStream<string>> {
+    return this.factory.generateResponseStream(request);
+  }
+
+  async generateResponseWithTools(
+    request: LLMRequest,
+    toolExecutor: ToolExecutor,
+    opts?: ToolLoopOptions
+  ): Promise<LLMResponse> {
+    return this.factory.generateResponseWithTools(request, toolExecutor, opts);
+  }
+
+  async classify<T = unknown>(
+    input: string | LLMRequest,
+    options?: ClassifyOptions<T>
+  ): Promise<ClassifyResult<T>> {
+    return this.factory.classify<T>(input, options);
+  }
+
+  async analyzeImage(input: AnalyzeImageInput): Promise<LLMResponse> {
+    return this.factory.analyzeImage(input);
+  }
+
+  async getProviderBalance(provider?: string): Promise<ProviderBalance | Record<string, ProviderBalance>> {
+    return this.factory.getProviderBalance(provider);
+  }
+
+  /**
+   * Get specific provider instance
    */
   getProvider(name: string): LLMProvider | undefined {
     return this.factory.getProvider(name);
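A wiring sketch for the `LLMProviders` facade with the new quota options. The key names follow `ProviderFactoryConfig`; the API-key bindings and `memoryQuotaHook` helper are assumptions from the earlier sketches, not part of this diff:

```ts
import { LLMProviders } from './index';
import { memoryQuotaHook } from './quota-example'; // hypothetical module from the QuotaHook sketch

declare const OPENAI_API_KEY: string;    // placeholder bindings
declare const ANTHROPIC_API_KEY: string;

const llm = new LLMProviders({
  openai: { apiKey: OPENAI_API_KEY },
  anthropic: { apiKey: ANTHROPIC_API_KEY },
  quotaHook: memoryQuotaHook(5),
  quotaFailPolicy: 'open', // a quota-backend outage lets traffic through instead of failing closed
  costOptimization: true,
});
```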
diff --git a/src/providers/anthropic.ts b/src/providers/anthropic.ts
index 197e7f4..e5308a3 100755
--- a/src/providers/anthropic.ts
+++ b/src/providers/anthropic.ts
@@ -3,7 +3,15 @@
  * Implementation for Claude models with streaming and tools support
  */
 
-import type { LLMRequest, LLMResponse, AnthropicConfig, ModelCapabilities, Tool, ToolCall } from '../types';
+import type {
+  LLMRequest,
+  LLMResponse,
+  AnthropicConfig,
+  ModelCapabilities,
+  ProviderBalance,
+  Tool,
+  ToolCall
+} from '../types';
 import { BaseProvider } from './base';
 import {
   LLMErrorFactory,
@@ -13,13 +21,18 @@ import {
 } from '../errors';
 
 interface AnthropicContentBlock {
-  type: 'text' | 'tool_use' | 'tool_result';
+  type: 'text' | 'tool_use' | 'tool_result' | 'image';
   text?: string;
   id?: string;
   name?: string;
   input?: Record<string, unknown>;
   content?: string;
   is_error?: boolean;
+  source?: {
+    type: 'base64';
+    media_type: string;
+    data: string;
+  };
 }
 
 interface AnthropicMessage {
@@ -86,6 +99,7 @@ export class AnthropicProvider extends BaseProvider {
   supportsStreaming = true;
   supportsTools = true;
   supportsBatching = false;
+  supportsVision = true;
 
   private apiKey: string;
   private baseUrl: string;
@@ -112,7 +126,7 @@ export class AnthropicProvider extends BaseProvider {
     try {
       const response = await this.executeWithResiliency(async () => {
         const anthropicRequest = this.formatRequest(request);
-        const httpResponse = await this.makeAnthropicRequest('/v1/messages', anthropicRequest);
+        const httpResponse = await this.makeAnthropicRequest('/v1/messages', anthropicRequest, 'POST', request);
 
         if (!httpResponse.ok) {
           throw await LLMErrorFactory.fromFetchResponse('anthropic', httpResponse);
@@ -178,12 +192,42 @@ export class AnthropicProvider extends BaseProvider {
     }
   }
 
+  async getProviderBalance(): Promise<ProviderBalance> {
+    try {
+      const response = await this.makeAnthropicRequest('/v1/organizations/cost_report', null, 'GET');
+      if (!response.ok) {
+        return {
+          provider: this.name,
+          status: 'unavailable',
+          source: 'provider_api',
+          message: `Anthropic usage API returned HTTP ${response.status}`
+        };
+      }
+
+      const raw = await response.json();
+      return {
+        provider: this.name,
+        status: 'available',
+        source: 'provider_api',
+        raw
+      };
+    } catch (error) {
+      return {
+        provider: this.name,
+        status: 'error',
+        source: 'provider_api',
+        message: (error as Error).message
+      };
+    }
+  }
+
   protected getModelCapabilities(): Record<string, ModelCapabilities> {
     return {
       'claude-opus-4-6-20250618': {
         maxContextLength: 200000,
         supportsStreaming: true,
         supportsTools: true,
+        supportsVision: true,
         supportsBatching: false,
         inputTokenCost: 0.015, // $15 per 1M tokens
         outputTokenCost: 0.075, // $75 per 1M tokens
@@ -193,6 +237,7 @@ export class AnthropicProvider extends BaseProvider {
         maxContextLength: 200000,
         supportsStreaming: true,
         supportsTools: true,
+        supportsVision: true,
         supportsBatching: false,
         inputTokenCost: 0.003, // $3 per 1M tokens
         outputTokenCost: 0.015, // $15 per 1M tokens
@@ -285,12 +330,14 @@ export class AnthropicProvider extends BaseProvider {
   private async makeAnthropicRequest(
     endpoint: string,
     body: AnthropicRequest | null,
-    method: string = 'POST'
+    method: string = 'POST',
+    request?: LLMRequest
   ): Promise<Response> {
     const headers: Record<string, string> = {
       'x-api-key': this.apiKey,
       'anthropic-version': this.version,
-      'Content-Type': 'application/json'
+      'Content-Type': 'application/json',
+      ...this.getAIGatewayHeaders(request)
     };
 
     const options: RequestInit = {
@@ -315,7 +362,7 @@ export class AnthropicProvider extends BaseProvider {
 
     const anthropicMessage: AnthropicMessage = {
       role: message.role as 'user' | 'assistant',
-      content: message.content
+      content: this.formatMessageContent(message.content, message.role === 'user' ? request.images : undefined)
     };
 
     // Handle tool calls and results
@@ -385,6 +432,36 @@ export class AnthropicProvider extends BaseProvider {
     return anthropicRequest;
   }
 
+  private formatMessageContent(
+    text: string,
+    images?: LLMRequest['images']
+  ): string | AnthropicContentBlock[] {
+    if (!images || images.length === 0) {
+      return text;
+    }
+
+    return [
+      { type: 'text', text },
+      ...images.map(image => {
+        if (image.url) {
+          return {
+            type: 'text' as const,
+            text: `[Image URL: ${image.url}]`
+          };
+        }
+
+        return {
+          type: 'image' as const,
+          source: {
+            type: 'base64' as const,
+            media_type: image.mimeType || 'image/jpeg',
+            data: image.data || ''
+          }
+        };
+      })
+    ];
+  }
+
   private formatResponse(
     data: AnthropicResponse,
     responseTime: number
@@ -461,7 +538,7 @@ export class AnthropicProvider extends BaseProvider {
     return new ReadableStream({
       start: async (controller) => {
         try {
-          const response = await this.makeAnthropicRequest('/v1/messages', anthropicRequest);
+          const response = await this.makeAnthropicRequest('/v1/messages', anthropicRequest, 'POST', request);
 
           if (!response.ok) {
             throw await LLMErrorFactory.fromFetchResponse('anthropic', response);
diff --git a/src/providers/base.ts b/src/providers/base.ts
index 9ea4156..83384e7 100755
--- a/src/providers/base.ts
+++ b/src/providers/base.ts
@@ -3,11 +3,12 @@
  * Abstract base class for all LLM providers with common functionality
  */
 
-import type {
-  LLMProvider,
-  LLMRequest,
-  LLMResponse,
-  ProviderConfig,
+import type {
+  LLMProvider,
+  LLMImageInput,
+  LLMRequest,
+  LLMResponse,
+  ProviderConfig,
   ModelCapabilities,
   ProviderMetrics,
   ToolCall
@@ -249,7 +250,7 @@ export abstract class BaseProvider implements LLMProvider {
   /**
    * Validate request before processing
    */
-  protected validateRequest(request: LLMRequest): void {
+  protected validateRequest(request: LLMRequest): void {
     if (!request.messages || request.messages.length === 0) {
       throw new ConfigurationError(this.name, 'Request must contain at least one message');
     }
@@ -263,13 +264,63 @@ export abstract class BaseProvider implements LLMProvider {
     }
 
     // Validate model if specified
-    if (request.model && !this.models.includes(request.model)) {
-      throw new ConfigurationError(
-        this.name,
-        `Model '${request.model}' not supported. Available models: ${this.models.join(', ')}`
-      );
-    }
-  }
+    if (request.model && !this.models.includes(request.model)) {
+      throw new ConfigurationError(
+        this.name,
+        `Model '${request.model}' not supported. Available models: ${this.models.join(', ')}`
+      );
+    }
+
+    if (request.images) {
+      for (const image of request.images) {
+        this.validateImageInput(image);
+      }
+    }
+  }
+
+  protected estimateTextTokens(request: LLMRequest): number {
+    return (
+      (request.systemPrompt ? Math.ceil(request.systemPrompt.length / 4) : 0) +
+      request.messages.reduce((sum, msg) => sum + Math.ceil(msg.content.length / 4), 0)
+    );
+  }
+
+  protected validateImageInput(image: LLMImageInput): void {
+    if (!image.data && !image.url) {
+      throw new ConfigurationError(this.name, 'Image input must include data or url');
+    }
+
+    if (image.data && !image.mimeType) {
+      throw new ConfigurationError(this.name, 'Base64 image input must include mimeType');
+    }
+  }
+
+  protected getAIGatewayHeaders(request?: LLMRequest): Record<string, string> {
+    const baseUrl = typeof this.config.baseUrl === 'string' ? this.config.baseUrl : '';
+    if (!baseUrl.includes('gateway.ai.cloudflare.com') || !request?.gatewayMetadata) {
+      return {};
+    }
+
+    const headers: Record<string, string> = {};
+    const metadata = {
+      ...(request.gatewayMetadata.customMetadata ?? {}),
+      ...(request.gatewayMetadata.requestId ? { requestId: request.gatewayMetadata.requestId } : {}),
+      ...(request.requestId ? { llmRequestId: request.requestId } : {}),
+      ...(request.tenantId ? { tenantId: request.tenantId } : {}),
+    };
+
+    if (Object.keys(metadata).length > 0) {
+      headers['cf-aig-metadata'] = JSON.stringify(metadata);
+    }
+    if (request.gatewayMetadata.cacheKey) {
+      headers['cf-aig-cache-key'] = request.gatewayMetadata.cacheKey;
+    }
+    if (typeof request.gatewayMetadata.cacheTtl === 'number') {
+      headers['cf-aig-cache-ttl'] = String(request.gatewayMetadata.cacheTtl);
+    }
+
+    return headers;
+  }
 
   /**
    * Validate and sanitize tool calls returned by the provider.
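`getAIGatewayHeaders` only fires when the configured `baseUrl` points at `gateway.ai.cloudflare.com`; direct provider URLs yield no extra headers. Mirroring the Groq test above, a request like this produces the commented headers (provider construction is assumed):

```ts
import { GroqProvider } from './groq';

async function demoGateway() {
  const provider = new GroqProvider({
    apiKey: 'test-groq-key',
    baseUrl: 'https://gateway.ai.cloudflare.com/v1/account/gateway/groq',
  });

  await provider.generateResponse({
    messages: [{ role: 'user', content: 'hi' }],
    requestId: 'req-42',
    tenantId: 'tenant-7',
    gatewayMetadata: { cacheKey: 'greeting-v1', cacheTtl: 300, customMetadata: { feature: 'demo' } },
  });
  // Outgoing headers include:
  //   cf-aig-cache-key: greeting-v1
  //   cf-aig-cache-ttl: 300
  //   cf-aig-metadata: {"feature":"demo","llmRequestId":"req-42","tenantId":"tenant-7"}
}
```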
diff --git a/src/providers/cerebras.ts b/src/providers/cerebras.ts
index d9598b7..cd0337c 100644
--- a/src/providers/cerebras.ts
+++ b/src/providers/cerebras.ts
@@ -35,6 +35,7 @@ interface CerebrasRequest {
   stream?: boolean;
   tools?: CerebrasTool[];
   tool_choice?: 'auto' | 'none' | { type: 'function'; function: { name: string } };
+  seed?: number;
 }
 
 interface CerebrasResponse {
@@ -103,7 +104,7 @@ export class CerebrasProvider extends BaseProvider {
     try {
       const response = await this.executeWithResiliency(async () => {
         const cerebrasRequest = this.formatRequest(request);
-        const httpResponse = await this.makeCerebrasRequest('/chat/completions', cerebrasRequest);
+        const httpResponse = await this.makeCerebrasRequest('/chat/completions', cerebrasRequest, 'POST', request);
 
         if (!httpResponse.ok) {
           throw await LLMErrorFactory.fromFetchResponse('cerebras', httpResponse);
@@ -208,7 +209,7 @@ export class CerebrasProvider extends BaseProvider {
     return new ReadableStream({
       start: async (controller) => {
         try {
-          const response = await this.makeCerebrasRequest('/chat/completions', cerebrasRequest);
+          const response = await this.makeCerebrasRequest('/chat/completions', cerebrasRequest, 'POST', request);
 
           if (!response.ok) {
             throw await LLMErrorFactory.fromFetchResponse('cerebras', response);
@@ -265,11 +266,13 @@ export class CerebrasProvider extends BaseProvider {
   private async makeCerebrasRequest(
     endpoint: string,
     body: CerebrasRequest | null,
-    method: string = 'POST'
+    method: string = 'POST',
+    request?: LLMRequest
   ): Promise<Response> {
     const headers: Record<string, string> = {
       'Authorization': `Bearer ${this.apiKey}`,
-      'Content-Type': 'application/json'
+      'Content-Type': 'application/json',
+      ...this.getAIGatewayHeaders(request)
     };
 
     const options: RequestInit = {
@@ -348,7 +351,8 @@ export class CerebrasProvider extends BaseProvider {
       messages,
       temperature: request.temperature,
       max_tokens: request.maxTokens,
-      stream: request.stream
+      stream: request.stream,
+      seed: request.seed
     };
 
     // Add tools if provided. Unsupported tool models are rejected above.
diff --git a/src/providers/groq.ts b/src/providers/groq.ts
index e1e2674..2cae1df 100644
--- a/src/providers/groq.ts
+++ b/src/providers/groq.ts
@@ -3,7 +3,7 @@
  * Implementation for Groq fast inference models (OpenAI-compatible API)
  */
 
-import type { LLMRequest, LLMResponse, GroqConfig, ModelCapabilities, ToolCall } from '../types';
+import type { LLMRequest, LLMResponse, GroqConfig, ModelCapabilities, ProviderBalance, ToolCall } from '../types';
 import { BaseProvider } from './base';
 import {
   LLMErrorFactory,
@@ -36,6 +36,7 @@ interface GroqRequest {
   response_format?: { type: 'json_object' | 'text' };
   tools?: GroqTool[];
   tool_choice?: 'auto' | 'none' | { type: 'function'; function: { name: string } };
+  seed?: number;
 }
 
 interface GroqResponse {
@@ -103,7 +104,7 @@ export class GroqProvider extends BaseProvider {
     try {
       const response = await this.executeWithResiliency(async () => {
         const groqRequest = this.formatRequest(request);
-        const httpResponse = await this.makeGroqRequest('/chat/completions', groqRequest);
+        const httpResponse = await this.makeGroqRequest('/chat/completions', groqRequest, 'POST', request);
 
         if (!httpResponse.ok) {
           throw await LLMErrorFactory.fromFetchResponse('groq', httpResponse);
@@ -156,6 +157,15 @@ export class GroqProvider extends BaseProvider {
     }
   }
 
+  async getProviderBalance(): Promise<ProviderBalance> {
+    return {
+      provider: this.name,
+      status: 'unavailable',
+      source: 'not_supported',
+      message: 'Groq does not expose a public billing or credit-balance API; use CreditLedger reporting for local quota state.'
+    };
+  }
+
   protected getModelCapabilities(): Record<string, ModelCapabilities> {
     return {
       'llama-3.3-70b-versatile': {
@@ -199,7 +209,7 @@ export class GroqProvider extends BaseProvider {
     return new ReadableStream({
       start: async (controller) => {
         try {
-          const response = await this.makeGroqRequest('/chat/completions', groqRequest);
+          const response = await this.makeGroqRequest('/chat/completions', groqRequest, 'POST', request);
 
           if (!response.ok) {
             throw await LLMErrorFactory.fromFetchResponse('groq', response);
@@ -256,11 +266,13 @@ export class GroqProvider extends BaseProvider {
   private async makeGroqRequest(
     endpoint: string,
     body: GroqRequest | null,
-    method: string = 'POST'
+    method: string = 'POST',
+    request?: LLMRequest
   ): Promise<Response> {
     const headers: Record<string, string> = {
       'Authorization': `Bearer ${this.apiKey}`,
-      'Content-Type': 'application/json'
+      'Content-Type': 'application/json',
+      ...this.getAIGatewayHeaders(request)
     };
 
     const options: RequestInit = {
@@ -339,7 +351,8 @@ export class GroqProvider extends BaseProvider {
       messages,
       temperature: request.temperature,
       max_tokens: request.maxTokens,
-      stream: request.stream
+      stream: request.stream,
+      seed: request.seed
     };
 
     // Pass through response_format if provided
diff --git a/src/providers/openai.ts b/src/providers/openai.ts
index b2f2e0a..215d5a7 100755
--- a/src/providers/openai.ts
+++ b/src/providers/openai.ts
@@ -18,9 +18,15 @@ interface OpenAIToolCall {
   function: { name: string; arguments: string };
 }
 
+interface OpenAIContentPart {
+  type: 'text' | 'image_url';
+  text?: string;
+  image_url?: { url: string };
+}
+
 interface OpenAIMessage {
   role: 'system' | 'user' | 'assistant' | 'tool';
-  content: string | null;
+  content: string | null | OpenAIContentPart[];
   name?: string;
   tool_calls?: OpenAIToolCall[];
   tool_call_id?: string;
@@ -46,6 +52,7 @@ interface OpenAIRequest {
   tools?: OpenAITool[];
   tool_choice?: OpenAIToolChoice;
   response_format?: { type: 'json_object' | 'text' };
+  seed?: number;
 }
 
 interface OpenAIResponse {
@@ -85,6 +92,7 @@ export class OpenAIProvider extends BaseProvider {
   supportsStreaming = true;
   supportsTools = true;
   supportsBatching = false;
+  supportsVision = true;
 
   private apiKey: string;
   private baseUrl: string;
@@ -112,7 +120,7 @@ export class OpenAIProvider extends BaseProvider {
     try {
       const response = await this.executeWithResiliency(async () => {
         const openaiRequest = this.formatRequest(request);
-        const httpResponse = await this.makeOpenAIRequest('/chat/completions', openaiRequest);
+        const httpResponse = await this.makeOpenAIRequest('/chat/completions', openaiRequest, 'POST', request);
 
         if (!httpResponse.ok) {
           throw await LLMErrorFactory.fromFetchResponse('openai', httpResponse);
@@ -172,6 +180,7 @@ export class OpenAIProvider extends BaseProvider {
         maxContextLength: 128000,
         supportsStreaming: true,
         supportsTools: true,
+        supportsVision: true,
         supportsBatching: false,
         inputTokenCost: 0.005, // $5 per 1M tokens
         outputTokenCost: 0.015, // $15 per 1M tokens
@@ -181,6 +190,7 @@ export class OpenAIProvider extends BaseProvider {
         maxContextLength: 128000,
         supportsStreaming: true,
         supportsTools: true,
+        supportsVision: true,
         supportsBatching: false,
         inputTokenCost: 0.00015, // $0.15 per 1M tokens
         outputTokenCost: 0.0006, // $0.60 per 1M tokens
@@ -228,11 +238,13 @@ export class OpenAIProvider extends BaseProvider {
   private async makeOpenAIRequest(
     endpoint: string,
     body: OpenAIRequest | null,
-    method: string = 'POST'
+    method: string = 'POST',
+    request?: LLMRequest
   ): Promise<Response> {
     const headers: Record<string, string> = {
       'Authorization': `Bearer ${this.apiKey}`,
-      'Content-Type': 'application/json'
+      'Content-Type': 'application/json',
+      ...this.getAIGatewayHeaders(request)
     };
 
     if (this.organization) {
@@ -274,7 +286,7 @@ export class OpenAIProvider extends BaseProvider {
 
     const openaiMessage: OpenAIMessage = {
       role: message.role as OpenAIMessage['role'],
-      content: message.content
+      content: this.formatMessageContent(message.content, message.role === 'user' ? request.images : undefined)
     };
 
     // Add tool calls if present
@@ -287,6 +299,19 @@ export class OpenAIProvider extends BaseProvider {
       openaiMessage.content = null; // Content must be null when tool_calls present
     }
 
+    if (message.toolResults && message.toolResults.length > 0) {
+      for (const toolResult of message.toolResults) {
+        messages.push({
+          role: 'tool',
+          content: toolResult.error
+            ? JSON.stringify({ output: toolResult.output, error: toolResult.error })
+            : toolResult.output,
+          tool_call_id: toolResult.id
+        });
+      }
+      continue;
+    }
+
     messages.push(openaiMessage);
   }
 
@@ -295,7 +320,8 @@ export class OpenAIProvider extends BaseProvider {
     messages,
     temperature: request.temperature,
     max_tokens: request.maxTokens,
-    stream: request.stream
+    stream: request.stream,
+    seed: request.seed
   };
 
   // Add tools if provided
@@ -318,6 +344,25 @@ export class OpenAIProvider extends BaseProvider {
     return openaiRequest;
   }
 
+  private formatMessageContent(
+    text: string,
+    images?: LLMRequest['images']
+  ): OpenAIMessage['content'] {
+    if (!images || images.length === 0) {
+      return text;
+    }
+
+    return [
+      { type: 'text', text },
+      ...images.map(image => ({
+        type: 'image_url' as const,
+        image_url: {
+          url: image.url || `data:${image.mimeType};base64,${image.data}`
+        }
+      }))
+    ];
+  }
+
   private formatResponse(
     data: OpenAIResponse,
     responseTime: number
@@ -378,7 +423,7 @@ export class OpenAIProvider extends BaseProvider {
     return new ReadableStream({
       start: async (controller) => {
         try {
-          const response = await this.makeOpenAIRequest('/chat/completions', openaiRequest);
+          const response = await this.makeOpenAIRequest('/chat/completions', openaiRequest, 'POST', request);
 
           if (!response.ok) {
             throw await LLMErrorFactory.fromFetchResponse('openai', response);
diff --git a/src/types.ts b/src/types.ts
index c290f75..d2f083e 100755
--- a/src/types.ts
+++ b/src/types.ts
@@ -7,13 +7,26 @@ import type { Logger } from './utils/logger';
 import type { ObservabilityHooks } from './utils/hooks';
 export type { Logger, ObservabilityHooks };
 
-export interface LLMMessage {
-  role: 'user' | 'assistant' | 'system';
-  content: string;
-  timestamp?: string;
-  toolCalls?: ToolCall[];
-  toolResults?: ToolResult[];
-}
+export interface LLMMessage {
+  role: 'user' | 'assistant' | 'system';
+  content: string;
+  timestamp?: string;
+  toolCalls?: ToolCall[];
+  toolResults?: ToolResult[];
+}
+
+export interface LLMImageInput {
+  data?: string;
+  url?: string;
+  mimeType?: string;
+}
+
+export interface GatewayMetadata {
+  requestId?: string;
+  cacheKey?: string;
+  cacheTtl?: number;
+  customMetadata?: Record<string, string>;
+}
 
 export interface ToolCall {
   id: string;
@@ -30,20 +43,23 @@ export interface ToolResult {
   error?: string;
 }
 
-export interface LLMRequest {
-  messages: LLMMessage[];
-  model?: string;
-  temperature?: number;
-  maxTokens?: number;
-  stream?: boolean;
-  systemPrompt?: string;
-  tools?: Tool[];
-  toolChoice?: 'auto' | 'none' | { type: 'function'; function: { name: string } };
-  response_format?: { type: 'json_object' | 'text' };
-  tenantId?: string;
-  requestId?: string;
-  metadata?: Record<string, unknown>;
-}
+export interface LLMRequest {
+  messages: LLMMessage[];
+  model?: string;
+  temperature?: number;
+  maxTokens?: number;
+  stream?: boolean;
+  systemPrompt?: string;
+  images?: LLMImageInput[];
+  tools?: Tool[];
+  toolChoice?: 'auto' | 'none' | { type: 'function'; function: { name: string } };
+  response_format?: { type: 'json_object' | 'text' };
+  seed?: number;
+  gatewayMetadata?: GatewayMetadata;
+  tenantId?: string;
+  requestId?: string;
+  metadata?: Record<string, unknown>;
+}
 
 export interface Tool {
   type: 'function';
@@ -78,17 +94,20 @@ export interface TokenUsage {
   cost: number; // Cost in USD
 }
 
-export interface LLMProvider {
-  name: string;
-  models: string[];
-  supportsStreaming: boolean;
-  supportsTools: boolean;
-  supportsBatching: boolean;
-
-  generateResponse(request: LLMRequest): Promise<LLMResponse>;
-  validateConfig(): boolean;
-  getModels(): string[];
-  estimateCost(request: LLMRequest): number;
+export interface LLMProvider {
+  name: string;
+  models: string[];
+  supportsStreaming: boolean;
+  supportsTools: boolean;
+  supportsBatching: boolean;
+  supportsVision?: boolean;
+
+  generateResponse(request: LLMRequest): Promise<LLMResponse>;
+  streamResponse?(request: LLMRequest): Promise<ReadableStream<string>>;
+  getProviderBalance?(): Promise<ProviderBalance>;
+  validateConfig(): boolean;
+  getModels(): string[];
+  estimateCost(request: LLMRequest): number;
   healthCheck(): Promise<boolean>;
   getMetrics(): ProviderMetrics;
   resetMetrics(): void;
@@ -207,13 +226,14 @@ export interface FallbackRule {
   fallbackModel?: string;
 }
 
-export interface ModelCapabilities {
-  maxContextLength: number;
-  supportsStreaming: boolean;
-  supportsTools: boolean;
-  toolCalling?: boolean;
-  supportsBatching: boolean;
-  inputTokenCost: number;
+export interface ModelCapabilities {
+  maxContextLength: number;
+  supportsStreaming: boolean;
+  supportsTools: boolean;
+  supportsVision?: boolean;
+  toolCalling?: boolean;
+  supportsBatching: boolean;
+  inputTokenCost: number;
   outputTokenCost: number;
   description: string;
 }
@@ -240,10 +260,107 @@ export interface StreamChunk {
   usage?: Partial<TokenUsage>;
 }
 
-export interface StreamResponse {
-  stream: ReadableStream<StreamChunk>;
-  controller: ReadableStreamDefaultController<StreamChunk>;
-}
+export interface StreamResponse {
+  stream: ReadableStream<StreamChunk>;
+  controller: ReadableStreamDefaultController<StreamChunk>;
+}
+
+export interface QuotaCheckInput {
+  tenantId?: string;
+  provider: string;
+  model: string;
+  estimatedCost: number;
+  metadata?: Record<string, unknown>;
+}
+
+export interface QuotaCheckResult {
+  allowed: boolean;
+  reason?: string;
+  remainingBudget?: number;
+}
+
+export interface QuotaRecordInput {
+  tenantId?: string;
+  provider: string;
+  model: string;
+  actualCost: number;
+  inputTokens?: number;
+  outputTokens?: number;
+  metadata?: Record<string, unknown>;
+}
+
+export interface QuotaHook {
+  check(input: QuotaCheckInput): Promise<QuotaCheckResult>;
+  record(input: QuotaRecordInput): Promise<void>;
+}
+
+export interface ToolExecutor {
+  execute(name: string, argumentsValue: unknown): Promise<unknown>;
+}
+
+export interface ToolLoopState {
+  iteration: number;
+  cumulativeCost: number;
+  messageCount: number;
+  lastToolCalls: ToolCall[];
+}
+
+export interface ToolLoopOptions {
+  maxIterations?: number;
+  maxCostUSD?: number;
+  onIteration?: (iteration: number, state: ToolLoopState) => void | Promise<void>;
+  abortSignal?: AbortSignal;
+}
+
+export interface ClassifyOptions<T = unknown> {
+  schema?: Record<string, unknown> | { parse(data: unknown): T };
+  systemPrompt?: string;
+  model?: string;
+  temperature?: number;
+  maxTokens?: number;
+  confidenceField?: string;
+  seed?: number;
+}
+
+export interface ClassifyResult<T = unknown> {
+  data: T;
+  confidence?: number;
+  response: LLMResponse;
+}
+
+export interface AnalyzeImageInput {
+  image: LLMImageInput;
+  prompt: string;
+  model?: string;
+  systemPrompt?: string;
+  temperature?: number;
+  maxTokens?: number;
+  response_format?: LLMRequest['response_format'];
+  tenantId?: string;
+  requestId?: string;
+  metadata?: Record<string, unknown>;
+}
+
+export interface RateLimitBalance {
+  limit?: number;
+  used?: number;
+  remaining?: number;
+}
+
+export interface ProviderBalance {
+  provider: string;
+  status: 'available' | 'unavailable' | 'error';
+  source: 'provider_api' | 'ledger' | 'headers' | 'not_supported';
+  currentSpend?: number;
+  monthlyBudget?: number;
+  remainingBudget?: number;
+  usedTokens?: number;
+  requestCount?: number;
+  rateLimits?: Record<string, RateLimitBalance>;
+  resetAt?: string;
+  message?: string;
+  raw?: unknown;
+}
 
 // Batch processing
 export interface BatchRequest {
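A sketch combining `ToolLoopOptions` fields from the types above: an external timeout via `abortSignal` plus an `onIteration` callback that aborts early on cost. The `factory`, `request`, and `executor` values are assumed from the earlier sketches:

```ts
async function cappedRun(
  factory: import('./factory').LLMProviderFactory,
  request: import('./types').LLMRequest,
  executor: import('./types').ToolExecutor
) {
  const controller = new AbortController();
  const timeout = setTimeout(() => controller.abort(), 30_000);

  try {
    return await factory.generateResponseWithTools(request, executor, {
      maxIterations: 8,
      maxCostUSD: 0.10,
      abortSignal: controller.signal,
      // Invoked after each completed iteration with running totals.
      onIteration: async (i, state) => {
        if (state.cumulativeCost > 0.08) controller.abort();
      },
    });
  } finally {
    clearTimeout(timeout);
  }
}
```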
diff --git a/src/utils/hooks.ts b/src/utils/hooks.ts
index 610abb9..3eb96ac 100644
--- a/src/utils/hooks.ts
+++ b/src/utils/hooks.ts
@@ -9,7 +9,14 @@
  * interface stays for debug output; hooks are for structured observability.
  */
 
-import type { CircuitBreakerState, ProviderMetrics, TokenUsage } from '../types';
+import type {
+  CircuitBreakerState,
+  ProviderBalance,
+  ProviderMetrics,
+  QuotaCheckInput,
+  QuotaCheckResult,
+  TokenUsage
+} from '../types';
 
 // ── Event types ──────────────────────────────────────────────────────────
 
@@ -87,6 +94,24 @@ export interface BudgetThresholdEvent {
   timestamp: number;
 }
 
+export interface QuotaCheckEvent {
+  input: QuotaCheckInput;
+  result: QuotaCheckResult;
+  timestamp: number;
+}
+
+export interface QuotaDeniedEvent {
+  input: QuotaCheckInput;
+  reason?: string;
+  timestamp: number;
+}
+
+export interface ProviderBalanceEvent {
+  provider: string;
+  balance: ProviderBalance;
+  timestamp: number;
+}
+
 // ── Hooks interface ──────────────────────────────────────────────────────
 
 export interface ObservabilityHooks {
@@ -98,6 +123,9 @@ export interface ObservabilityHooks {
   onCircuitStateChange?(event: CircuitStateChangeEvent): void;
   onQuotaExhausted?(event: QuotaExhaustedEvent): void;
   onBudgetThreshold?(event: BudgetThresholdEvent): void;
+  onQuotaCheck?(event: QuotaCheckEvent): void;
+  onQuotaDenied?(event: QuotaDeniedEvent): void;
+  onProviderBalance?(event: ProviderBalanceEvent): void;
 }
 
 /** Silent hooks — default. */
@@ -111,7 +139,8 @@ export function composeHooks(...implementations: ObservabilityHooks[]): Observab
   const methods = [
     'onRequestStart', 'onRequestEnd', 'onRequestError',
     'onRetry', 'onFallback', 'onCircuitStateChange',
-    'onQuotaExhausted', 'onBudgetThreshold',
+    'onQuotaExhausted', 'onBudgetThreshold', 'onQuotaCheck',
+    'onQuotaDenied', 'onProviderBalance',
   ] as const;
 
   const composed: ObservabilityHooks = {};
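Finally, a small sketch of wiring the three new events through `composeHooks`; each composed implementation runs in registration order, and the resulting object is passed as `hooks` in `ProviderFactoryConfig`:

```ts
import { composeHooks } from './utils/hooks';

const hooks = composeHooks(
  {
    onQuotaDenied: e =>
      console.warn(`quota denied (${e.input.tenantId ?? 'anon'}): ${e.reason}`),
  },
  {
    onProviderBalance: e => {
      const left = e.balance.remainingBudget;
      if (left !== undefined && left < 1) console.warn(`${e.provider} has $${left} left`);
    },
  },
);
```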