Skip to content

Commit de89b24

Browse files
committed
Add BYOK support for user AI provider keys
Introduce Bring-Your-Own-Key (BYOK) support so users can save encrypted AI provider API keys and use them for requests. This adds AES-256-GCM encryption utilities and an ENCRYPTION_KEY env var; a new UserProviderKey entity with DTOs, controller, service, and ProviderKeysModule; and a migration to create the user_provider_keys table. BYOK is integrated across the stack: the Public API attaches decrypted keys for free-tier users, the ai-service accepts an ephemeral api_key (with ProviderRegistry.get_ephemeral creating one-off provider instances), and QuotaInterceptor bypasses the platform quota for users who have BYOK keys. Billing now records metered overage usage to Stripe for paid tiers (new subscription fields plus a migration). The frontend settings UI has been updated to manage, save, validate, and remove provider keys and to show hints; minor UI cleanup and export adjustments are included.
1 parent 2a9a3c0 commit de89b24

21 files changed

Lines changed: 1173 additions & 147 deletions

.env.example

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ GEMINI_API_KEY=
4646
HUGGINGFACE_API_KEY=
4747
DEFAULT_AI_PROVIDER=openai
4848

49+
# --- BYOK Encryption (AES-256-GCM key for user-provided API key storage) ---
50+
# Generate with: node -e "console.log(require('crypto').randomBytes(32).toString('hex'))"
51+
ENCRYPTION_KEY=
52+
4953
# --- AI Service ---
5054
AI_SERVICE_URL=http://localhost:8000
5155
AI_SERVICE_PORT=8000

