-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsampling_proxy.py
More file actions
2291 lines (2005 loc) · 123 KB
/
sampling_proxy.py
File metadata and controls
2291 lines (2005 loc) · 123 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import json
import httpx
from typing import Optional
from fastapi import FastAPI, Request, Response, status
from fastapi.responses import StreamingResponse
from contextlib import asynccontextmanager
import uvicorn
import asyncio # Import asyncio for potential sleep
import argparse # Import argparse for command-line arguments
import threading
# Process-wide request counter. Each incoming request takes the next value so
# that its log output can be correlated; the lock serialises increments.
_request_counter = 0
_request_counter_lock = threading.Lock()


def get_request_id():
    """Return the next request ID (1, 2, 3, ...) for log correlation.

    The read-increment-write happens while holding ``_request_counter_lock``,
    so two concurrent callers can never be handed the same ID.
    """
    global _request_counter
    with _request_counter_lock:
        next_id = _request_counter + 1
        _request_counter = next_id
    return next_id
def log_info(request_id: int, message: str):
    """Write an INFO line to stdout, tagged with the request's correlation ID."""
    tag = "[INFO][R:{}]".format(request_id)
    print("{} {}".format(tag, message))
# Import validator module for garbage detection
from validator import (
validate_response,
validate_response_partial,
save_failed_response,
save_mid_stream_failure,
create_error_message,
calculate_retry_delay,
ValidationResult,
StreamingValidator,
StreamingValidationBuffer,
count_words_in_text,
extract_text_from_sse_chunks,
build_anthropic_error_stream,
build_openai_error_stream
)
def load_config(config_path="config.json"):
    """
    Load the proxy configuration from a JSON file, layered over built-in defaults.

    Merge rules:

    * A top-level section that is an object both in the defaults and in the
      file is merged one level deep; keys from the file win.
    * Any other value, and any top-level key the defaults do not know, is
      taken from the file as-is.
    * ``null`` values are then dropped from ``default_sampling_params``,
      ``override.sampling_params`` (a ``null`` section becomes ``{}``) and
      each ``model_sampling_params[<model>]``.
    * Models left with no parameters, or whose entry is not an object, are
      omitted (non-null, non-object entries produce a printed warning).

    Args:
        config_path: Path of the JSON config file to read.

    Returns:
        The merged configuration dict. If the file does not exist, is not
        valid JSON, or cannot be processed, the built-in defaults are returned
        instead and a message is printed. The defaults are never mutated while
        merging, so that fallback really is the untouched defaults.
    """
    default_config = {
        "server": {
            "target_base_url": "http://127.0.0.1:8000/v1",
            "sampling_proxy_base_path": "",
            "sampling_proxy_host": "0.0.0.0",
            "sampling_proxy_port": 8001,
            "connect_timeout_seconds": 5.0,
            "timeout_seconds": 1200.0,
            "supports_openai": True,
            "supports_anthropic": False
        },
        "logging": {
            "enable_debug_logs": False,
            "enable_override_logs": False
        },
        "default_sampling_params": {},
        "override": {
            "only_anthropic": False,
            "model_name": None,
            "sampling_params": {}
        },
        "model_sampling_params": {},
        "validation": {
            "enabled": False,
            "validator_url": "http://127.0.0.1:1234",
            "validator_model": "qwen-3.5-0.8b",
            "supports_openai": True,
            "supports_anthropic": False,
            "connect_timeout_seconds": 5.0,
            "timeout_seconds": 300.0,
            "max_retries": 3,
            "retry_base_delay_seconds": 1.0,
            "retry_multiplier": 2.0,
            "mid_stream_validation_enabled": False,
            "mid_stream_validation_interval_words": 300
        }
    }
    if not os.path.exists(config_path):
        print(f"WARNING: Config file '{config_path}' not found. Using default values.")
        return default_config
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Merge with defaults to ensure all required keys exist. Fresh section
        # dicts are built instead of update()-ing the default sections in
        # place: a shallow copy shares those nested dicts, and mutating them
        # would corrupt the defaults returned by the error fallback below.
        merged_config = dict(default_config)
        for key, value in config.items():
            default_value = merged_config.get(key)
            if isinstance(default_value, dict) and isinstance(value, dict):
                merged_config[key] = {**default_value, **value}
            else:
                merged_config[key] = value

        # Drop null sampling params so they are never forwarded to the backend.
        if merged_config.get("default_sampling_params"):
            merged_config["default_sampling_params"] = {
                k: v for k, v in merged_config["default_sampling_params"].items()
                if v is not None
            }

        override_config = merged_config.get("override")
        if override_config and "sampling_params" in override_config:
            # "sampling_params": null means "no override params", not an error.
            override_params = override_config["sampling_params"] or {}
            merged_config["override"] = {
                **override_config,
                "sampling_params": {
                    k: v for k, v in override_params.items() if v is not None
                },
            }

        if merged_config.get("model_sampling_params"):
            filtered_model_params = {}
            for model, params in merged_config["model_sampling_params"].items():
                if not isinstance(params, dict):
                    # Skip just this model instead of discarding the whole file.
                    if params is not None:
                        print(f"WARNING: Ignoring non-object sampling params for model '{model}' in '{config_path}'.")
                    continue
                filtered_params = {
                    k: v for k, v in params.items()
                    if v is not None
                }
                if filtered_params:  # Only include models with non-null params
                    filtered_model_params[model] = filtered_params
            merged_config["model_sampling_params"] = filtered_model_params

        print(f"Configuration loaded from '{config_path}'")
        return merged_config
    except json.JSONDecodeError as e:
        print(f"ERROR: Invalid JSON in config file '{config_path}': {e}. Using default values.")
        return default_config
    except Exception as e:
        print(f"ERROR: Error loading config file '{config_path}': {e}. Using default values.")
        return default_config
