-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathchat.py
More file actions
823 lines (709 loc) · 32 KB
/
chat.py
File metadata and controls
823 lines (709 loc) · 32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
"""
Chat API 模块 - 实现 OpenAI 兼容的 Chat API endpoint, 支持流式响应、对话持久化、自动标题生成、推理内容分离等功能
"""
import asyncio
import json
from typing import Optional, Dict, Any, List, Callable
from fastapi import HTTPException, status, Request
from fastapi.responses import StreamingResponse
from openai import AsyncOpenAI
from auth import User
from history import (
# get_history_recent,
append_history,
update_chat_title,
get_history_full,
get_combined_history,
ChatNotFoundError,
ChatAccessDeniedError,
DatabaseError,
rollback_history
)
from appinit import (
MAIN_LLM_MODEL,
LITE_LLM_MODEL,
DASHSCOPE_BASE_URL,
DASHSCOPE_API_KEY,
)
from utils.apikeytool import get_user_api_key, decrypt_api_key_record
from utils.usage import normalize_usage
from utils.load_prompts import get_prompt
from utils.load_models import get_models
# 标题生成超时时间(秒)
TITLE_GENERATION_TIMEOUT = 20
# 检查 API Key 是否配置
if not DASHSCOPE_API_KEY:
print("警告:未配置有效的 DASHSCOPE_API_KEY,请设置环境变量或在 config.yaml 中配置")
def _check_request_model(
requested_model: Optional[str],
chat_current_model: Optional[str]
) -> str:
"""检查并确定使用的模型
优先级:
1. 请求中的 model 字段(如果存在于 models.json 中)
2. chat 记录的 current_model 字段(如果存在于 models.json 中)
3. 系统默认 MAIN_LLM_MODEL
Args:
requested_model: 请求中的 model 字段
chat_current_model: chat 记录的 current_model 字段
Returns:
str: 实际使用的模型 ID
"""
# 获取有效的模型列表
models_data = get_models()
valid_model_ids = [model["id"] for model in models_data.get("data", [])]
# 1. 检查请求中的 model
if requested_model and requested_model in valid_model_ids:
return requested_model
# 2. 检查 chat 的 current_model
if chat_current_model and chat_current_model in valid_model_ids:
return chat_current_model
# 3.看看MAIN_LLM_MODEL
if MAIN_LLM_MODEL in valid_model_ids:
return MAIN_LLM_MODEL
# 最后兜底:取第一个可用模型
if valid_model_ids:
return next(iter(valid_model_ids))
# 这次是真的兜底了,返回 MAIN_LLM_MODEL(即使它不在 models.json 中,也保证有一个默认值),并炸掉
print("警告:模型列表无效且主模型配置错误,返回 MAIN_LLM_MODEL 作为默认值,请检查模型配置")
return MAIN_LLM_MODEL
def get_openai_client(api_key: Optional[str] = None, base_url: Optional[str] = None) -> AsyncOpenAI:
"""创建 OpenAI 客户端
Args:
api_key: API Key,如果提供则使用传入的 key,否则使用全局配置
base_url: API Base URL,如果提供则使用传入的 URL,否则使用全局配置
Returns:
AsyncOpenAI: OpenAI 客户端
"""
return AsyncOpenAI(
api_key=api_key if api_key else DASHSCOPE_API_KEY,
base_url=base_url if base_url else DASHSCOPE_BASE_URL
)
def extract_last_user_message(messages: List[Dict[str, Any]]) -> Optional[str]:
"""从 messages 中提取最后一条 user message"""
if not messages:
return None
# 从后往前找第一条 user message
for msg in reversed(messages):
if isinstance(msg, dict) and msg.get("role") == "user":
content = msg.get("content", "")
if isinstance(content, str):
return content
elif isinstance(content, list):
# 处理多模态内容
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
return item.get("text", "")
# 如果没有 text 类型,尝试获取第一个元素的 text
if content:
first_item = content[0]
if isinstance(first_item, dict):
result = first_item.get("text", "") or first_item.get("content", "")
return str(result) if result else None
return str(content) if content else None
return None
async def generate_chat_title(
chat_id: str,
messages: List[Dict[str, Any]],
response_content: str,
api_key: Optional[str] = None,
base_url: Optional[str] = None
):
"""异步生成对话标题
使用轻量 LLM 模型总结第一轮对话内容,生成简短标题(不超过 15 个中文字符)
Args:
chat_id: 对话 ID
messages: 对话消息列表
response_content: AI 回复内容
api_key: API Key(BYOK 用户使用),如果为 None 则使用全局配置
base_url: API Base URL(BYOK 用户使用),如果为 None 则使用全局配置
"""
try:
# 构建用于总结的 prompt
user_message = None
for msg in reversed(messages):
if msg.get("role") == "user":
user_message = msg.get("content", "")
break
if not user_message:
return
# 构建总结请求
system_prompt = get_prompt("generate_title")
summary_prompt = f"""'用户提问':{user_message}
'AI 回答':{response_content[:500]}"""
# 创建客户端用于标题生成(使用传入的 api_key 和 base_url)
title_client = get_openai_client(api_key=api_key, base_url=base_url)
# 调用 LLM API(非流式,带超时)
response = await asyncio.wait_for(
title_client.chat.completions.create(
model=LITE_LLM_MODEL,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": summary_prompt}
],
max_tokens=50,
temperature=0.1,
extra_body={
"enable_thinking": False
}
),
timeout=TITLE_GENERATION_TIMEOUT
)
# 提取生成的标题
title = None
if response.choices:
content = response.choices[0].message.content
if content:
title = content.strip()
if title:
# 确保标题不超过 24 个中文字符
if len(title) > 24:
title = title[:24]
# 更新对话标题(不传入 user_uuid,因为这是后台自动操作)
# 如果对话不存在或无权访问,异常会被静默捕获
update_chat_title(chat_id, title)
print(f"对话标题已更新:{title}")
except asyncio.TimeoutError:
print(f"标题生成超时({TITLE_GENERATION_TIMEOUT}秒),跳过标题生成")
except (ChatNotFoundError, ChatAccessDeniedError):
# 标题生成是后台异步操作,对话可能已被删除,静默处理异常
pass
except Exception as e:
print(f"标题生成失败:{e}")
def _build_chunk(
chunk_id: str,
created: int,
model: str,
delta: Dict[str, Any],
finish_reason: Optional[str] = None
) -> Dict[str, Any]:
"""构建 SSE chunk 数据
OpenAI 标准行为:
- 第一个 chunk 仅包含 role 字段用于初始化
- 后续 chunk 仅包含 content 或 reasoning_content 字段
- 最后一个 chunk 包含 finish_reason
Args:
chunk_id: chunk ID
created: 时间戳
model: 模型名称
delta: delta 内容(包含 role/content/reasoning_content)
finish_reason: 结束原因(仅最后一个 chunk 需要)
Returns:
Dict: 构建好的 chunk 数据
"""
return {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{
"index": 0,
"delta": delta,
"finish_reason": finish_reason
}]
}
async def stream_response(
request: Request,
client: AsyncOpenAI,
messages: List[Dict[str, Any]],
enable_thinking: Optional[bool] = False,
persist: bool = False,
chat_id: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model: str = MAIN_LLM_MODEL,
regenerate_mode: bool = False,
rollback_callback: Optional[Callable[[], None]] = None
):
"""流式转发 LLM 响应,同时解析 usage 信息
推理内容(reasoning_content)也会流式转发到前端,同时监听 reasoning,统一使用 reasoning_content 字段转发
response_content 只记录最终回答内容,不包含推理内容
OpenAI 兼容行为:
- 第一个 chunk 发送 role: "assistant" 用于前端初始化
- 后续 chunk 仅发送 content 或 reasoning_content,不包含 role
Args:
request: FastAPI 请求对象
client: AsyncOpenAI 客户端
messages: 消息列表
enable_thinking: 是否启用推理,默认 False
persist: 是否持久化对话,默认 False
chat_id: 对话 ID(仅在 persist=True 时使用)
api_key: API Key(BYOK 用户使用),如果为 None 则使用全局配置
base_url: API Base URL(BYOK 用户使用),如果为 None 则使用全局配置
model: 使用的模型 ID
regenerate_mode: 是否为重新生成模式,默认 False
rollback_callback: 回滚回调函数(仅在 regenerate_mode=True 时使用)
"""
# 构建请求参数 - enable_thinking 在 extra_body 中,与 messages 同一层级
chat_kwargs = {
"model": model,
"messages": messages,
"stream": True,
"stream_options": {"include_usage": True},
"extra_body": {
"enable_search": True,
"enable_thinking": enable_thinking # 默认 False,如果传入 True 则启用
}
}
try:
stream = await client.chat.completions.create(**chat_kwargs)
# 用于暂存回答和 usage 信息
response_content = ""
usage_info = {}
# 标记是否已发送 role(OpenAI 标准:仅第一个 chunk 发送 role)
role_sent = False
# 用于记录上游返回的实际 model
upstream_model: Optional[str] = None
async for chunk in stream:
# 检查前端连接是否断开
if await request.is_disconnected():
print("前端连接断开,终止流式传输")
return
# 记录上游返回的 model(用于最后一个 chunk)
if hasattr(chunk, 'model') and chunk.model:
upstream_model = chunk.model
# 解析 chunk
if chunk.choices and len(chunk.choices) > 0:
delta = chunk.choices[0].delta
chunk_data = None
# 检查是否有 reasoning_content(推理内容)
reasoning_content = None
if hasattr(delta, 'reasoning_content') and delta.reasoning_content is not None:
reasoning_content = delta.reasoning_content
elif hasattr(delta, 'reasoning') and delta.reasoning is not None:
reasoning_content = delta.reasoning # 兼容新 vllm 字段
# 1. 优先处理 role chunk(独立发送)
# 无需判断上游是否有 role,只要没发过 role 就直接发送
if not role_sent:
chunk_data = _build_chunk(
chunk_id=chunk.id,
created=chunk.created,
model=chunk.model,
delta={"role": "assistant"}
)
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
role_sent = True
# 2. 处理 reasoning_content(独立于 content)
if reasoning_content is not None:
chunk_data = _build_chunk(
chunk_id=chunk.id,
created=chunk.created,
model=chunk.model,
delta={"reasoning_content": reasoning_content}
)
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
# 3. 独立处理 content(使用并列 if,不与 reasoning 互斥)
if hasattr(delta, 'content') and delta.content is not None:
# 累积记录本次回答内容(不包含推理内容)
response_content += delta.content
chunk_data = _build_chunk(
chunk_id=chunk.id,
created=chunk.created,
model=chunk.model,
delta={"content": delta.content}
)
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
# 检查 usage 信息,使用标准化函数解析
if hasattr(chunk, 'usage') and chunk.usage:
usage_info = normalize_usage(chunk.usage)
# 构建并发送 usage 信息(如有)
if usage_info:
final_model = model # 防止上游因模型路由导致返回的 model 与请求不一致,优先使用请求的 model 以确保前端和后续处理逻辑能够使用正确的模型id
final_chunk = {
"id": "final",
"object": "chat.completion.chunk",
"created": 0,
"model": final_model,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "stop"
}],
"usage": usage_info
}
yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n"
print(f"Usage 信息:{usage_info}")
# 重新生成模式:先执行回滚回调,再发送 [DONE],然后持久化
# 回滚回调会抛出异常如果失败
if regenerate_mode and rollback_callback and chat_id:
# 1. 先执行数据库回滚(如果失败会抛出异常)
print("重新生成模式:执行数据库回滚...")
rollback_callback()
print("重新生成模式:数据库回滚成功")
# 发送结束标记
yield "data: [DONE]\n\n"
# 完成流式传输后,将对话持久化到数据库(如果启用)
if persist and chat_id and response_content:
# 从 messages 中提取最后一条 user message
user_message = None
for msg in reversed(messages):
if msg.get("role") == "user":
user_message = msg.get("content", "")
break
if user_message:
# 追加到历史记录(传入 api_key、base_url 和 model 用于 BYOK 用户的摘要生成和模型记录)
append_history(
chat_id, user_message, response_content,
api_key=api_key, base_url=base_url, model=model
)
print(f"对话已保存到 {chat_id}")
# 检查是否是第一轮对话(history_full 中只有一轮对话)
history_full = get_history_full(chat_id) or []
# 每轮对话包含 2 条消息(user + assistant),所以 2 条消息表示第一轮对话
if len(history_full) == 2:
# 异步生成标题(不阻塞当前请求)
asyncio.create_task(
generate_chat_title(
chat_id, messages, response_content,
api_key=api_key, base_url=base_url
)
)
except asyncio.CancelledError:
print("请求被取消")
raise
except Exception as e:
print(f"LLM API 请求错误:{e}")
# 向前端发送错误信息 - 使用 json.dumps 确保正确的 JSON 格式
error_chunk = {
"error": {
"message": str(e),
"type": "upstream_error"
}
}
yield f"data: {json.dumps(error_chunk, ensure_ascii=False)}\n\n"
async def handle_chat_completion(
request: Request,
current_user: User,
chat_id: Optional[str] = None
) -> StreamingResponse:
"""处理 Chat 补全请求
支持 OpenAI Chat Completions API 标准格式
接受的参数包括:messages, model, temperature, max_tokens, top_p,
frequency_penalty, presence_penalty, stop, stream, user, enable_thinking 等
其中只有 messages,model 和 enable_thinking 会被使用,其他参数会被接收但忽略
BYOK 用户处理逻辑:
1. 检查用户角色,如果不是 admin 则为 BYOK 用户
2. 获取用户的 API Key 记录
3. 检查 key 状态(只有 valid 才能使用,pending/invalid 都返回 401)
4. 解密 API Key 用于请求上游 LLM
Args:
request: FastAPI 请求对象
current_user: 当前认证用户
chat_id: 对话 ID(如果提供则启用持久化)
Returns:
StreamingResponse: 流式响应
Raises:
HTTPException: 401 (API Key 无效/未配置/pending), 429 (配额不足)
"""
try:
# 解析请求体
body = await request.json()
# 必需参数:messages
messages = body.get("messages", [])
# 可选参数:model (OpenAI 标准字段)
requested_model = body.get("model")
# 可选参数:enable_thinking (自定义参数)
# 支持两种格式:顶层 enable_thinking 或嵌套在 extra_body 中
enable_thinking = body.get("enable_thinking")
if enable_thinking is None:
extra_body = body.get("extra_body")
if isinstance(extra_body, dict):
enable_thinking = extra_body.get("enable_thinking")
# 验证必需参数
if not messages or not isinstance(messages, list):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="请求必须包含 messages 数组"
)
# 提取最后一条 user message
last_user_message = extract_last_user_message(messages)
if not last_user_message:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="messages 中必须包含至少一条 user role 的消息"
)
# =========================================================================
# BYOK 用户 API Key 检查
# =========================================================================
api_key: Optional[str] = None
base_url: Optional[str] = None
# 如果用户不是 admin,则为 BYOK 用户,需要获取并检查 API Key
if current_user.role != "admin":
# 获取用户的 API Key 记录
api_key_record = get_user_api_key(current_user)
if api_key_record is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未找到 API Key 记录,请配置您的 API Key"
)
# 检查 API Key 状态
# pending 和 invalid 都视为无效 key,返回 401
if api_key_record.status in ["pending", "invalid"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="您的 API Key 无效或待验证,请检查您的 API Key"
)
if api_key_record.status == "quota":
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="您的 API Key 配额不足,请检查您的 API Key"
)
# 状态为 valid,解密 API Key
# 注意:明文 API Key 仅在内存中使用,不得记录到日志
try:
api_key = decrypt_api_key_record(api_key_record)
base_url = api_key_record.base_url
except Exception as e:
print(f"API Key 解密失败:{e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="API Key 解密失败"
)
# =========================================================================
# 确定使用的模型
if chat_id:
# 获取 chat 的 current_model
from history import get_chat_by_id_and_uuid
chat = get_chat_by_id_and_uuid(chat_id, current_user.uuid)
chat_current_model = chat.current_model if chat else None
else:
chat_current_model = None
# 使用 _check_request_model 确定实际使用的模型
actual_model = _check_request_model(requested_model, chat_current_model)
# 构建发送给 LLM 的 messages
if chat_id:
# 带持久化模式:获取压缩历史和最近历史并拼接
# 格式:system prompt + history_compressed + history_recent + user prompt
compressed_messages, history_recent = get_combined_history(chat_id)
llm_messages = [
{"role": "system", "content": get_prompt("chat_system")}, # 持久化对话的 system prompt
]
llm_messages.extend(compressed_messages) # 压缩历史(已去除 round 元数据)
llm_messages.extend(history_recent) # 最近历史
llm_messages.append({"role": "user", "content": last_user_message})
else:
# 无状态模式:只发送当前 message
llm_messages = [
{"role": "system", "content": get_prompt("stateless_system")}, # system prompt
{"role": "user", "content": last_user_message}
]
# 创建 OpenAI 客户端(使用 BYOK 用户的 key 或全局配置)
client = get_openai_client(api_key=api_key, base_url=base_url)
# 请求 LLM API 并返回流式响应
return StreamingResponse(
stream_response(
request,
client,
llm_messages,
enable_thinking,
persist=(chat_id is not None),
chat_id=chat_id,
api_key=api_key,
base_url=base_url,
model=actual_model
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
except HTTPException:
raise
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="请求体必须是有效的 JSON 格式"
)
except Exception as e:
print(f"处理请求时发生错误:{e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"处理请求时发生错误:{str(e)}"
)
async def handle_regenerate(
request: Request,
current_user: User,
chat_id: str,
regenerate_round_request_id: Optional[int] = None
) -> StreamingResponse:
"""处理重新生成请求
支持重新生成指定轮次的 AI 回复
流程:
1. 验证对话存在性和所有权
2. 解析 regenerate_round_request_id(None/0/1 都表示最后一轮)
3. 调用 rollback_history 执行模拟回滚,获取历史消息和被回滚的 user message
4. 在内存中拼接 prompt(system + compressed + recent + user_message)
5. 调用 stream_response(传入 regenerate_mode=True 和 rollback_callback)
6. stream_response 在流式传输成功后,调用 rollback_callback 执行数据库回滚
7. 回滚完成后发送 [DONE] 标记
8. 将新的对话历史持久化到数据库
Args:
request: FastAPI 请求对象
current_user: 当前认证用户
chat_id: 对话 ID
regenerate_round_request_id: 要重新生成的轮次(1=最后一轮,0 或 None=最后一轮,6=倒数第 6 轮)
Returns:
StreamingResponse: 流式响应
Raises:
HTTPException: 404 (对话不存在/无权访问), 400 (regenerate_round_request_id 无效), 500 (服务器错误)
"""
try:
# 解析请求体(获取 enable_thinking 等参数)
body = await request.json() if request.method == "POST" else {}
# 获取 regenerate_round_request_id 参数(支持两种命名)
regenerate_round_request_id = body.get("regenerate_round_request_id") or body.get("round_id")
# 可选参数:enable_thinking
enable_thinking = body.get("enable_thinking")
if enable_thinking is None:
extra_body = body.get("extra_body")
if isinstance(extra_body, dict):
enable_thinking = extra_body.get("enable_thinking")
# 可选参数:model
requested_model = body.get("model")
# =========================================================================
# BYOK 用户 API Key 检查
# =========================================================================
api_key: Optional[str] = None
base_url: Optional[str] = None
if current_user.role != "admin":
api_key_record = get_user_api_key(current_user)
if api_key_record is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未找到 API Key 记录,请配置您的 API Key"
)
if api_key_record.status in ["pending", "invalid"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="您的 API Key 无效或待验证,请检查您的 API Key"
)
if api_key_record.status == "quota":
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="您的 API Key 配额不足,请检查您的 API Key"
)
try:
api_key = decrypt_api_key_record(api_key_record)
base_url = api_key_record.base_url
except Exception as e:
print(f"API Key 解密失败:{e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="API Key 解密失败"
)
# =========================================================================
print("接收到重新生成请求,鉴权通过,chat_id: {}, user: {}, round_id: {}".format(chat_id, current_user.email, regenerate_round_request_id))
# 处理 regenerate_round_request_id:None/0/1 都表示最后一轮
if regenerate_round_request_id is None or regenerate_round_request_id == 0:
regenerate_round_request_id = 1
# 调用 rollback_history 进行模拟回滚(不写入数据库)
# 返回:compressed_messages, recent_messages, user_message
# 该函数内部会验证对话存在性和所有权
compressed_messages, recent_messages, user_message = rollback_history(
chat_id, regenerate_round_request_id, current_user.uuid, simulate_only=True
)
# 确定使用的模型
# 获取 chat 的 current_model
from history import get_chat_by_id_and_uuid
chat = get_chat_by_id_and_uuid(chat_id, current_user.uuid)
if chat is None:
# rollback_history 已经验证过,这里理论上不会为 None
# 但为了安全起见,还是检查一下
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="对话不存在或无权访问"
)
chat_current_model = chat.current_model
actual_model = _check_request_model(requested_model, chat_current_model)
# 保存原始 updated_at,用于乐观锁检查
original_updated_at = chat.updated_at
# 构建发送给 LLM 的 messages
# 格式:system prompt + compressed_messages + recent_messages + user_message
llm_messages = [
{"role": "system", "content": get_prompt("chat_system")},
]
llm_messages.extend(compressed_messages)
llm_messages.extend(recent_messages)
llm_messages.append({"role": "user", "content": user_message})
# 创建 OpenAI 客户端
client = get_openai_client(api_key=api_key, base_url=base_url)
# 定义回滚回调函数(在流式传输成功后执行真正的数据库回滚)
def do_rollback() -> None:
"""执行真正的数据库回滚(不模拟)
Raises:
RuntimeError: 对话状态已变更(乐观锁检查失败)
DatabaseError: 数据库访问失败
"""
rollback_history(
chat_id,
regenerate_round_request_id,
current_user.uuid,
simulate_only=False,
original_updated_at=original_updated_at
)
# 请求 LLM API 并返回流式响应
# 注意:persist=True 确保新对话被持久化,regenerate_mode=True 启用回滚回调
return StreamingResponse(
stream_response(
request,
client,
llm_messages,
enable_thinking,
persist=True, # 重新生成需要持久化新对话
chat_id=chat_id,
api_key=api_key,
base_url=base_url,
model=actual_model,
regenerate_mode=True, # 启用重新生成模式
rollback_callback=do_rollback # 回滚回调
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
except HTTPException:
raise
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="请求体必须是有效的 JSON 格式"
)
except ValueError as e:
# round_id 超出范围等验证错误
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except ChatNotFoundError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="对话不存在"
)
except ChatAccessDeniedError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="无权访问该对话"
)
except RuntimeError as e:
# 乐观锁检查失败:对话状态已变更
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e)
)
except DatabaseError as e:
# 数据库访问失败
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"数据库错误:{str(e)}"
)
except Exception as e:
print(f"处理重新生成请求时发生错误:{e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"处理重新生成请求时发生错误:{str(e)}"
)