packages/ai-service/app/routers/conversations.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Conversation management and AI chat endpoints."""
22

3-
from fastapi import APIRouter, HTTPException
4-
from pydantic import BaseModel
5-
from typing import Optional
3+
from fastapi import APIRouter, HTTPException # type: ignore
4+
from pydantic import BaseModel # type: ignore
5+
from typing import Any, Dict, Optional
66
from datetime import datetime
77

88
from app.core.database import get_db
@@ -17,10 +17,10 @@
1717

1818
# Initialize providers on import
1919
ProviderRegistry.initialize(
20-
openai_key=settings.OPENAI_API_KEY,
21-
anthropic_key=settings.ANTHROPIC_API_KEY,
22-
gemini_key=settings.GEMINI_API_KEY,
23-
huggingface_key=settings.HUGGINGFACE_API_KEY,
20+
openai_key=settings.OPENAI_API_KEY or "",
21+
anthropic_key=settings.ANTHROPIC_API_KEY or "",
22+
gemini_key=settings.GEMINI_API_KEY or "",
23+
huggingface_key=settings.HUGGINGFACE_API_KEY or "",
2424
)
2525

2626

@@ -37,13 +37,16 @@ class SendMessageRequest(BaseModel):
3737
temperature: Optional[float] = 0.7
3838
max_tokens: Optional[int] = 2048
3939
system_prompt: Optional[str] = None
40+
# BYOK: caller-supplied key (decrypted by the core service,
41+
# transmitted over the internal network — never exposed to browsers).
42+
api_key: Optional[str] = None
4043

4144

4245
@router.post("/conversations")
4346
async def create_conversation(req: CreateConversationRequest):
4447
"""Create a new AI conversation."""
4548
db = get_db()
46-
conversation = {
49+
conversation: Dict[str, Any] = {
4750
"applicationId": req.application_id,
4851
"userId": req.user_id,
4952
"title": req.title,
@@ -61,10 +64,12 @@ async def create_conversation(req: CreateConversationRequest):
6164
@router.get("/conversations/{conversation_id}")
6265
async def get_conversation(conversation_id: str):
6366
"""Get conversation by ID with message history."""
64-
from bson import ObjectId
67+
from bson import ObjectId # type: ignore
6568

6669
db = get_db()
67-
conv = await db.ai_conversations.find_one({"_id": ObjectId(conversation_id)})
70+
conv = await db.ai_conversations.find_one(
71+
{"_id": ObjectId(conversation_id)}
72+
)
6873
if not conv:
6974
raise HTTPException(status_code=404, detail="Conversation not found")
7075

@@ -75,21 +80,42 @@ async def get_conversation(conversation_id: str):
7580
@router.post("/conversations/{conversation_id}/messages")
7681
async def send_message(conversation_id: str, req: SendMessageRequest):
7782
"""Send a message and get an AI response."""
78-
from bson import ObjectId
83+
from bson import ObjectId # type: ignore
7984

8085
db = get_db()
81-
conv = await db.ai_conversations.find_one({"_id": ObjectId(conversation_id)})
86+
conv = await db.ai_conversations.find_one(
87+
{"_id": ObjectId(conversation_id)}
88+
)
8289
if not conv:
8390
raise HTTPException(status_code=404, detail="Conversation not found")
8491

8592
# Determine provider
8693
provider_name = req.provider or settings.DEFAULT_AI_PROVIDER
87-
provider = ProviderRegistry.get(provider_name)
88-
if not provider:
89-
raise HTTPException(
90-
status_code=400,
91-
detail=f"Provider '{provider_name}' not available. Configure API key.",
92-
)
94+
95+
# If a BYOK key was supplied by the core service, create a short-lived
96+
# ephemeral provider instance (not cached — avoids key leaks
97+
# across requests).
98+
if req.api_key:
99+
provider = ProviderRegistry.get_ephemeral(provider_name, req.api_key)
100+
if not provider:
101+
raise HTTPException(
102+
status_code=400,
103+
detail=(
104+
f"Provider '{provider_name}' is not"
105+
" supported for BYOK."
106+
),
107+
)
108+
else:
109+
provider = ProviderRegistry.get(provider_name)
110+
if not provider:
111+
raise HTTPException(
112+
status_code=400,
113+
detail=(
114+
f"Provider '{provider_name}' not available."
115+
" Configure API key in Settings → AI"
116+
" Providers."
117+
),
118+
)
93119

94120
# Build message history
95121
messages = []
@@ -113,7 +139,10 @@ async def send_message(conversation_id: str, req: SendMessageRequest):
113139
try:
114140
response = await provider.chat(chat_request)
115141
except Exception as e:
116-
raise HTTPException(status_code=500, detail=f"AI provider error: {str(e)}")
142+
raise HTTPException(
143+
status_code=500,
144+
detail=f"AI provider error: {str(e)}",
145+
)
117146

118147
# Store messages
119148
user_msg = {

packages/ai-service/app/services/ai_providers.py

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ async def chat(self, request: ChatRequest) -> ChatResponse:
4343
pass
4444

4545
@abstractmethod
46-
async def chat_stream(self, request: ChatRequest) -> AsyncGenerator[str, None]:
46+
def chat_stream(
47+
self, request: ChatRequest
48+
) -> AsyncGenerator[str, None]:
4749
pass
4850

4951

@@ -60,13 +62,19 @@ def name(self) -> str:
6062

6163
@property
6264
def available_models(self) -> list[str]:
63-
return ["gpt-4", "gpt-4-turbo", "gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]
65+
return [
66+
"gpt-4", "gpt-4-turbo", "gpt-4o",
67+
"gpt-4o-mini", "gpt-3.5-turbo",
68+
]
6469

6570
async def chat(self, request: ChatRequest) -> ChatResponse:
6671
model = request.model or "gpt-4"
6772
response = await self.client.chat.completions.create(
6873
model=model,
69-
messages=[{"role": m.role, "content": m.content} for m in request.messages],
74+
messages=[ # type: ignore
75+
{"role": m.role, "content": m.content} # type: ignore
76+
for m in request.messages
77+
],
7078
temperature=request.temperature,
7179
max_tokens=request.max_tokens,
7280
)
@@ -75,22 +83,35 @@ async def chat(self, request: ChatRequest) -> ChatResponse:
7583
model=model,
7684
provider=self.name,
7785
usage={
78-
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
79-
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
80-
"total_tokens": response.usage.total_tokens if response.usage else 0,
86+
"prompt_tokens": (
87+
response.usage.prompt_tokens if response.usage else 0
88+
),
89+
"completion_tokens": (
90+
response.usage.completion_tokens
91+
if response.usage
92+
else 0
93+
),
94+
"total_tokens": (
95+
response.usage.total_tokens if response.usage else 0
96+
),
8197
},
8298
)
8399

84-
async def chat_stream(self, request: ChatRequest) -> AsyncGenerator[str, None]:
100+
async def chat_stream(
101+
self, request: ChatRequest
102+
) -> AsyncGenerator[str, None]:
85103
model = request.model or "gpt-4"
86104
stream = await self.client.chat.completions.create(
87105
model=model,
88-
messages=[{"role": m.role, "content": m.content} for m in request.messages],
106+
messages=[ # type: ignore
107+
{"role": m.role, "content": m.content} # type: ignore
108+
for m in request.messages
109+
],
89110
temperature=request.temperature,
90111
max_tokens=request.max_tokens,
91112
stream=True,
92113
)
93-
async for chunk in stream:
114+
async for chunk in stream: # type: ignore[union-attr]
94115
if chunk.choices[0].delta.content:
95116
yield chunk.choices[0].delta.content
96117

@@ -126,20 +147,24 @@ async def chat(self, request: ChatRequest) -> ChatResponse:
126147
model=model,
127148
max_tokens=request.max_tokens,
128149
system=system if system else "You are a helpful assistant.",
129-
messages=messages,
150+
messages=messages, # type: ignore[arg-type]
130151
)
131152
return ChatResponse(
132-
content=response.content[0].text,
153+
content=response.content[0].text, # type: ignore[union-attr]
133154
model=model,
134155
provider=self.name,
135156
usage={
136157
"prompt_tokens": response.usage.input_tokens,
137158
"completion_tokens": response.usage.output_tokens,
138-
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
159+
"total_tokens": (
160+
response.usage.input_tokens + response.usage.output_tokens
161+
),
139162
},
140163
)
141164

142-
async def chat_stream(self, request: ChatRequest) -> AsyncGenerator[str, None]:
165+
async def chat_stream(
166+
self, request: ChatRequest
167+
) -> AsyncGenerator[str, None]:
143168
model = request.model or "claude-sonnet-4-5-20250929"
144169
system = ""
145170
messages = []
@@ -153,7 +178,7 @@ async def chat_stream(self, request: ChatRequest) -> AsyncGenerator[str, None]:
153178
model=model,
154179
max_tokens=request.max_tokens,
155180
system=system if system else "You are a helpful assistant.",
156-
messages=messages,
181+
messages=messages, # type: ignore[arg-type]
157182
) as stream:
158183
async for text in stream.text_stream:
159184
yield text
@@ -163,7 +188,7 @@ class GeminiProvider(AIProvider):
163188
"""Google Gemini provider."""
164189

165190
def __init__(self, api_key: str):
166-
import google.generativeai as genai
191+
import google.generativeai as genai # type: ignore
167192
self._genai = genai
168193
genai.configure(api_key=api_key)
169194

@@ -175,8 +200,12 @@ def name(self) -> str:
175200
def available_models(self) -> list[str]:
176201
return ["gemini-2.0-flash", "gemini-1.5-pro", "gemini-1.5-flash"]
177202

178-
def _build_contents(self, messages: list[ChatMessage]) -> tuple[list[dict], str]:
179-
"""Convert ChatMessages to Gemini format, extracting system instruction."""
203+
def _build_contents(
204+
self, messages: list[ChatMessage]
205+
) -> tuple[list[dict], str]:
206+
"""Convert ChatMessages to Gemini format, extracting system
207+
instruction.
208+
"""
180209
system_instruction = ""
181210
contents = []
182211
for m in messages:
@@ -226,7 +255,9 @@ async def chat(self, request: ChatRequest) -> ChatResponse:
226255
usage=usage,
227256
)
228257

229-
async def chat_stream(self, request: ChatRequest) -> AsyncGenerator[str, None]:
258+
async def chat_stream(
259+
self, request: ChatRequest
260+
) -> AsyncGenerator[str, None]:
230261
model_name = request.model or "gemini-2.0-flash"
231262
contents, system_instruction = self._build_contents(request.messages)
232263

@@ -322,7 +353,9 @@ async def chat(self, request: ChatRequest) -> ChatResponse:
322353
usage={},
323354
)
324355

325-
async def chat_stream(self, request: ChatRequest) -> AsyncGenerator[str, None]:
356+
async def chat_stream(
357+
self, request: ChatRequest
358+
) -> AsyncGenerator[str, None]:
326359
model = request.model or "mistralai/Mistral-7B-Instruct-v0.3"
327360
prompt = self._build_prompt(request.messages)
328361

@@ -342,7 +375,6 @@ async def chat_stream(self, request: ChatRequest) -> AsyncGenerator[str, None]:
342375
json=payload,
343376
) as response:
344377
response.raise_for_status()
345-
buffer = ""
346378
async for line in response.aiter_lines():
347379
if line.startswith("data:"):
348380
import json
@@ -356,7 +388,9 @@ async def chat_stream(self, request: ChatRequest) -> AsyncGenerator[str, None]:
356388

357389

358390
class ProviderRegistry:
359-
"""Registry of available AI providers."""
391+
"""Registry of available AI providers (platform keys, initialized at
392+
startup).
393+
"""
360394

361395
_providers: dict[str, AIProvider] = {}
362396

@@ -378,10 +412,10 @@ def list_providers(cls) -> list[dict]:
378412
@classmethod
379413
def initialize(
380414
cls,
381-
openai_key: str = None,
382-
anthropic_key: str = None,
383-
gemini_key: str = None,
384-
huggingface_key: str = None,
415+
openai_key: Optional[str] = None,
416+
anthropic_key: Optional[str] = None,
417+
gemini_key: Optional[str] = None,
418+
huggingface_key: Optional[str] = None,
385419
):
386420
if openai_key:
387421
cls.register(OpenAIProvider(openai_key))
@@ -391,3 +425,21 @@ def initialize(
391425
cls.register(GeminiProvider(gemini_key))
392426
if huggingface_key:
393427
cls.register(HuggingFaceProvider(huggingface_key))
428+
429+
@classmethod
430+
def get_ephemeral(cls, name: str, api_key: str) -> Optional[AIProvider]:
431+
"""
432+
Create a one-time, non-cached provider instance using a
433+
caller-supplied key. Used for BYOK (Bring Your Own Key) requests
434+
so user keys are never stored in the registry and do not leak
435+
across requests.
436+
"""
437+
if name == "openai":
438+
return OpenAIProvider(api_key)
439+
if name == "anthropic":
440+
return AnthropicProvider(api_key)
441+
if name == "gemini":
442+
return GeminiProvider(api_key)
443+
if name == "huggingface":
444+
return HuggingFaceProvider(api_key)
445+
return None

packages/core/src/app.module.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import { SsoModule } from "./modules/sso/sso.module";
3333
import { DataExportModule } from "./modules/data-export/data-export.module";
3434
import { SystemHealthModule } from "./modules/system-health/system-health.module";
3535
import { StripeModule } from "./modules/stripe/stripe.module";
36+
import { ProviderKeysModule } from "./modules/provider-keys/provider-keys.module";
3637
import { ScheduleModule } from "@nestjs/schedule";
3738

3839
@Module({
@@ -131,6 +132,7 @@ import { ScheduleModule } from "@nestjs/schedule";
131132
SsoModule,
132133
DataExportModule,
133134
SystemHealthModule,
135+
ProviderKeysModule,
134136
],
135137
})
136138
export class AppModule {}

0 commit comments

Comments
 (0)