def extract_base_path(url):
    """
    Return just the path component of ``url``.

    Example: "http://127.0.0.1:8000/abc/v4" -> "/abc/v4". A URL that has no
    path (e.g. "http://127.0.0.1:8000") yields "".
    """
    from urllib.parse import urlparse
    components = urlparse(url)
    return components.path
def transform_path(original_path, from_base_path, to_base_path):
    """
    Re-root a request path from one base path onto another.

    Both base paths are normalised to start with "/" and to have no trailing
    "/". The base only matches whole path segments, so with
    from_base_path="/v1" and to_base_path="/abc/v4":
        "/v1/completions"      -> "/abc/v4/completions"
        "/v1/chat/completions" -> "/abc/v4/chat/completions"
        "/v1"                  -> "/abc/v4"
        "/v1beta/models"       -> "/v1beta/models"  (not under /v1)
    An empty (or "/") from_base_path matches every path, so to_base_path is
    simply prepended.

    Args:
        original_path: The incoming request path.
        from_base_path: Base path to strip (with or without slashes).
        to_base_path: Base path to put in its place (with or without slashes).

    Returns:
        The transformed path, or original_path unchanged when it does not lie
        under from_base_path.
    """
    # Ensure base paths start with /
    if not from_base_path.startswith('/'):
        from_base_path = '/' + from_base_path
    if not to_base_path.startswith('/'):
        to_base_path = '/' + to_base_path
    # Remove trailing slashes for consistent comparison
    from_base_path = from_base_path.rstrip('/')
    to_base_path = to_base_path.rstrip('/')
    # Require a segment boundary after the base so "/v1" does not match "/v1beta".
    under_base = (
        not from_base_path
        or original_path == from_base_path
        or original_path.startswith(from_base_path + '/')
    )
    if not under_base:
        return original_path
    return to_base_path + original_path[len(from_base_path):]
# --- Configuration ---
# Placeholders only: the main block assigns the real values after the JSON
# config has been loaded.

# Backend target and proxy listen address / base paths.
TARGET_BASE_URL = TARGET_BASE_PATH = None
SAMPLING_PROXY_HOST = SAMPLING_PROXY_PORT = SAMPLING_PROXY_BASE_PATH = None

# Logging switches.
ENABLE_DEBUG_LOGS = ENABLE_OVERRIDE_LOGS = ENABLE_VALIDATION_LOGS = False

# Sampling-parameter layers. Tuple unpacking keeps every dict a distinct object.
DEFAULT_SAMPLING_PARAMS, OVERRIDE_CONFIG = {}, {}
OVERRIDE_ONLY_ANTHROPIC = False
OVERRIDE_MODEL_NAME = None
OVERRIDE_SAMPLING_PARAMS, MODEL_SAMPLING_PARAMS = {}, {}

# Request formats the backend accepts natively; the request handler uses
# these to choose between passthrough and Anthropic -> OpenAI conversion.
SERVER_SUPPORTS_OPENAI, SERVER_SUPPORTS_ANTHROPIC = True, False

VALIDATION_CONFIG = {"enabled": False}

# A request counts as "generation" when its normalised path ENDS WITH one of
# these suffixes, whatever prefix precedes it.
GENERATION_ENDPOINT_SUFFIXES = [
    "generate",          # SGLang native generation endpoint
    "completions",       # OpenAI-compatible text completions
    "chat/completions",  # OpenAI-compatible chat completions
    "v1/messages",       # Anthropic-compatible messages
]

# Anthropic-only endpoints that the proxy answers itself instead of forwarding.
ANTHROPIC_ENDPOINTS = [
    "api/event_logging/batch",   # event logging
    "v1/messages/count_tokens",  # token counting
]

# Model name put into converted Anthropic requests. At startup the lifespan
# handler replaces it with the override model name or with the first id from
# the backend's /models listing. "any" suits SGLang (accepts any name);
# vLLM requires an exact match.
FIRST_AVAILABLE_MODEL = "any"

