Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/__tests__/cerebras.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import { describe, it, expect, beforeEach, vi } from 'vitest';
import { CerebrasProvider } from '../providers/cerebras';
import { AuthenticationError } from '../errors';
import { AuthenticationError, ConfigurationError } from '../errors';
import { defaultCircuitBreakerManager } from '../utils/circuit-breaker';
import type { LLMRequest } from '../types';

Expand Down Expand Up @@ -208,6 +208,24 @@ describe('CerebrasProvider', () => {
model: 'gpt-4'
})).rejects.toThrow("Model 'gpt-4' not supported");
});

it('should reject tools for non-tool-capable models', async () => {
await expect(provider.generateResponse({
messages: [{ role: 'user', content: 'What is the weather?' }],
model: 'llama-3.1-8b',
tools: [{
type: 'function',
function: {
name: 'get_weather',
description: 'Get current weather',
parameters: { type: 'object', properties: { location: { type: 'string' } } }
}
}],
toolChoice: 'auto'
})).rejects.toBeInstanceOf(ConfigurationError);

expect(mockFetch).not.toHaveBeenCalled();
});
});

describe('estimateCost', () => {
Expand Down
188 changes: 158 additions & 30 deletions src/__tests__/factory.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
* Tests for the provider factory with mocked providers
*/

import { describe, it, expect, beforeEach, vi } from 'vitest';
import { LLMProviderFactory, createCostOptimizedFactory } from '../factory';
import { AuthenticationError } from '../errors';
import type { LLMRequest, LLMResponse } from '../types';
import { defaultCostTracker } from '../utils/cost-tracker';
import { defaultCircuitBreakerManager } from '../utils/circuit-breaker';
import { defaultExhaustionRegistry } from '../utils/exhaustion';
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { LLMProviderFactory, createCostOptimizedFactory } from '../factory';
import { AuthenticationError } from '../errors';
import type { LLMRequest, LLMResponse } from '../types';
import { OpenAIProvider } from '../providers/openai';
import { CreditLedger } from '../utils/credit-ledger';
import { defaultCostTracker } from '../utils/cost-tracker';
import { defaultCircuitBreakerManager } from '../utils/circuit-breaker';
import { defaultExhaustionRegistry } from '../utils/exhaustion';
import { defaultLatencyHistogram } from '../utils/latency-histogram';

// Mock providers
Expand Down Expand Up @@ -122,15 +124,37 @@ describe('LLMProviderFactory', () => {
temperature: 0.7
};

beforeEach(() => {
vi.clearAllMocks();
defaultCostTracker.reset();
defaultCircuitBreakerManager.resetAll();
defaultExhaustionRegistry.reset();
defaultLatencyHistogram.reset();

factory = new LLMProviderFactory({
openai: { apiKey: 'test-openai-key' },
beforeEach(() => {
vi.clearAllMocks();
defaultCostTracker.reset();
defaultCircuitBreakerManager.resetAll();
defaultExhaustionRegistry.reset();
defaultLatencyHistogram.reset();

mockOpenAIProvider.generateResponse.mockReset().mockResolvedValue({
message: 'OpenAI response',
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.001 },
model: 'gpt-3.5-turbo',
provider: 'openai',
responseTime: 1000
} as LLMResponse);
mockAnthropicProvider.generateResponse.mockReset().mockResolvedValue({
message: 'Anthropic response',
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.002 },
model: 'claude-3-haiku-20240307',
provider: 'anthropic',
responseTime: 1200
} as LLMResponse);
mockCloudflareProvider.generateResponse.mockReset().mockResolvedValue({
message: 'Cloudflare response',
usage: { inputTokens: 10, outputTokens: 20, totalTokens: 30, cost: 0.0001 },
model: '@cf/meta/llama-3.1-8b-instruct',
provider: 'cloudflare',
responseTime: 800
} as LLMResponse);

factory = new LLMProviderFactory({
openai: { apiKey: 'test-openai-key' },
anthropic: { apiKey: 'test-anthropic-key' },
cloudflare: { ai: {} as Ai },
defaultProvider: 'auto',
Expand Down Expand Up @@ -267,18 +291,62 @@ describe('LLMProviderFactory', () => {
expect(mockAnthropicProvider.generateResponse).toHaveBeenCalled();
});

it('should build cost-optimized chain for auto mode', async () => {
const autoRequest: LLMRequest = {
...testRequest
it('should build cost-optimized chain for auto mode', async () => {
const autoRequest: LLMRequest = {
...testRequest
// No specific model, should use auto selection
};

await factory.generateResponse(autoRequest);

// Should prefer Cloudflare for cost optimization
expect(mockCloudflareProvider.generateResponse).toHaveBeenCalled();
});
});
// Should prefer Cloudflare for cost optimization
expect(mockCloudflareProvider.generateResponse).toHaveBeenCalled();
});

it('should honor fallbackProvider as the next route when a rule matches', async () => {
const ruleFactory = new LLMProviderFactory({
openai: { apiKey: 'test-openai-key' },
anthropic: { apiKey: 'test-anthropic-key' },
cloudflare: { ai: {} as Ai },
defaultProvider: 'openai',
costOptimization: false,
fallbackRules: [{ condition: 'error', fallbackProvider: 'anthropic' }]
});

mockOpenAIProvider.generateResponse.mockRejectedValueOnce(new Error('OpenAI down'));

const response = await ruleFactory.generateResponse(testRequest);

expect(response.provider).toBe('anthropic');
expect(mockAnthropicProvider.generateResponse).toHaveBeenCalled();
expect(mockCloudflareProvider.generateResponse).not.toHaveBeenCalled();
});

it('should apply fallbackModel when routing through a fallback rule', async () => {
const ruleFactory = new LLMProviderFactory({
openai: { apiKey: 'test-openai-key' },
anthropic: { apiKey: 'test-anthropic-key' },
defaultProvider: 'openai',
costOptimization: false,
fallbackRules: [{
condition: 'error',
fallbackProvider: 'anthropic',
fallbackModel: 'claude-3-haiku-20240307'
}]
});

mockOpenAIProvider.generateResponse.mockRejectedValueOnce(new Error('OpenAI down'));

await ruleFactory.generateResponse({
...testRequest,
model: 'gpt-4'
});

expect(mockAnthropicProvider.generateResponse).toHaveBeenCalledWith(
expect.objectContaining({ model: 'claude-3-haiku-20240307' })
);
});
});

describe('Error Handling', () => {
it('should handle all providers failing', async () => {
Expand All @@ -302,7 +370,7 @@ describe('LLMProviderFactory', () => {
});
});

describe('Configuration Updates', () => {
describe('Configuration Updates', () => {
it('should update configuration', () => {
factory.updateConfig({
defaultProvider: 'anthropic',
Expand All @@ -314,15 +382,75 @@ describe('LLMProviderFactory', () => {
expect(() => factory.updateConfig({})).not.toThrow();
});

it('should reset metrics and circuit breakers', () => {
factory.reset();
it('should reset metrics and circuit breakers', () => {
factory.reset();

expect(mockOpenAIProvider.resetMetrics).toHaveBeenCalled();
expect(mockAnthropicProvider.resetMetrics).toHaveBeenCalled();
expect(mockCloudflareProvider.resetMetrics).toHaveBeenCalled();
});
});
});
expect(mockCloudflareProvider.resetMetrics).toHaveBeenCalled();
});

it('should pass maxRetries 0 to providers when factory retries are disabled', () => {
new LLMProviderFactory({
openai: { apiKey: 'test-key' },
enableRetries: false
});

expect(OpenAIProvider).toHaveBeenCalledWith(
expect.objectContaining({
apiKey: 'test-key',
maxRetries: 0
})
);
});

it('should preserve explicit provider maxRetries when factory retries are disabled', () => {
new LLMProviderFactory({
openai: { apiKey: 'test-key', maxRetries: 2 },
enableRetries: false
});

expect(OpenAIProvider).toHaveBeenCalledWith(
expect.objectContaining({
apiKey: 'test-key',
maxRetries: 2
})
);
});
});

describe('CreditLedger integration', () => {
it('should record successful factory calls into the provided ledger even without cost optimization', async () => {
const ledger = new CreditLedger({
budgets: [{
provider: 'cloudflare',
monthlyBudget: 1,
rateLimits: { rpm: 10, rpd: 100, tpm: 1000, tpd: 10_000 }
}]
});

const ledgerFactory = new LLMProviderFactory({
cloudflare: { ai: {} as Ai },
costOptimization: false,
ledger
});

await ledgerFactory.generateResponse(testRequest);

const accumulator = ledger.getProviderAccumulator('cloudflare');
expect(accumulator).toMatchObject({
spend: 0.0001,
inputTokens: 10,
outputTokens: 20,
requestCount: 1
});
expect(accumulator!.rateLimits.rpm!.used).toBe(1);
expect(accumulator!.rateLimits.rpd!.used).toBe(1);
expect(accumulator!.rateLimits.tpm!.used).toBe(30);
expect(accumulator!.rateLimits.tpd!.used).toBe(30);
});
});
});

describe('Cost Optimized Factory', () => {
beforeEach(() => {
Expand Down
29 changes: 5 additions & 24 deletions src/__tests__/groq.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import { describe, it, expect, beforeEach, vi } from 'vitest';
import { GroqProvider } from '../providers/groq';
import { AuthenticationError } from '../errors';
import { AuthenticationError, ConfigurationError } from '../errors';
import { defaultCircuitBreakerManager } from '../utils/circuit-breaker';
import type { LLMRequest } from '../types';

Expand Down Expand Up @@ -365,32 +365,13 @@ describe('GroqProvider', () => {
expect(toolMsg.content).toBe('{"temp": 15, "condition": "cloudy"}');
});

it('should not include tools for non-tool-capable models', async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
id: 'chatcmpl-123',
object: 'chat.completion',
created: 1700000000,
model: 'llama-3.1-8b-instant',
choices: [{
index: 0,
message: { role: 'assistant', content: 'I cannot call tools.' },
finish_reason: 'stop'
}],
usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }
}),
headers: new Headers({ 'content-type': 'application/json' })
});

await provider.generateResponse({
it('should reject tools for non-tool-capable models', async () => {
await expect(provider.generateResponse({
...toolRequest,
model: 'llama-3.1-8b-instant'
});
})).rejects.toBeInstanceOf(ConfigurationError);

const body = JSON.parse(mockFetch.mock.calls[0][1].body);
expect(body.tools).toBeUndefined();
expect(body.tool_choice).toBeUndefined();
expect(mockFetch).not.toHaveBeenCalled();
});

it('should support tool calling on llama-3.3-70b-versatile', async () => {
Expand Down
Loading
Loading