diff --git a/src/__tests__/cerebras.test.ts b/src/__tests__/cerebras.test.ts
index 07fe537..0c23e35 100644
--- a/src/__tests__/cerebras.test.ts
+++ b/src/__tests__/cerebras.test.ts
@@ -5,7 +5,7 @@
 import { describe, it, expect, beforeEach, vi } from 'vitest';
 import { CerebrasProvider } from '../providers/cerebras';
-import { AuthenticationError } from '../errors';
+import { AuthenticationError, ConfigurationError } from '../errors';
 import { defaultCircuitBreakerManager } from '../utils/circuit-breaker';
 import type { LLMRequest } from '../types';
 
@@ -208,6 +208,24 @@ describe('CerebrasProvider', () => {
         model: 'gpt-4'
       })).rejects.toThrow("Model 'gpt-4' not supported");
     });
+
+    it('should reject tools for non-tool-capable models', async () => {
+      await expect(provider.generateResponse({
+        messages: [{ role: 'user', content: 'What is the weather?' }],
+        model: 'llama-3.1-8b',
+        tools: [{
+          type: 'function',
+          function: {
+            name: 'get_weather',
+            description: 'Get current weather',
+            parameters: { type: 'object', properties: { location: { type: 'string' } } }
+          }
+        }],
+        toolChoice: 'auto'
+      })).rejects.toBeInstanceOf(ConfigurationError);
+
+      expect(mockFetch).not.toHaveBeenCalled();
+    });
   });
 
   describe('estimateCost', () => {
diff --git a/src/__tests__/factory.test.ts b/src/__tests__/factory.test.ts
index 67a36f1..0ae6a44 100755
--- a/src/__tests__/factory.test.ts
+++ b/src/__tests__/factory.test.ts
@@ -3,13 +3,15 @@
  * Tests for the provider factory with mocked providers
  */
 
-import { describe, it, expect, beforeEach, vi } from 'vitest';
-import { LLMProviderFactory, createCostOptimizedFactory } from '../factory';
-import { AuthenticationError } from '../errors';
-import type { LLMRequest, LLMResponse } from '../types';
-import { defaultCostTracker } from '../utils/cost-tracker';
-import { defaultCircuitBreakerManager } from '../utils/circuit-breaker';
-import { defaultExhaustionRegistry } from '../utils/exhaustion';
+import { describe, it, expect, beforeEach, vi } from 'vitest';
+import { LLMProviderFactory, createCostOptimizedFactory } from '../factory';
+import { AuthenticationError } from '../errors';
+import type { LLMRequest, LLMResponse } from '../types';
+import { OpenAIProvider } from '../providers/openai';
+import { CreditLedger } from '../utils/credit-ledger';
+import { defaultCostTracker } from '../utils/cost-tracker';
+import { defaultCircuitBreakerManager } from '../utils/circuit-breaker';
+import { defaultExhaustionRegistry } from '../utils/exhaustion';
 import { defaultLatencyHistogram } from '../utils/latency-histogram';
 
 // Mock providers
@@ -122,15 +124,37 @@ describe('LLMProviderFactory', () => {
     temperature: 0.7
   };
 
-  beforeEach(() => {
-    vi.clearAllMocks();
-    defaultCostTracker.reset();
-    defaultCircuitBreakerManager.resetAll();
-    defaultExhaustionRegistry.reset();
-    defaultLatencyHistogram.reset();
-
-    factory = new LLMProviderFactory({
-      openai: { apiKey: 'test-openai-key' },
+  beforeEach(() => {
+    vi.clearAllMocks();
+    defaultCostTracker.reset();
+    defaultCircuitBreakerManager.resetAll();
+    defaultExhaustionRegistry.reset();
+    defaultLatencyHistogram.reset();
+
+    mockOpenAIProvider.generateResponse.mockReset().mockResolvedValue({
+      message: 'OpenAI response',
+      usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.001 },
+      model: 'gpt-3.5-turbo',
+      provider: 'openai',
+      responseTime: 1000
+    } as LLMResponse);
+    mockAnthropicProvider.generateResponse.mockReset().mockResolvedValue({
+      message: 'Anthropic response',
+      usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.002 },
+      model: 'claude-3-haiku-20240307',
+      provider: 'anthropic',
+      responseTime: 1200
+    } as LLMResponse);
+    mockCloudflareProvider.generateResponse.mockReset().mockResolvedValue({
+      message: 'Cloudflare response',
+      usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.0001 },
+      model: '@cf/meta/llama-3.1-8b-instruct',
+      provider: 'cloudflare',
+      responseTime: 800
+    } as LLMResponse);
+
+    factory = new LLMProviderFactory({
+      openai: { apiKey: 'test-openai-key' },
       anthropic: { apiKey: 'test-anthropic-key' },
       cloudflare: { ai: {} as Ai },
       defaultProvider: 'auto',
@@ -267,18 +291,62 @@ describe('LLMProviderFactory', () => {
       expect(mockAnthropicProvider.generateResponse).toHaveBeenCalled();
     });
 
-    it('should build cost-optimized chain for auto mode', async () => {
-      const autoRequest: LLMRequest = {
-        ...testRequest
+    it('should build cost-optimized chain for auto mode', async () => {
+      const autoRequest: LLMRequest = {
+        ...testRequest
         // No specific model, should use auto selection
       };
 
       await factory.generateResponse(autoRequest);
 
-      // Should prefer Cloudflare for cost optimization
-      expect(mockCloudflareProvider.generateResponse).toHaveBeenCalled();
-    });
-  });
+      // Should prefer Cloudflare for cost optimization
+      expect(mockCloudflareProvider.generateResponse).toHaveBeenCalled();
+    });
+
+    it('should honor fallbackProvider as the next route when a rule matches', async () => {
+      const ruleFactory = new LLMProviderFactory({
+        openai: { apiKey: 'test-openai-key' },
+        anthropic: { apiKey: 'test-anthropic-key' },
+        cloudflare: { ai: {} as Ai },
+        defaultProvider: 'openai',
+        costOptimization: false,
+        fallbackRules: [{ condition: 'error', fallbackProvider: 'anthropic' }]
+      });
+
+      mockOpenAIProvider.generateResponse.mockRejectedValueOnce(new Error('OpenAI down'));
+
+      const response = await ruleFactory.generateResponse(testRequest);
+
+      expect(response.provider).toBe('anthropic');
+      expect(mockAnthropicProvider.generateResponse).toHaveBeenCalled();
+      expect(mockCloudflareProvider.generateResponse).not.toHaveBeenCalled();
+    });
+
+    it('should apply fallbackModel when routing through a fallback rule', async () => {
+      const ruleFactory = new LLMProviderFactory({
+        openai: { apiKey: 'test-openai-key' },
+        anthropic: { apiKey: 'test-anthropic-key' },
+        defaultProvider: 'openai',
+        costOptimization: false,
+        fallbackRules: [{
+          condition: 'error',
+          fallbackProvider: 'anthropic',
+          fallbackModel: 'claude-3-haiku-20240307'
+        }]
+      });
+
+      mockOpenAIProvider.generateResponse.mockRejectedValueOnce(new Error('OpenAI down'));
+
+      await ruleFactory.generateResponse({
+        ...testRequest,
+        model: 'gpt-4'
+      });
+
+      expect(mockAnthropicProvider.generateResponse).toHaveBeenCalledWith(
+        expect.objectContaining({ model: 'claude-3-haiku-20240307' })
+      );
+    });
+  });
 
   describe('Error Handling', () => {
     it('should handle all providers failing', async () => {
@@ -302,7 +370,7 @@
     });
   });
 
-  describe('Configuration Updates', () => {
+  describe('Configuration Updates', () => {
     it('should update configuration', () => {
       factory.updateConfig({
         defaultProvider: 'anthropic',
@@ -314,15 +382,75 @@
       expect(() => factory.updateConfig({})).not.toThrow();
     });
 
-    it('should reset metrics and circuit breakers', () => {
-      factory.reset();
+    it('should reset metrics and circuit breakers', () => {
+      factory.reset();
 
       expect(mockOpenAIProvider.resetMetrics).toHaveBeenCalled();
       expect(mockAnthropicProvider.resetMetrics).toHaveBeenCalled();
-      expect(mockCloudflareProvider.resetMetrics).toHaveBeenCalled();
-    });
-  });
-});
+      expect(mockCloudflareProvider.resetMetrics).toHaveBeenCalled();
+    });
+
+    it('should pass maxRetries 0 to providers when factory retries are disabled', () => {
+      new LLMProviderFactory({
+        openai: { apiKey: 'test-key' },
+        enableRetries: false
+      });
+
+      expect(OpenAIProvider).toHaveBeenCalledWith(
+        expect.objectContaining({
+          apiKey: 'test-key',
+          maxRetries: 0
+        })
+      );
+    });
+
+    it('should preserve explicit provider maxRetries when factory retries are disabled', () => {
+      new LLMProviderFactory({
+        openai: { apiKey: 'test-key', maxRetries: 2 },
+        enableRetries: false
+      });
+
+      expect(OpenAIProvider).toHaveBeenCalledWith(
+        expect.objectContaining({
+          apiKey: 'test-key',
+          maxRetries: 2
+        })
+      );
+    });
+  });
+
+  describe('CreditLedger integration', () => {
+    it('should record successful factory calls into the provided ledger even without cost optimization', async () => {
+      const ledger = new CreditLedger({
+        budgets: [{
+          provider: 'cloudflare',
+          monthlyBudget: 1,
+          rateLimits: { rpm: 10, rpd: 100, tpm: 1000, tpd: 10_000 }
+        }]
+      });
+
+      const ledgerFactory = new LLMProviderFactory({
+        cloudflare: { ai: {} as Ai },
+        costOptimization: false,
+        ledger
+      });
+
+      await ledgerFactory.generateResponse(testRequest);
+
+      const accumulator = ledger.getProviderAccumulator('cloudflare');
+      expect(accumulator).toMatchObject({
+        spend: 0.0001,
+        inputTokens: 10,
+        outputTokens: 20,
+        requestCount: 1
+      });
+      expect(accumulator!.rateLimits.rpm!.used).toBe(1);
+      expect(accumulator!.rateLimits.rpd!.used).toBe(1);
+      expect(accumulator!.rateLimits.tpm!.used).toBe(30);
+      expect(accumulator!.rateLimits.tpd!.used).toBe(30);
+    });
+  });
+});
 
 describe('Cost Optimized Factory', () => {
   beforeEach(() => {
diff --git a/src/__tests__/groq.test.ts b/src/__tests__/groq.test.ts
index 30ce1c9..a225224 100644
--- a/src/__tests__/groq.test.ts
+++ b/src/__tests__/groq.test.ts
@@ -5,7 +5,7 @@
 import { describe, it, expect, beforeEach, vi } from 'vitest';
 import { GroqProvider } from '../providers/groq';
-import { AuthenticationError } from '../errors';
+import { AuthenticationError, ConfigurationError } from '../errors';
 import { defaultCircuitBreakerManager } from '../utils/circuit-breaker';
 import type { LLMRequest } from '../types';
 
@@ -365,32 +365,13 @@ describe('GroqProvider', () => {
     expect(toolMsg.content).toBe('{"temp": 15, "condition": "cloudy"}');
   });
 
-  it('should not include tools for non-tool-capable models', async () => {
-    mockFetch.mockResolvedValueOnce({
-      ok: true,
-      json: async () => ({
-        id: 'chatcmpl-123',
-        object: 'chat.completion',
-        created: 1700000000,
-        model: 'llama-3.1-8b-instant',
-        choices: [{
-          index: 0,
-          message: { role: 'assistant', content: 'I cannot call tools.' },
-          finish_reason: 'stop'
-        }],
-        usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }
-      }),
-      headers: new Headers({ 'content-type': 'application/json' })
-    });
-
-    await provider.generateResponse({
+  it('should reject tools for non-tool-capable models', async () => {
+    await expect(provider.generateResponse({
       ...toolRequest,
       model: 'llama-3.1-8b-instant'
-    });
+    })).rejects.toBeInstanceOf(ConfigurationError);
 
-    const body = JSON.parse(mockFetch.mock.calls[0][1].body);
-    expect(body.tools).toBeUndefined();
-    expect(body.tool_choice).toBeUndefined();
+    expect(mockFetch).not.toHaveBeenCalled();
   });
 
   it('should support tool calling on llama-3.3-70b-versatile', async () => {
diff --git a/src/factory.ts b/src/factory.ts
index 9111a1c..fcd413a 100755
--- a/src/factory.ts
+++ b/src/factory.ts
@@ -79,6 +79,12 @@ export interface ProviderHealthEntry {
   error?: string;
 }
 
+interface FallbackDecision {
+  shouldFallback: boolean;
+  fallbackProvider?: string;
+  fallbackModel?: string;
+}
+
 export class LLMProviderFactory {
   private providers: Map<string, BaseProvider> = new Map();
   private config: ProviderFactoryConfig;
@@ -91,7 +97,9 @@
     this.config = config;
     this.logger = config.logger ?? noopLogger;
    this.hooks = config.hooks ?? noopHooks;
-    this.costTracker = defaultCostTracker;
+    this.costTracker = config.ledger
+      ? new CostTracker({}, config.ledger, this.logger)
+      : defaultCostTracker;
     this.fallbackRules = config.fallbackRules || this.getDefaultFallbackRules();
 
     this.initializeProviders();
@@ -114,8 +122,13 @@
       if (!providerConfig) continue;
 
       try {
+        const retryConfig: Partial<{ maxRetries: number }> =
+          this.config.enableRetries === false && providerConfig.maxRetries === undefined
+            ? { maxRetries: 0 }
+            : {};
         const provider = new ProviderClass({
           ...providerConfig,
+          ...retryConfig,
           logger: this.logger,
           hooks: this.hooks,
         });
@@ -138,10 +151,13 @@
    */
   async generateResponse(request: LLMRequest): Promise<LLMResponse> {
     const providerChain = this.buildProviderChain(request);
+    const providerModels = new Map<string, string>();
     let lastError: Error | null = null;
     let previousProvider: string | null = null;
 
-    for (const providerName of providerChain) {
+    for (let index = 0; index < providerChain.length; index++) {
+      const providerName = providerChain[index];
+
       try {
         const provider = this.providers.get(providerName);
         if (!provider) continue;
@@ -161,18 +177,8 @@
           }
         }
 
-        // Check rate limits if ledger is configured
-        if (this.config.ledger) {
-          const rpmCheck = this.config.ledger.checkRateLimit(providerName, 'rpm');
-          if (!rpmCheck.allowed) {
-            this.logger.warn(`[LLMProviderFactory] Rate limit (rpm) exceeded for ${providerName} (${rpmCheck.used}/${rpmCheck.limit}), skipping`);
-            continue;
-          }
-          const rpdCheck = this.config.ledger.checkRateLimit(providerName, 'rpd');
-          if (!rpdCheck.allowed) {
-            this.logger.warn(`[LLMProviderFactory] Rate limit (rpd) exceeded for ${providerName} (${rpdCheck.used}/${rpdCheck.limit}), skipping`);
-            continue;
-          }
+        if (this.config.ledger && this.isLedgerLimited(providerName)) {
+          continue;
         }
 
         // Emit fallback event if this isn't the first provider attempted
@@ -189,7 +195,8 @@
 
         this.logger.debug(`[LLMProviderFactory] Trying provider: ${providerName}`);
 
-        const model = request.model || provider.models[0] || 'unknown';
+        const providerRequest = this.requestForProvider(request, providerName, providerModels);
+        const model = providerRequest.model || provider.models[0] || 'unknown';
         this.hooks.onRequestStart?.({
           provider: providerName,
           model,
@@ -199,7 +206,7 @@
         });
 
         const startTime = Date.now();
-        const response = await provider.generateResponse(request);
+        const response = await provider.generateResponse(providerRequest);
         const durationMs = Date.now() - startTime;
 
         this.hooks.onRequestEnd?.({
@@ -213,8 +220,8 @@
           timestamp: Date.now(),
         });
 
-        // Track cost if enabled
-        if (this.config.costOptimization) {
+        // Track spend whenever analytics or ledger accounting is configured.
+        if (this.config.costOptimization || this.config.ledger) {
           this.costTracker.trackCost(providerName, response);
         }
 
@@ -249,10 +256,18 @@
           });
         }
 
-        // Check if we should continue trying other providers
-        if (!this.shouldFallback(error as Error)) {
+        const fallbackDecision = this.getFallbackDecision(error as Error);
+        if (!fallbackDecision.shouldFallback) {
           throw error;
         }
+
+        this.applyFallbackDecision(
+          fallbackDecision,
+          providerName,
+          providerChain,
+          index,
+          providerModels
+        );
       }
     }
 
@@ -366,38 +381,49 @@
    * Check if we should fallback to another provider
    */
   private shouldFallback(error: Error): boolean {
+    return this.getFallbackDecision(error).shouldFallback;
+  }
+
+  /**
+   * Get fallback routing decision for an error.
+   */
+  private getFallbackDecision(error: Error): FallbackDecision {
     // Don't fallback for authentication errors
     if (error instanceof AuthenticationError) {
-      return false;
+      return { shouldFallback: false };
     }
 
     // Don't fallback for configuration errors
     if (error instanceof ConfigurationError) {
-      return false;
+      return { shouldFallback: false };
+    }
+
+    // Custom fallback rules can provide explicit provider/model routing.
+    for (const rule of this.fallbackRules) {
+      if (this.evaluateFallbackRule(rule, error)) {
+        return {
+          shouldFallback: true,
+          fallbackProvider: rule.fallbackProvider,
+          fallbackModel: rule.fallbackModel
+        };
+      }
     }
 
     // Fallback for circuit breaker, rate limits, and server errors
     if (error instanceof CircuitBreakerOpenError || error instanceof RateLimitError) {
-      return true;
+      return { shouldFallback: true };
     }
 
     if (error instanceof LLMProviderError) {
      if (error.code === 'SERVER_ERROR' || error.code === 'NETWORK_ERROR' || error.code === 'TIMEOUT') {
-        return true;
-      }
-    }
-
-    // Check custom fallback rules
-    for (const rule of this.fallbackRules) {
-      if (this.evaluateFallbackRule(rule, error)) {
-        return true;
+        return { shouldFallback: true };
       }
     }
 
-    return false;
+    return { shouldFallback: false };
   }
 
   /**
@@ -565,7 +591,7 @@
       }
     }
 
-    if (this.config.costOptimization) {
+    if (this.config.costOptimization || this.config.ledger) {
       this.costTracker.reset();
     }
 
@@ -579,16 +605,88 @@
   updateConfig(config: Partial<ProviderFactoryConfig>): void {
     this.config = { ...this.config, ...config };
 
+    if ('ledger' in config) {
+      this.costTracker = config.ledger
+        ? new CostTracker({}, config.ledger, this.logger)
+        : defaultCostTracker;
+    }
+
     if (config.fallbackRules) {
       this.fallbackRules = config.fallbackRules;
     }
 
     // Re-initialize providers if configs changed
-    if (config.openai || config.anthropic || config.cloudflare || config.cerebras || config.groq) {
+    if (
+      config.openai ||
+      config.anthropic ||
+      config.cloudflare ||
+      config.cerebras ||
+      config.groq ||
+      config.enableRetries !== undefined
+    ) {
       this.providers.clear();
       this.initializeProviders();
     }
   }
+
+  private isLedgerLimited(providerName: string): boolean {
+    if (!this.config.ledger) return false;
+
+    for (const dimension of ['rpm', 'rpd', 'tpm', 'tpd'] as const) {
+      const check = this.config.ledger.checkRateLimit(providerName, dimension);
+      if (!check.allowed) {
+        this.logger.warn(
+          `[LLMProviderFactory] Rate limit (${dimension}) exceeded for ${providerName} (${check.used}/${check.limit}), skipping`
+        );
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  private requestForProvider(
+    request: LLMRequest,
+    providerName: string,
+    providerModels: Map<string, string>
+  ): LLMRequest {
+    const model = providerModels.get(providerName);
+    if (!model) {
+      return request;
+    }
+
+    return { ...request, model };
+  }
+
+  private applyFallbackDecision(
+    decision: FallbackDecision,
+    failedProvider: string,
+    providerChain: string[],
+    currentIndex: number,
+    providerModels: Map<string, string>
+  ): void {
+    const targetProvider = decision.fallbackProvider;
+    if (!targetProvider || targetProvider === failedProvider || !this.providers.has(targetProvider)) {
+      return;
+    }
+
+    if (decision.fallbackModel) {
+      providerModels.set(targetProvider, decision.fallbackModel);
+    }
+
+    const nextIndex = currentIndex + 1;
+    const firstIndex = providerChain.indexOf(targetProvider);
+    if (firstIndex >= 0 && firstIndex <= currentIndex) {
+      return;
+    }
+
+    const existingIndex = providerChain.indexOf(targetProvider, nextIndex);
+    if (existingIndex >= 0) {
+      providerChain.splice(existingIndex, 1);
+    }
+
+    providerChain.splice(nextIndex, 0, targetProvider);
+  }
 }
 
 /**
diff --git a/src/providers/cerebras.ts b/src/providers/cerebras.ts
index d9657b2..d9598b7 100644
--- a/src/providers/cerebras.ts
+++ b/src/providers/cerebras.ts
@@ -7,7 +7,8 @@
 import type { LLMRequest, LLMResponse, CerebrasConfig, ModelCapabilities, ToolCall } from '../types';
 import { BaseProvider } from './base';
 import {
   LLMErrorFactory,
-  AuthenticationError
+  AuthenticationError,
+  ConfigurationError
 } from '../errors';
 
 interface CerebrasMessage {
@@ -286,9 +287,21 @@ export class CerebrasProvider extends BaseProvider {
   private formatRequest(request: LLMRequest): CerebrasRequest {
     const messages: CerebrasMessage[] = [];
     const model = request.model || 'llama-3.1-8b';
+    const usesTools =
+      (request.tools?.length ?? 0) > 0 ||
+      request.messages.some(message =>
+        (message.toolCalls?.length ?? 0) > 0 || (message.toolResults?.length ?? 0) > 0
+      );
     const jsonMode = request.response_format?.type === 'json_object';
     const jsonInstruction = '\n\nYou must respond with valid JSON only. No markdown fences, no commentary, no text outside the JSON.';
 
+    if (usesTools && !TOOL_CAPABLE_MODELS.has(model)) {
+      throw new ConfigurationError(
+        this.name,
+        `Model '${model}' does not support tool calling on Cerebras`
+      );
+    }
+
     if (request.systemPrompt) {
       messages.push({
         role: 'system',
@@ -338,8 +351,8 @@ export class CerebrasProvider extends BaseProvider {
       stream: request.stream
     };
 
-    // Add tools if the model supports them and tools are provided
-    if (request.tools && request.tools.length > 0 && TOOL_CAPABLE_MODELS.has(model)) {
+    // Add tools if provided. Unsupported tool models are rejected above.
+    if (request.tools && request.tools.length > 0) {
       result.tools = request.tools.map(t => ({
         type: 'function',
         function: {
diff --git a/src/providers/groq.ts b/src/providers/groq.ts
index 3f49c81..e1e2674 100644
--- a/src/providers/groq.ts
+++ b/src/providers/groq.ts
@@ -7,7 +7,8 @@
 import type { LLMRequest, LLMResponse, GroqConfig, ModelCapabilities, ToolCall } from '../types';
 import { BaseProvider } from './base';
 import {
   LLMErrorFactory,
-  AuthenticationError
+  AuthenticationError,
+  ConfigurationError
 } from '../errors';
 
 interface GroqMessage {
@@ -277,9 +278,21 @@ export class GroqProvider extends BaseProvider {
   private formatRequest(request: LLMRequest): GroqRequest {
     const messages: GroqMessage[] = [];
     const model = request.model || 'llama-3.3-70b-versatile';
+    const usesTools =
+      (request.tools?.length ?? 0) > 0 ||
+      request.messages.some(message =>
+        (message.toolCalls?.length ?? 0) > 0 || (message.toolResults?.length ?? 0) > 0
+      );
     const jsonMode = request.response_format?.type === 'json_object';
     const jsonInstruction = '\n\nYou must respond with valid JSON only. No markdown fences, no commentary, no text outside the JSON.';
 
+    if (usesTools && !TOOL_CAPABLE_MODELS.has(model)) {
+      throw new ConfigurationError(
+        this.name,
+        `Model '${model}' does not support tool calling on Groq`
+      );
+    }
+
     if (request.systemPrompt) {
       messages.push({
         role: 'system',
@@ -334,8 +347,8 @@ export class GroqProvider extends BaseProvider {
       groqRequest.response_format = request.response_format;
     }
 
-    // Add tools if the model supports them and tools are provided
-    if (request.tools && request.tools.length > 0 && TOOL_CAPABLE_MODELS.has(model)) {
+    // Add tools if provided. Unsupported tool models are rejected above.
+    if (request.tools && request.tools.length > 0) {
       groqRequest.tools = request.tools.map(t => ({
         type: 'function',
         function: {