# Shared httpx.AsyncClient for backend calls; created with the configured
# timeouts and closed by the lifespan handler.
client = None
# --- FastAPI Application Lifespan Setup ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Handle startup and shutdown for the FastAPI application.

    Startup:
      * Reject a configuration in which the backend supports neither the
        OpenAI nor the Anthropic format. This check runs before the httpx
        client is created, so an invalid config never leaves a client open.
      * Create the module-level ``client`` (base URL TARGET_BASE_URL, connect
        and read timeouts from CONFIG["server"]).
      * Set FIRST_AVAILABLE_MODEL to OVERRIDE_MODEL_NAME when one is
        configured. Otherwise, on OpenAI-capable backends only, use the id of
        the first model listed by the backend's /models endpoint. Polling
        problems are only logged as warnings.

    Shutdown:
      * Close the httpx client. This sits in a ``finally`` so the client is
        also closed if anything after its creation raises.

    Raises:
        ValueError: if SERVER_SUPPORTS_OPENAI and SERVER_SUPPORTS_ANTHROPIC
            are both false.
    """
    global FIRST_AVAILABLE_MODEL, client
    print("FastAPI application startup.")
    # Validate server capabilities first - at least one format must be supported.
    if not SERVER_SUPPORTS_OPENAI and not SERVER_SUPPORTS_ANTHROPIC:
        raise ValueError(
            "Invalid configuration: server must support at least one format. "
            "Set 'supports_openai: true' and/or 'supports_anthropic: true' in config."
        )
    # Initialize client with the correct TARGET_BASE_URL and timeout from config
    connect_timeout = CONFIG["server"].get("connect_timeout_seconds", 5.0)
    read_timeout = CONFIG["server"].get("timeout_seconds", 1200.0)
    timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=read_timeout, pool=connect_timeout)
    client = httpx.AsyncClient(base_url=TARGET_BASE_URL, timeout=timeout)
    try:
        # Poll /models only if the server supports OpenAI (/models is OpenAI-only)
        # and no override model is configured.
        if SERVER_SUPPORTS_OPENAI and not OVERRIDE_MODEL_NAME:
            # An empty target base path means a standard server: use /v1/models.
            # Otherwise the base path already includes the prefix.
            if not TARGET_BASE_PATH:
                models_path = "/v1/models"
            else:
                models_path = "/models"
            try:
                print(f"Polling {TARGET_BASE_URL}{models_path} to get available models...")
                response = await client.get(models_path)
                if response.status_code == 200:
                    models_data = response.json()
                    if "data" in models_data and len(models_data["data"]) > 0:
                        FIRST_AVAILABLE_MODEL = models_data["data"][0]["id"]
                        print(f"Successfully retrieved first available model: {FIRST_AVAILABLE_MODEL}")
                    else:
                        print("WARNING: No models found in /models response")
                else:
                    print(f"WARNING: Failed to get models from {models_path}. Status: {response.status_code}")
            except Exception as e:
                # Best effort: the proxy still starts with the previous default model name.
                print(f"WARNING: Error polling {models_path}: {e}")
        elif OVERRIDE_MODEL_NAME:
            # Use the override model name as the first available model
            FIRST_AVAILABLE_MODEL = OVERRIDE_MODEL_NAME
            print(f"Using override model name: {FIRST_AVAILABLE_MODEL}")
        else:
            print("Skipping model polling (server doesn't support OpenAI format)")
        yield  # Application starts here
    finally:
        print("FastAPI application shutdown.")
        if client:
            await client.aclose()
            print("HTTPX client closed.")
# --- FastAPI Application Setup ---
# The lifespan handler owns the backend httpx client (created on startup,
# closed on shutdown).
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="Sampling Proxy",
    description="A middleware server to override sampling parameters for generation requests, supports OpenAI Compatible target server and OpenAI Compatible and Anthropic requests.",
)
@app.get("/")
async def read_root():
    """
    Health-check endpoint.

    Returns a JSON object confirming the proxy is up, together with a
    snapshot of its active configuration:
      * the backend URL and the proxy's listen port;
      * the default sampling params and the override section;
      * which models have per-model params;
      * the monitored generation suffixes and the locally handled Anthropic
        endpoints;
      * whether debug logging is enabled.
    """
    status_report = dict(
        message="Sampling Proxy is running.",
        target_backend=TARGET_BASE_URL,
        sampling_proxy_port=SAMPLING_PROXY_PORT,
        default_sampling_params=DEFAULT_SAMPLING_PARAMS,
        override=OVERRIDE_CONFIG,
        model_sampling_params_configured=list(MODEL_SAMPLING_PARAMS),
        generation_endpoints_monitored=GENERATION_ENDPOINT_SUFFIXES,
        anthropic_endpoints_handled_locally=ANTHROPIC_ENDPOINTS,
        debug_logs_enabled=ENABLE_DEBUG_LOGS,
    )
    return status_report
def parse_sse_to_response(sse_text: str) -> Optional[dict]:
    """
    Rebuild a complete Anthropic Messages response from raw SSE stream text.

    Only ``data:`` lines are read; ``event:`` and blank lines are ignored, and
    the space after the colon is optional. Lines whose payload is not a JSON
    object are skipped. Events are folded as follows:

    * ``message_start``: the embedded ``message`` becomes the base object.
    * ``content_block_start``: a copy of ``content_block`` is stored at
      ``index``.
    * ``content_block_delta``: the delta is appended to the block at
      ``index``:
        - ``text_delta`` goes to ``text``;
        - ``thinking_delta`` goes to ``thinking``;
        - ``signature_delta`` goes to ``signature``;
        - ``input_json_delta`` is buffered and JSON-decoded into ``input``
          for ``tool_use`` blocks (``{}`` if the JSON is invalid).
    * ``message_delta``: top-level fields in ``delta`` (e.g.
      ``stop_reason``) are merged into the message, and ``usage`` counters
      are merged into the message's usage.
    * ``message_stop``: the blocks are placed into ``content`` in index
      order and the message is returned.

    Args:
        sse_text: The full buffered SSE body.

    Returns:
        The reconstructed message dict, or None when no ``message_stop``
        followed a non-empty ``message_start`` message.
    """
    content_blocks = {}
    message_data = None
    # delta type -> (block field to append to, delta field holding the fragment)
    delta_fields = {
        'text_delta': ('text', 'text'),
        'thinking_delta': ('thinking', 'thinking'),
        'signature_delta': ('signature', 'signature'),
        'input_json_delta': ('_partial_json', 'partial_json'),
    }
    for raw_line in sse_text.split('\n'):
        line = raw_line.strip()
        if not line.startswith('data:'):
            continue
        data_str = line[5:].strip()
        if data_str == '[DONE]':
            continue
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            continue
        if not isinstance(data, dict):
            continue
        event_type = data.get('type')
        if event_type == 'message_start':
            message_data = data.get('message', {})
        elif event_type == 'content_block_start':
            content_blocks[data.get('index', 0)] = dict(data.get('content_block') or {})
        elif event_type == 'content_block_delta':
            block = content_blocks.setdefault(data.get('index', 0), {})
            delta = data.get('delta') or {}
            fields = delta_fields.get(delta.get('type'))
            if fields:
                block_field, delta_field = fields
                block[block_field] = (block.get(block_field) or '') + (delta.get(delta_field) or '')
        elif event_type == 'message_delta':
            # Carries the final stop_reason and cumulative usage counters.
            if isinstance(message_data, dict):
                message_data.update(data.get('delta') or {})
                if isinstance(data.get('usage'), dict):
                    usage = dict(message_data.get('usage') or {})
                    usage.update(data['usage'])
                    message_data['usage'] = usage
        elif event_type == 'message_stop' and message_data:
            content = []
            for idx in sorted(content_blocks):
                block = content_blocks[idx]
                # Always remove the internal buffer key; only tool_use blocks use it.
                partial_json = block.pop('_partial_json', '')
                if block.get('type', 'text') == 'tool_use' and partial_json:
                    try:
                        block['input'] = json.loads(partial_json)
                    except json.JSONDecodeError:
                        block['input'] = {}
                content.append(block)
            message_data['content'] = content
            return message_data
    return None
def parse_openai_sse_to_response(sse_text: str) -> Optional[dict]:
    """
    Rebuild a non-streaming OpenAI ``chat.completion`` object from raw SSE text.

    Reads ``data:`` lines (the space after the colon is optional) and skips
    ``[DONE]`` and any payload that is not a JSON object.

    * ``id``, ``model`` and ``created`` come from the first chunk that
      carries an ``id``.
    * For the first choice of each chunk, ``delta.content`` fragments are
      concatenated.
    * ``delta.tool_calls`` are accumulated per tool-call ``index``. Only
      non-empty ids and names overwrite earlier ones, and ``arguments``
      fragments are concatenated.
    * The last non-empty ``finish_reason`` and ``usage`` win.

    Args:
        sse_text: The full buffered SSE body.

    Returns:
        A ``chat.completion`` dict with a single choice, or None when no
        chunk carried an ``id``. ``finish_reason`` defaults to "stop",
        ``created`` to 0, and message ``content`` to "" when no text was
        streamed.
    """
    content_parts = []
    tool_calls = {}  # tool-call index -> {'id', 'name', 'arguments'}
    finish_reason = None
    response_id = None
    model = None
    created = None
    usage = None
    for raw_line in sse_text.split('\n'):
        line = raw_line.strip()
        if not line.startswith('data:'):
            continue
        data_str = line[5:].strip()
        if data_str == '[DONE]':
            continue
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            continue
        if not isinstance(data, dict):
            continue
        # Extract metadata until a chunk with an id has been seen
        if response_id is None:
            response_id = data.get('id')
            model = data.get('model')
            created = data.get('created')
        # Chunks may carry "usage": null; only keep a real usage object so a
        # later null cannot erase counters already received.
        if data.get('usage'):
            usage = data['usage']
        choices = data.get('choices') or []
        if not choices:
            continue
        choice = choices[0]
        delta = choice.get('delta') or {}
        finish_reason = choice.get('finish_reason') or finish_reason
        # Handle text content
        if delta.get('content'):
            content_parts.append(delta['content'])
        # Handle tool calls ("tool_calls": null is treated as no tool calls)
        for tc in delta.get('tool_calls') or []:
            entry = tool_calls.setdefault(tc.get('index', 0), {'id': '', 'name': '', 'arguments': ''})
            # Continuation chunks may repeat id/name as null or ""; never let
            # them clobber values that were already received.
            if tc.get('id'):
                entry['id'] = tc['id']
            function = tc.get('function') or {}
            if function.get('name'):
                entry['name'] = function['name']
            if function.get('arguments'):
                entry['arguments'] += function['arguments']
    # Build final OpenAI response
    if response_id is None:
        return None
    message = {'role': 'assistant', 'content': ''.join(content_parts)}
    if tool_calls:
        message['tool_calls'] = [
            {
                'id': tool_calls[idx]['id'] or f'call_{idx}',
                'type': 'function',
                'function': {
                    'name': tool_calls[idx]['name'],
                    'arguments': tool_calls[idx]['arguments']
                }
            }
            for idx in sorted(tool_calls)
        ]
    response = {
        'id': response_id,
        'object': 'chat.completion',
        'created': created or 0,
        'model': model or '',
        'choices': [{
            'index': 0,
            'message': message,
            'finish_reason': finish_reason or 'stop'
        }]
    }
    if usage:
        response['usage'] = usage
    return response
def convert_openai_sse_to_anthropic_chunks(sse_text: str) -> list:
    """
    Translate a buffered OpenAI chat-completions SSE stream into the ordered
    list of Anthropic Messages stream events.

    Event order produced::

        message_start
        (content_block_start, content_block_delta*, content_block_stop)*
        message_delta      # only when a finish_reason was received
        message_stop

    Mapping rules:

    * Text deltas go into one ``text`` block, opened on the first text
      fragment. Text that arrives after the first tool call is dropped.
    * Each OpenAI tool call (keyed by its ``index``) becomes its own
      ``tool_use`` block, opened when its name first arrives.
    * ``arguments`` fragments become ``input_json_delta`` events, including
      a fragment sent in the same chunk as the name. Fragments for a tool
      call whose name never arrived are dropped.
    * Anthropic block indices are assigned sequentially, so tool blocks
      never collide with the text block. Opening a block closes the
      previous one.
    * ``finish_reason`` is mapped to an Anthropic ``stop_reason``.
    * ``usage`` is taken from whichever chunk carries a usage object, which
      may be a trailing chunk with no choices.

    NOTE(review): assumes the backend streams tool calls one after another
    (not interleaved); interleaved fragments would target an already-stopped
    block - confirm for backends emitting parallel tool calls.

    Lines that are not ``data:`` lines, ``[DONE]``, or not JSON objects are
    ignored. Input with no usable chunks yields an empty list.

    Args:
        sse_text: The full buffered OpenAI SSE body.

    Returns:
        A list of Anthropic event dicts, each with a ``type`` key.
    """
    stop_reason_map = {
        'stop': 'end_turn',
        'length': 'max_tokens',
        'tool_calls': 'tool_use',
        'content_filter': 'stop_sequence',
        'function_call': 'tool_use'
    }

    # First, collect all chunks to determine structure
    openai_chunks = []
    for raw_line in sse_text.split('\n'):
        line = raw_line.strip()
        if not line.startswith('data:'):
            continue
        payload = line[5:].strip()
        if payload == '[DONE]':
            continue
        try:
            chunk = json.loads(payload)
        except json.JSONDecodeError:
            continue
        if isinstance(chunk, dict):
            openai_chunks.append(chunk)
    if not openai_chunks:
        return []

    first_chunk = openai_chunks[0]
    events = [{
        'type': 'message_start',
        'message': {
            'id': first_chunk.get('id') or 'msg_unknown',
            'type': 'message',
            'role': 'assistant',
            'content': [],
            'model': first_chunk.get('model') or '',
            'stop_reason': None,
            'usage': {'input_tokens': 0, 'output_tokens': 0}
        }
    }]

    next_block_index = 0
    open_block_index = None
    text_block_index = None
    tool_block_indices = {}  # OpenAI tool-call index -> Anthropic block index
    has_tool_calls = False
    stop_reason = None
    usage = None

    def open_block(content_block):
        """Close the currently open block (if any), start a new one, return its index."""
        nonlocal next_block_index, open_block_index
        if open_block_index is not None:
            events.append({'type': 'content_block_stop', 'index': open_block_index})
        block_index = next_block_index
        next_block_index += 1
        open_block_index = block_index
        events.append({
            'type': 'content_block_start',
            'index': block_index,
            'content_block': content_block
        })
        return block_index

    for data in openai_chunks:
        # "usage": null appears on ordinary chunks; only a real object counts.
        if isinstance(data.get('usage'), dict):
            usage = data['usage']
        choices = data.get('choices') or []
        if not choices:
            continue
        choice = choices[0]
        delta = choice.get('delta') or {}

        # Handle text content (only before any tool call has started)
        text = delta.get('content')
        if text and not has_tool_calls:
            if text_block_index is None:
                text_block_index = open_block({'type': 'text', 'text': ''})
            events.append({
                'type': 'content_block_delta',
                'index': text_block_index,
                'delta': {'type': 'text_delta', 'text': text}
            })

        # Handle tool calls
        tool_call_deltas = delta.get('tool_calls') or []
        if tool_call_deltas:
            has_tool_calls = True
        for tc in tool_call_deltas:
            tool_idx = tc.get('index', 0)
            func = tc.get('function') or {}
            if func.get('name') and tool_idx not in tool_block_indices:
                tool_block_indices[tool_idx] = open_block({
                    'type': 'tool_use',
                    'id': tc.get('id') or f'toolu_{tool_idx}',
                    'name': func['name'],
                    'input': {}
                })
            # Not an elif: the first chunk may carry both name and arguments.
            arguments = func.get('arguments')
            if arguments and tool_idx in tool_block_indices:
                events.append({
                    'type': 'content_block_delta',
                    'index': tool_block_indices[tool_idx],
                    'delta': {'type': 'input_json_delta', 'partial_json': arguments}
                })

        if choice.get('finish_reason') and stop_reason is None:
            stop_reason = stop_reason_map.get(choice['finish_reason'], 'end_turn')

    if open_block_index is not None:
        events.append({'type': 'content_block_stop', 'index': open_block_index})
    if stop_reason is not None:
        usage_data = {'output_tokens': 0}
        if usage:
            usage_data = {
                'input_tokens': usage.get('prompt_tokens') or 0,
                'output_tokens': usage.get('completion_tokens') or 0
            }
        events.append({
            'type': 'message_delta',
            'delta': {'stop_reason': stop_reason, 'stop_sequence': None},
            'usage': usage_data
        })
    events.append({'type': 'message_stop'})
    return events
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def proxy_target_requests(path: str, request: Request):
"""
Catch-all route to proxy all incoming requests to the OpenAI Compatible backend.
For POST requests to configured generation endpoints, it applies
the sampling parameter override logic.
Supports streaming responses from the OpenAI Compatible backend back to the client.
"""
# Access ENABLE_DEBUG_LOGS from the global scope
global ENABLE_DEBUG_LOGS
# Get request ID at the START for proper log correlation
request_id = get_request_id()
if ENABLE_DEBUG_LOGS:
print(f"\n--- Incoming Request: {request.method} {path} ---")
# Normalize path by removing leading/trailing slashes for consistent matching
original_path = path
path = path.strip('/')
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Normalized path for matching: '{path}' (Original: '{original_path}')")
# Handle Anthropic-specific endpoints locally
if path in ANTHROPIC_ENDPOINTS:
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Handling Anthropic endpoint '{path}' locally")
if path == "api/event_logging/batch":
# Handle event logging endpoint - return success response
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Processing event logging request")
try:
# Read the request body to acknowledge receipt
body = await request.body()
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Event logging body received: {len(body)} bytes")
# Return a success response that mimics what Anthropic expects
response_data = {
"status": "success",
"message": "Events logged successfully"
}
return Response(
content=json.dumps(response_data),
status_code=200,
media_type="application/json"
)
except Exception as e:
if ENABLE_DEBUG_LOGS:
print(f"ERROR: Error processing event logging: {e}")
return Response(
content=json.dumps({"error": "Failed to process events"}),
status_code=500,
media_type="application/json"
)
elif path == "v1/messages/count_tokens":
# Handle token counting endpoint
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Processing token counting request")
try:
# Read and parse the request body
body = await request.body()
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Token counting body received: {len(body)} bytes")
request_data = json.loads(body.decode('utf-8'))
messages = request_data.get("messages", [])
model = request_data.get("model", "claude-3-sonnet-20241022")
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Token counting request - model: {model}, messages: {messages}")
# Simple token estimation (rough approximation)
# In a real implementation, you might want to use a proper tokenizer
total_tokens = 0
for message in messages:
content = message.get("content", "")
if isinstance(content, list):
# Handle complex content format
for content_item in content:
if isinstance(content_item, dict) and content_item.get("type") == "text":
text = content_item.get("text", "")
# Rough estimation: ~4 characters per token for English text
total_tokens += len(text) // 4 + 1
elif isinstance(content_item, str):
total_tokens += len(content_item) // 4 + 1
elif isinstance(content, str):
total_tokens += len(content) // 4 + 1
else:
total_tokens += len(str(content)) // 4 + 1
# Return response in Anthropic format
response_data = {
"input_tokens": total_tokens
}
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Token counting result: {total_tokens} tokens")
return Response(
content=json.dumps(response_data),
status_code=200,
media_type="application/json"
)
except json.JSONDecodeError as e:
if ENABLE_DEBUG_LOGS:
print(f"ERROR: Invalid JSON in token counting request: {e}")
return Response(
content=json.dumps({"error": {"type": "invalid_request_error", "message": "Invalid JSON"}}),
status_code=400,
media_type="application/json"
)
except Exception as e:
if ENABLE_DEBUG_LOGS:
print(f"ERROR: Error processing token counting: {e}")
return Response(
content=json.dumps({"error": {"type": "api_error", "message": "Failed to count tokens"}}),
status_code=500,
media_type="application/json"
)
# For any other Anthropic endpoints, return a generic success
return Response(
content=json.dumps({"status": "ok"}),
status_code=200,
media_type="application/json"
)
# Prepare headers for the outgoing request to OpenAI Compatible backend.
# We copy the incoming headers and remove 'host' and 'content-length'
# as httpx will manage these for the new request.
headers = dict(request.headers)
headers.pop("host", None)
headers.pop("content-length", None) # httpx will recalculate if body changes
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Outgoing Request Headers (initial): {headers}")
request_content = None # This will hold the request body to be sent to target
is_generation_request = False
is_anthropic_request = False # Initialize Anthropic request flag
incoming_json_body = {} # Initialize in case it's not a POST/JSON request
# Determine if the current request path is a recognized generation endpoint
# Use suffix matching to handle paths with or without v1 prefix
is_generation_request = any(path.endswith(suffix) for suffix in GENERATION_ENDPOINT_SUFFIXES)
is_anthropic_request = path.endswith("v1/messages") # Check if this is an Anthropic request
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: is_generation_request after check: {is_generation_request}")
print(f"DEBUG: is_anthropic_request: {is_anthropic_request}")
# Construct the target URL based on server capabilities
# Determine passthrough mode based on request format and server capabilities
should_passthrough_anthropic = is_anthropic_request and SERVER_SUPPORTS_ANTHROPIC
should_passthrough_openai = not is_anthropic_request and SERVER_SUPPORTS_OPENAI
if is_anthropic_request:
if should_passthrough_anthropic:
# Keep Anthropic path as-is, no conversion
target_path = transform_path("/" + original_path, SAMPLING_PROXY_BASE_PATH, TARGET_BASE_PATH)
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Anthropic passthrough mode - keeping path: {target_path}")
else:
# Convert /v1/messages to /chat/completions for OpenAI Compatible backend
# First apply the path transformation, then change to chat completions
transformed_path = transform_path("/" + original_path, SAMPLING_PROXY_BASE_PATH, TARGET_BASE_PATH)
target_path = transformed_path.replace("/v1/messages", "/chat/completions", 1)
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Converting Anthropic request from {original_path} to {target_path}")
else:
if not SERVER_SUPPORTS_OPENAI:
# OpenAI request but server doesn't support OpenAI - we can't convert OpenAI to Anthropic
return Response(
content=json.dumps({
"error": {
"type": "invalid_request_error",
"message": "Server does not support OpenAI format requests. Only Anthropic format is supported."
}
}),
status_code=400,
media_type="application/json"
)
# Apply base path transformation
target_path = transform_path("/" + original_path, SAMPLING_PROXY_BASE_PATH, TARGET_BASE_PATH)
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Path transformation: /{original_path} -> {target_path}")
print(f"DEBUG: Base paths - Proxy: {SAMPLING_PROXY_BASE_PATH}, Target: {TARGET_BASE_PATH}")
# Since httpx.AsyncClient is created with base_url=TARGET_BASE_URL,
# we need to provide only the path portion relative to the target base path
# Strip the TARGET_BASE_PATH from the beginning of target_path if it exists
if TARGET_BASE_PATH and target_path.startswith(TARGET_BASE_PATH):
relative_path = target_path[len(TARGET_BASE_PATH):]
# Ensure the relative path starts with / if it's not empty
if relative_path and not relative_path.startswith('/'):
relative_path = '/' + relative_path
else:
relative_path = target_path
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Relative path for httpx: {relative_path}")
# Ensure the query string is encoded to bytes as required by httpx.URL
target_url = httpx.URL(path=relative_path, query=request.url.query.encode("utf-8"))
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Target OpenAI Compatible URL: {target_url}")
# --- Sampling Parameter Override Logic ---
if is_generation_request and request.method == "POST":
if ENABLE_DEBUG_LOGS:
print("DEBUG: This is a POST generation request. Applying override logic.")
try:
# Attempt to parse the incoming request body as JSON.
# Generation requests typically send JSON payloads.
raw_body = await request.body()
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Raw incoming request body: {raw_body.decode('utf-8')}")
incoming_json_body = json.loads(raw_body) # This will be available for response processing
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Parsed incoming JSON body: {incoming_json_body}")
# Handle Anthropic request based on server capabilities
if is_anthropic_request:
if should_passthrough_anthropic:
# Passthrough mode: keep request as-is, but apply sampling params
if ENABLE_DEBUG_LOGS:
print("DEBUG: Anthropic passthrough mode - keeping request format")
# Don't modify incoming_json_body - it stays as Anthropic format
else:
# Convert Anthropic to OpenAI format
if ENABLE_DEBUG_LOGS:
print("DEBUG: Converting Anthropic request to OpenAI format.")
try:
# Extract Anthropic format data
anthropic_messages = incoming_json_body.get("messages", [])
anthropic_model = incoming_json_body.get("model")
anthropic_max_tokens = incoming_json_body.get("max_tokens")
anthropic_temperature = incoming_json_body.get("temperature")
anthropic_top_p = incoming_json_body.get("top_p")
anthropic_stream = incoming_json_body.get("stream", False)
anthropic_tools = incoming_json_body.get("tools")
anthropic_tool_choice = incoming_json_body.get("tool_choice")
# Convert Anthropic messages to OpenAI format
openai_messages = []
for msg_idx, msg in enumerate(anthropic_messages):
try:
# Map Anthropic roles to OpenAI roles
anthropic_role = msg.get("role", "user")
if anthropic_role == "user":
openai_role = "user"
elif anthropic_role == "assistant":
openai_role = "assistant"
elif anthropic_role == "system":
openai_role = "system"
else:
# Default to user for unknown roles
openai_role = "user"
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Unknown Anthropic role '{anthropic_role}' mapped to 'user'")
openai_msg = {
"role": openai_role,
"content": "" # Initialize with empty string instead of None
}
# Handle complex Anthropic content format
content = msg.get("content", [])
if isinstance(content, list):
content_parts = []
tool_calls = []
for content_item in content:
if isinstance(content_item, dict):
content_type = content_item.get("type")
if content_type == "text":
text_content = content_item.get("text", "")
if text_content:
content_parts.append(text_content)
elif content_type == "tool_use":
# Convert Anthropic tool_use to OpenAI tool_call format
tool_call = {
"id": content_item.get("id", f"call_{len(tool_calls)}"),
"type": "function",
"function": {
"name": content_item.get("name", ""),
"arguments": json.dumps(content_item.get("input", {}))
}
}
tool_calls.append(tool_call)
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Converted Anthropic tool_use to OpenAI tool_call: {tool_call}")
elif content_type == "tool_result":
# Convert Anthropic tool_result to OpenAI tool call format
tool_result_id = content_item.get("tool_use_id")
result_content = content_item.get("content", "")
is_error = content_item.get("is_error", False)
# Create a tool_call message with the result
if tool_result_id:
tool_call_msg = {
"role": "tool",
"tool_call_id": tool_result_id,
"content": str(result_content) if result_content else "No content"
}
if is_error:
tool_call_msg["content"] = f"Error: {result_content}"
openai_messages.append(tool_call_msg)
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Converted Anthropic tool_result to OpenAI tool message: {tool_call_msg}")
elif isinstance(content_item, str):
content_parts.append(content_item)
# Set content and tool_calls for the main message
if content_parts:
openai_msg["content"] = "".join(content_parts)
else:
# If no content parts but there are tool calls, set content to null
# Otherwise set to empty string
openai_msg["content"] = None if tool_calls else ""
if tool_calls:
openai_msg["tool_calls"] = tool_calls
elif isinstance(content, str):
openai_msg["content"] = content if content else ""
elif content is None:
openai_msg["content"] = ""
else:
openai_msg["content"] = str(content)
# Validate the message before adding
if openai_msg.get("role") != "tool" or "tool_call_id" in openai_msg:
# Ensure content is never None for non-tool messages
if openai_msg.get("content") is None and not openai_msg.get("tool_calls"):
openai_msg["content"] = ""
# Only add if the message has valid content or tool calls
if openai_msg.get("content") or openai_msg.get("tool_calls"):
openai_messages.append(openai_msg)
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Converted message {msg_idx}: {openai_msg}")
else:
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Skipping empty message {msg_idx}")
else:
if ENABLE_DEBUG_LOGS:
print(f"DEBUG: Skipping invalid tool message {msg_idx}")
except Exception as e:
if ENABLE_DEBUG_LOGS:
print(f"ERROR: Failed to convert message {msg_idx}: {e}")
# Continue with next message instead of failing completely
continue
# Override model for Anthropic requests
if OVERRIDE_MODEL_NAME:
overridden_model = OVERRIDE_MODEL_NAME
if ENABLE_OVERRIDE_LOGS:
print(f"OVERRIDE: Anthropic model '{anthropic_model}' OVERRIDDEN to '{OVERRIDE_MODEL_NAME}'")
else:
overridden_model = FIRST_AVAILABLE_MODEL if FIRST_AVAILABLE_MODEL else anthropic_model
if ENABLE_DEBUG_LOGS and FIRST_AVAILABLE_MODEL:
print(f"DEBUG: Using first available model '{FIRST_AVAILABLE_MODEL}' for Anthropic request")
# Convert to OpenAI chat completions format
openai_request = {
"model": overridden_model,
"messages": openai_messages,
"max_tokens": anthropic_max_tokens,
"stream": anthropic_stream
}
# Add optional parameters if present
if anthropic_temperature is not None:
openai_request["temperature"] = anthropic_temperature
if anthropic_top_p is not None:
openai_request["top_p"] = anthropic_top_p
# Convert tools if present
if anthropic_tools:
openai_tools = []
for tool in anthropic_tools:
openai_tool = {
"type": "function",
"function": {
"name": tool.get("name"),
"description": tool.get("description", ""),
"parameters": tool.get("input_schema", {})
}
}