From e11b96f13f2a8893c31e24ee93294f78a6150bd4 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 23 Apr 2026 14:34:52 -0700 Subject: [PATCH 01/41] feat: add multi-turn dataset manager with flat JSONL support Add MultiTurnDataset, MultiTurnConfig schema, tool-calling types, Query.metadata transport field, adapter tools= kwarg, and multi-turn factory routing. --- docs/MULTI_TURN_QUICKSTART.md | 345 ++++++ examples/09_MultiTurn/README.md | 311 +++++ .../agentic_coding_benchmark.yaml | 33 + .../agentic_workflow_benchmark.yaml | 33 + .../customer_support_conversations.jsonl | 10 + .../09_MultiTurn/multi_turn_benchmark.yaml | 44 + .../multi_turn_with_concurrency.yaml | 44 + src/inference_endpoint/config/schema.py | 38 +- .../templates/concurrency_template.yaml | 2 +- .../templates/concurrency_template_full.yaml | 4 +- .../templates/offline_template_full.yaml | 4 +- .../config/templates/online_template.yaml | 2 +- .../templates/online_template_full.yaml | 4 +- src/inference_endpoint/core/types.py | 28 + .../dataset_manager/__init__.py | 2 + .../dataset_manager/factory.py | 9 +- .../dataset_manager/multi_turn_dataset.py | 428 +++++++ .../endpoint_client/adapter_protocol.py | 12 +- .../endpoint_client/http.py | 2 + .../endpoint_client/worker.py | 7 +- src/inference_endpoint/openai/accumulator.py | 58 +- .../openai/openai_adapter.py | 34 +- .../openai/openai_msgspec_adapter.py | 73 +- src/inference_endpoint/openai/types.py | 13 +- tests/unit/core/test_types.py | 54 + .../test_multi_turn_dataset.py | 1073 +++++++++++++++++ tests/unit/openai/test_msgspec_adapter.py | 146 +++ 27 files changed, 2747 insertions(+), 66 deletions(-) create mode 100644 docs/MULTI_TURN_QUICKSTART.md create mode 100644 examples/09_MultiTurn/README.md create mode 100644 examples/09_MultiTurn/agentic_coding_benchmark.yaml create mode 100644 examples/09_MultiTurn/agentic_workflow_benchmark.yaml create mode 100644 examples/09_MultiTurn/customer_support_conversations.jsonl create mode 100644 examples/09_MultiTurn/multi_turn_benchmark.yaml create mode 100644 examples/09_MultiTurn/multi_turn_with_concurrency.yaml create mode 100644 src/inference_endpoint/dataset_manager/multi_turn_dataset.py create mode 100644 tests/unit/dataset_manager/test_multi_turn_dataset.py create mode 100644 tests/unit/openai/test_msgspec_adapter.py diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md new file mode 100644 index 00000000..73ed6678 --- /dev/null +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -0,0 +1,345 @@ +# Multi-Turn Conversation Benchmarking - Quick Start Guide + +## πŸš€ Quick Start in 5 Minutes + +### 1. Prepare Your Dataset + +Create a JSONL file with your conversations. All rows for a given `conversation_id` must appear +**consecutively** in the file (no interleaving with other conversations): + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello!", "system": "You are a helpful assistant"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi! How can I help?"} +{"conversation_id": "c1", "turn": 3, "role": "user", "content": "What's 2+2?"} +{"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "2+2 equals 4."} +``` + +**Rules**: + +- Alternate between "user" and "assistant" roles +- Start with "user" role +- Sequential turn numbers (1, 2, 3, ...) +- Same `conversation_id` for all turns in a conversation +- All rows for the same `conversation_id` must be grouped together + +### 2. 
Create Configuration File + +Save as `multi_turn_config.yaml`: + +```yaml +name: "my-multi-turn-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: my_conversations + type: performance + path: path/to/your/conversations.jsonl + format: ".jsonl" + multi_turn: # ← Presence of this block enables multi-turn mode + mode: independent # ← Per-conv pipelines; no cross-conv turn barrier + turn_timeout_s: 300 # ← Max wait for prev turn + +settings: + load_pattern: + type: multi_turn # ← Use multi-turn scheduler + target_concurrency: 32 # ← OPTIONAL: limit concurrent requests + + client: + workers: 4 + +endpoint_config: + endpoints: + - "http://your-endpoint:8000" + api_type: openai + +report_dir: logs/my_multi_turn_benchmark +``` + +Results are written to `report_dir` (here: `logs/my_multi_turn_benchmark/`). + +### 3. Run Benchmark + +```bash +inference-endpoint benchmark from-config --config multi_turn_config.yaml +``` + +That's it! Your benchmark will now: + +- βœ… Enforce turn ordering (turn N+1 waits for turn N) +- βœ… Include conversation history in each request +- βœ… Track per-turn and per-conversation metrics +- βœ… Log all turns with conversation metadata + +--- + +## πŸ“Š Understanding Results + +After the benchmark completes, check the directory configured via `report_dir`: + +### Events Database + +The `events.db` SQLite database includes: + +- Standard fields: sample_uuid, event_type, timestamp_ns +- **New fields**: conversation_id, turn_number + +Query example: + +```sql +SELECT conversation_id, turn_number, event_type, timestamp_ns +FROM events +WHERE conversation_id = 'c1' +ORDER BY turn_number; +``` + +### Metrics + +Currently available: + +- **Per-turn metrics**: Latency, TTFT, TPOT for each turn +- **Conversation tracking**: All events tagged with conversation_id + +_Note: Per-conversation aggregation (e.g., "conversations/sec") is coming in a future update._ + +--- + +## 🎯 Conversation Modes Explained + +### Independent Mode (Default) + +```yaml +mode: independent +``` + +**Behavior**: + +- Issues turn-1 of ALL conversations at t=0 +- Then sequences turns within each conversation independently +- Maximum parallelism and throughput + +**Use for**: Realistic production load where short conversations finish while long ones are still running. +For single-conversation debugging, use `mode: independent` with `target_concurrency: 1`. +Note: unlike the plain `ConcurrencyScheduler`, multi-turn + `target_concurrency: 1` still enforces +per-conversation turn ordering β€” turn N+1 waits for turn N even at concurrency 1. + +**Example timeline**: + +``` +t=0: conv1-turn1, conv2-turn1, conv3-turn1 (all at once) +t=0.5: conv1-turn2 (after conv1-turn1 completes) +t=0.7: conv2-turn2 (after conv2-turn1 completes) +t=0.8: conv1-turn3 (after conv1-turn2 completes) +... +``` + +--- + +## πŸŽ›οΈ Concurrency Control (NEW!) + +For benchmarks with **> 50 conversations**, use `target_concurrency` to prevent endpoint overload: + +```yaml +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Limit to 32 concurrent requests +``` + +**Why?** Without this, independent mode issues ALL turn-1s at once (could be 100+), overwhelming your endpoint. 
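The sketch below is not the benchmarker's actual `MultiTurnScheduler`; it is only an illustration of how a global concurrency cap composes with per-conversation turn ordering, as described above: a shared semaphore bounds in-flight requests, while each conversation still sends its turns strictly in sequence.

```python
import asyncio

# Illustrative sketch only: a shared semaphore caps in-flight requests across all
# conversations, while each conversation still issues its turns strictly in order.
async def run_conversation(turns, semaphore, send_turn):
    for turn in turns:              # turn N+1 never starts before turn N finishes...
        async with semaphore:       # ...and also waits for a free concurrency slot
            await send_turn(turn)

async def run_benchmark(conversations, target_concurrency, send_turn):
    semaphore = asyncio.Semaphore(target_concurrency)
    await asyncio.gather(
        *(run_conversation(turns, semaphore, send_turn) for turns in conversations)
    )
```
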
+ +**Rule of thumb**: + +- Small (< 50 convs): No limit needed +- Medium (50-500 convs): `target_concurrency: 32` +- Large (500+ convs): `target_concurrency: 64` + +--- + +## πŸ”§ Common Configurations + +### Recommended: With Concurrency Control + +```yaml +multi_turn: + mode: independent + +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Prevents overload + client: + workers: 8 + +datasets: + - samples: 100 +``` + +### High Throughput Testing + +```yaml +multi_turn: + mode: independent + turn_timeout_s: 600 + +settings: + client: + workers: 16 # More workers for parallel conversations +``` + +### Long Conversations + +```yaml +multi_turn: + mode: independent + turn_timeout_s: 1800 # 30 minutes for slow responses +``` + +--- + +## ❓ Troubleshooting + +### "Conversation has invalid role sequence" + +**Problem**: Your dataset doesn't alternate between user/assistant. + +**Fix**: Check your JSONL - must be: user, assistant, user, assistant, ... + +### "Rows for conversation X are not consecutive" + +**Problem**: Rows for the same `conversation_id` are interleaved with rows from other conversations. + +**Fix**: Sort your JSONL so all rows for each conversation appear together. + +### "Turn timed out waiting for prev turn" + +**Problem**: Previous turn took longer than `turn_timeout_s`. + +**Fixes**: + +1. Increase `turn_timeout_s` in config +2. Check if your endpoint is slow or unresponsive +3. Look for errors in the endpoint logs + +### Dataset not loading + +**Problem**: MultiTurnDataset not recognized. + +**Fix**: Ensure `format: ".jsonl"` is specified in config: + +```yaml +datasets: + - path: your_file.jsonl + format: ".jsonl" # ← Required for JSONL +``` + +--- + +## πŸ“ Example Datasets + +### Simple 2-Turn Conversation + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hello!"} +``` + +### With System Prompt + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Who won?", "system": "You are a sports expert"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "The Lakers won."} +``` + +### Multiple Conversations + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hello!"} +{"conversation_id": "c2", "turn": 1, "role": "user", "content": "Hey"} +{"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "Hi there!"} +``` + +### With Model Override + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Summarize this", "model": "gpt-4"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Here's the summary..."} +``` + +--- + +## πŸ§ͺ Testing Your Setup + +### 1. Use the Example Dataset + +```bash +cd examples/multi_turn +inference-endpoint benchmark from-config --config multi_turn_benchmark.yaml +``` + +### 2. Check the Logs + +```bash +cat logs/multi_turn_test/benchmark.log +# Look for: "Turn X of conversation_id issued" +``` + +### 3. 
Verify Event Recording + +```bash +sqlite3 logs/multi_turn_test/events.db +sqlite> SELECT DISTINCT conversation_id FROM events; +# Should show your conversation IDs +``` + +--- + +## πŸ’‘ Tips & Best Practices + +### Dataset Design + +- **Keep conversations realistic**: 2-10 turns typical +- **Test edge cases**: 1-turn conversations, very long conversations +- **Include system prompts**: Helps model understand context + +### Performance + +- **Workers**: Set `workers` = number of concurrent conversations +- **Timeout**: Set `turn_timeout_s` = 2x your longest expected turn latency +- **Memory**: ~1KB per turn, plan accordingly for large datasets + +### Debugging + +- **Start small**: Test with 1-2 conversations first +- **Single conversation**: Use `mode: independent` with `target_concurrency: 1` +- **Check events.db**: Verify turn ordering in database + +--- + +## πŸ”— More Information + +- **Full Documentation**: See `examples/09_MultiTurn/README.md` +- **Architecture**: See `AGENTS.md` (Multi-Turn section) + +--- + +## βœ… Checklist + +Before running your first multi-turn benchmark: + +- [ ] Dataset follows format (alternating user/assistant roles) +- [ ] All rows for each conversation_id are grouped together +- [ ] Config has `multi_turn:` block in the dataset section +- [ ] Config has `load_pattern.type: multi_turn` +- [ ] Endpoint is running and reachable +- [ ] `format: ".jsonl"` specified for JSONL datasets +- [ ] Conversation IDs are unique per conversation +- [ ] Turn numbers are sequential (1, 2, 3, ...) + +Happy benchmarking! πŸš€ diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md new file mode 100644 index 00000000..0d3348c7 --- /dev/null +++ b/examples/09_MultiTurn/README.md @@ -0,0 +1,311 @@ +# Multi-Turn Conversation Benchmarking Examples + +This directory contains examples for benchmarking conversational AI workloads with multi-turn conversation support. + +## Overview + +Multi-turn conversation benchmarking enables testing realistic conversational AI scenarios where each turn depends on previous responses. The system maintains conversation history and enforces turn sequencing to simulate real-world multi-turn interactions. + +## Dataset Format + +Multi-turn datasets use JSONL format with the following structure: + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "...", "system": "..."} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "..."} +{"conversation_id": "c1", "turn": 3, "role": "user", "content": "..."} +``` + +### Required Fields + +- `conversation_id`: Unique identifier for each conversation +- `turn`: Turn number within conversation (1-indexed) +- `role`: Speaker role ("user" or "assistant") +- `content`: Message content + +### Optional Fields + +- `system`: System prompt (typically only on first user turn) +- `model`: Model name override for this turn +- `max_new_tokens`: Maximum tokens to generate for this turn + +### Validation Rules + +1. All rows for a given `conversation_id` must appear **consecutively** in the file (no interleaving + with rows from other conversations). Turns within a conversation must be in order. + The flat-row format is intentional: it enables row-by-row streaming without loading entire + conversations into memory first. +2. Conversations must follow a valid role sequence: + - Plain chat: `user β†’ assistant β†’ user β†’ ...` + - Agentic: `user β†’ assistant (with tool_calls) β†’ tool β†’ [tool | assistant (with tool_calls)]* β†’ assistant β†’ user β†’ ...` +3. 
First turn must be "user" role +4. Turn numbers must be sequential (1, 2, 3, ...) +5. Each conversation must have at least one turn + +## Agentic (Tool-Sequence) Datasets + +For agentic workloads where the model dispatches tools, the dataset must include tool-call +metadata. The source format for these datasets is a **snapshot JSONL** β€” each line contains the +full conversation history at a particular checkpoint. The benchmarker requires **flat-row JSONL** +(one row per message), so a conversion step is needed first. + +### Source snapshot format + +Each line in the source file represents one snapshot of a conversation: + +```json +{ + "conversation_id": "sim_001", + "conversation_idx": 5, + "messages": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "..."}, + {"role": "assistant", "tool_calls": [{"id": "...", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"ls\"}"}}]}, + {"role": "tool", "tool_call_id": "...", "content": "file1.txt\nfile2.txt"}, + {"role": "assistant", "content": "Done."} + ], + "tools": [...], + "metadata": {} +} +``` + +Multiple snapshots may exist per `conversation_id` (one per `conversation_idx`); only the +highest-indexed snapshot per conversation is used. + +### Converting to flat-row format + +The following commands convert each source snapshot file to the flat-row format required by the benchmarker. +Run from the repo root: + +```bash +python scripts/convert_agentic_snapshot.py \ + /path/to/agentic_coding_dataset.jsonl \ # input snapshot JSONL + examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl \ # output flat-row JSONL + --verify + +python scripts/convert_agentic_snapshot.py \ + /path/to/agentic_workflow_dataset.jsonl \ # input snapshot JSONL + examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl \ # output flat-row JSONL + --verify +``` + +The `--verify` flag cross-checks every client turn's message history against the source snapshot +and exits with code 1 if any mismatch is found. The script also: + +- Collapses consecutive `user` messages into one (keeps turn sequencing clean) +- Merges consecutive `tool` messages for the same assistant dispatch into a single row with a + `tool_results` list (so all parallel results are sent together in one API call) + +### Flat-row format after conversion + +The extra fields supported beyond plain user/assistant: + +| Row role | Extra fields | +| -------------------------------- | ------------------------------------------------------------------ | +| `assistant` with tool calls | `tool_calls: [{id, type, function: {name, arguments}}]` | +| `tool` single result | `tool_call_id: `, `content: ` | +| `tool` parallel results (merged) | `tool_results: [{tool_call_id, content}, ...]` | +| `user` or `tool` turns | `tools: [...]` (OpenAI tool definitions forwarded to the endpoint) | + +Example rows from a converted agentic dataset: + +```jsonl +{"conversation_id": "sim_001", "turn": 1, "role": "user", "content": "Fix the bug in foo.py", "system": "You are a coding agent.", "tools": [...]} +{"conversation_id": "sim_001", "turn": 2, "role": "assistant", "tool_calls": [{"id": "functions.bash:0", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"cat foo.py\"}"}}]} +{"conversation_id": "sim_001", "turn": 3, "role": "tool", "tool_call_id": "functions.bash:0", "content": "def foo():\n return 1/0", "tools": [...]} +{"conversation_id": "sim_001", "turn": 4, "role": "assistant", "content": "The bug is a ZeroDivisionError. 
Here is the fix: ..."} +``` + +### Running agentic benchmarks + +After converting the datasets, update the `path` field in the config files and run: + +```bash +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/agentic_coding_benchmark.yaml + +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/agentic_workflow_benchmark.yaml +``` + +--- + +## Configuration + +### Basic Configuration + +```yaml +datasets: + - name: customer_support + type: performance + path: examples/multi_turn/customer_support_conversations.jsonl + format: ".jsonl" + multi_turn: + mode: independent + turn_timeout_s: 300.0 + +settings: + load_pattern: + type: multi_turn +``` + +### Concurrency Control (Optional) + +The multi-turn scheduler supports **optional concurrency limiting** to control the maximum number of in-flight requests across all conversations: + +```yaml +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Limit to 32 concurrent requests +``` + +**Behavior**: + +- Without `target_concurrency`: Unlimited concurrency (all turn-1s issue at t=0 in INDEPENDENT mode) +- With `target_concurrency`: Limits total in-flight requests across all conversations +- Combines with turn sequencing: Turn N+1 still waits for turn N, AND waits for available slot + +**Use cases**: + +- 🎯 **Prevent endpoint overload**: Control request rate to busy endpoints +- 🎯 **Large-scale testing**: Benchmark 1000+ conversations without overwhelming system +- 🎯 **Resource management**: Stay within port limits, memory constraints + +**Example**: 100 conversations with `target_concurrency: 32` + +``` +t=0: Issue first 32 turn-1s (concurrency limit reached) +t=0.5: Turn-1 completes β†’ issue next turn-1 (slot filled) +t=1.0: Turn-1 completes β†’ issue turn-2 of completed conv (slot filled) +... Maintains ~32 in-flight across all conversations +``` + +### Conversation Modes + +The default mode is `independent`. + +#### Independent Mode (Default) + +Issues turns for each conversation independently β€” no cross-conversation turn barrier. + +```yaml +multi_turn: + mode: independent + +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 +``` + +**Use case**: Realistic production load where short conversations finish while long ones are +still running. Turn 1 of one conversation and turn 100 of another can be in-flight simultaneously. + +For single-conversation debugging, use `mode: independent` with `target_concurrency: 1`. + +### Turn Timeout + +Configure maximum wait time for previous turn completion: + +```yaml +multi_turn: + turn_timeout_s: 300.0 # 5 minutes +``` + +If a turn times out waiting for the previous turn, it will be skipped and logged as a warning. + +## Running Multi-Turn Benchmarks + +### Using Configuration File + +```bash +inference-endpoint benchmark from-config \ + --config examples/multi_turn/multi_turn_benchmark.yaml +``` + +### Viewing Results + +Multi-turn benchmarks produce both per-turn and per-conversation metrics: + +- **Per-turn metrics**: Latency, TTFT, TPOT for each individual turn +- **Per-conversation metrics**: Total conversation latency, conversations per second + +Results are stored in the configured `report_dir` with conversation metadata included in the events database. 
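As a sketch (assuming the `events` table exposes `conversation_id`, `turn_number`, and `timestamp_ns` as described in the quick-start guide; exact event type names may differ between versions), a per-conversation rollup can be computed directly from the events database:

```python
import sqlite3

# Sketch: approximate per-conversation wall time from first to last recorded event.
# Assumes the events table schema described above (conversation_id, turn_number,
# timestamp_ns); adjust the path to your configured report_dir.
conn = sqlite3.connect("logs/multi_turn_test/events.db")
rows = conn.execute(
    """
    SELECT conversation_id,
           COUNT(DISTINCT turn_number)                   AS turns,
           (MAX(timestamp_ns) - MIN(timestamp_ns)) / 1e9 AS conversation_seconds
    FROM events
    WHERE conversation_id IS NOT NULL
    GROUP BY conversation_id
    ORDER BY conversation_seconds DESC
    """
).fetchall()
for conversation_id, turns, seconds in rows:
    print(f"{conversation_id}: {turns} turns in {seconds:.2f}s")
```
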
+ +## Example Datasets + +### customer_support_conversations.jsonl + +Simple customer support conversations demonstrating basic multi-turn interactions: + +- 3 conversations +- 2-4 turns per conversation +- Customer support agent system prompt + +## Architecture Notes + +### Key Components + +- **ConversationManager**: Tracks conversation state and message history +- **MultiTurnScheduler**: Enforces turn sequencing within conversations +- **ConversationSample**: Sample with conversation metadata +- **MultiTurnDataset**: Validates and structures multi-turn data + +### Turn Sequencing + +The system ensures that: + +1. Turn N+1 cannot be issued until turn N completes +2. Message history is included in subsequent requests +3. Concurrent conversations are supported (in independent mode) + +### Memory Considerations + +Each conversation maintains message history in memory. For large-scale benchmarks with long conversations: + +- Memory usage: ~1KB per turn (approximate) +- 1000 conversations Γ— 10 turns = ~10MB + +## Troubleshooting + +### "Conversation has invalid role sequence" + +**Cause**: Conversation doesn't follow a valid role sequence. + +**Fix**: For plain chat, ensure the dataset alternates between user and assistant: + +``` +user -> assistant -> user -> assistant -> ... +``` + +For agentic datasets, use the conversion script (`scripts/convert_agentic_snapshot.py`) to +produce a properly sequenced flat-row file. The valid agentic sequence is: + +``` +user -> assistant (tool_calls) -> tool -> [tool | assistant (tool_calls)]* -> assistant -> user -> ... +``` + +### "Turn timed out waiting for prev turn" + +**Cause**: Previous turn took longer than `turn_timeout_s` to complete. + +**Fixes**: + +- Increase `turn_timeout_s` in configuration +- Check endpoint performance +- Verify endpoint is responding + +### Single-turn benchmarks unaffected + +Multi-turn logic is only activated when a `multi_turn:` block is present in the dataset configuration. Existing single-turn benchmarks continue to work unchanged with zero performance overhead. 
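For a quick sanity check of that routing, the dataset can also be loaded programmatically. This is only a sketch: the factory routes to `MultiTurnDataset` solely because the `multi_turn` block is present, and the exact required fields of the `Dataset` config model may vary between versions.

```python
from inference_endpoint.config.schema import Dataset as DatasetConfig, MultiTurnConfig
from inference_endpoint.dataset_manager import DataLoaderFactory

# Sketch: routing to MultiTurnDataset is keyed purely off the presence of the
# multi_turn block; omit it and the ordinary single-turn loader is used instead.
config = DatasetConfig(
    name="customer_support_conversations",
    type="performance",
    path="examples/09_MultiTurn/customer_support_conversations.jsonl",
    format=".jsonl",
    multi_turn=MultiTurnConfig(),
)
dataset = DataLoaderFactory.create_loader(config)
print(type(dataset).__name__)   # expected: MultiTurnDataset
dataset.load()
print(dataset.num_samples())    # number of client (user/tool) turns
print(dataset.load_sample(0))   # first user turn, with pre_built_messages attached
```
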
+ +## Future Enhancements + +Planned features: + +- [ ] Poisson conversation arrival mode implementation +- [ ] Per-conversation metrics in reporting +- [ ] Conversation-level latency percentiles +- [ ] Support for tool/function calls in conversations +- [ ] Dynamic conversation branching diff --git a/examples/09_MultiTurn/agentic_coding_benchmark.yaml b/examples/09_MultiTurn/agentic_coding_benchmark.yaml new file mode 100644 index 00000000..5a1036a7 --- /dev/null +++ b/examples/09_MultiTurn/agentic_coding_benchmark.yaml @@ -0,0 +1,33 @@ +name: "agentic-coding-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" # Replace with your actual model name + max_new_tokens: 1024 + +datasets: + - name: agentic_coding + type: performance + # Run: python scripts/convert_agentic_snapshot.py examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl --verify + path: examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl + format: ".jsonl" + multi_turn: + mode: independent + turn_timeout_s: 600.0 + +settings: + runtime: + min_duration_ms: 0 + max_duration_ms: 3600000 + + load_pattern: + type: multi_turn + target_concurrency: 4096 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/agentic_coding diff --git a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml new file mode 100644 index 00000000..e8885465 --- /dev/null +++ b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml @@ -0,0 +1,33 @@ +name: "agentic-workflow-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" # Replace with your actual model name + max_new_tokens: 512 + +datasets: + - name: agentic_workflow + type: performance + # Run: python scripts/convert_agentic_snapshot.py examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl --verify + path: examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl + format: ".jsonl" + multi_turn: + mode: independent + turn_timeout_s: 600.0 + +settings: + runtime: + min_duration_ms: 0 + max_duration_ms: 3600000 + + load_pattern: + type: multi_turn + target_concurrency: 96 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/agentic_workflow diff --git a/examples/09_MultiTurn/customer_support_conversations.jsonl b/examples/09_MultiTurn/customer_support_conversations.jsonl new file mode 100644 index 00000000..ac19e907 --- /dev/null +++ b/examples/09_MultiTurn/customer_support_conversations.jsonl @@ -0,0 +1,10 @@ +{"conversation_id": "conv_001", "turn": 1, "role": "user", "content": "I need help resetting my password", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_001", "turn": 2, "role": "assistant", "content": "I'd be happy to help you reset your password. Can you provide your email address?"} +{"conversation_id": "conv_001", "turn": 3, "role": "user", "content": "It's user@example.com"} +{"conversation_id": "conv_001", "turn": 4, "role": "assistant", "content": "Thank you. I've sent a password reset link to user@example.com. Please check your inbox and follow the instructions."} +{"conversation_id": "conv_002", "turn": 1, "role": "user", "content": "What are your business hours?", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_002", "turn": 2, "role": "assistant", "content": "We're open Monday-Friday, 9 AM to 5 PM EST. 
How can I assist you today?"} +{"conversation_id": "conv_002", "turn": 3, "role": "user", "content": "Do you offer weekend support?"} +{"conversation_id": "conv_002", "turn": 4, "role": "assistant", "content": "For urgent issues, we offer limited support on weekends from 10 AM to 2 PM EST. For non-urgent matters, please contact us during our regular business hours."} +{"conversation_id": "conv_003", "turn": 1, "role": "user", "content": "Can I cancel my subscription?", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_003", "turn": 2, "role": "assistant", "content": "Yes, you can cancel your subscription at any time. Would you like me to guide you through the cancellation process?"} diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml new file mode 100644 index 00000000..9ed6c9f1 --- /dev/null +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -0,0 +1,44 @@ +name: "multi-turn-customer-support" +version: "1.0" +type: "online" + +model_params: + name: "meta-llama/Llama-3.2-1B-Instruct" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: customer_support_conversations + type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + format: ".jsonl" + samples: 10 + multi_turn: + mode: independent + turn_timeout_s: 300.0 + +settings: + runtime: + min_duration_ms: 60000 + max_duration_ms: 300000 + + load_pattern: + type: multi_turn + # target_concurrency: 32 # Optional: limit concurrent requests across all conversations + + client: + warmup_connections: 0 + +metrics: + collect: + - throughput + - latency + - ttft + - tpot + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/multi_turn_test diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml new file mode 100644 index 00000000..491e6b4b --- /dev/null +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -0,0 +1,44 @@ +name: "multi-turn-with-concurrency-control" +version: "1.0" +type: "online" + +model_params: + name: "meta-llama/Llama-3.2-1B-Instruct" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: customer_support_conversations + type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + format: ".jsonl" + samples: 10 + multi_turn: + mode: independent # All conv turn-1 start together + turn_timeout_s: 300.0 + +settings: + runtime: + min_duration_ms: 60000 + max_duration_ms: 300000 + + load_pattern: + type: multi_turn + target_concurrency: 32 # ← NEW: Limit to 32 concurrent requests + + client: + warmup_connections: 0 + +metrics: + collect: + - throughput + - latency + - ttft + - tpot + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/multi_turn_with_concurrency diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 8ab8f3b0..1cd0f172 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -60,10 +60,17 @@ class LoadPatternType(str, Enum): MAX_THROUGHPUT = "max_throughput" # Offline: all queries at t=0 POISSON = "poisson" # Online: fixed QPS with Poisson distribution CONCURRENCY = "concurrency" # Online: fixed concurrent requests + MULTI_TURN = "multi_turn" # Multi-turn conversations with turn sequencing BURST = "burst" # Burst pattern (TODO) STEP = "step" # Step pattern (TODO) +class ConversationMode(str, 
Enum): + """Multi-turn conversation scheduling modes.""" + + INDEPENDENT = "independent" # Per-conv pipelines; no cross-conv turn barrier + + class OSLDistributionType(str, Enum): """Output Sequence Length distribution types.""" @@ -230,6 +237,26 @@ def get_ruleset_instance(self) -> BenchmarkSuiteRuleset: return get_ruleset(self.ruleset) +class MultiTurnConfig(BaseModel): + """Multi-turn conversation configuration. + + Configuration for benchmarking conversational AI workloads with turn sequencing. + Enables testing multi-turn conversations where each turn depends on previous responses. + Presence of this block in the dataset config enables multi-turn mode. + + Attributes: + mode: Conversation scheduling strategy (currently only independent). + turn_timeout_s: Maximum seconds to wait for previous turn completion. + use_dataset_history: If True, use pre-built message history from dataset. + """ + + model_config = {"extra": "forbid"} + + mode: ConversationMode = ConversationMode.INDEPENDENT + turn_timeout_s: float = 300.0 + use_dataset_history: bool = True + + class Dataset(BaseModel): """Dataset configuration. @@ -260,6 +287,9 @@ class Dataset(BaseModel): accuracy_config: AccuracyConfig | None = Field( None, description="Accuracy evaluation settings" ) + multi_turn: MultiTurnConfig | None = Field( + None, description="Multi-turn conversation configuration" + ) @model_validator(mode="after") def _auto_derive_name(self) -> Self: @@ -586,9 +616,13 @@ def _resolve_and_validate(self) -> Self: f"Offline benchmarks must use 'max_throughput', got '{lp.type}'" ) elif effective_mode == TestType.ONLINE: - if lp.type not in (LoadPatternType.POISSON, LoadPatternType.CONCURRENCY): + if lp.type not in ( + LoadPatternType.POISSON, + LoadPatternType.CONCURRENCY, + LoadPatternType.MULTI_TURN, + ): raise ValueError( - "Online mode requires --load-pattern (poisson or concurrency)" + "Online mode requires --load-pattern (poisson, concurrency, or multi_turn)" ) return self diff --git a/src/inference_endpoint/config/templates/concurrency_template.yaml b/src/inference_endpoint/config/templates/concurrency_template.yaml index 7b560ed7..c44295b4 100644 --- a/src/inference_endpoint/config/templates/concurrency_template.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template.yaml @@ -14,7 +14,7 @@ settings: max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) n_samples_to_issue: null # Sample count override load_pattern: - type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_concurrency: 32 # Concurrent requests endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. 
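A minimal sketch of the new `MultiTurnConfig` block introduced in the schema change above (field names and defaults taken from that diff; exercising the model standalone like this is illustrative only):

```python
from inference_endpoint.config.schema import ConversationMode, MultiTurnConfig

# Defaults from the schema above: independent mode, 300 s turn timeout,
# pre-built dataset history enabled.
cfg = MultiTurnConfig()
assert cfg.mode is ConversationMode.INDEPENDENT
assert cfg.turn_timeout_s == 300.0
assert cfg.use_dataset_history is True

# extra="forbid" rejects unknown keys, so config typos fail fast rather than
# silently running with defaults.
try:
    MultiTurnConfig(turn_timout_s=600)  # deliberately misspelled key
except Exception as exc:
    print(f"rejected: {exc}")
```
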
diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index b2a0c89d..5def2719 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -22,6 +22,7 @@ datasets: # Dataset configs parser: # Column remapping: {prompt: , system: } prompt: text_input accuracy_config: null # Accuracy evaluation settings + multi_turn: null # Multi-turn conversation configuration - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -36,6 +37,7 @@ datasets: # Dataset configs ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation + multi_turn: null # Multi-turn conversation configuration settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) @@ -44,7 +46,7 @@ settings: scheduler_random_seed: 42 # Scheduler RNG seed dataloader_random_seed: 42 # Dataloader RNG seed load_pattern: - type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: null # Target QPS target_concurrency: 32 # Concurrent requests client: diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index 6914ca3c..57c200d5 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -22,6 +22,7 @@ datasets: # Dataset configs parser: # Column remapping: {prompt: , system: } prompt: text_input accuracy_config: null # Accuracy evaluation settings + multi_turn: null # Multi-turn conversation configuration - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -36,6 +37,7 @@ datasets: # Dataset configs ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation + multi_turn: null # Multi-turn conversation configuration settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) @@ -44,7 +46,7 @@ settings: scheduler_random_seed: 42 # Scheduler RNG seed dataloader_random_seed: 42 # Dataloader RNG seed load_pattern: - type: max_throughput # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: max_throughput # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: null # Target QPS target_concurrency: null # Concurrent requests client: diff --git a/src/inference_endpoint/config/templates/online_template.yaml b/src/inference_endpoint/config/templates/online_template.yaml index d33c1fd5..a56dc9b0 100644 --- a/src/inference_endpoint/config/templates/online_template.yaml +++ b/src/inference_endpoint/config/templates/online_template.yaml @@ -14,7 +14,7 @@ settings: max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) 
n_samples_to_issue: null # Sample count override load_pattern: - type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: 10.0 # Target QPS endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index 0e45267e..fea2a0c0 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -22,6 +22,7 @@ datasets: # Dataset configs parser: # Column remapping: {prompt: , system: } prompt: text_input accuracy_config: null # Accuracy evaluation settings + multi_turn: null # Multi-turn conversation configuration - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -36,6 +37,7 @@ datasets: # Dataset configs ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation + multi_turn: null # Multi-turn conversation configuration settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) @@ -44,7 +46,7 @@ settings: scheduler_random_seed: 42 # Scheduler RNG seed dataloader_random_seed: 42 # Dataloader RNG seed load_pattern: - type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: 10.0 # Target QPS target_concurrency: null # Concurrent requests client: diff --git a/src/inference_endpoint/core/types.py b/src/inference_endpoint/core/types.py index f09f8a7e..5b8209e1 100644 --- a/src/inference_endpoint/core/types.py +++ b/src/inference_endpoint/core/types.py @@ -232,6 +232,7 @@ class Query( Attributes: id: Unique identifier for this query (auto-generated UUID). data: Request payload as a dictionary (typically contains prompt, model, etc.). + metadata: Internal metadata that round-trips through transport (e.g., conversation_id). headers: HTTP headers to include in the request (e.g., authorization). created_at: Timestamp when query was created (seconds since epoch). @@ -255,6 +256,7 @@ class Query( id: str = msgspec.field(default_factory=lambda: str(uuid.uuid4())) data: dict[str, Any] = msgspec.field(default_factory=dict) + metadata: dict[str, Any] = msgspec.field(default_factory=dict) headers: dict[str, str] = msgspec.field(default_factory=dict) created_at: float = msgspec.field(default_factory=time.time) @@ -337,6 +339,32 @@ def get_response_output_string(self) -> str: else: return "" + def with_metadata( + self, additional_metadata: dict[str, Any] | None + ) -> "QueryResult": + """Return a new QueryResult with merged metadata. + + Args: + additional_metadata: Metadata to merge into existing metadata. + Values in additional_metadata override existing keys. + + Returns: + New QueryResult with merged metadata (existing + additional). + If additional_metadata is None or empty, returns self unchanged. 
+ """ + if not additional_metadata: + return self + + merged = dict(self.metadata) + merged.update(additional_metadata) + + return QueryResult( + id=self.id, + response_output=self.response_output, + metadata=merged, + error=self.error, + ) + class StreamChunk( msgspec.Struct, diff --git a/src/inference_endpoint/dataset_manager/__init__.py b/src/inference_endpoint/dataset_manager/__init__.py index 4bb6c575..12938f8e 100644 --- a/src/inference_endpoint/dataset_manager/__init__.py +++ b/src/inference_endpoint/dataset_manager/__init__.py @@ -21,6 +21,7 @@ from .dataset import Dataset, EmptyDataset from .factory import DataLoaderFactory +from .multi_turn_dataset import MultiTurnDataset from .predefined.aime25 import AIME25 from .predefined.cnndailymail import CNNDailyMail from .predefined.gpqa import GPQA @@ -58,4 +59,5 @@ "CNNDailyMail", "RandomDataset", "ShopifyProductCatalogue", + "MultiTurnDataset", ] diff --git a/src/inference_endpoint/dataset_manager/factory.py b/src/inference_endpoint/dataset_manager/factory.py index 6ed1674a..8c1226c6 100644 --- a/src/inference_endpoint/dataset_manager/factory.py +++ b/src/inference_endpoint/dataset_manager/factory.py @@ -24,6 +24,7 @@ from inference_endpoint.config.schema import Dataset as DatasetConfig from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat +from .multi_turn_dataset import MultiTurnDataset from .transforms import ColumnRemap, MakeAdapterCompatible, Transform logger = logging.getLogger(__name__) @@ -95,18 +96,24 @@ def create_loader(config: DatasetConfig, num_repeats: int = 1, **kwargs) -> Data if file_format is not None: format_enum = DatasetFormat(file_format) + dataset_id = None + if config.multi_turn is not None: + dataset_id = MultiTurnDataset.DATASET_ID + transforms: list[Transform] = [] if remap is not None: # Parser convention is {target: source} (e.g. {prompt: article}). # ColumnRemap expects {source: target} β€” flip it. flipped = {src: dst for dst, src in remap.items()} transforms.append(ColumnRemap(flipped)) # type: ignore[arg-type] - transforms.append(MakeAdapterCompatible()) + if dataset_id != MultiTurnDataset.DATASET_ID: + transforms.append(MakeAdapterCompatible()) assert dataset_path is not None return Dataset.load_from_file( Path(dataset_path), transforms=transforms, format=format_enum, + dataset_id=dataset_id, num_repeats=num_repeats, ) diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py new file mode 100644 index 00000000..bff8c011 --- /dev/null +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -0,0 +1,428 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Multi-turn conversation dataset for conversational AI benchmarking.""" + +from typing import Any + +import pandas as pd + +from ..config.schema import APIType, ModelParams, StreamingMode +from ..exceptions import InputValidationError +from .dataset import Dataset +from .transforms import apply_transforms + +# Known generation parameter fields to forward from dataset to API requests. +# Aligned with OpenAI API specification and openai_msgspec_adapter.py implementation. +# These parameters work in both single-turn and multi-turn modes. +GENERATION_PARAMS = { + "model", + "max_new_tokens", + "max_completion_tokens", + "stream", + "temperature", + "top_p", + "top_k", + "seed", + "repetition_penalty", + "frequency_penalty", + "presence_penalty", + "stop", + "n", + "logit_bias", # Token probability adjustments + "name", # Entity name for role (NOT model name, e.g., 'Bob' for tracking) + "user", # End-user identifier for monitoring/abuse detection + "chat_template", # Custom chat formatting template + "tools", # OpenAI tool definitions (list[dict]) for tool-calling models +} + + +def _model_param_defaults(model_params: ModelParams | None) -> dict[str, Any]: + """Build per-request defaults for multi-turn rows from model params. + + Multi-turn datasets use `content` and conversation metadata rather than the + single-turn `prompt` field expected by adapter dataset transforms. Applying + those transforms would drop the conversation schema before load_sample() can + construct the messages array. Instead, we inject the request defaults here. + """ + if model_params is None: + return {} + + return { + "model": model_params.name, + "stream": model_params.streaming == StreamingMode.ON, + "max_completion_tokens": model_params.max_new_tokens, + "temperature": model_params.temperature, + "top_p": model_params.top_p, + "top_k": model_params.top_k, + "repetition_penalty": model_params.repetition_penalty, + } + + +def _expand_tool_results(row: dict) -> list[dict]: + """Expand a tool row into one OpenAI tool message per result. + + All ``role: "tool"`` rows carry a ``tool_results`` array. Each entry expands to + one OpenAI tool message with ``tool_call_id`` and ``content``. + + Returns an empty list if ``tool_results`` is absent or not a list (non-tool rows). + """ + tool_results = row.get("tool_results") + if not isinstance(tool_results, list): + return [] + return [ + { + "role": "tool", + "tool_call_id": result.get("tool_call_id"), + "content": result.get("content"), + } + for result in tool_results + ] + + +class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): + """Dataset for multi-turn conversations. + + Supports conversational AI benchmarking with turn sequencing and conversation history. + Validates that conversations have proper structure (alternating user/assistant roles) + and builds metadata for the scheduler to enforce turn ordering. 
+ + Dataset format (JSONL): + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "...", "system": "..."} + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "..."} + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "..."} + + Required columns: + - conversation_id: Unique identifier for each conversation + - turn: Turn number within conversation (1-indexed) + - role: Speaker role ("user" or "assistant") + - content: Message content + + Optional columns: + - system: System prompt associated with the conversation (typically set on the first user turn) + - model: Model name override + - max_new_tokens: Max tokens for this turn + + Attributes: + conversation_metadata: Metadata dict containing: + - samples: List of user turn metadata (index, conversation_id, turn, system) + - num_conversations: Total number of unique conversations + - max_turns_per_conv: Maximum turns in any conversation + """ + + COLUMN_NAMES = ["conversation_id", "turn", "role", "content"] + + def __init__(self, dataframe: pd.DataFrame, **kwargs): + """Initialize multi-turn dataset. + + Args: + dataframe: DataFrame with conversation data. + **kwargs: Additional arguments passed to Dataset.__init__. + + Raises: + ValueError: If conversation structure is invalid. + """ + super().__init__(dataframe, **kwargs) + self._validate_conversation_grouping() + self._validate_conversation_structure() + self._validate_turn_numbering() + self.conversation_metadata = self._build_metadata() + self._client_turn_indices: list[int] | None = None + + def _validate_conversation_grouping(self) -> None: + """Validate that all rows for each conversation_id appear consecutively in file order. + + Raises: + InputValidationError: If rows for a conversation_id are interleaved with other conversations. + """ + assert self.dataframe is not None, "Dataframe must be initialized" + seen: set[str] = set() + last_conv: str | None = None + for row in self.dataframe.to_dict(orient="records"): + conv_id = str(row["conversation_id"]) + if conv_id != last_conv: + if conv_id in seen: + raise InputValidationError( + f"Rows for conversation '{conv_id}' are not consecutive. " + "All rows for a conversation must appear together in the file." + ) + seen.add(conv_id) + last_conv = conv_id + + def _validate_conversation_structure(self): + """Validate conversations are well-formed. + + Accepts plain user/assistant alternation as well as tool sequences: + user β†’ assistant β†’ tool β†’ [assistant β†’ tool]* β†’ assistant β†’ user + + Raises: + ValueError: If any conversation has invalid role sequence. + """ + assert self.dataframe is not None, "Dataframe must be initialized" + + # Valid state transitions (flat 4-state machine β€” no assistant_tc node, + # no toolβ†’tool; converter always merges consecutive tool rows into tool_results) + VALID_NEXT: dict[str, set[str]] = { + "start": {"user"}, + "user": {"assistant"}, + "assistant": {"tool", "user"}, + "tool": {"assistant", "user"}, + } + + for conv_id, group in self.dataframe.groupby("conversation_id"): + sorted_group = group.sort_values("turn") + state = "start" + + for _, row in sorted_group.iterrows(): + role = row["role"] + + if role not in VALID_NEXT.get(state, set()): + raise ValueError( + f"Conversation {conv_id} has invalid role sequence at turn " + f"{row['turn']}: got '{role}' after state '{state}'" + ) + state = role + + def _validate_turn_numbering(self): + """Validate turn numbers are consecutive starting at 1. 
+ + Raises: + ValueError: If turn numbers are not exactly 1, 2, 3, …, N. + """ + assert self.dataframe is not None, "Dataframe must be initialized" + + for conv_id, group in self.dataframe.groupby("conversation_id"): + turns = sorted(group["turn"].tolist()) + expected = list(range(1, len(turns) + 1)) + if turns != expected: + raise ValueError( + f"Conversation {conv_id}: Turn numbers must be consecutive starting at 1, " + f"got {turns}" + ) + + def _build_metadata(self) -> dict[str, Any]: + """Build metadata for scheduler (maps sample index to conversation context). + + Pre-computes the complete message list for each client turn so that + conversation history does not need to be accumulated at runtime. + + Returns: + Metadata dict with samples list, num_conversations, max_turns_per_conv, + client_turns_per_conversation, and pre_built_messages_by_key. + """ + assert self.dataframe is not None, "Dataframe must be initialized" + samples = [] + client_turns_df = self.dataframe[self.dataframe["role"].isin(["user", "tool"])] + + # Count client turns (user + tool) per conversation for completion tracking + client_turns_per_conv = ( + client_turns_df.groupby("conversation_id").size().to_dict() + ) + + # Map (conversation_id, turn) β†’ complete message list ready to send to endpoint. + # Each entry is: [system (optional)] + all prior rows formatted as messages + # + the current client turn message. + # This includes assistant rows (tool dispatches or terminal responses) + # so no runtime injection is required. + pre_built_messages_by_key: dict[tuple, list[dict]] = {} + + for conv_id, group in self.dataframe.groupby("conversation_id"): + sorted_group = group.sort_values("turn") + client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])] + + # Extract system prompt from the first row that has it (typically turn 1) + system_content: str | None = None + for _, srow in sorted_group.iterrows(): + val = srow.get("system") + if val and isinstance(val, str): + system_content = val + break + + for idx, row in client_rows.iterrows(): + t_n = int(row["turn"]) + + messages: list[dict] = [] + if system_content: + messages.append({"role": "system", "content": system_content}) + + # All dataset rows strictly before this client turn (includes + # assistant rows and prior tool results). + prior_rows = sorted_group[sorted_group["turn"] < t_n] + for _, prior_row in prior_rows.iterrows(): + msg: dict[str, Any] = {} + for key in ("role", "content", "tool_calls"): + val = prior_row.get(key) + if val is not None and not ( + isinstance(val, float) and pd.isna(val) + ): + msg[key] = val + if msg.get("role"): + # Expand merged parallel tool results: a single row with + # tool_results: [{tool_call_id, content}, ...] expands into + # one OpenAI tool message per result entry. + expanded = _expand_tool_results(msg) + if expanded: + messages.extend(expanded) + else: + messages.append(msg) + + # Append the current client turn message. + # A merged parallel-tool row carries tool_results instead of a + # single tool_call_id/content pair; expand to one message per result. 
+ expanded = _expand_tool_results(row) + if expanded: + messages.extend(expanded) + else: + cur: dict[str, Any] = {} + for key in ("role", "content"): + val = row.get(key) + if val is not None and not ( + isinstance(val, float) and pd.isna(val) + ): + cur[key] = val + messages.append(cur) + + pre_built_messages_by_key[(conv_id, t_n)] = messages + + samples.append( + { + "index": idx, + "conversation_id": conv_id, + "turn": t_n, + } + ) + + return { + "samples": samples, + "num_conversations": self.dataframe["conversation_id"].nunique(), + "max_turns_per_conv": self.dataframe.groupby("conversation_id")["turn"] + .max() + .max(), + "client_turns_per_conversation": client_turns_per_conv, + "pre_built_messages_by_key": pre_built_messages_by_key, + } + + def load( + self, + adapter=None, + api_type: APIType | None = None, + model_params: ModelParams | None = None, + force: bool = False, + ): + """Load dataset and build a dense user-turn index. + + Multi-turn benchmarks only issue user turns. Assistant turns remain in the + backing data so the conversation structure can still be validated. + + Unlike single-turn datasets, multi-turn rows do not have a `prompt` + column, so adapter dataset transforms are intentionally skipped here. + They would apply a single-turn ColumnFilter and strip the conversation + fields required by load_sample(). Request defaults from model_params are + merged directly into the conversation rows instead. + """ + if not force and self.data is not None: + self._client_turn_indices = [ + index + for index, row in enumerate(self.data) + if row["role"] in ("user", "tool") + ] + return + + df = self.dataframe + if df is None: + raise ValueError( + f"Cannot load dataset {self.__class__.__name__}: dataframe is None" + ) + + transforms = [] + if self.transforms is not None: + transforms.extend(self.transforms) + + if transforms: + df = apply_transforms(df, transforms) + + defaults = _model_param_defaults(model_params) + for key, value in defaults.items(): + if value is None: + continue + if key in df.columns: + df[key] = df[key].where(pd.notna(df[key]), value) + else: + df[key] = value + + self.data = df.to_dict(orient="records") + assert self.data is not None, "Failed to convert DataFrame to records" + + self._client_turn_indices = [ + index + for index, row in enumerate(self.data) + if row["role"] in ("user", "tool") + ] + + def load_sample(self, index: int) -> dict[str, Any]: + """Load the Nth client turn (user or tool) as a benchmark sample.""" + assert self.data is not None, "Dataset not loaded. Call load() first." + assert ( + self._client_turn_indices is not None + ), "Dataset not loaded. Call load() first." + row = self.data[self._client_turn_indices[index]] + + content_val = row.get("content") + sample: dict[str, Any] = { + "conversation_id": row["conversation_id"], + "turn": row["turn"], + "role": row["role"], + } + if content_val is not None and not ( + isinstance(content_val, float) and pd.isna(content_val) + ): + sample["content"] = content_val + + for param in GENERATION_PARAMS: + if param in row: + value = row[param] + # Skip pandas NaN/None values + if value is not None and ( + not isinstance(value, float) or not pd.isna(value) + ): + sample[param] = value + + # Set defaults for critical params if not present + if "max_new_tokens" not in sample and "max_completion_tokens" not in sample: + sample["max_new_tokens"] = 128 + if "stream" not in sample: + sample["stream"] = False + + # Attach pre-built message list (system + history + current turn). 
+ key = (row["conversation_id"], int(row["turn"])) + pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}).get( + key, [] + ) + sample["pre_built_messages"] = pre_built + + # Fields for use_dataset_history=False path (live history accumulation). + sample["current_turn_message"] = pre_built[-1] if pre_built else {} + first = pre_built[0] if pre_built else {} + sample["system_content"] = ( + first.get("content") if first.get("role") == "system" else None + ) + + return sample + + def num_samples(self) -> int: + assert ( + self._client_turn_indices is not None + ), "Dataset not loaded. Call load() first." + return len(self._client_turn_indices) diff --git a/src/inference_endpoint/endpoint_client/adapter_protocol.py b/src/inference_endpoint/endpoint_client/adapter_protocol.py index 164f71e1..feb590a4 100644 --- a/src/inference_endpoint/endpoint_client/adapter_protocol.py +++ b/src/inference_endpoint/endpoint_client/adapter_protocol.py @@ -95,22 +95,22 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: @abstractmethod def decode_sse_message(cls, json_bytes: bytes) -> Any: """ - Decode SSE message and extract content. + Decode SSE message and return adapter-specific chunk object. Args: json_bytes: Raw JSON bytes from SSE stream Returns: - Decoded SSE content (type depends on the adapter implementation) + Adapter-specific chunk object passed to accumulator.add_chunk() """ raise NotImplementedError("decode_sse_message not implemented") @classmethod - def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[str]: + def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[Any]: """ - Parse SSE chunk and extract all content strings. + Parse SSE chunk and extract all chunk objects. - Extracts JSON documents from SSE stream and decodes them to content strings. + Extracts JSON documents from SSE stream and decodes them to chunk objects. Silently ignores non-content SSE messages (role, finish_reason, etc). Args: @@ -118,7 +118,7 @@ def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[str]: end_pos: End position in buffer to parse up to Returns: - List of content strings extracted from the SSE chunk + List of chunk objects extracted from the SSE chunk """ json_docs = cls.SSE_DATA_PATTERN.findall(buffer[:end_pos]) parsed_contents = [] diff --git a/src/inference_endpoint/endpoint_client/http.py b/src/inference_endpoint/endpoint_client/http.py index d9047301..1e67a023 100644 --- a/src/inference_endpoint/endpoint_client/http.py +++ b/src/inference_endpoint/endpoint_client/http.py @@ -792,10 +792,12 @@ class InFlightRequest: query_id: Correlates response back to original Query. http_bytes: Serialized HTTP request for socket.write(). is_streaming: Whether this is a streaming (SSE) request or not. + query_metadata: Internal metadata carried alongside the request. connection: PooledConnection assigned to this request (set once request is fired). 
""" query_id: str http_bytes: bytes is_streaming: bool + query_metadata: dict[str, object] = field(default_factory=dict) connection: PooledConnection = field(default=None, repr=False) # type: ignore[assignment] diff --git a/src/inference_endpoint/endpoint_client/worker.py b/src/inference_endpoint/endpoint_client/worker.py index 8e0e560e..8fb69fce 100644 --- a/src/inference_endpoint/endpoint_client/worker.py +++ b/src/inference_endpoint/endpoint_client/worker.py @@ -341,6 +341,7 @@ def _prepare_request(self, query: Query) -> InFlightRequest: query_id=query.id, http_bytes=http_bytes, is_streaming=is_streaming, + query_metadata=query.metadata, ) return req @@ -429,7 +430,9 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None: self._pool.release(conn) # Send final complete back to main rank - self._responses.send(accumulator.get_final_output()) + self._responses.send( + accumulator.get_final_output().with_metadata(req.query_metadata) + ) @profile async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: @@ -447,7 +450,7 @@ async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: result = self._adapter.decode_response(response_bytes, query_id) # Send result back to main rank - self._responses.send(result) + self._responses.send(result.with_metadata(req.query_metadata)) async def _handle_error(self, query_id: str, error: Exception | str) -> None: """Send error response for a query.""" diff --git a/src/inference_endpoint/openai/accumulator.py b/src/inference_endpoint/openai/accumulator.py index 6cb23ed8..a01b7b44 100644 --- a/src/inference_endpoint/openai/accumulator.py +++ b/src/inference_endpoint/openai/accumulator.py @@ -15,15 +15,13 @@ """OpenAI SSE stream accumulator implementation.""" -import logging +from typing import Any from inference_endpoint.core.types import QueryResult, StreamChunk, TextModelOutput from inference_endpoint.endpoint_client.accumulator_protocol import ( SSEAccumulatorProtocol, ) -from inference_endpoint.openai.types import SSEDelta as OpenAISSEDelta - -logger = logging.getLogger(__name__) +from inference_endpoint.openai.types import SSEChoice class OpenAISSEAccumulator(SSEAccumulatorProtocol): @@ -32,15 +30,41 @@ class OpenAISSEAccumulator(SSEAccumulatorProtocol): def __init__(self, query_id: str, stream_all_chunks: bool): self.output_chunks: list[str] = [] self.reasoning_chunks: list[str] = [] + self._tool_calls: dict[int, dict[str, Any]] = {} + self._finish_reason: str | None = None self.first_chunk_sent = False self.query_id = query_id self.stream_all_chunks = stream_all_chunks - def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: - if not isinstance(delta, OpenAISSEDelta): + def add_chunk(self, choice: SSEChoice | None) -> StreamChunk | None: + if not isinstance(choice, SSEChoice): + return None + + if choice.finish_reason: + self._finish_reason = choice.finish_reason + + delta = choice.delta + if delta is None: return None + # Accumulate tool_calls partials (streamed as incremental JSON fragments) + if delta.tool_calls: + for partial in delta.tool_calls: + idx = partial.get("index", 0) + tc = self._tool_calls.setdefault( + idx, {"type": "function", "function": {"arguments": ""}} + ) + if partial.get("id"): + tc["id"] = partial["id"] + if partial.get("type"): + tc["type"] = partial["type"] + fn = partial.get("function") or {} + if fn.get("name"): + tc["function"]["name"] = fn["name"] + if fn.get("arguments"): + tc["function"]["arguments"] += fn["arguments"] + content = None if delta.content: 
self.output_chunks.append(delta.content) @@ -68,9 +92,6 @@ def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: def get_final_output(self) -> QueryResult: if self.reasoning_chunks: - # If there are reasoning chunks, then the first chunk received - # is the first reasoning chunk. The rest of the reasoning chunks, - # as well as the output chunks can be joined together. resp_reasoning: list[str] = [self.reasoning_chunks[0]] if len(self.reasoning_chunks) > 1: resp_reasoning.append("".join(self.reasoning_chunks[1:])) @@ -79,19 +100,26 @@ def get_final_output(self) -> QueryResult: reasoning=resp_reasoning, ) elif self.output_chunks: - # If there are only output chunks, the first chunk is used for - # TTFT calculations. The rest are joined together. resp_output: list[str] = [self.output_chunks[0]] if len(self.output_chunks) > 1: resp_output.append("".join(self.output_chunks[1:])) text_output = TextModelOutput(output=resp_output, reasoning=None) else: text_output = TextModelOutput(output=[], reasoning=None) + + metadata: dict[str, Any] = { + "first_chunk": not self.first_chunk_sent, + "final_chunk": True, + } + if self._finish_reason: + metadata["finish_reason"] = self._finish_reason + if self._tool_calls: + metadata["tool_calls"] = [ + self._tool_calls[i] for i in sorted(self._tool_calls) + ] + return QueryResult( id=self.query_id, response_output=text_output, - metadata={ - "first_chunk": not self.first_chunk_sent, - "final_chunk": True, - }, + metadata=metadata, ) diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 5834d6b0..4830c682 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -36,7 +36,7 @@ Role6, ServiceTier, ) -from .types import SSEMessage +from .types import SSEChoice, SSEMessage class OpenAIAdapter(HttpRequestAdapter): @@ -75,10 +75,12 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: return cls.from_endpoint_response(openai_response, result_id=query_id) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: - """Decode SSE message and extract content string.""" + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice | None: + """Decode SSE message and return SSEChoice (delta + finish_reason).""" msg = msgspec.json.decode(json_bytes, type=SSEMessage) - return msg.choices[0].delta + if not msg.choices: + return None + return msg.choices[0] # ======================================================================== # Internal APIs @@ -86,15 +88,21 @@ def decode_sse_message(cls, json_bytes: bytes) -> str: @classmethod def to_endpoint_request(cls, query: Query) -> CreateChatCompletionRequest: - """Convert a Query to an OpenAI request.""" - if "prompt" not in query.data: - raise ValueError("prompt not found in query.data") - - messages = [{"role": Role5.user.value, "content": query.data["prompt"]}] - if "system" in query.data: - messages.insert( - 0, {"role": Role3.system.value, "content": query.data["system"]} - ) + """Convert a Query to an OpenAI request. + + Supports both single-turn (prompt/system) and multi-turn (messages array) formats. 
+ """ + if "messages" in query.data and isinstance(query.data["messages"], list): + messages = query.data["messages"] + else: + if "prompt" not in query.data: + raise ValueError("prompt not found in query.data") + + messages = [{"role": Role5.user.value, "content": query.data["prompt"]}] + if "system" in query.data: + messages.insert( + 0, {"role": Role3.system.value, "content": query.data["system"]} + ) request = CreateChatCompletionRequest( model=ModelIdsShared(query.data.get("model", "no-model-name")), diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index 6106e1bd..e8f15ce6 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -18,6 +18,7 @@ """ import time +from typing import Any import msgspec from inference_endpoint.config.schema import ModelParams, StreamingMode @@ -37,6 +38,7 @@ ChatCompletionResponse, ChatCompletionResponseMessage, ChatMessage, + SSEChoice, SSEMessage, ) @@ -45,6 +47,17 @@ # ============================================================================ +def _chat_message_from_dict(msg: dict) -> "ChatMessage": + """Build a ChatMessage from a dict, forwarding all supported fields.""" + return ChatMessage( + role=msg["role"], + content=msg.get("content"), + name=msg.get("name"), + tool_calls=msg.get("tool_calls"), + tool_call_id=msg.get("tool_call_id"), + ) + + class OpenAIMsgspecAdapter(HttpRequestAdapter): """OpenAI adapter using msgspec for serialization/deserialization.""" @@ -105,10 +118,12 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: return cls.from_endpoint_response(openai_response, result_id=query_id) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: - """Decode SSE message and extract content string.""" + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice | None: + """Decode SSE message and return the SSEChoice (delta + finish_reason).""" msg = cls._sse_decoder.decode(json_bytes) - return msg.choices[0].delta + if not msg.choices: + return None + return msg.choices[0] # ======================================================================== # Internal APIs @@ -129,24 +144,31 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest: Returns: msgspec.Struct ChatCompletionRequest """ - if "prompt" not in query.data: - raise ValueError("prompt not found in query.data") - - messages = [ - ChatMessage( - role="user", - content=query.data["prompt"], - name=query.data.get("name"), - ), - ] - if "system" in query.data: - messages.insert( - 0, + if "messages" in query.data and isinstance(query.data["messages"], list): + messages = [] + for message in query.data["messages"]: + if not isinstance(message, dict): + raise ValueError("messages entries must be dicts") + messages.append(_chat_message_from_dict(message)) + else: + if "prompt" not in query.data: + raise ValueError("prompt not found in query.data") + + messages = [ ChatMessage( - role="system", - content=query.data["system"], + role="user", + content=query.data["prompt"], + name=query.data.get("name"), ), - ) + ] + if "system" in query.data: + messages.insert( + 0, + ChatMessage( + role="system", + content=query.data["system"], + ), + ) return ChatCompletionRequest( model=query.data.get("model", "no-model-name"), @@ -164,6 +186,7 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest: logit_bias=query.data.get("logit_bias"), user=query.data.get("user"), 
chat_template=query.data.get("chat_template"), + tools=query.data.get("tools"), ) @classmethod @@ -184,9 +207,19 @@ def from_endpoint_response( if not response.choices: raise ValueError("Response must contain at least one choice") + choice = response.choices[0] + metadata: dict[str, Any] = {} + if choice.finish_reason: + metadata["finish_reason"] = choice.finish_reason + if choice.message.tool_calls: + metadata["tool_calls"] = choice.message.tool_calls + if choice.message.reasoning_content: + metadata["reasoning_content"] = choice.message.reasoning_content + return QueryResult( id=result_id or response.id, - response_output=TextModelOutput(output=response.choices[0].message.content), + response_output=TextModelOutput(output=choice.message.content or ""), + metadata=metadata if metadata else None, ) @classmethod diff --git a/src/inference_endpoint/openai/types.py b/src/inference_endpoint/openai/types.py index 875656fa..4c301db5 100644 --- a/src/inference_endpoint/openai/types.py +++ b/src/inference_endpoint/openai/types.py @@ -46,6 +46,7 @@ class SSEDelta(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc content: str = "" reasoning: str = "" + tool_calls: list[dict[str, Any]] | None = None class SSEChoice( @@ -75,12 +76,17 @@ class ChatMessage( ): # type: ignore[call-arg] """Chat message in OpenAI format. - content: str for text-only messages; list[dict] for multimodal (vision). + content: str for text-only messages; list[dict] for multimodal (vision); + None for tool-dispatching assistant messages. + tool_calls: list of tool call objects for assistant messages that invoke tools. + tool_call_id: correlates a tool result message to its tool call. """ role: str - content: ChatMessageContent + content: ChatMessageContent | None = None name: str | None = None + tool_calls: list[dict[str, Any]] | None = None + tool_call_id: str | None = None class ChatCompletionRequest( @@ -103,6 +109,7 @@ class ChatCompletionRequest( logit_bias: dict[str, float] | None = None user: str | None = None chat_template: str | None = None + tools: list[dict[str, Any]] | None = None class ChatCompletionResponseMessage( @@ -118,6 +125,8 @@ class ChatCompletionResponseMessage( role: str content: str | None = None refusal: str | None = None + tool_calls: list[dict[str, Any]] | None = None + reasoning_content: str | None = None class ChatCompletionChoice( diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py index 52bdbe77..9c3bec15 100644 --- a/tests/unit/core/test_types.py +++ b/tests/unit/core/test_types.py @@ -891,3 +891,57 @@ def test_numeric_types_in_metadata(self): assert decoded.metadata["large_int"] == 9999999999999999 assert decoded.metadata["negative"] == -123.456 assert decoded.metadata["zero"] == 0 + + +@pytest.mark.unit +class TestQueryResultWithMetadata: + """Test QueryResult.with_metadata() method for metadata merging.""" + + def test_with_metadata_merge_behavior(self): + """Test that with_metadata adds new keys and overwrites existing ones.""" + result = QueryResult( + id="test", + response_output=TextModelOutput(output="hello"), + metadata={"key1": "old_value", "key2": "keep_me"}, + ) + + updated = result.with_metadata({"key1": "new_value", "key3": "added"}) + + assert updated.metadata == { + "key1": "new_value", + "key2": "keep_me", + "key3": "added", + } + assert updated.id == "test" + assert updated.response_output == TextModelOutput(output="hello") + + def test_with_metadata_none_returns_self(self): + """Test that with_metadata(None) returns self unchanged.""" + 
result = QueryResult( + id="test", + response_output=TextModelOutput(output="hello"), + metadata={"key1": "value"}, + ) + assert result.with_metadata(None) is result + + def test_with_metadata_empty_returns_self(self): + """Test that with_metadata({}) returns self unchanged.""" + result = QueryResult( + id="test", + response_output=TextModelOutput(output="hello"), + metadata={"key1": "value"}, + ) + assert result.with_metadata({}) is result + + def test_query_metadata_field_roundtrips(self): + """Test that Query.metadata round-trips through msgspec encoding.""" + query = Query( + data={"prompt": "Hello"}, + metadata={"conversation_id": "conv-1", "turn": 2}, + ) + + encoded = msgspec.json.encode(query) + decoded = msgspec.json.decode(encoded, type=Query) + + assert decoded.metadata["conversation_id"] == "conv-1" + assert decoded.metadata["turn"] == 2 diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py new file mode 100644 index 00000000..09ccd224 --- /dev/null +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -0,0 +1,1073 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import tempfile +from collections.abc import Generator +from pathlib import Path + +import pandas as pd +import pytest +from inference_endpoint.dataset_manager.dataset import DatasetFormat +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset + + +@pytest.fixture +def valid_multi_turn_jsonl() -> Generator[str, None, None]: + """Create valid multi-turn conversation JSONL data.""" + data = [ + { + "conversation_id": "conv_001", + "turn": 1, + "role": "user", + "content": "Hello, how are you?", + "system": "You are a helpful assistant", + }, + { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "I'm doing well, thank you!", + }, + { + "conversation_id": "conv_001", + "turn": 3, + "role": "user", + "content": "What can you help me with?", + }, + { + "conversation_id": "conv_002", + "turn": 1, + "role": "user", + "content": "What's the weather?", + }, + { + "conversation_id": "conv_002", + "turn": 2, + "role": "assistant", + "content": "I don't have access to real-time weather data.", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.fixture +def invalid_role_sequence_jsonl() -> Generator[str, None, None]: + """Create JSONL with invalid role sequence (not alternating).""" + data = [ + {"conversation_id": "conv_001", "turn": 1, "role": "user", "content": "Hello"}, + { + "conversation_id": "conv_001", + "turn": 2, + "role": "user", + "content": "Another user message", + }, # Invalid - consecutive user + { + "conversation_id": "conv_001", + "turn": 3, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.fixture +def missing_fields_jsonl() -> Generator[str, None, None]: + """Create JSONL with missing required fields.""" + data = [ + {"conversation_id": "conv_001", "turn": 1, "role": "user"}, # Missing content + { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_load_valid_data(valid_multi_turn_jsonl): + """Test loading valid multi-turn conversation data.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Should have 5 rows total (3 for conv_001, 2 for conv_002) + assert len(dataset.data) == 5 + + # Should have 3 user turns (samples) - only user turns are indexed + assert dataset.num_samples() == 3 + + +@pytest.mark.unit +def test_multi_turn_dataset_user_turn_indexing(valid_multi_turn_jsonl): + """Test that only client turns (user + tool) are indexed as samples.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Verify client turn indices are correct (fixture has only user turns) + assert len(dataset._client_turn_indices) == 3 + + # Check that indices point to client turns + for idx in dataset._client_turn_indices: + assert dataset.data[idx]["role"] in ("user", "tool") 
+
+
+@pytest.mark.unit
+def test_multi_turn_dataset_load_sample(valid_multi_turn_jsonl):
+    """Test load_sample returns correct user turns with dense indexing."""
+    dataset = MultiTurnDataset.load_from_file(
+        valid_multi_turn_jsonl, format=DatasetFormat.JSONL
+    )
+    dataset.load()
+
+    # Sample 0 should be first user turn
+    sample_0 = dataset.load_sample(0)
+    assert sample_0["conversation_id"] == "conv_001"
+    assert sample_0["turn"] == 1
+    assert sample_0["role"] == "user"
+    assert sample_0["content"] == "Hello, how are you?"
+    # System prompt is in pre_built_messages, not as a separate field
+    assert sample_0["pre_built_messages"][0]["role"] == "system"
+    assert sample_0["pre_built_messages"][0]["content"] == "You are a helpful assistant"
+
+    # Sample 1 should be second user turn (conv_001 turn 3)
+    sample_1 = dataset.load_sample(1)
+    assert sample_1["conversation_id"] == "conv_001"
+    assert sample_1["turn"] == 3
+    assert sample_1["role"] == "user"
+    assert sample_1["content"] == "What can you help me with?"
+
+    # Sample 2 should be third user turn (conv_002 turn 1)
+    sample_2 = dataset.load_sample(2)
+    assert sample_2["conversation_id"] == "conv_002"
+    assert sample_2["turn"] == 1
+    assert sample_2["role"] == "user"
+    assert sample_2["content"] == "What's the weather?"
+
+
+@pytest.mark.unit
+def test_multi_turn_dataset_conversation_metadata(valid_multi_turn_jsonl):
+    """Test conversation metadata generation."""
+    dataset = MultiTurnDataset.load_from_file(
+        valid_multi_turn_jsonl, format=DatasetFormat.JSONL
+    )
+    dataset.load()
+
+    metadata = dataset.conversation_metadata
+
+    # Check metadata structure
+    assert "samples" in metadata
+    assert "num_conversations" in metadata
+    assert "max_turns_per_conv" in metadata
+    assert "client_turns_per_conversation" in metadata
+
+    # Should have 3 client turn samples (fixture has only user turns, no tool turns)
+    assert len(metadata["samples"]) == 3
+
+    # Should have 2 conversations
+    assert metadata["num_conversations"] == 2
+
+    # Max turns per conversation should be 3 (conv_001 has 3 turns)
+    assert metadata["max_turns_per_conv"] == 3
+
+    # Check sample metadata structure
+    sample_meta = metadata["samples"][0]
+    assert "index" in sample_meta
+    assert "conversation_id" in sample_meta
+    assert "turn" in sample_meta
+
+
+@pytest.mark.unit
+def test_multi_turn_dataset_validation_invalid_role_sequence(
+    invalid_role_sequence_jsonl,
+):
+    """Test validation rejects invalid role sequences."""
+    # Validation happens during load_from_file (in __init__), not during load()
+    with pytest.raises(ValueError, match="invalid role sequence"):
+        MultiTurnDataset.load_from_file(
+            invalid_role_sequence_jsonl, format=DatasetFormat.JSONL
+        )
+
+
+@pytest.mark.unit
+def test_multi_turn_dataset_validation_missing_fields(missing_fields_jsonl):
+    """Missing content field is omitted from the loaded sample."""
+    dataset = MultiTurnDataset.load_from_file(
+        missing_fields_jsonl, format=DatasetFormat.JSONL
+    )
+    dataset.load()
+
+    sample = dataset.load_sample(0)
+    # Missing content is not propagated to the sample dict
+    assert "content" not in sample
+
+
+@pytest.mark.unit
+def test_multi_turn_dataset_multiple_conversations():
+    """Test dataset with multiple conversations of varying lengths."""
+    data = [
+        # Conversation 1: 3 turns (user-assistant-user, missing final assistant)
+        {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"},
+        {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"},
+        {"conversation_id": "c1", 
"turn": 3, "role": "user", "content": "msg1b"}, + # Conversation 2: 4 turns (complete user-assistant alternation) + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "msg2"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "resp2"}, + {"conversation_id": "c2", "turn": 3, "role": "user", "content": "msg3"}, + {"conversation_id": "c2", "turn": 4, "role": "assistant", "content": "resp3"}, + # Conversation 3: 2 turns (complete user-assistant) + {"conversation_id": "c3", "turn": 1, "role": "user", "content": "msg4"}, + {"conversation_id": "c3", "turn": 2, "role": "assistant", "content": "resp4"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # 9 total rows, 5 user turns (c1:t1, c1:t3, c2:t1, c2:t3, c3:t1) + assert len(dataset.data) == 9 + assert dataset.num_samples() == 5 + + # Metadata checks + metadata = dataset.conversation_metadata + assert metadata["num_conversations"] == 3 + assert metadata["max_turns_per_conv"] == 4 # c2 has 4 turns + + # Verify user turns are correctly indexed + samples = [dataset.load_sample(i) for i in range(5)] + + # Check we got all the user turns + user_turns = [(s["conversation_id"], s["turn"]) for s in samples] + expected_turns = [("c1", 1), ("c1", 3), ("c2", 1), ("c2", 3), ("c3", 1)] + assert sorted(user_turns) == sorted(expected_turns) + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_system_prompt_handling(valid_multi_turn_jsonl): + """Test system prompt is included as the first message in pre_built_messages. + + The system prompt is pre-baked into every client turn's message list so the + conversation manager no longer needs to track it separately. 
+ """ + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # First sample: pre_built_messages starts with system message + sample_0 = dataset.load_sample(0) + assert "pre_built_messages" in sample_0 + msgs = sample_0["pre_built_messages"] + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "You are a helpful assistant" + + # Second sample (same conversation, turn 3): system message still first + sample_1 = dataset.load_sample(1) + msgs_1 = sample_1["pre_built_messages"] + assert msgs_1[0]["role"] == "system" + assert msgs_1[0]["content"] == "You are a helpful assistant" + + +@pytest.mark.unit +def test_multi_turn_dataset_single_turn_conversations(): + """Test conversations with only one turn.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Single turn"}, + # No assistant response + { + "conversation_id": "c2", + "turn": 1, + "role": "user", + "content": "Another single", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # 2 rows, 2 user turns + assert len(dataset.data) == 2 + assert dataset.num_samples() == 2 + + # Both samples should be user turns + assert dataset.load_sample(0)["role"] == "user" + assert dataset.load_sample(1)["role"] == "user" + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_empty_conversation(): + """Empty JSONL file raises ValueError (no columns to validate against).""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + temp_path = f.name + + try: + with pytest.raises(ValueError): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_conversation_grouping(): + """Test that properly grouped conversations load correctly.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1t1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "c1t2"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "c1t3"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2t1"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "c2t2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # 5 total rows, 3 user turns (c1t1, c1t3, c2t1) + assert len(dataset.data) == 5 + assert dataset.num_samples() == 3 + + # Load samples to verify conversation grouping + samples = [dataset.load_sample(i) for i in range(3)] + + # Verify conversation IDs + conv_ids = [s["conversation_id"] for s in samples] + assert conv_ids == ["c1", "c1", "c2"] + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_interleaved_conversations_rejected(): + """Test that interleaved conversation rows raise InputValidationError.""" + from inference_endpoint.exceptions import InputValidationError + + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1t1"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": 
"c2t1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "c1t2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(InputValidationError, match="not consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +@pytest.mark.parametrize( + "rows", + [ + # assistant-first + [ + {"conversation_id": "c1", "turn": 1, "role": "assistant", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "user", "content": "B"}, + ], + # consecutive assistants + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + {"conversation_id": "c1", "turn": 3, "role": "assistant", "content": "C"}, + ], + # tool directly after user (tool-before-assistant) + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "tool", + "tool_results": [{"tool_call_id": "x", "content": "r"}], + }, + ], + # consecutive users + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "user", "content": "B"}, + ], + ], +) +def test_validation_rejects_invalid_role_sequence(rows): + """Invalid role sequences raise ValueError regardless of turn numbering.""" + with pytest.raises(ValueError, match="invalid role sequence"): + MultiTurnDataset(pd.DataFrame(rows)) + + +@pytest.mark.unit +def test_multi_turn_dataset_additional_fields(): + """Test that additional fields (model, max_new_tokens, etc.) are preserved.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hello", + "model": "gpt-4", + "max_new_tokens": 256, + "temperature": 0.7, + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + # Fields may or may not be present depending on how dataframe handles them + # Just check they're accessible if present + if "model" in sample: + assert sample["model"] == "gpt-4" + if "max_new_tokens" in sample: + assert sample["max_new_tokens"] == 256 + if "temperature" in sample: + assert sample["temperature"] == pytest.approx(0.7) + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_openai_field_forwarding(): + """Test that OpenAI-specific fields are preserved and forwarded.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hello", + # OpenAI fields that should be forwarded + "n": 3, + "name": "Alice", + "user": "user_12345", + "logit_bias": {"50256": -100}, + "chat_template": "custom_template", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + + # Verify OpenAI fields are 
present + assert sample.get("n") == 3 + assert sample.get("name") == "Alice" + assert sample.get("user") == "user_12345" + assert sample.get("logit_bias") == {"50256": -100} + assert sample.get("chat_template") == "custom_template" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_all_generation_params(): + """Test that all generation parameters in GENERATION_PARAMS are forwarded.""" + from inference_endpoint.dataset_manager.multi_turn_dataset import GENERATION_PARAMS + + # Create dataset with all possible generation params + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Test", + # Include all params from GENERATION_PARAMS + "model": "test-model", + "max_new_tokens": 100, + "max_completion_tokens": 100, + "stream": True, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "seed": 42, + "repetition_penalty": 1.1, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + "stop": ["END"], + "n": 2, + "logit_bias": {"100": 10}, + "name": "TestEntity", + "user": "test_user_001", + "chat_template": "test_template", + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + + # Verify all GENERATION_PARAMS fields are forwarded + # (excluding conversational fields like conversation_id, turn, role, content, system) + for param in GENERATION_PARAMS: + if param in data[0]: + assert ( + param in sample + ), f"Generation parameter '{param}' not forwarded to sample" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_non_contiguous_turns(): + """Turn numbers must be consecutive; gaps are rejected.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "a"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "b"}, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "c"}, + {"conversation_id": "c1", "turn": 6, "role": "assistant", "content": "d"}, + ] + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset(pd.DataFrame(rows)) + + +@pytest.mark.unit +def test_validation_rejects_turns_not_starting_at_one(): + """Validation should reject conversations whose turns don't start at 1.""" + data = [ + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg"}, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_accepts_valid_contiguous_turns(): + """Validation should accept contiguous turn sequences.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg2"}, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "resp2"}, + ] + + with 
tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + assert dataset.num_samples() == 2 + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_turn_starting_at_zero(): + """Validation should reject conversations starting at turn 0.""" + data = [ + {"conversation_id": "c1", "turn": 0, "role": "user", "content": "msg"}, + {"conversation_id": "c1", "turn": 1, "role": "assistant", "content": "resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_duplicate_turn_numbers(): + """Duplicate turn numbers within a conversation are rejected.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "msg2"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "resp2"}, + # c2 has duplicate turn 2 β€” second assistant row with same turn number + {"conversation_id": "c2", "turn": 2, "role": "user", "content": "dup"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_assistant_tc_role_literal(): + """role='assistant_tc' literal in dataset is rejected; only 'assistant' is valid.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Q"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant_tc", + "tool_calls": [ + { + "id": "c0", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [{"tool_call_id": "c0", "content": "r"}], + }, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "A"}, + ] + with pytest.raises(ValueError, match="invalid role sequence"): + MultiTurnDataset(pd.DataFrame(rows)) + + +# ============================================================================ +# Tool sequence tests +# ============================================================================ + + +def _make_tool_sequence_df(): + """Return a DataFrame with a tool sequence embedded between user turns.""" + return pd.DataFrame( + [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "What is the weather?", + "system": "Be helpful", + }, + # assistant (with tool_calls): dispatches a tool call + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_c1_0", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + # tool result + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": 
"call_c1_0", "content": '{"temp": 22}'} + ], + }, + # terminal assistant + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "The weather is 22Β°C.", + }, + # second user turn + { + "conversation_id": "c1", + "turn": 5, + "role": "user", + "content": "Thanks!", + }, + ] + ) + + +@pytest.mark.unit +def test_validation_accepts_tool_sequence(): + """user β†’ assistant β†’ tool β†’ assistant β†’ user passes validation.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + assert ds.num_samples() == 3 # user(1), tool(3), user(5) are all client turns + + +@pytest.mark.unit +def test_validation_accepts_parallel_tool_calls(): + """Assistant with two tool_calls + merged tool_results row passes.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + assert ds.num_samples() == 2 # user(1), tool(3) are client turns + + +@pytest.mark.unit +def test_load_sample_merged_tool_row_has_no_content_key(): + """load_sample for a merged tool_results row must not emit content: NaN.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Go"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + + # Sample 1 is the merged tool row (turn 3) + s1 = ds.load_sample(1) + assert s1["role"] == "tool" + assert "content" not in s1 # must NOT emit NaN + assert "pre_built_messages" in s1 + + +@pytest.mark.unit +def test_build_metadata_pre_built_messages(): + """pre_built_messages_by_key contains complete message arrays for each client turn. 
+ + Dataset: + turn 1: user ← client turn 1 + turn 2: asst_tc ← scripted (assistant with tool_calls) + turn 3: tool ← client turn 2 + turn 4: assistant ← terminal assistant + turn 5: user ← client turn 3 + + Expected pre_built_messages: + client turn 1 (t=1): [system, user(1)] + client turn 2 (t=3): [system, user(1), asst_tc(2), tool(3)] + client turn 3 (t=5): [system, user(1), asst_tc(2), tool(3), asst(4), user(5)] + """ + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Client turn 1 (user, t=1): [system, user(1)] + msgs_t1 = pbm[("c1", 1)] + assert len(msgs_t1) == 2 + assert msgs_t1[0] == {"role": "system", "content": "Be helpful"} + assert msgs_t1[1] == {"role": "user", "content": "What is the weather?"} + + # Client turn 2 (tool, t=3): [system, user(1), asst_tc(2), tool(3)] + msgs_t3 = pbm[("c1", 3)] + assert len(msgs_t3) == 4 + assert msgs_t3[0]["role"] == "system" + assert msgs_t3[1]["role"] == "user" + assert msgs_t3[2]["role"] == "assistant" + assert "tool_calls" in msgs_t3[2] + assert msgs_t3[3]["role"] == "tool" + assert msgs_t3[3]["content"] == '{"temp": 22}' + assert msgs_t3[3]["tool_call_id"] == "call_c1_0" + + # Client turn 3 (user, t=5): [system, user(1), asst_tc(2), tool(3), asst(4), user(5)] + msgs_t5 = pbm[("c1", 5)] + assert len(msgs_t5) == 6 + assert msgs_t5[4] == {"role": "assistant", "content": "The weather is 22Β°C."} + assert msgs_t5[5] == {"role": "user", "content": "Thanks!"} + + +@pytest.mark.unit +def test_build_metadata_pre_built_messages_no_tools(): + """Plain user/assistant alternation produces correct pre_built_messages.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "C"}, + ] + ) + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Turn 1: just the user message (no system, no prior rows) + assert pbm[("c1", 1)] == [{"role": "user", "content": "A"}] + + # Turn 3: user(1) + assistant(2) + user(3) + msgs = pbm[("c1", 3)] + assert len(msgs) == 3 + assert msgs[0] == {"role": "user", "content": "A"} + assert msgs[1] == {"role": "assistant", "content": "B"} + assert msgs[2] == {"role": "user", "content": "C"} + + +@pytest.mark.unit +def test_load_sample_includes_pre_built_messages(): + """load_sample returns pre_built_messages with the complete message list.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + s0 = ds.load_sample(0) # user turn 1 + assert "pre_built_messages" in s0 + msgs = s0["pre_built_messages"] + assert msgs[0]["role"] == "system" + assert msgs[-1] == {"role": "user", "content": "What is the weather?"} + + s1 = ds.load_sample(1) # tool turn 3 + assert s1["role"] == "tool" + msgs_t3 = s1["pre_built_messages"] + # system + user(1) + asst_tc(2) + tool(3) = 4 messages + assert len(msgs_t3) == 4 + assert msgs_t3[-1]["role"] == "tool" + + s2 = ds.load_sample(2) # user turn 5 + msgs_t5 = s2["pre_built_messages"] + # system + user(1) + asst_tc(2) + tool(3) + asst(4) + user(5) = 6 messages + assert len(msgs_t5) == 6 + + +@pytest.mark.unit +def test_client_turns_include_tool_rows(): + """Tool rows are counted in num_samples() as client turns.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + # 5 rows total: user(1), assistant(2), tool(3), assistant(4), user(5) + # Client turns: user(1), 
tool(3), user(5) β†’ 3 + assert ds.num_samples() == 3 + + +# ============================================================================ +# Pre-built messages content correctness +# ============================================================================ + + +@pytest.mark.unit +def test_pre_built_messages_include_prior_assistant_response(valid_multi_turn_jsonl): + """The terminal assistant response before each user turn is included in pre_built_messages.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Sample 0: turn 1 (first user) β†’ just [system, user(1)] + s0 = dataset.load_sample(0) + msgs_0 = s0["pre_built_messages"] + assert msgs_0[0]["role"] == "system" + assert msgs_0[-1]["role"] == "user" + + # Sample 1: turn 3 (second user) β†’ [system, user(1), assistant(2), user(3)] + s1 = dataset.load_sample(1) + msgs_1 = s1["pre_built_messages"] + assert len(msgs_1) == 4 + assert msgs_1[2] == {"role": "assistant", "content": "I'm doing well, thank you!"} + assert msgs_1[3]["role"] == "user" + + # Sample 2: turn 1 of conv_002 β†’ no prior assistant row + s2 = dataset.load_sample(2) + msgs_2 = s2["pre_built_messages"] + assert all(m["role"] != "assistant" for m in msgs_2) + + +@pytest.mark.unit +def test_pre_built_messages_no_cross_conversation_bleed(): + """Messages for conv_001 must not appear in conv_002's pre_built_messages.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1 user"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2 user"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "c2 resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # c1: only its own user message + s_c1 = dataset.load_sample(0) + assert s_c1["pre_built_messages"] == [{"role": "user", "content": "c1 user"}] + + # c2: only c2 messages (no c1 content) + s_c2 = dataset.load_sample(1) + contents = [m.get("content") for m in s_c2["pre_built_messages"]] + assert "c1 user" not in contents + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_pre_built_messages_with_tool_sequence_terminal_assistant(): + """Terminal assistant response (turn 4) appears in pre_built_messages for user(5).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + s2 = ds.load_sample(2) # user turn 5 + msgs = s2["pre_built_messages"] + # The terminal assistant at turn 4 should be included + assistant_msgs = [m for m in msgs if m["role"] == "assistant" and m.get("content")] + assert any(m["content"] == "The weather is 22Β°C." for m in assistant_msgs) diff --git a/tests/unit/openai/test_msgspec_adapter.py b/tests/unit/openai/test_msgspec_adapter.py new file mode 100644 index 00000000..8127d199 --- /dev/null +++ b/tests/unit/openai/test_msgspec_adapter.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for OpenAIMsgspecAdapter with tool call fields.""" + +import json + +import msgspec +import pytest +from inference_endpoint.core.types import Query +from inference_endpoint.openai.openai_msgspec_adapter import ( + OpenAIMsgspecAdapter, + _chat_message_from_dict, +) +from inference_endpoint.openai.types import ChatMessage + + +@pytest.mark.unit +def test_chat_message_tool_calls_serialised(): + """tool_calls field is included in the JSON output when non-None.""" + tool_calls = [ + { + "id": "call_0", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ] + msg = ChatMessage(role="assistant", tool_calls=tool_calls) + encoded = msgspec.json.encode(msg) + decoded = json.loads(encoded) + assert decoded["role"] == "assistant" + assert decoded["tool_calls"] == tool_calls + assert "content" not in decoded # omit_defaults=True, None omitted + + +@pytest.mark.unit +def test_chat_message_tool_call_id_serialised(): + """tool_call_id field is included in the JSON output when non-None.""" + msg = ChatMessage(role="tool", content="result", tool_call_id="call_0") + encoded = msgspec.json.encode(msg) + decoded = json.loads(encoded) + assert decoded["role"] == "tool" + assert decoded["content"] == "result" + assert decoded["tool_call_id"] == "call_0" + + +@pytest.mark.unit +def test_to_endpoint_request_preserves_tool_calls(): + """to_endpoint_request forwards tool_calls in the messages array.""" + tool_calls = [ + { + "id": "call_0", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q": "test"}'}, + } + ] + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": None, "tool_calls": tool_calls}, + {"role": "tool", "content": "answer", "tool_call_id": "call_0"}, + {"role": "assistant", "content": "Done"}, + ] + query = Query( + id="q1", + data={ + "model": "test-model", + "messages": messages, + }, + ) + request = OpenAIMsgspecAdapter.to_endpoint_request(query) + encoded = msgspec.json.encode(request) + payload = json.loads(encoded) + + msgs = payload["messages"] + # assistant tool-dispatch row + assert msgs[1]["role"] == "assistant" + assert msgs[1]["tool_calls"] == tool_calls + assert "content" not in msgs[1] + # tool result row + assert msgs[2]["role"] == "tool" + assert msgs[2]["tool_call_id"] == "call_0" + assert msgs[2]["content"] == "answer" + # terminal assistant row + assert msgs[3]["content"] == "Done" + + +@pytest.mark.unit +def test_backward_compat_plain_messages_unchanged(): + """Plain user/assistant messages encode identically to before the change.""" + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + query = Query( + id="q2", + data={"model": "m", "messages": messages}, + ) + request = OpenAIMsgspecAdapter.to_endpoint_request(query) + encoded = msgspec.json.encode(request) + payload = json.loads(encoded) + + for i, msg in enumerate(payload["messages"]): + assert msg["role"] == messages[i]["role"] + assert msg["content"] == messages[i]["content"] + assert "tool_calls" not in msg + 
assert "tool_call_id" not in msg + + +@pytest.mark.unit +def test_chat_message_from_dict_all_fields(): + """_chat_message_from_dict forwards all four optional fields.""" + tool_calls = [ + {"id": "x", "type": "function", "function": {"name": "f", "arguments": "{}"}} + ] + msg = _chat_message_from_dict( + { + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + "tool_call_id": None, + } + ) + assert msg.role == "assistant" + assert msg.content is None + assert msg.tool_calls == tool_calls + assert msg.tool_call_id is None + + +@pytest.mark.unit +def test_chat_message_content_optional(): + """ChatMessage accepts content=None for tool-dispatching assistant turns.""" + msg = ChatMessage(role="assistant", tool_calls=[]) + assert msg.content is None From 4a135ff1b943278e273b5a4b30d7af51460b77b1 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 23 Apr 2026 15:01:35 -0700 Subject: [PATCH 02/41] feat: add ConversationManager and MultiTurnStrategy Add per-conversation asyncio.Event sequencing (ConversationManager), async turn pipeline (MultiTurnStrategy), and benchmark execution wiring (execute.py, session.py PhaseIssuer data_override). --- .../commands/benchmark/execute.py | 44 +- .../dataset_manager/__init__.py | 2 + .../dataset_manager/multi_turn_dataset.py | 198 +++----- .../dataset_manager/transforms.py | 24 + .../load_generator/conversation_manager.py | 356 +++++++++++++++ .../load_generator/multi_turn_strategy.py | 229 ++++++++++ .../load_generator/session.py | 30 +- .../load_generator/strategy.py | 21 +- tests/integration/test_multi_turn.py | 425 ++++++++++++++++++ .../test_multi_turn_dataset.py | 142 +++--- tests/unit/dataset_manager/test_transforms.py | 50 ++- .../test_multi_turn_conversation_manager.py | 396 ++++++++++++++++ .../test_multi_turn_strategy.py | 279 ++++++++++++ 13 files changed, 1972 insertions(+), 224 deletions(-) create mode 100644 src/inference_endpoint/load_generator/conversation_manager.py create mode 100644 src/inference_endpoint/load_generator/multi_turn_strategy.py create mode 100644 tests/integration/test_multi_turn.py create mode 100644 tests/unit/load_generator/test_multi_turn_conversation_manager.py create mode 100644 tests/unit/load_generator/test_multi_turn_strategy.py diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 34a9fd40..1efe1a3a 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -72,6 +72,7 @@ from inference_endpoint.core.types import QueryResult from inference_endpoint.dataset_manager.dataset import Dataset from inference_endpoint.dataset_manager.factory import DataLoaderFactory +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset from inference_endpoint.endpoint_client.cpu_affinity import AffinityPlan, pin_loadgen from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer @@ -82,6 +83,8 @@ InputValidationError, SetupError, ) +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy from inference_endpoint.load_generator.session import ( BenchmarkSession, PhaseConfig, @@ -354,14 +357,21 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo ) -def _build_phases(ctx: BenchmarkContext) -> 
list[PhaseConfig]: +def _build_phases( + ctx: BenchmarkContext, + perf_strategy: MultiTurnStrategy | None = None, +) -> list[PhaseConfig]: """Build the phase list from BenchmarkContext.""" phases: list[PhaseConfig] = [] # Performance phase phases.append( PhaseConfig( - "performance", ctx.rt_settings, ctx.dataloader, PhaseType.PERFORMANCE + "performance", + ctx.rt_settings, + ctx.dataloader, + PhaseType.PERFORMANCE, + strategy=perf_strategy, ) ) @@ -524,16 +534,42 @@ async def _run_benchmark_async( launcher.kill_all() raise SetupError(f"Failed to connect to endpoint: {e}") from e + # Build multi-turn strategy if the performance dataset is a MultiTurnDataset. + multi_turn_strategy: MultiTurnStrategy | None = None + if isinstance(ctx.dataloader, MultiTurnDataset): + mt_cfg = None + if ctx.config.datasets: + perf_ds_cfg = next( + ( + d + for d in ctx.config.datasets + if d.type == DatasetType.PERFORMANCE + ), + None, + ) + if perf_ds_cfg is not None: + mt_cfg = perf_ds_cfg.multi_turn + multi_turn_strategy = MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ctx.dataloader.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + def _on_sample_complete(result: QueryResult) -> None: + if multi_turn_strategy is not None: + multi_turn_strategy.on_sample_complete(result) + collector.on_complete_hook(result) + # Create session session = BenchmarkSession( issuer=issuer, event_publisher=publisher, loop=loop, - on_sample_complete=collector.on_complete_hook, + on_sample_complete=_on_sample_complete, session_id=session_id, ) - phases = _build_phases(ctx) + phases = _build_phases(ctx, perf_strategy=multi_turn_strategy) report: Report | None = None loop.add_signal_handler(signal.SIGINT, session.stop) diff --git a/src/inference_endpoint/dataset_manager/__init__.py b/src/inference_endpoint/dataset_manager/__init__.py index 12938f8e..403b8730 100644 --- a/src/inference_endpoint/dataset_manager/__init__.py +++ b/src/inference_endpoint/dataset_manager/__init__.py @@ -30,6 +30,7 @@ from .predefined.random import RandomDataset from .predefined.shopify_product_catalogue import ShopifyProductCatalogue from .transforms import ( + AddDefaultColumns, AddStaticColumns, ColumnFilter, ColumnRemap, @@ -46,6 +47,7 @@ "DataLoaderFactory", "ColumnFilter", "ColumnRemap", + "AddDefaultColumns", "AddStaticColumns", "UserPromptFormatter", "FusedRowProcessor", diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index bff8c011..f26c3d3e 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -19,56 +19,15 @@ import pandas as pd -from ..config.schema import APIType, ModelParams, StreamingMode +from ..config.schema import APIType, ModelParams from ..exceptions import InputValidationError from .dataset import Dataset -from .transforms import apply_transforms - -# Known generation parameter fields to forward from dataset to API requests. -# Aligned with OpenAI API specification and openai_msgspec_adapter.py implementation. -# These parameters work in both single-turn and multi-turn modes. 
-GENERATION_PARAMS = { - "model", - "max_new_tokens", - "max_completion_tokens", - "stream", - "temperature", - "top_p", - "top_k", - "seed", - "repetition_penalty", - "frequency_penalty", - "presence_penalty", - "stop", - "n", - "logit_bias", # Token probability adjustments - "name", # Entity name for role (NOT model name, e.g., 'Bob' for tracking) - "user", # End-user identifier for monitoring/abuse detection - "chat_template", # Custom chat formatting template - "tools", # OpenAI tool definitions (list[dict]) for tool-calling models -} - - -def _model_param_defaults(model_params: ModelParams | None) -> dict[str, Any]: - """Build per-request defaults for multi-turn rows from model params. - - Multi-turn datasets use `content` and conversation metadata rather than the - single-turn `prompt` field expected by adapter dataset transforms. Applying - those transforms would drop the conversation schema before load_sample() can - construct the messages array. Instead, we inject the request defaults here. - """ - if model_params is None: - return {} - - return { - "model": model_params.name, - "stream": model_params.streaming == StreamingMode.ON, - "max_completion_tokens": model_params.max_new_tokens, - "temperature": model_params.temperature, - "top_p": model_params.top_p, - "top_k": model_params.top_k, - "repetition_penalty": model_params.repetition_penalty, - } +from .transforms import ( + AddDefaultColumns, + AddStaticColumns, + apply_transforms, + get_transforms_for_api_type, +) def _expand_tool_results(row: dict) -> list[dict]: @@ -113,7 +72,7 @@ class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): Optional columns: - system: System prompt associated with the conversation (typically set on the first user turn) - model: Model name override - - max_new_tokens: Max tokens for this turn + - max_new_tokens / max_completion_tokens: Max tokens for this turn (alias; mapped to max_completion_tokens) Attributes: conversation_metadata: Metadata dict containing: @@ -139,7 +98,6 @@ def __init__(self, dataframe: pd.DataFrame, **kwargs): self._validate_conversation_structure() self._validate_turn_numbering() self.conversation_metadata = self._build_metadata() - self._client_turn_indices: list[int] | None = None def _validate_conversation_grouping(self) -> None: """Validate that all rows for each conversation_id appear consecutively in file order. @@ -321,23 +279,18 @@ def load( model_params: ModelParams | None = None, force: bool = False, ): - """Load dataset and build a dense user-turn index. + """Load dataset, apply adapter defaults, and pre-bake client-turn samples. - Multi-turn benchmarks only issue user turns. Assistant turns remain in the - backing data so the conversation structure can still be validated. + Unlike single-turn datasets, multi-turn rows do not have a `prompt` column, + so ColumnFilter (which requires prompt) is skipped. AddStaticColumns entries + from the adapter are applied via AddDefaultColumns (fill-missing-only) so that + per-row dataset overrides are preserved. - Unlike single-turn datasets, multi-turn rows do not have a `prompt` - column, so adapter dataset transforms are intentionally skipped here. - They would apply a single-turn ColumnFilter and strip the conversation - fields required by load_sample(). Request defaults from model_params are - merged directly into the conversation rows instead. 
+ After transforms, only client turns (user + tool) are stored in self.data as + fully assembled sample dicts (with messages, current_turn_message, system_content + attached). load_sample() and num_samples() are inherited from the base class. """ if not force and self.data is not None: - self._client_turn_indices = [ - index - for index, row in enumerate(self.data) - if row["role"] in ("user", "tool") - ] return df = self.dataframe @@ -353,76 +306,57 @@ def load( if transforms: df = apply_transforms(df, transforms) - defaults = _model_param_defaults(model_params) - for key, value in defaults.items(): - if value is None: + # Extract AddStaticColumns defaults from adapter transforms and apply as + # fill-missing-only (preserves per-row dataset values). + if api_type is not None and model_params is not None: + adapter_transforms = get_transforms_for_api_type(api_type, model_params) + defaults: dict[str, Any] = {} + for t in adapter_transforms: + if isinstance(t, AddStaticColumns): + defaults.update(t.data) + if defaults: + df = AddDefaultColumns(defaults)(df) + + all_rows = df.to_dict(orient="records") + + # Pre-bake: assemble one complete sample dict per client turn. + # NaN filtering replaces the GENERATION_PARAMS allowlist β€” any key whose + # value is float NaN was absent in the original dataset row. + pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}) + client_turn_samples: list[dict[str, Any]] = [] + + for row in all_rows: + if row.get("role") not in ("user", "tool"): continue - if key in df.columns: - df[key] = df[key].where(pd.notna(df[key]), value) - else: - df[key] = value - - self.data = df.to_dict(orient="records") - assert self.data is not None, "Failed to convert DataFrame to records" - - self._client_turn_indices = [ - index - for index, row in enumerate(self.data) - if row["role"] in ("user", "tool") - ] - - def load_sample(self, index: int) -> dict[str, Any]: - """Load the Nth client turn (user or tool) as a benchmark sample.""" - assert self.data is not None, "Dataset not loaded. Call load() first." - assert ( - self._client_turn_indices is not None - ), "Dataset not loaded. Call load() first." - row = self.data[self._client_turn_indices[index]] - - content_val = row.get("content") - sample: dict[str, Any] = { - "conversation_id": row["conversation_id"], - "turn": row["turn"], - "role": row["role"], - } - if content_val is not None and not ( - isinstance(content_val, float) and pd.isna(content_val) - ): - sample["content"] = content_val - - for param in GENERATION_PARAMS: - if param in row: - value = row[param] - # Skip pandas NaN/None values - if value is not None and ( - not isinstance(value, float) or not pd.isna(value) - ): - sample[param] = value - - # Set defaults for critical params if not present - if "max_new_tokens" not in sample and "max_completion_tokens" not in sample: - sample["max_new_tokens"] = 128 - if "stream" not in sample: - sample["stream"] = False - - # Attach pre-built message list (system + history + current turn). - key = (row["conversation_id"], int(row["turn"])) - pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}).get( - key, [] - ) - sample["pre_built_messages"] = pre_built - # Fields for use_dataset_history=False path (live history accumulation). 
- sample["current_turn_message"] = pre_built[-1] if pre_built else {} - first = pre_built[0] if pre_built else {} - sample["system_content"] = ( - first.get("content") if first.get("role") == "system" else None - ) + # Filter NaN values; keep all meaningful fields (extra keys are harmless + # since adapters consume only what they recognize). + sample: dict[str, Any] = { + k: v + for k, v in row.items() + if v is not None and not (isinstance(v, float) and pd.isna(v)) + } + + # max_new_tokens β†’ max_completion_tokens alias + if "max_completion_tokens" not in sample and "max_new_tokens" in sample: + sample["max_completion_tokens"] = sample.pop("max_new_tokens") + if "max_completion_tokens" not in sample: + sample["max_completion_tokens"] = 128 + if "stream" not in sample: + sample["stream"] = False + + # Attach pre-built message list (system + history + current turn). + key = (row["conversation_id"], int(row["turn"])) + messages = pre_built.get(key, []) + sample["messages"] = messages + + # Fields for use_dataset_history=False path (live history accumulation). + sample["current_turn_message"] = messages[-1] if messages else {} + first = messages[0] if messages else {} + sample["system_content"] = ( + first.get("content") if first.get("role") == "system" else None + ) - return sample + client_turn_samples.append(sample) - def num_samples(self) -> int: - assert ( - self._client_turn_indices is not None - ), "Dataset not loaded. Call load() first." - return len(self._client_turn_indices) + self.data = client_turn_samples diff --git a/src/inference_endpoint/dataset_manager/transforms.py b/src/inference_endpoint/dataset_manager/transforms.py index 79133796..a288da6d 100644 --- a/src/inference_endpoint/dataset_manager/transforms.py +++ b/src/inference_endpoint/dataset_manager/transforms.py @@ -127,6 +127,30 @@ def __call__(self, df: pd.DataFrame) -> pd.DataFrame: return df +class AddDefaultColumns(Transform): + """Add columns only where values are missing (NaN or absent). + + Unlike AddStaticColumns which unconditionally overwrites, this preserves + existing non-null values β€” dataset per-row overrides take precedence over + the supplied defaults. + """ + + def __init__(self, data: dict[str, Any]): + """Initialize the AddDefaultColumns transform.""" + self.data = data + + def __call__(self, df: pd.DataFrame) -> pd.DataFrame: + """Fill missing columns with defaults without overwriting existing values.""" + for key, value in self.data.items(): + if value is None: + continue + if key in df.columns: + df[key] = df[key].where(pd.notna(df[key]), value) + else: + df[key] = value + return df + + class Harmonize(RowProcessor): """Transform to convert a user prompt to an OpenAI Harmony-compatible format.""" diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py new file mode 100644 index 00000000..86741711 --- /dev/null +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -0,0 +1,356 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async conversation state management for multi-turn benchmarking.""" + +import asyncio +import logging +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class ConversationState: + """Tracks conversation sequencing for multi-turn benchmarking. + + Maintains turn counters and asyncio conditions so the strategy can enforce + sequential turn ordering within a conversation. Message history is NOT stored + here β€” it is pre-computed in MultiTurnDataset and served via load_sample(). + + Attributes: + conversation_id: Unique identifier for this conversation. + current_turn: Last completed turn number (0 = not started). + pending_client_turn: Turn number of in-flight client turn (None if idle). + expected_client_turns: Expected number of client turns (for completion tracking). + issued_client_turns: Count of client turns issued. + completed_client_turns: Count of client turns with responses. + failed_client_turns: Count of client turns that failed (error/timeout). + message_history: Accumulated message list (only populated when + use_dataset_history=False; empty otherwise). + condition: Per-conversation asyncio.Condition for turn-ready and turn-issued waits. + Scoped to this conversation so that state changes only wake the single + pipeline task waiting on this conversation, not all pipeline tasks. + """ + + conversation_id: str + current_turn: int = 0 + pending_client_turn: int | None = None + expected_client_turns: int | None = None + issued_client_turns: int = 0 + completed_client_turns: int = 0 + failed_client_turns: int = 0 + message_history: list[dict[str, Any]] = field(default_factory=list) + condition: asyncio.Condition = field(default_factory=asyncio.Condition) + + def add_client_turn(self, turn: int, message: dict[str, Any] | None = None): + """Record that a client turn has been issued (updates sequencing counters). + + Args: + turn: Turn number for this client message. + message: Message dict to append to message_history (only used when + use_dataset_history=False). + """ + self.pending_client_turn = turn + self.issued_client_turns += 1 + if message is not None: + self.message_history.append(message) + + def add_assistant_turn(self, content: str | None = None): + """Record assistant response and mark turn complete (success). + + Args: + content: Response content to append to message_history. Only + used when use_dataset_history=False; None means no history + update (pre-built messages path). 
+ """ + if content is not None: + self.message_history.append({"role": "assistant", "content": content}) + if self.pending_client_turn is not None: + self.current_turn = self.pending_client_turn + 1 + self.pending_client_turn = None + self.completed_client_turns += 1 + elif self.is_complete(): + pass + else: + logger.warning( + f"Received assistant response for {self.conversation_id} " + f"with no pending client turn (duplicate or out-of-order response)" + ) + self.current_turn = self.current_turn + 1 if self.current_turn > 0 else 1 + self.completed_client_turns += 1 + + if self.is_complete(): + if self.failed_client_turns > 0: + logger.info( + f"Conversation {self.conversation_id} completed with failures: " + f"{self.completed_client_turns - self.failed_client_turns}/" + f"{self.expected_client_turns} successful, " + f"{self.failed_client_turns} failed" + ) + else: + logger.debug( + f"Conversation {self.conversation_id} completed: " + f"{self.completed_client_turns}/{self.expected_client_turns} turns" + ) + + def mark_turn_failed(self, store_in_history: bool = False): + """Mark turn as failed (error/timeout) - still counts as completed for sequencing.""" + if self.pending_client_turn is not None: + self.current_turn = self.pending_client_turn + 1 + self.pending_client_turn = None + self.completed_client_turns += 1 + self.failed_client_turns += 1 + + if store_in_history: + self.message_history.append( + { + "role": "assistant", + "content": "[ERROR: Turn failed or timed out]", + } + ) + + logger.warning( + f"Turn {self.current_turn - 1} failed for conversation {self.conversation_id}" + ) + else: + logger.warning( + f"Attempted to mark failed turn for {self.conversation_id} " + f"with no pending client turn" + ) + + if self.is_complete(): + logger.info( + f"Conversation {self.conversation_id} completed with failures: " + f"{self.completed_client_turns - self.failed_client_turns}/" + f"{self.expected_client_turns} successful, " + f"{self.failed_client_turns} failed" + ) + + def is_complete(self) -> bool: + """Check if conversation is complete (all turns issued and responses received).""" + if self.expected_client_turns is None: + return False + return self.completed_client_turns >= self.expected_client_turns + + def is_ready_for_turn(self) -> bool: + """Check if the previous turn has completed and the next may be issued.""" + return ( + self.pending_client_turn is None + and self.issued_client_turns == self.completed_client_turns + and self.issued_client_turns > 0 + ) + + +class ConversationManager: + """Manages conversation sequencing for multi-turn benchmarking. + + Async manager that tracks multiple conversations and enforces turn ordering. + Conversations are identified by unique IDs. Message history is NOT maintained here + β€” it is pre-computed in MultiTurnDataset and passed directly to each request. + + The manager ensures that: + - Turn N+1 cannot be issued until turn N completes + - Concurrent access to conversation state is async-safe + + Each ConversationState carries its own asyncio.Condition so that state changes + (turn issued / turn complete) only wake the single pipeline task waiting + on that conversation, not all pipeline tasks across all conversations. + All conversation states are pre-created by the strategy before pipeline + tasks start, so wait_for_turn_issued never races against get_or_create. 
+ """ + + def __init__(self): + """Initialize conversation manager with empty state.""" + self._conversations: dict[str, ConversationState] = {} + self._lock = asyncio.Lock() + + def get_state(self, conversation_id: str) -> ConversationState | None: + """Get conversation state without creating (for read-only access).""" + return self._conversations.get(conversation_id) + + async def get_or_create( + self, + conversation_id: str, + expected_client_turns: int | None = None, + system_message: dict[str, Any] | None = None, + ) -> ConversationState: + """Get existing or create new conversation state. + + Args: + conversation_id: Unique identifier for conversation. + expected_client_turns: Expected number of client turns (for completion tracking). + system_message: System message dict to pre-populate message_history with. + Only used when use_dataset_history=False and conversation is new. + + Returns: + ConversationState for this conversation. + """ + async with self._lock: + if conversation_id not in self._conversations: + initial_history: list[dict[str, Any]] = ( + [system_message] if system_message is not None else [] + ) + state = ConversationState( + conversation_id=conversation_id, + current_turn=0, + pending_client_turn=None, + expected_client_turns=expected_client_turns, + issued_client_turns=0, + completed_client_turns=0, + failed_client_turns=0, + message_history=initial_history, + ) + self._conversations[conversation_id] = state + return self._conversations[conversation_id] + + async def wait_for_turn_ready( + self, conversation_id: str, turn: int, timeout: float | None = None + ) -> bool: + """Block until conversation is ready for this turn. + + Uses the per-conversation asyncio.Condition so only this conversation's pipeline + task is woken on state changes, not all pipeline tasks. + + Args: + conversation_id: Conversation to wait for. + turn: Turn number to wait for (unused in readiness check; kept for + call-site compatibility). + timeout: Maximum seconds to wait (None = infinite). + + Returns: + True if ready, False if timeout. + + Raises: + KeyError: If conversation_id not found in manager. + """ + state = self._conversations.get(conversation_id) + if state is None: + logger.error(f"Conversation {conversation_id} not found in manager") + raise KeyError(f"Conversation {conversation_id} not initialized") + + async with state.condition: + if timeout is None: + await state.condition.wait_for(state.is_ready_for_turn) + return True + try: + async with asyncio.timeout(timeout): + await state.condition.wait_for(state.is_ready_for_turn) + return True + except TimeoutError: + return state.is_ready_for_turn() + + async def wait_for_turn_issued( + self, + conversation_id: str, + min_issued: int, + timeout: float | None = None, + ) -> bool: + """Block until at least min_issued client turns have been issued. + + Args: + conversation_id: Conversation to wait for. + min_issued: Minimum number of issued turns to wait for. + timeout: Maximum seconds to wait (None = infinite). + + Returns: + True if condition met, False if timeout. + + Raises: + KeyError: If conversation_id not found (programming error β€” state must be + pre-created by the strategy before pipeline tasks are spawned). 
+ """ + state = self._conversations[conversation_id] + predicate = lambda: state.issued_client_turns >= min_issued # noqa: E731 + async with state.condition: + if timeout is None: + await state.condition.wait_for(predicate) + return True + try: + async with asyncio.timeout(timeout): + await state.condition.wait_for(predicate) + return True + except TimeoutError: + return state.issued_client_turns >= min_issued + + async def mark_turn_issued( + self, + conversation_id: str, + turn: int, + message: dict[str, Any] | None = None, + ): + """Mark that a client turn has been issued (updates sequencing counters). + + Args: + conversation_id: Conversation ID. + turn: Turn number being issued. + message: Message dict to append to history (used when + use_dataset_history=False). + + Raises: + KeyError: If conversation_id not found in manager. + """ + state = self._conversations.get(conversation_id) + if state is None: + raise KeyError(f"Conversation {conversation_id} not initialized") + async with state.condition: + state.add_client_turn(turn, message) + state.condition.notify_all() + + async def mark_turn_complete( + self, + conversation_id: str, + response: str, + store_in_history: bool = False, + ): + """Mark that assistant response has arrived. + + Args: + conversation_id: Conversation ID. + response: Model output (stored in history when store_in_history=True). + store_in_history: When True, append response to message_history. + + Raises: + KeyError: If conversation_id not found in manager. + """ + state = self._conversations.get(conversation_id) + if state is None: + raise KeyError(f"Conversation {conversation_id} not initialized") + async with state.condition: + state.add_assistant_turn(response if store_in_history else None) + state.condition.notify_all() + + async def mark_turn_failed( + self, conversation_id: str, store_in_history: bool = False + ): + """Mark that assistant response failed (error/timeout). + + Failed turns still count toward conversation completion to ensure + turn sequencing progresses even under errors. + + Args: + conversation_id: Conversation ID. + store_in_history: When True, append error placeholder to message_history. + + Raises: + KeyError: If conversation_id not found in manager. + """ + state = self._conversations.get(conversation_id) + if state is None: + raise KeyError(f"Conversation {conversation_id} not initialized") + async with state.condition: + state.mark_turn_failed(store_in_history=store_in_history) + state.condition.notify_all() diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py new file mode 100644 index 00000000..48f5b45f --- /dev/null +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Async multi-turn load strategy implementing the LoadStrategy protocol.""" + +import asyncio +import logging +from collections import defaultdict +from typing import Any + +from ..config.schema import MultiTurnConfig +from ..core.types import QueryResult +from .conversation_manager import ConversationManager +from .strategy import PhaseIssuerProtocol + +logger = logging.getLogger(__name__) + +# Default turn timeout when no MultiTurnConfig is provided. +_DEFAULT_TURN_TIMEOUT_S = 300.0 + + +class MultiTurnStrategy: + """Async multi-turn strategy. Spawns per-conversation asyncio.Tasks. + + Each conversation runs as an independent asyncio.Task that enforces + sequential turn ordering: turn N+1 cannot be issued until turn N completes. + Conversations run concurrently β€” no cross-conversation synchronization. + + Optional target_concurrency limits total in-flight requests across all + conversations using asyncio.Semaphore. + + Integration with BenchmarkSession: + - execute(): spawns conversation tasks, awaits all to complete + - on_query_complete(): releases semaphore slot (concurrency control only) + - on_sample_complete(): routes completed QueryResult to ConversationManager + + The response routing path: + 1. _conv_pipeline issues turn N via phase_issuer.issue(idx) β†’ query_id + 2. _conv_pipeline stores (conv_id, turn) in _inflight[query_id] + 3. BenchmarkSession calls on_sample_complete(result) with the QueryResult + 4. on_sample_complete looks up conv_id from _inflight, calls mark_turn_complete + 5. mark_turn_complete notifies the pipeline task waiting on wait_for_turn_ready + 6. _conv_pipeline proceeds to issue turn N+1 + """ + + def __init__( + self, + conversation_manager: ConversationManager, + dataset_metadata: dict[str, Any], + multi_turn_config: MultiTurnConfig | None = None, + target_concurrency: int | None = None, + ): + """Initialize multi-turn strategy. + + Args: + conversation_manager: Manages conversation sequencing state. + dataset_metadata: Metadata from MultiTurnDataset (samples list). + multi_turn_config: Multi-turn conversation configuration. + target_concurrency: Optional maximum concurrent in-flight requests. + """ + self._conv_manager = conversation_manager + self._dataset_metadata = dataset_metadata + self._multi_turn_config = multi_turn_config + self._turn_timeout_s = ( + multi_turn_config.turn_timeout_s + if multi_turn_config is not None + else _DEFAULT_TURN_TIMEOUT_S + ) + self._target_concurrency = target_concurrency + self._sem: asyncio.Semaphore | None = ( + asyncio.Semaphore(target_concurrency) + if target_concurrency is not None and target_concurrency > 0 + else None + ) + self._store_in_history = ( + not multi_turn_config.use_dataset_history + if multi_turn_config is not None + else False + ) + + # Maps query_id -> conversation_id for routing completions. + # Populated by _conv_pipeline after issue() returns query_id. + self._inflight: dict[str, str] = {} + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + """Drive multi-turn sample issuance. + + Args: + phase_issuer: Interface for issuing samples to the endpoint. + + Returns: + Total count of samples issued. + """ + conv_samples: dict[str, list[tuple[int, int]]] = defaultdict(list) + for sample_index, sample_meta in enumerate(self._dataset_metadata["samples"]): + conv_id = sample_meta["conversation_id"] + conv_samples[conv_id].append((sample_index, sample_meta["turn"])) + + # Pre-create all conversation states before spawning tasks. 
+ for conv_id, turns in conv_samples.items(): + await self._conv_manager.get_or_create( + conv_id, expected_client_turns=len(turns) + ) + + tasks = [ + asyncio.create_task( + self._conv_pipeline(conv_id, turns, phase_issuer), + name=f"mt-pipeline-{conv_id}", + ) + for conv_id, turns in conv_samples.items() + ] + + await asyncio.gather(*tasks, return_exceptions=True) + return phase_issuer.issued_count + + async def _conv_pipeline( + self, + conv_id: str, + turns: list[tuple[int, int]], + phase_issuer: PhaseIssuerProtocol, + ) -> None: + """Process all turns for a single conversation sequentially. + + For each turn after the first, waits for the previous turn to complete + (via wait_for_turn_ready) before issuing the next. This enforces strict + sequential ordering: turn N+1 is not issued until turn N's response arrives. + """ + sorted_turns = sorted(turns, key=lambda x: x[1]) + + for i, (idx, turn) in enumerate(sorted_turns): + if i > 0: + # Wait for the previous turn to complete before issuing the next. + ready = await self._conv_manager.wait_for_turn_ready( + conv_id, turn, timeout=self._turn_timeout_s + ) + if not ready: + logger.warning( + f"Turn {turn} of {conv_id} timed out waiting for previous turn" + ) + await self._conv_manager.mark_turn_failed(conv_id) + break + + # Acquire concurrency slot before issuing + if self._sem is not None: + await self._sem.acquire() + + # For live-history mode: build messages from accumulated history + current turn, + # and pass as data_override so the pre-built messages from the dataset are replaced. + data_override: dict[str, Any] | None = None + current_turn_message: dict[str, Any] | None = None + if self._store_in_history: + pre_built = self._dataset_metadata.get( + "pre_built_messages_by_key", {} + ).get((conv_id, turn), []) + current_turn_message = pre_built[-1] if pre_built else None + state = self._conv_manager.get_state(conv_id) + if state is not None and current_turn_message is not None: + live_messages = state.message_history.copy() + [ + current_turn_message + ] + data_override = {"messages": live_messages} + + query_id = phase_issuer.issue(idx, data_override=data_override) + if query_id is None: + # Session stopping β€” release slot and exit + if self._sem is not None: + self._sem.release() + break + + # Register this query_id -> conv_id mapping for response routing. + self._inflight[query_id] = conv_id + + # Mark the turn as issued so wait_for_turn_ready can gate the next turn. + await self._conv_manager.mark_turn_issued( + conv_id, turn, message=current_turn_message + ) + + def on_query_complete(self, query_id: str) -> None: + """Called by BenchmarkSession when a QueryResult arrives. + + Releases the concurrency semaphore slot. Response routing is done + via on_sample_complete (which receives the full QueryResult). + + Args: + query_id: ID of the completed query. + """ + if self._sem is not None: + self._sem.release() + + def on_sample_complete(self, result: QueryResult) -> None: + """Route completed QueryResult to ConversationManager. + + Called by execute.py on_sample_complete hook after each response. + Looks up the conversation_id from _inflight and calls mark_turn_complete. + + Args: + result: Completed QueryResult from the endpoint. 
+ """ + query_id = result.id + conv_id = self._inflight.pop(query_id, None) + if conv_id is None: + return + + response_text = result.get_response_output_string() + + if result.error is not None: + asyncio.ensure_future( + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + ) + else: + asyncio.ensure_future( + self._conv_manager.mark_turn_complete( + conv_id, response_text, store_in_history=self._store_in_history + ) + ) diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 3be480cb..2e4f67ef 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -26,9 +26,9 @@ import time import uuid from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum -from typing import Protocol +from typing import Any, Protocol from ..config.runtime_settings import RuntimeSettings from ..core.record import ( @@ -60,7 +60,7 @@ class PhaseType(str, Enum): WARMUP = "warmup" -@dataclass(frozen=True, slots=True) +@dataclass(frozen=True) class PhaseConfig: """Configuration for a single benchmark phase.""" @@ -68,6 +68,7 @@ class PhaseConfig: runtime_settings: RuntimeSettings dataset: Dataset phase_type: PhaseType = PhaseType.PERFORMANCE + strategy: LoadStrategy | None = field(default=None, compare=False) # --------------------------------------------------------------------------- @@ -172,11 +173,19 @@ def __init__( self.inflight: int = 0 self.issued_count: int = 0 - def issue(self, sample_index: int) -> str | None: + def issue( + self, sample_index: int, data_override: dict[str, Any] | None = None + ) -> str | None: """Load data, build Query, publish ISSUED, send to endpoint. Returns query_id on success, None if session is stopping. + Args: + sample_index: Index into the dataset. + data_override: If provided, merged over the loaded sample data. + Keys in data_override take precedence. Used by MultiTurnStrategy + to substitute live-accumulated message history. + Note: load_sample() runs synchronously before the ISSUED timestamp. For accurate timing, datasets MUST be pre-loaded into memory. Disk-backed datasets will inflate timing and delay subsequent issues. 
@@ -185,6 +194,8 @@ def issue(self, sample_index: int) -> str | None: return None query_id = uuid.uuid4().hex data = self._dataset.load_sample(sample_index) + if data_override is not None: + data = {**data, **data_override} query = Query(id=query_id, data=data) self.uuid_to_index[query_id] = sample_index ts = time.monotonic_ns() @@ -313,10 +324,13 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: phase_start = time.monotonic_ns() # Create per-phase state - sample_order = create_sample_order(phase.runtime_settings) - strategy = create_load_strategy( - phase.runtime_settings, self._loop, sample_order - ) + if phase.strategy is not None: + strategy = phase.strategy + else: + sample_order = create_sample_order(phase.runtime_settings) + strategy = create_load_strategy( + phase.runtime_settings, self._loop, sample_order + ) phase_issuer = PhaseIssuer( dataset=phase.dataset, issuer=self._issuer, diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py index dd311f10..8ee13722 100644 --- a/src/inference_endpoint/load_generator/strategy.py +++ b/src/inference_endpoint/load_generator/strategy.py @@ -29,7 +29,7 @@ import logging from collections.abc import Callable, Iterator from time import monotonic_ns -from typing import Protocol +from typing import Any, Protocol from ..config.runtime_settings import RuntimeSettings from ..config.schema import LoadPatternType @@ -47,8 +47,17 @@ class PhaseIssuerProtocol(Protocol): """Minimal interface that strategies see for issuing samples.""" - def issue(self, sample_index: int) -> str | None: - """Issue a sample. Returns query_id, or None if the session is stopping.""" + def issue( + self, sample_index: int, data_override: dict[str, Any] | None = None + ) -> str | None: + """Issue a sample. Returns query_id, or None if the session is stopping. + + Args: + sample_index: Index into the dataset. + data_override: If provided, use this as Query.data instead of + loading from the dataset. Used by MultiTurnStrategy for + live-history mode where the messages array is built at runtime. + """ ... issued_count: int @@ -297,5 +306,11 @@ def create_load_strategy( ) return ConcurrencyStrategy(lp.target_concurrency, sample_order) + case LoadPatternType.MULTI_TURN: + raise ValueError( + "MULTI_TURN load pattern requires a MultiTurnDataset β€” " + "use 'inference-endpoint benchmark from-config' with a multi-turn dataset" + ) + case _: raise ValueError(f"Unsupported load pattern type: {lp.type}") diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py new file mode 100644 index 00000000..5a4d128d --- /dev/null +++ b/tests/integration/test_multi_turn.py @@ -0,0 +1,425 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for multi-turn benchmarking end-to-end. 
+ +Validates that MultiTurnDataset + MultiTurnStrategy + BenchmarkSession work +correctly together against a real HTTP echo server. + +Tests cover: + 1. Dataset-history mode (use_dataset_history=True): pre-built messages are + issued as-is; each turn is issued sequentially per conversation. + 2. Live-history mode (use_dataset_history=False): messages are built at + runtime from ConversationManager.message_history; the injected messages + grow with each turn. + 3. Multiple concurrent conversations complete successfully. + 4. Turn ordering: turn N+1 is never issued before turn N completes. +""" + +import asyncio +import random +import time +from urllib.parse import urljoin + +import pandas as pd +import pytest +from inference_endpoint import metrics +from inference_endpoint.config.runtime_settings import RuntimeSettings +from inference_endpoint.config.schema import ( + LoadPattern, + LoadPatternType, + MultiTurnConfig, +) +from inference_endpoint.core.record import EventRecord +from inference_endpoint.core.types import QueryResult +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset +from inference_endpoint.endpoint_client.config import HTTPClientConfig +from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient +from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy +from inference_endpoint.load_generator.session import ( + BenchmarkSession, + PhaseConfig, + PhaseType, +) +from inference_endpoint.testing.echo_server import EchoServer + + +class _NoOpPublisher: + def publish(self, event_record: EventRecord) -> None: + pass + + def flush(self) -> None: + pass + + +def _make_dataset(rows: list[dict]) -> MultiTurnDataset: + """Build a loaded MultiTurnDataset from a list of row dicts.""" + df = pd.DataFrame(rows) + ds = MultiTurnDataset(dataframe=df) + ds.load() + return ds + + +def _make_strategy( + ds: MultiTurnDataset, + use_dataset_history: bool = True, +) -> MultiTurnStrategy: + mt_cfg = MultiTurnConfig( + turn_timeout_s=10.0, + use_dataset_history=use_dataset_history, + ) + return MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + +async def _run_session( + server_url: str, + ds: MultiTurnDataset, + strategy: MultiTurnStrategy, + responses_out: dict, +) -> int: + """Wire up HTTPEndpointClient + BenchmarkSession and run one phase. + + Populates responses_out[query_id] = response_text for every completed turn. + Returns issued_count. 
+ """ + loop = asyncio.get_running_loop() + + def on_complete(result: QueryResult) -> None: + strategy.on_sample_complete(result) + responses_out[result.id] = result.get_response_output_string() + + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(server_url, "/v1/chat/completions")], + warmup_connections=0, + num_workers=2, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=30_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig( + "perf", + rt, + ds, + PhaseType.PERFORMANCE, + strategy=strategy, + ) + result = await asyncio.wait_for(session.run([phase]), timeout=30.0) + return result.perf_results[0].issued_count + finally: + await http_client.shutdown_async() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def echo_server(): + server = EchoServer(port=0) + server.start() + try: + yield server + finally: + server.stop() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_single_conversation_all_turns_issued(echo_server): + """All turns of a single conversation are issued and completed.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Bye"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # Two user turns (turns 1 and 3); turn 2 is assistant so not a client turn + assert count == 2 + assert len(responses) == 2 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_multiple_conversations_all_issued(echo_server): + """Multiple conversations complete independently and concurrently.""" + rows = [] + for conv_idx in range(3): + conv_id = f"conv_{conv_idx}" + rows.append( + { + "conversation_id": conv_id, + "turn": 1, + "role": "user", + "content": f"Q1 {conv_idx}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 2, + "role": "assistant", + "content": f"A1 {conv_idx}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 3, + "role": "user", + "content": f"Q2 {conv_idx}", + } + ) + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # 3 conversations Γ— 2 user turns each = 6 + assert count == 6 + assert len(responses) == 6 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_dataset_history_messages_present(echo_server): + """Dataset-history mode: each request contains the messages array from the dataset.""" + received_payloads: list[dict] = [] + + # Override 
get_response to capture the incoming request body. + # EchoServer._handle_echo_chat_completions_request parses it into + # CreateChatCompletionRequest β€” we capture the raw JSON at the HTTP layer + # by subclassing and overriding get_response (called with first user content). + # Instead, use a custom echo server that logs the full payload. + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "First question", + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "First answer", + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "user", + "content": "Second question", + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=True) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + # Both requests must include a "messages" array + assert len(received_payloads) == 2 + for payload in received_payloads: + assert "messages" in payload + assert len(payload["messages"]) >= 1 + + # Turn 1 should have 1 user message; turn 3 should have 3 messages + # (system? no system here β€” user, assistant, user) + msg_counts = sorted(len(p["messages"]) for p in received_payloads) + assert msg_counts == [1, 3] + finally: + server.stop() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_live_history_messages_grow_each_turn(echo_server): + """Live-history mode: messages array grows with each completed turn.""" + received_payloads: list[dict] = [] + + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Turn one"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Answer one", + }, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Turn two"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=False) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + assert len(received_payloads) == 2 + msg_counts = sorted(len(p["messages"]) for p in received_payloads) + # Turn 1: [user msg] = 1; Turn 3: [user, assistant, user] = 3 + assert msg_counts == [1, 3] + finally: + server.stop() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_turn_ordering_enforced_end_to_end(echo_server): + """Turn N+1 is issued after Turn N's response arrives, verified by timestamps.""" + complete_times: dict[str, float] = {} + + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "First"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Second"}, + ] + ds = _make_dataset(rows) + mt_cfg = MultiTurnConfig(turn_timeout_s=10.0, use_dataset_history=True) + conv_manager = 
ConversationManager() + strategy = MultiTurnStrategy( + conversation_manager=conv_manager, + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + # Wrap on_sample_complete to record completion timestamps + orig_on_sample_complete = strategy.on_sample_complete + + def tracked_on_sample_complete(result: QueryResult) -> None: + # Map query_id β†’ sample_index via uuid_to_index (set after session runs) + complete_times[result.id] = time.monotonic() + orig_on_sample_complete(result) + + strategy.on_sample_complete = tracked_on_sample_complete + + loop = asyncio.get_running_loop() + responses: dict[str, str] = {} + + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(echo_server.url, "/v1/chat/completions")], + warmup_connections=0, + num_workers=1, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=30_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + + try: + + def on_complete(result: QueryResult) -> None: + tracked_on_sample_complete(result) + responses[result.id] = result.get_response_output_string() + + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) + result = await asyncio.wait_for(session.run([phase]), timeout=30.0) + finally: + await http_client.shutdown_async() + + assert result.perf_results[0].issued_count == 2 + + # Build query_id β†’ sample_index from session result + uuid_to_index = result.perf_results[0].uuid_to_index + index_to_query = {v: k for k, v in uuid_to_index.items()} + + # Sample 0 = turn 1, sample 1 = turn 3 + q_turn1 = index_to_query[0] + q_turn3 = index_to_query[1] + + # Turn 3 must complete after turn 1 completes + assert complete_times[q_turn3] >= complete_times[q_turn1] diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index 09ccd224..a42fc1f3 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -128,8 +128,8 @@ def test_multi_turn_dataset_load_valid_data(valid_multi_turn_jsonl): ) dataset.load() - # Should have 5 rows total (3 for conv_001, 2 for conv_002) - assert len(dataset.data) == 5 + # data contains only client turns (3 user turns), not all rows + assert len(dataset.data) == 3 # Should have 3 user turns (samples) - only user turns are indexed assert dataset.num_samples() == 3 @@ -137,18 +137,18 @@ def test_multi_turn_dataset_load_valid_data(valid_multi_turn_jsonl): @pytest.mark.unit def test_multi_turn_dataset_user_turn_indexing(valid_multi_turn_jsonl): - """Test that only client turns (user + tool) are indexed as samples.""" + """Test that only client turns (user + tool) are stored as samples.""" dataset = MultiTurnDataset.load_from_file( valid_multi_turn_jsonl, format=DatasetFormat.JSONL ) dataset.load() - # Verify client turn indices are correct (fixture has only user turns) - assert len(dataset._client_turn_indices) == 3 + # data contains only client turns (fixture has only user turns) + assert dataset.num_samples() == 3 - # Check 
that indices point to client turns - for idx in dataset._client_turn_indices: - assert dataset.data[idx]["role"] in ("user", "tool") + # Every sample in data is a client turn + for i in range(dataset.num_samples()): + assert dataset.load_sample(i)["role"] in ("user", "tool") @pytest.mark.unit @@ -165,9 +165,9 @@ def test_multi_turn_dataset_load_sample(valid_multi_turn_jsonl): assert sample_0["turn"] == 1 assert sample_0["role"] == "user" assert sample_0["content"] == "Hello, how are you?" - # System prompt is in pre_built_messages, not as a separate field - assert sample_0["pre_built_messages"][0]["role"] == "system" - assert sample_0["pre_built_messages"][0]["content"] == "You are a helpful assistant" + # System prompt is the first message in the messages array + assert sample_0["messages"][0]["role"] == "system" + assert sample_0["messages"][0]["content"] == "You are a helpful assistant" # Sample 1 should be second user turn (conv_001 turn 3) sample_1 = dataset.load_sample(1) @@ -268,8 +268,8 @@ def test_multi_turn_dataset_multiple_conversations(): dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) dataset.load() - # 9 total rows, 5 user turns (c1:t1, c1:t3, c2:t1, c2:t3, c3:t1) - assert len(dataset.data) == 9 + # data contains only client turns: 5 user turns (c1:t1, c1:t3, c2:t1, c2:t3, c3:t1) + assert len(dataset.data) == 5 assert dataset.num_samples() == 5 # Metadata checks @@ -291,7 +291,7 @@ def test_multi_turn_dataset_multiple_conversations(): @pytest.mark.unit def test_multi_turn_dataset_system_prompt_handling(valid_multi_turn_jsonl): - """Test system prompt is included as the first message in pre_built_messages. + """Test system prompt is included as the first message in the messages array. The system prompt is pre-baked into every client turn's message list so the conversation manager no longer needs to track it separately. 
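For reference, a minimal sketch of the pre-baked `messages` array for a
follow-up user turn (hypothetical row values; the shape mirrors the assertions
in the hunks below):

```python
import pandas as pd

from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset

rows = [
    {"conversation_id": "c1", "turn": 1, "role": "user",
     "content": "Hello", "system": "You are a helpful assistant"},
    {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi!"},
    {"conversation_id": "c1", "turn": 3, "role": "user", "content": "What's 2+2?"},
]
ds = MultiTurnDataset(dataframe=pd.DataFrame(rows))
ds.load()

msgs = ds.load_sample(1)["messages"]   # second client turn = turn 3
assert [m["role"] for m in msgs] == ["system", "user", "assistant", "user"]
assert msgs[2] == {"role": "assistant", "content": "Hi!"}
assert msgs[-1]["content"] == "What's 2+2?"
```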
@@ -301,16 +301,16 @@ def test_multi_turn_dataset_system_prompt_handling(valid_multi_turn_jsonl): ) dataset.load() - # First sample: pre_built_messages starts with system message + # First sample: messages starts with system message sample_0 = dataset.load_sample(0) - assert "pre_built_messages" in sample_0 - msgs = sample_0["pre_built_messages"] + assert "messages" in sample_0 + msgs = sample_0["messages"] assert msgs[0]["role"] == "system" assert msgs[0]["content"] == "You are a helpful assistant" # Second sample (same conversation, turn 3): system message still first sample_1 = dataset.load_sample(1) - msgs_1 = sample_1["pre_built_messages"] + msgs_1 = sample_1["messages"] assert msgs_1[0]["role"] == "system" assert msgs_1[0]["content"] == "You are a helpful assistant" @@ -383,8 +383,8 @@ def test_multi_turn_dataset_conversation_grouping(): dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) dataset.load() - # 5 total rows, 3 user turns (c1t1, c1t3, c2t1) - assert len(dataset.data) == 5 + # data contains only client turns: 3 user turns (c1t1, c1t3, c2t1) + assert len(dataset.data) == 3 assert dataset.num_samples() == 3 # Load samples to verify conversation grouping @@ -485,14 +485,9 @@ def test_multi_turn_dataset_additional_fields(): dataset.load() sample = dataset.load_sample(0) - # Fields may or may not be present depending on how dataframe handles them - # Just check they're accessible if present - if "model" in sample: - assert sample["model"] == "gpt-4" - if "max_new_tokens" in sample: - assert sample["max_new_tokens"] == 256 - if "temperature" in sample: - assert sample["temperature"] == pytest.approx(0.7) + assert sample["model"] == "gpt-4" + assert sample["max_completion_tokens"] == 256 + assert sample["temperature"] == pytest.approx(0.7) finally: Path(temp_path).unlink() @@ -540,34 +535,33 @@ def test_multi_turn_dataset_openai_field_forwarding(): @pytest.mark.unit def test_multi_turn_dataset_all_generation_params(): - """Test that all generation parameters in GENERATION_PARAMS are forwarded.""" - from inference_endpoint.dataset_manager.multi_turn_dataset import GENERATION_PARAMS - - # Create dataset with all possible generation params + """Test that dataset-supplied generation parameters are forwarded to the sample.""" + # Create dataset with a representative set of generation params + row_params = { + "model": "test-model", + "max_completion_tokens": 100, + "stream": True, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "seed": 42, + "repetition_penalty": 1.1, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + "stop": ["END"], + "n": 2, + "logit_bias": {"100": 10}, + "name": "TestEntity", + "user": "test_user_001", + "chat_template": "test_template", + } data = [ { "conversation_id": "c1", "turn": 1, "role": "user", "content": "Test", - # Include all params from GENERATION_PARAMS - "model": "test-model", - "max_new_tokens": 100, - "max_completion_tokens": 100, - "stream": True, - "temperature": 0.8, - "top_p": 0.95, - "top_k": 50, - "seed": 42, - "repetition_penalty": 1.1, - "frequency_penalty": 0.5, - "presence_penalty": 0.3, - "stop": ["END"], - "n": 2, - "logit_bias": {"100": 10}, - "name": "TestEntity", - "user": "test_user_001", - "chat_template": "test_template", + **row_params, }, { "conversation_id": "c1", @@ -588,13 +582,9 @@ def test_multi_turn_dataset_all_generation_params(): sample = dataset.load_sample(0) - # Verify all GENERATION_PARAMS fields are forwarded - # (excluding conversational fields like conversation_id, turn, 
role, content, system) - for param in GENERATION_PARAMS: - if param in data[0]: - assert ( - param in sample - ), f"Generation parameter '{param}' not forwarded to sample" + # All non-NaN row fields must appear in the pre-baked sample + for param in row_params: + assert param in sample, f"Parameter '{param}' not forwarded to sample" finally: Path(temp_path).unlink() @@ -888,7 +878,7 @@ def test_load_sample_merged_tool_row_has_no_content_key(): s1 = ds.load_sample(1) assert s1["role"] == "tool" assert "content" not in s1 # must NOT emit NaN - assert "pre_built_messages" in s1 + assert "messages" in s1 @pytest.mark.unit @@ -961,27 +951,27 @@ def test_build_metadata_pre_built_messages_no_tools(): @pytest.mark.unit -def test_load_sample_includes_pre_built_messages(): - """load_sample returns pre_built_messages with the complete message list.""" +def test_load_sample_includes_messages(): + """load_sample returns messages with the complete message list.""" df = _make_tool_sequence_df() ds = MultiTurnDataset(df) ds.load() s0 = ds.load_sample(0) # user turn 1 - assert "pre_built_messages" in s0 - msgs = s0["pre_built_messages"] + assert "messages" in s0 + msgs = s0["messages"] assert msgs[0]["role"] == "system" assert msgs[-1] == {"role": "user", "content": "What is the weather?"} s1 = ds.load_sample(1) # tool turn 3 assert s1["role"] == "tool" - msgs_t3 = s1["pre_built_messages"] + msgs_t3 = s1["messages"] # system + user(1) + asst_tc(2) + tool(3) = 4 messages assert len(msgs_t3) == 4 assert msgs_t3[-1]["role"] == "tool" s2 = ds.load_sample(2) # user turn 5 - msgs_t5 = s2["pre_built_messages"] + msgs_t5 = s2["messages"] # system + user(1) + asst_tc(2) + tool(3) + asst(4) + user(5) = 6 messages assert len(msgs_t5) == 6 @@ -1003,8 +993,8 @@ def test_client_turns_include_tool_rows(): @pytest.mark.unit -def test_pre_built_messages_include_prior_assistant_response(valid_multi_turn_jsonl): - """The terminal assistant response before each user turn is included in pre_built_messages.""" +def test_messages_include_prior_assistant_response(valid_multi_turn_jsonl): + """The terminal assistant response before each user turn is included in messages.""" dataset = MultiTurnDataset.load_from_file( valid_multi_turn_jsonl, format=DatasetFormat.JSONL ) @@ -1012,26 +1002,26 @@ def test_pre_built_messages_include_prior_assistant_response(valid_multi_turn_js # Sample 0: turn 1 (first user) β†’ just [system, user(1)] s0 = dataset.load_sample(0) - msgs_0 = s0["pre_built_messages"] + msgs_0 = s0["messages"] assert msgs_0[0]["role"] == "system" assert msgs_0[-1]["role"] == "user" # Sample 1: turn 3 (second user) β†’ [system, user(1), assistant(2), user(3)] s1 = dataset.load_sample(1) - msgs_1 = s1["pre_built_messages"] + msgs_1 = s1["messages"] assert len(msgs_1) == 4 assert msgs_1[2] == {"role": "assistant", "content": "I'm doing well, thank you!"} assert msgs_1[3]["role"] == "user" # Sample 2: turn 1 of conv_002 β†’ no prior assistant row s2 = dataset.load_sample(2) - msgs_2 = s2["pre_built_messages"] + msgs_2 = s2["messages"] assert all(m["role"] != "assistant" for m in msgs_2) @pytest.mark.unit -def test_pre_built_messages_no_cross_conversation_bleed(): - """Messages for conv_001 must not appear in conv_002's pre_built_messages.""" +def test_messages_no_cross_conversation_bleed(): + """Messages for conv_001 must not appear in conv_002's messages array.""" data = [ {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1 user"}, {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2 user"}, @@ 
-1049,25 +1039,25 @@ def test_pre_built_messages_no_cross_conversation_bleed(): # c1: only its own user message s_c1 = dataset.load_sample(0) - assert s_c1["pre_built_messages"] == [{"role": "user", "content": "c1 user"}] + assert s_c1["messages"] == [{"role": "user", "content": "c1 user"}] # c2: only c2 messages (no c1 content) s_c2 = dataset.load_sample(1) - contents = [m.get("content") for m in s_c2["pre_built_messages"]] + contents = [m.get("content") for m in s_c2["messages"]] assert "c1 user" not in contents finally: Path(temp_path).unlink() @pytest.mark.unit -def test_pre_built_messages_with_tool_sequence_terminal_assistant(): - """Terminal assistant response (turn 4) appears in pre_built_messages for user(5).""" +def test_messages_with_tool_sequence_terminal_assistant(): + """Terminal assistant response (turn 4) appears in messages for user(5).""" df = _make_tool_sequence_df() ds = MultiTurnDataset(df) ds.load() s2 = ds.load_sample(2) # user turn 5 - msgs = s2["pre_built_messages"] + msgs = s2["messages"] # The terminal assistant at turn 4 should be included assistant_msgs = [m for m in msgs if m["role"] == "assistant" and m.get("content")] assert any(m["content"] == "The weather is 22Β°C." for m in assistant_msgs) diff --git a/tests/unit/dataset_manager/test_transforms.py b/tests/unit/dataset_manager/test_transforms.py index ab342204..5eca41b4 100644 --- a/tests/unit/dataset_manager/test_transforms.py +++ b/tests/unit/dataset_manager/test_transforms.py @@ -23,6 +23,7 @@ import pandas as pd import pytest from inference_endpoint.dataset_manager.transforms import ( + AddDefaultColumns, AddStaticColumns, ColumnFilter, ColumnRemap, @@ -824,4 +825,51 @@ def test_no_matching_columns(self): # Should not raise error or create prompt column assert "prompt" not in result.columns - assert "unrelated" in result.columns + + +class TestAddDefaultColumns: + """Unit tests for AddDefaultColumns transform.""" + + @pytest.mark.unit + def test_fills_missing_columns(self): + """New columns are added when absent.""" + df = pd.DataFrame({"a": [1, 2]}) + result = AddDefaultColumns({"b": 10, "c": "x"})(df) + assert list(result["b"]) == [10, 10] + assert list(result["c"]) == ["x", "x"] + + @pytest.mark.unit + def test_preserves_existing_non_null_values(self): + """Existing non-null values are not overwritten.""" + df = pd.DataFrame({"a": [1, 2]}) + result = AddDefaultColumns({"a": 99})(df) + assert list(result["a"]) == [1, 2] + + @pytest.mark.unit + def test_fills_nan_values_in_existing_column(self): + """NaN cells in an existing column are replaced with the default.""" + + df = pd.DataFrame({"a": [1.0, float("nan"), 3.0]}) + result = AddDefaultColumns({"a": 99})(df) + assert result["a"].tolist()[0] == 1.0 + assert result["a"].tolist()[1] == 99 + assert result["a"].tolist()[2] == 3.0 + + @pytest.mark.unit + def test_skips_none_default_values(self): + """A None default value is ignored; the column is not modified.""" + df = pd.DataFrame({"a": [1]}) + original_a = df["a"].copy() + result = AddDefaultColumns({"a": None, "b": None})(df) + assert list(result["a"]) == list(original_a) + assert "b" not in result.columns + + @pytest.mark.unit + def test_mixed_nan_and_real_values(self): + """Only NaN cells are filled; real values in the same column are preserved.""" + + df = pd.DataFrame({"temp": [0.9, float("nan"), 0.5]}) + result = AddDefaultColumns({"temp": 0.7})(df) + assert result["temp"].tolist()[0] == pytest.approx(0.9) + assert result["temp"].tolist()[1] == pytest.approx(0.7) + assert 
result["temp"].tolist()[2] == pytest.approx(0.5) diff --git a/tests/unit/load_generator/test_multi_turn_conversation_manager.py b/tests/unit/load_generator/test_multi_turn_conversation_manager.py new file mode 100644 index 00000000..62602626 --- /dev/null +++ b/tests/unit/load_generator/test_multi_turn_conversation_manager.py @@ -0,0 +1,396 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging + +import pytest +from inference_endpoint.load_generator.conversation_manager import ( + ConversationManager, + ConversationState, +) + + +@pytest.mark.unit +def test_conversation_state_initialization(): + """Test ConversationState initializes with correct default values.""" + state = ConversationState(conversation_id="conv_001") + + assert state.conversation_id == "conv_001" + assert state.current_turn == 0 + assert state.pending_client_turn is None + + +@pytest.mark.unit +def test_conversation_state_add_client_turn(): + """Test adding a client turn updates sequencing state.""" + state = ConversationState(conversation_id="conv_001") + + state.add_client_turn(1) + + assert state.pending_client_turn == 1 + assert state.issued_client_turns == 1 + assert state.current_turn == 0 # Not incremented until assistant response + + +@pytest.mark.unit +def test_conversation_state_add_assistant_turn(): + """Test adding assistant turn completes turn cycle.""" + state = ConversationState(conversation_id="conv_001") + + state.add_client_turn(1) + state.add_assistant_turn() + + assert state.current_turn == 2 + assert state.pending_client_turn is None + assert state.completed_client_turns == 1 + + +@pytest.mark.unit +def test_conversation_state_late_response_after_complete_is_silently_ignored(caplog): + """Late response for a conversation that already completed is silently dropped.""" + state = ConversationState(conversation_id="conv_001", expected_client_turns=1) + + state.add_client_turn(1) + state.add_assistant_turn() + assert state.is_complete() + + completed_before = state.completed_client_turns + current_turn_before = state.current_turn + + with caplog.at_level(logging.WARNING): + state.add_assistant_turn() + + assert state.completed_client_turns == completed_before + assert state.current_turn == current_turn_before + assert "no pending client turn" not in caplog.text + + +@pytest.mark.unit +def test_conversation_state_is_ready_for_turn(): + """Test turn readiness checks using completion counts.""" + state = ConversationState(conversation_id="conv_001") + + assert not state.is_ready_for_turn() + + state.add_client_turn(1) + assert not state.is_ready_for_turn() + + state.add_assistant_turn() + assert state.is_ready_for_turn() + + state.add_client_turn(2) + assert not state.is_ready_for_turn() + + state.add_assistant_turn() + assert state.is_ready_for_turn() + + +@pytest.mark.unit +def test_conversation_state_multi_turn_sequence(): + """Test 
multi-turn conversation flow updates current_turn correctly.""" + state = ConversationState(conversation_id="conv_001") + + state.add_client_turn(1) + state.add_assistant_turn() + assert state.current_turn == 2 + + state.add_client_turn(3) + state.add_assistant_turn() + assert state.current_turn == 4 + + state.add_client_turn(5) + state.add_assistant_turn() + assert state.current_turn == 6 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_get_or_create(): + """Test get_or_create returns same state for same conversation_id.""" + manager = ConversationManager() + + state1 = await manager.get_or_create("conv_001") + state2 = await manager.get_or_create("conv_001") + + assert state1 is state2 + assert state1.conversation_id == "conv_001" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_multiple_conversations(): + """Test manager can track multiple conversations independently.""" + manager = ConversationManager() + + state1 = await manager.get_or_create("conv_001") + state2 = await manager.get_or_create("conv_002") + + assert state1 is not state2 + + await manager.mark_turn_issued("conv_001", 1) + await manager.mark_turn_complete("conv_001", "Response to conv_001") + + assert state1.current_turn == 2 + assert state2.current_turn == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_mark_turn_issued(): + """Test mark_turn_issued updates sequencing state.""" + manager = ConversationManager() + state = await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + + assert state.pending_client_turn == 1 + assert state.issued_client_turns == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_mark_turn_complete(): + """Test mark_turn_complete updates sequencing state.""" + manager = ConversationManager() + state = await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + await manager.mark_turn_complete("conv_001", "Assistant response") + + assert state.current_turn == 2 + assert state.pending_client_turn is None + assert state.completed_client_turns == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_wait_for_turn_ready_immediate(): + """Test wait_for_turn_ready returns immediately when previous turn is complete.""" + manager = ConversationManager() + await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + await manager.mark_turn_complete("conv_001", "First response") + + result = await manager.wait_for_turn_ready("conv_001", 9, timeout=1.0) + + assert result is True + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_wait_for_turn_ready_blocking(): + """Test wait_for_turn_ready blocks until previous turn completes.""" + manager = ConversationManager() + await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + + ready_flag = [] + + async def waiter(): + result = await manager.wait_for_turn_ready("conv_001", 3, timeout=2.0) + if result: + ready_flag.append(True) + + waiter_task = asyncio.create_task(waiter()) + await asyncio.sleep(0.05) + assert not ready_flag + + await manager.mark_turn_complete("conv_001", "Assistant response") + await asyncio.sleep(0.05) + await waiter_task + + assert ready_flag == [True] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_wait_for_turn_ready_timeout(): + """Test wait_for_turn_ready respects timeout.""" + 
manager = ConversationManager() + await manager.get_or_create("conv_001") + + await manager.mark_turn_issued("conv_001", 1) + + result = await manager.wait_for_turn_ready("conv_001", 3, timeout=0.1) + + assert result is False + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_completion_tracking(): + """Test conversation completion detection.""" + manager = ConversationManager() + + state = await manager.get_or_create("conv_001", expected_client_turns=2) + + assert not state.is_complete() + + await manager.mark_turn_issued("conv_001", 1) + assert not state.is_complete() + + await manager.mark_turn_complete("conv_001", "response 1") + assert not state.is_complete() + + await manager.mark_turn_issued("conv_001", 3) + await manager.mark_turn_complete("conv_001", "response 2") + + assert state.is_complete() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_completion_without_expected_turns(): + """Test that completion tracking works when expected_client_turns is None.""" + manager = ConversationManager() + + state = await manager.get_or_create("conv_001", expected_client_turns=None) + + assert not state.is_complete() + + await manager.mark_turn_issued("conv_001", 1) + await manager.mark_turn_complete("conv_001", "response 1") + + assert not state.is_complete() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_completion_with_failures(): + """Test that conversations complete even when turns fail.""" + manager = ConversationManager() + state = await manager.get_or_create("conv1", expected_client_turns=3) + + await manager.mark_turn_issued("conv1", 1) + await manager.mark_turn_complete("conv1", "Hi there") + assert state.completed_client_turns == 1 + assert not state.is_complete() + + await manager.mark_turn_issued("conv1", 2) + await manager.mark_turn_failed("conv1") + assert state.completed_client_turns == 2 + assert state.failed_client_turns == 1 + assert not state.is_complete() + + await manager.mark_turn_issued("conv1", 3) + await manager.mark_turn_complete("conv1", "Bye!") + assert state.completed_client_turns == 3 + assert state.is_complete() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_mark_turn_failed_with_no_pending(): + """Test that marking failed turn without pending turn logs warning.""" + manager = ConversationManager() + state = await manager.get_or_create("conv1", expected_client_turns=1) + + await manager.mark_turn_failed("conv1") + + assert state.completed_client_turns == 0 + assert state.failed_client_turns == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_all_turns_fail(): + """Test conversation completion when all turns fail.""" + manager = ConversationManager() + state = await manager.get_or_create("conv1", expected_client_turns=2) + + await manager.mark_turn_issued("conv1", 1) + await manager.mark_turn_failed("conv1") + + await manager.mark_turn_issued("conv1", 2) + await manager.mark_turn_failed("conv1") + + assert state.is_complete() + assert state.completed_client_turns == 2 + assert state.failed_client_turns == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_concurrent_access(): + """Test async concurrent access to multiple conversations.""" + manager = ConversationManager() + num_conversations = 10 + user_turns_per_conv = 5 + + for i in range(num_conversations): + await manager.get_or_create(f"conv_{i:03d}") + + errors = [] + + async def process_conversation(conv_id: str): + try: + for user_turn_idx in range(user_turns_per_conv): + 
turn = user_turn_idx * 2 + 1 + + if user_turn_idx > 0: + ready = await manager.wait_for_turn_ready( + conv_id, turn, timeout=5.0 + ) + if not ready: + errors.append(f"{conv_id} turn {turn} timeout") + return + + await manager.mark_turn_issued(conv_id, turn) + await asyncio.sleep(0.001) + await manager.mark_turn_complete(conv_id, f"Response {turn}") + except Exception as e: + errors.append(f"{conv_id} error: {e}") + + tasks = [ + asyncio.create_task(process_conversation(f"conv_{i:03d}")) + for i in range(num_conversations) + ] + await asyncio.gather(*tasks) + + assert not errors, f"Errors occurred: {errors}" + + for i in range(num_conversations): + conv_id = f"conv_{i:03d}" + state = manager._conversations[conv_id] + assert state.current_turn == user_turns_per_conv * 2 + assert state.completed_client_turns == user_turns_per_conv + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_wait_for_turn_ready_reliably_wakes_on_completion(): + """Test completion wakeups do not depend on timing windows.""" + + async def run_one_iteration(): + mgr = ConversationManager() + await mgr.get_or_create("conv_001") + await mgr.mark_turn_issued("conv_001", 1) + + ready: list[bool] = [] + + async def waiter(m: ConversationManager, r: list) -> None: + r.append(await m.wait_for_turn_ready("conv_001", 3, timeout=0.5)) + + waiter_task = asyncio.create_task(waiter(mgr, ready)) + await asyncio.sleep(0.005) + await mgr.mark_turn_complete("conv_001", "Assistant response") + await asyncio.wait_for(waiter_task, timeout=0.5) + assert ready == [True] + + for _ in range(10): + await run_one_iteration() diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py new file mode 100644 index 00000000..55c51994 --- /dev/null +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -0,0 +1,279 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
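For illustration (not part of the patch): the concurrent-access and wakeup tests above pin down the async sequencing contract of `ConversationManager` at this point in the series. A minimal per-conversation driver written against only the methods those tests exercise might look like the sketch below; the helper names `drive_conversation` and `main` are placeholders, while the odd-numbered user-turn convention mirrors the tests.

```python
import asyncio

from inference_endpoint.load_generator.conversation_manager import ConversationManager


async def drive_conversation(manager: ConversationManager, conv_id: str, user_turns: int) -> None:
    # User turns occupy odd positions (1, 3, 5, ...); assistant replies fill the evens.
    await manager.get_or_create(conv_id, expected_client_turns=user_turns)
    for i in range(user_turns):
        turn = i * 2 + 1
        if i > 0:
            # Block until the previous turn's response has been recorded.
            if not await manager.wait_for_turn_ready(conv_id, turn, timeout=5.0):
                return  # previous turn timed out; abandon the conversation
        await manager.mark_turn_issued(conv_id, turn)
        # In the real pipeline the request is sent here and the completion
        # callback reports the response; the tests call it directly:
        await manager.mark_turn_complete(conv_id, f"response to turn {turn}")


async def main() -> None:
    manager = ConversationManager()
    await asyncio.gather(
        *(drive_conversation(manager, f"conv_{i:03d}", 3) for i in range(4))
    )
```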
+ +"""Unit tests for MultiTurnStrategy.""" + +import asyncio + +import pytest +from inference_endpoint.core.types import QueryResult, TextModelOutput +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy + + +class FakePhaseIssuer: + """Minimal PhaseIssuerProtocol stub.""" + + def __init__(self, stop_after: int | None = None): + self._count = 0 + self._stop_after = stop_after + self.issued: list[int] = [] + self.issued_count = 0 + + def issue(self, sample_index: int, data_override: dict | None = None) -> str | None: + if self._stop_after is not None and self._count >= self._stop_after: + return None + self._count += 1 + self.issued_count += 1 + query_id = f"q{sample_index:04d}" + self.issued.append(sample_index) + return query_id + + +def _make_dataset_metadata(conversations: dict[str, list[int]]) -> dict: + """Build dataset_metadata dict from {conv_id: [turn_numbers]} mapping.""" + samples = [] + sample_index = 0 + for conv_id, turns in conversations.items(): + for turn in turns: + samples.append( + { + "conversation_id": conv_id, + "turn": turn, + "sample_index": sample_index, + } + ) + sample_index += 1 + return {"samples": samples} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_single_conversation_single_turn(): + """Single conversation, single turn β€” should issue exactly one sample.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + # Simulate response completion (turn 1 is issued, then completes) + async def complete_turns(): + # Wait a tick for the strategy to issue the first turn + await asyncio.sleep(0.01) + # Mark turn 1 complete + state = conv_manager.get_state("conv1") + if state: + await conv_manager.mark_turn_complete("conv1", "response 1") + + asyncio.create_task(complete_turns()) + count = await strategy.execute(issuer) + + assert count == 1 + assert issuer.issued == [0] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_single_conversation_multi_turn(): + """Single conversation, 3 turns β€” turns must be issued sequentially.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3, 5]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + issued_order: list[str] = [] + original_issue = issuer.issue + + def tracked_issue(idx, data_override=None): + q = original_issue(idx, data_override=data_override) + if q: + issued_order.append(q) + return q + + issuer.issue = tracked_issue + + async def simulate_responses(): + await asyncio.sleep(0.01) + for turn_q, resp in [("q0000", "r1"), ("q0001", "r2"), ("q0002", "r3")]: + # Signal turn complete via on_sample_complete + result = QueryResult( + id=turn_q, response_output=TextModelOutput(output=resp) + ) + strategy.on_sample_complete(result) + await asyncio.sleep(0.01) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 3 + assert issuer.issued == [0, 1, 2] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_multiple_conversations_concurrent(): + """Two conversations run concurrently, each with 2 turns.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3], "conv2": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + async def 
simulate_responses(): + await asyncio.sleep(0.02) + # Complete all turns for both conversations + for q_prefix in range(4): + q = f"q{q_prefix:04d}" + result = QueryResult(id=q, response_output=TextModelOutput(output="resp")) + strategy.on_sample_complete(result) + await asyncio.sleep(0.01) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 4 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_turn_ordering_enforced(): + """Turn 2 must not be issued before Turn 1 completes.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + issue_timestamps: dict[int, float] = {} + complete_timestamps: dict[int, float] = {} + + class TimedIssuer: + issued_count = 0 + issued: list[int] = [] + + def issue(self, idx: int, data_override: dict | None = None) -> str | None: + import time + + issue_timestamps[idx] = time.monotonic() + self.issued.append(idx) + self.issued_count += 1 + return f"q{idx:04d}" + + issuer = TimedIssuer() + + async def simulate_responses(): + import time + + await asyncio.sleep(0.02) + # Complete turn 1 (sample 0) after a delay + complete_timestamps[0] = time.monotonic() + result = QueryResult(id="q0000", response_output=TextModelOutput(output="r1")) + strategy.on_sample_complete(result) + await asyncio.sleep(0.05) + # Complete turn 2 (sample 1) + complete_timestamps[1] = time.monotonic() + result = QueryResult(id="q0001", response_output=TextModelOutput(output="r2")) + strategy.on_sample_complete(result) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 2 + # Turn 2 (sample index 1) must be issued AFTER turn 1 completes + assert issue_timestamps[1] >= complete_timestamps[0] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_turn_timeout_triggers_failure(): + """A turn that never completes should timeout and abort remaining turns.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=None) + strategy._turn_timeout_s = 0.1 # Very short timeout for testing + issuer = FakePhaseIssuer() + + # Do NOT simulate any response β€” turn 1 will timeout + await strategy.execute(issuer) + + # Only turn 1 should be issued (turn 2 never gets to run) + assert issuer.issued_count == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_query_complete_releases_semaphore(): + """on_query_complete releases the concurrency semaphore.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=1) + assert strategy._sem is not None + + # Acquire the semaphore manually + await strategy._sem.acquire() + assert strategy._sem._value == 0 # type: ignore[attr-defined] + + strategy.on_query_complete("some-query") + assert strategy._sem._value == 1 # type: ignore[attr-defined] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_sample_complete_routes_to_manager(): + """on_sample_complete marks the turn complete in the ConversationManager.""" + conv_manager = ConversationManager() + await conv_manager.get_or_create("conv1", expected_client_turns=1) + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + # Simulate issuer registering conv_id in _inflight + strategy._inflight["q0001"] = 
"conv1" + # Pre-issue a turn so the state has pending_client_turn + await conv_manager.mark_turn_issued("conv1", 1) + + result = QueryResult(id="q0001", response_output=TextModelOutput(output="hello")) + strategy.on_sample_complete(result) + + # Allow the ensure_future coroutine to run + await asyncio.sleep(0.01) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.completed_client_turns == 1 + assert state.is_complete() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_error_response_marks_turn_failed(): + """on_sample_complete marks failed when result.error is set.""" + from inference_endpoint.core.types import ErrorData + + conv_manager = ConversationManager() + await conv_manager.get_or_create("conv1", expected_client_turns=1) + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + strategy._inflight["q0001"] = "conv1" + await conv_manager.mark_turn_issued("conv1", 1) + + result = QueryResult( + id="q0001", + response_output=None, + error=ErrorData(error_type="timeout", error_message="timed out"), + ) + strategy.on_sample_complete(result) + await asyncio.sleep(0.01) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.failed_client_turns == 1 From eb99f583376b209800fe4de2b161a085b7acc5c2 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 24 Apr 2026 10:32:39 -0700 Subject: [PATCH 03/41] test: add multi-turn unit and integration tests Add unit tests for MultiTurnDataset, ConversationManager, and MultiTurnStrategy; add integration tests including tool-use scenarios and large-concurrency stress tests. --- tests/integration/test_multi_turn.py | 309 ++++++++++++++++++++++++++- 1 file changed, 308 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index 5a4d128d..ca17a236 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -16,7 +16,8 @@ """Integration tests for multi-turn benchmarking end-to-end. Validates that MultiTurnDataset + MultiTurnStrategy + BenchmarkSession work -correctly together against a real HTTP echo server. +correctly together against a real HTTP echo server (echo tests) and a live +model endpoint (live tests at port 8868). Tests cover: 1. Dataset-history mode (use_dataset_history=True): pre-built messages are @@ -26,12 +27,16 @@ grow with each turn. 3. Multiple concurrent conversations complete successfully. 4. Turn ordering: turn N+1 is never issued before turn N completes. + 5. Live concurrency: parametrized target_concurrency levels against a real + model endpoint verify all turns complete regardless of throttle setting. 
""" import asyncio +import json import random import time from urllib.parse import urljoin +from urllib.request import urlopen import pandas as pd import pytest @@ -423,3 +428,305 @@ def on_complete(result: QueryResult) -> None: # Turn 3 must complete after turn 1 completes assert complete_times[q_turn3] >= complete_times[q_turn1] + + +# --------------------------------------------------------------------------- +# Live endpoint fixtures and helpers +# --------------------------------------------------------------------------- + +_LIVE_ENDPOINT = "http://localhost:8868" + + +def _query_model_name(endpoint: str) -> str: + """Return the first model name from the endpoint, or skip if unreachable.""" + try: + with urlopen(f"{endpoint}/v1/models", timeout=5.0) as resp: + data = json.loads(resp.read()) + return data["data"][0]["id"] + except Exception as e: + pytest.skip(f"Live endpoint {endpoint} not reachable: {e}") + return "" + + +def _make_live_rows( + model: str, n_conversations: int = 20, n_user_turns: int = 3 +) -> list[dict]: + """Build a multi-conversation dataset rows list. + + Each conversation has n_user_turns user turns interleaved with scripted + assistant placeholders (needed to satisfy the turn-structure validator but + never sent to the endpoint). The resulting dataset produces + n_conversations Γ— n_user_turns client-turn samples. + """ + rows = [] + _user_prompts = [ + "Reply with exactly one word: the number {n} in English.", + "Add one to the previous number. Reply with only that word.", + "Add one more. Reply with only that word.", + ] + for i in range(n_conversations): + conv_id = f"live_conv_{i:03d}" + turn = 1 + for j in range(n_user_turns): + prompt = _user_prompts[j % len(_user_prompts)].format(n=i + 1) + rows.append( + { + "conversation_id": conv_id, + "turn": turn, + "role": "user", + "content": prompt, + "model": model, + "max_completion_tokens": 10, + } + ) + turn += 1 + if j < n_user_turns - 1: + rows.append( + { + "conversation_id": conv_id, + "turn": turn, + "role": "assistant", + "content": "placeholder", + } + ) + turn += 1 + return rows + + +async def _run_live_session( + model: str, + n_conversations: int, + n_user_turns: int, + target_concurrency: int | None, + timeout_s: float = 300.0, +) -> tuple[int, dict[str, str]]: + """Run a live multi-turn session against the endpoint at _LIVE_ENDPOINT. + + Returns (issued_count, {query_id: response_text}). 
+ """ + rows = _make_live_rows(model, n_conversations, n_user_turns) + ds = MultiTurnDataset(dataframe=pd.DataFrame(rows)) + ds.load() + + mt_cfg = MultiTurnConfig( + turn_timeout_s=60.0, + use_dataset_history=True, + ) + strategy = MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + target_concurrency=target_concurrency, + ) + + loop = asyncio.get_running_loop() + responses: dict[str, str] = {} + + def on_complete(result: QueryResult) -> None: + strategy.on_sample_complete(result) + responses[result.id] = result.get_response_output_string() + + http_config = HTTPClientConfig( + endpoint_urls=[f"{_LIVE_ENDPOINT}/v1/chat/completions"], + warmup_connections=0, + num_workers=4, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=int(timeout_s * 1000), + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) + result = await asyncio.wait_for(session.run([phase]), timeout=timeout_s) + return result.perf_results[0].issued_count, responses + finally: + await http_client.shutdown_async() + + +# --------------------------------------------------------------------------- +# Live concurrency tests +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize( + "target_concurrency", + [ + pytest.param(1, id="concurrency_1"), + pytest.param(4, id="concurrency_4"), + pytest.param(None, id="concurrency_unlimited"), + ], +) +async def test_live_concurrency(target_concurrency): + """All turns of 20 concurrent conversations complete for each concurrency level. + + Uses the live model endpoint at port 8868. Each conversation has 3 user + turns (60 total requests). Verifies that every turn receives a non-empty + response regardless of the concurrency throttle applied by target_concurrency. + """ + model = _query_model_name(_LIVE_ENDPOINT) + n_conversations = 20 + n_user_turns = 3 + expected_turns = n_conversations * n_user_turns # 60 total requests + + issued, responses = await _run_live_session( + model=model, + n_conversations=n_conversations, + n_user_turns=n_user_turns, + target_concurrency=target_concurrency, + timeout_s=300.0, + ) + + assert issued == expected_turns, f"Expected {expected_turns} issued, got {issued}" + assert ( + len(responses) == expected_turns + ), f"Expected {expected_turns} responses, got {len(responses)}" + for qid, text in responses.items(): + assert text.strip(), f"Query {qid} returned empty response" + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_live_turn_ordering_multi_conversation(): + """Turn N+1 of each conversation is always issued after turn N completes. + + Runs 10 conversations with 3 turns each concurrently (30 total requests). + Records per-query completion timestamps and asserts that within every + conversation each successive turn completes no earlier than the previous. 
+ """ + model = _query_model_name(_LIVE_ENDPOINT) + n_conversations = 10 + n_user_turns = 3 + rows = _make_live_rows(model, n_conversations, n_user_turns) + + ds = MultiTurnDataset(dataframe=pd.DataFrame(rows)) + ds.load() + + conv_manager = ConversationManager() + mt_cfg = MultiTurnConfig(turn_timeout_s=60.0, use_dataset_history=True) + strategy = MultiTurnStrategy( + conversation_manager=conv_manager, + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + complete_times: dict[str, float] = {} + orig_on_sample_complete = strategy.on_sample_complete + + def tracked_complete(result: QueryResult) -> None: + complete_times[result.id] = time.monotonic() + orig_on_sample_complete(result) + + strategy.on_sample_complete = tracked_complete + + loop = asyncio.get_running_loop() + responses: dict[str, str] = {} + + http_config = HTTPClientConfig( + endpoint_urls=[f"{_LIVE_ENDPOINT}/v1/chat/completions"], + warmup_connections=0, + num_workers=4, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + + def on_complete(result: QueryResult) -> None: + tracked_complete(result) + responses[result.id] = result.get_response_output_string() + + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=300_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) + result = await asyncio.wait_for(session.run([phase]), timeout=300.0) + finally: + await http_client.shutdown_async() + + expected_total = n_conversations * n_user_turns + assert result.perf_results[0].issued_count == expected_total + + # Build index β†’ query_id map and verify per-conversation ordering. + # Samples are grouped by conversation, turns sorted ascending within each: + # conv_0_t1, conv_0_t2, conv_0_t3, conv_1_t1, ... + uuid_to_index = result.perf_results[0].uuid_to_index + index_to_query = {v: k for k, v in uuid_to_index.items()} + + for conv_i in range(n_conversations): + base = conv_i * n_user_turns + for turn_j in range(n_user_turns - 1): + q_cur = index_to_query[base + turn_j] + q_next = index_to_query[base + turn_j + 1] + assert complete_times[q_cur] <= complete_times[q_next], ( + f"conv {conv_i}: turn {turn_j + 2} completed before turn {turn_j + 1} " + f"(t{turn_j + 1}={complete_times[q_cur]:.4f}, " + f"t{turn_j + 2}={complete_times[q_next]:.4f})" + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_live_large_concurrency(): + """All turns complete correctly under a large concurrency limit (>=512). + + Uses 200 conversations Γ— 3 turns = 600 total requests with + target_concurrency=512. The semaphore allows up to 512 simultaneous + in-flight requests, so the first wave of 200 first-turns is issued + without throttling, and subsequent turns queue naturally. Verifies + that all 600 turns complete and return non-empty responses, confirming + the semaphore implementation handles large values without deadlock or + starvation. 
+ """ + model = _query_model_name(_LIVE_ENDPOINT) + n_conversations = 200 + n_user_turns = 3 + expected_turns = n_conversations * n_user_turns # 600 total requests + + issued, responses = await _run_live_session( + model=model, + n_conversations=n_conversations, + n_user_turns=n_user_turns, + target_concurrency=512, + timeout_s=300.0, + ) + + assert issued == expected_turns, f"Expected {expected_turns} issued, got {issued}" + assert ( + len(responses) == expected_turns + ), f"Expected {expected_turns} responses, got {len(responses)}" + for qid, text in responses.items(): + assert text.strip(), f"Query {qid} returned empty response" From 75b64d657acb6fc7d836de94b69f646d097ccdd9 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 24 Apr 2026 11:20:42 -0700 Subject: [PATCH 04/41] feat: wire multi-turn into benchmark execution pipeline Consolidate multi-turn dataset with single-turn transform pipeline, fix prior-row extraction, live-history mode, system prompt injection, tool_calls preservation, and asyncio.Event-based sequencing. --- .../09_MultiTurn/multi_turn_benchmark.yaml | 9 +- .../multi_turn_with_concurrency.yaml | 7 - .../commands/benchmark/execute.py | 1 + .../config/runtime_settings.py | 11 + src/inference_endpoint/config/schema.py | 19 + .../dataset_manager/multi_turn_dataset.py | 59 ++- .../load_generator/conversation_manager.py | 348 +++---------- .../load_generator/multi_turn_strategy.py | 96 ++-- .../openai/openai_adapter.py | 1 + tests/integration/test_multi_turn.py | 474 +++++++----------- tests/unit/config/test_schema.py | 119 +++++ .../test_multi_turn_dataset.py | 363 ++++++++++++++ .../test_multi_turn_conversation_manager.py | 386 +++++--------- .../test_multi_turn_strategy.py | 125 ++++- tests/unit/openai/test_openai_adapter.py | 147 ++++++ 15 files changed, 1281 insertions(+), 884 deletions(-) create mode 100644 tests/unit/openai/test_openai_adapter.py diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml index 9ed6c9f1..da4773e0 100644 --- a/examples/09_MultiTurn/multi_turn_benchmark.yaml +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -24,18 +24,11 @@ settings: load_pattern: type: multi_turn - # target_concurrency: 32 # Optional: limit concurrent requests across all conversations + target_concurrency: 32 client: warmup_connections: 0 -metrics: - collect: - - throughput - - latency - - ttft - - tpot - endpoint_config: endpoints: - "http://localhost:8868" diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml index 491e6b4b..ba5362e3 100644 --- a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -29,13 +29,6 @@ settings: client: warmup_connections: 0 -metrics: - collect: - - throughput - - latency - - ttft - - tpot - endpoint_config: endpoints: - "http://localhost:8868" diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 1efe1a3a..b5230a53 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -553,6 +553,7 @@ async def _run_benchmark_async( conversation_manager=ConversationManager(), dataset_metadata=ctx.dataloader.conversation_metadata, multi_turn_config=mt_cfg, + target_concurrency=ctx.config.settings.load_pattern.target_concurrency, ) def _on_sample_complete(result: QueryResult) -> None: diff --git 
a/src/inference_endpoint/config/runtime_settings.py b/src/inference_endpoint/config/runtime_settings.py index fb349a02..a3fb3106 100644 --- a/src/inference_endpoint/config/runtime_settings.py +++ b/src/inference_endpoint/config/runtime_settings.py @@ -194,6 +194,17 @@ def total_samples_to_issue( ) return self.n_samples_to_issue + # Multi-turn must issue exactly all client turns β€” QPS-based formulas are meaningless. + if ( + self.load_pattern is not None + and self.load_pattern.type.value == "multi_turn" + ): + result = max(self.min_sample_count, self.n_samples_from_dataset) + logger.debug( + f"Sample count: {result} (multi-turn: issuing all {self.n_samples_from_dataset} client turns)" + ) + return result + # If min_duration is 0, use all dataset samples (new CLI default behavior) if self.min_duration_ms == 0: result = max(self.min_sample_count, self.n_samples_from_dataset) diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 1cd0f172..1f487cbe 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -419,6 +419,12 @@ def _validate_completeness(self) -> Self: raise ValueError( "Concurrency requires --concurrency (e.g., --concurrency 10)" ) + if self.type == LoadPatternType.MULTI_TURN and ( + not self.target_concurrency or self.target_concurrency <= 0 + ): + raise ValueError( + "Multi-turn requires --concurrency (e.g., --concurrency 96)" + ) return self @@ -625,6 +631,19 @@ def _resolve_and_validate(self) -> Self: "Online mode requires --load-pattern (poisson, concurrency, or multi_turn)" ) + # Cross-validate load_pattern.type=multi_turn ↔ dataset.multi_turn config + has_multi_turn_dataset = any( + d.multi_turn is not None for d in (self.datasets or []) + ) + if lp.type == LoadPatternType.MULTI_TURN and not has_multi_turn_dataset: + raise ValueError( + "load_pattern.type=multi_turn requires at least one dataset with multi_turn config" + ) + if has_multi_turn_dataset and lp.type != LoadPatternType.MULTI_TURN: + raise ValueError( + f"Datasets with multi_turn config require load_pattern.type=multi_turn, got '{lp.type}'" + ) + return self @model_validator(mode="after") diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index f26c3d3e..574619c8 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -195,6 +195,8 @@ def _build_metadata(self) -> dict[str, Any]: # This includes assistant rows (tool dispatches or terminal responses) # so no runtime injection is required. 
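For illustration (not part of the patch): the mappings initialised immediately below are keyed by `(conversation_id, turn)` for client turns, matching the unit tests earlier in the series. For a hypothetical conversation `c1` with a system prompt and rows user(1) / assistant(2) / user(3), the pre-computed structures would look roughly like this (message contents are made up):

```python
pre_built_messages_by_key = {
    ("c1", 1): [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hi, how are you?"},
    ],
    ("c1", 3): [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Hi, how are you?"},
        {"role": "assistant", "content": "I'm doing well, thank you!"},
        {"role": "user", "content": "What's the weather like?"},
    ],
}
current_turn_messages_by_key = {
    ("c1", 1): [{"role": "user", "content": "Hi, how are you?"}],
    ("c1", 3): [{"role": "user", "content": "What's the weather like?"}],
}
system_prompts_by_conv = {"c1": "You are a helpful assistant"}
```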
pre_built_messages_by_key: dict[tuple, list[dict]] = {} + current_turn_messages_by_key: dict[tuple, list[dict]] = {} + system_prompts_by_conv: dict[str, str | None] = {} for conv_id, group in self.dataframe.groupby("conversation_id"): sorted_group = group.sort_values("turn") @@ -207,6 +209,7 @@ def _build_metadata(self) -> dict[str, Any]: if val and isinstance(val, str): system_content = val break + system_prompts_by_conv[str(conv_id)] = system_content for idx, row in client_rows.iterrows(): t_n = int(row["turn"]) @@ -220,12 +223,18 @@ def _build_metadata(self) -> dict[str, Any]: prior_rows = sorted_group[sorted_group["turn"] < t_n] for _, prior_row in prior_rows.iterrows(): msg: dict[str, Any] = {} - for key in ("role", "content", "tool_calls"): + for key in ("role", "content", "tool_calls", "tool_results"): val = prior_row.get(key) if val is not None and not ( isinstance(val, float) and pd.isna(val) ): msg[key] = val + if ( + msg.get("role") == "assistant" + and "tool_calls" in msg + and "content" not in msg + ): + msg["content"] = None if msg.get("role"): # Expand merged parallel tool results: a single row with # tool_results: [{tool_call_id, content}, ...] expands into @@ -239,9 +248,10 @@ def _build_metadata(self) -> dict[str, Any]: # Append the current client turn message. # A merged parallel-tool row carries tool_results instead of a # single tool_call_id/content pair; expand to one message per result. + current_turn_msgs: list[dict] = [] expanded = _expand_tool_results(row) if expanded: - messages.extend(expanded) + current_turn_msgs = expanded else: cur: dict[str, Any] = {} for key in ("role", "content"): @@ -250,9 +260,11 @@ def _build_metadata(self) -> dict[str, Any]: isinstance(val, float) and pd.isna(val) ): cur[key] = val - messages.append(cur) + current_turn_msgs = [cur] + messages.extend(current_turn_msgs) pre_built_messages_by_key[(conv_id, t_n)] = messages + current_turn_messages_by_key[(conv_id, t_n)] = current_turn_msgs samples.append( { @@ -270,6 +282,8 @@ def _build_metadata(self) -> dict[str, Any]: .max(), "client_turns_per_conversation": client_turns_per_conv, "pre_built_messages_by_key": pre_built_messages_by_key, + "current_turn_messages_by_key": current_turn_messages_by_key, + "system_prompts_by_conv": system_prompts_by_conv, } def load( @@ -287,8 +301,8 @@ def load( per-row dataset overrides are preserved. After transforms, only client turns (user + tool) are stored in self.data as - fully assembled sample dicts (with messages, current_turn_message, system_content - attached). load_sample() and num_samples() are inherited from the base class. + fully assembled sample dicts (with messages attached). + load_sample() and num_samples() are inherited from the base class. """ if not force and self.data is not None: return @@ -325,6 +339,26 @@ def load( pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}) client_turn_samples: list[dict[str, Any]] = [] + # Collect per-conversation defaults from the first user row so that + # fields like model/max_completion_tokens propagate to tool rows. 
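For illustration (not part of the patch): a hypothetical example of the default propagation implemented below. The first user row of a conversation donates its propagated fields to later tool rows that omit them; the concrete values here are made up.

```python
rows = [
    {"conversation_id": "c1", "turn": 1, "role": "user",
     "content": "What is the weather?", "model": "gpt-4", "max_completion_tokens": 64},
    {"conversation_id": "c1", "turn": 3, "role": "tool",
     "tool_call_id": "call_1", "content": "22 degrees C"},
]
conv_defaults = {"c1": {"model": "gpt-4", "max_completion_tokens": 64}}
# The turn-3 tool sample inherits model="gpt-4" and max_completion_tokens=64
# even though its own row does not specify them.
```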
+ _PROPAGATED_KEYS = { + "model", + "max_completion_tokens", + "max_new_tokens", + "stream", + } + conv_defaults: dict[str, dict[str, Any]] = {} + for row in all_rows: + cid = row.get("conversation_id") + if cid not in conv_defaults and row.get("role") == "user": + conv_defaults[cid] = { + k: row[k] + for k in _PROPAGATED_KEYS + if k in row + and row[k] is not None + and not (isinstance(row[k], float) and pd.isna(row[k])) + } + for row in all_rows: if row.get("role") not in ("user", "tool"): continue @@ -336,6 +370,14 @@ def load( for k, v in row.items() if v is not None and not (isinstance(v, float) and pd.isna(v)) } + # Strip dataset-internal fields that must not reach the endpoint. + sample.pop("tool_results", None) + sample.pop("tool_calls", None) + + # Fill missing propagated fields from the first user row of this conversation. + for k, v in conv_defaults.get(row.get("conversation_id"), {}).items(): + if k not in sample: + sample[k] = v # max_new_tokens β†’ max_completion_tokens alias if "max_completion_tokens" not in sample and "max_new_tokens" in sample: @@ -350,13 +392,6 @@ def load( messages = pre_built.get(key, []) sample["messages"] = messages - # Fields for use_dataset_history=False path (live history accumulation). - sample["current_turn_message"] = messages[-1] if messages else {} - first = messages[0] if messages else {} - sample["system_content"] = ( - first.get("content") if first.get("role") == "system" else None - ) - client_turn_samples.append(sample) self.data = client_turn_samples diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index 86741711..ba9a02ea 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Async conversation state management for multi-turn benchmarking.""" +"""Conversation state management for multi-turn benchmarking.""" import asyncio import logging @@ -25,332 +25,150 @@ @dataclass class ConversationState: - """Tracks conversation sequencing for multi-turn benchmarking. + """Per-conversation state for multi-turn benchmarking. - Maintains turn counters and asyncio conditions so the strategy can enforce - sequential turn ordering within a conversation. Message history is NOT stored - here β€” it is pre-computed in MultiTurnDataset and served via load_sample(). + The pipeline task awaits ``turn_done`` between turns; ``mark_turn_complete`` + and ``mark_turn_failed`` set it synchronously from ``on_sample_complete``. Attributes: conversation_id: Unique identifier for this conversation. - current_turn: Last completed turn number (0 = not started). - pending_client_turn: Turn number of in-flight client turn (None if idle). - expected_client_turns: Expected number of client turns (for completion tracking). - issued_client_turns: Count of client turns issued. - completed_client_turns: Count of client turns with responses. - failed_client_turns: Count of client turns that failed (error/timeout). - message_history: Accumulated message list (only populated when + turn_done: Event set when a response arrives. Pipeline waits, then clears + it before issuing the next turn. + message_history: Accumulated message list (populated only when use_dataset_history=False; empty otherwise). - condition: Per-conversation asyncio.Condition for turn-ready and turn-issued waits. 
- Scoped to this conversation so that state changes only wake the single - pipeline task waiting on this conversation, not all pipeline tasks. + completed_turns: Turns with responses (success or failure) β€” observability only. + failed_turns: Turns that failed β€” observability only. + expected_client_turns: Expected total client turns (for completion detection). """ conversation_id: str - current_turn: int = 0 - pending_client_turn: int | None = None - expected_client_turns: int | None = None - issued_client_turns: int = 0 - completed_client_turns: int = 0 - failed_client_turns: int = 0 + turn_done: asyncio.Event = field(default_factory=asyncio.Event) message_history: list[dict[str, Any]] = field(default_factory=list) - condition: asyncio.Condition = field(default_factory=asyncio.Condition) - - def add_client_turn(self, turn: int, message: dict[str, Any] | None = None): - """Record that a client turn has been issued (updates sequencing counters). - - Args: - turn: Turn number for this client message. - message: Message dict to append to message_history (only used when - use_dataset_history=False). - """ - self.pending_client_turn = turn - self.issued_client_turns += 1 - if message is not None: - self.message_history.append(message) - - def add_assistant_turn(self, content: str | None = None): - """Record assistant response and mark turn complete (success). - - Args: - content: Response content to append to message_history. Only - used when use_dataset_history=False; None means no history - update (pre-built messages path). - """ - if content is not None: - self.message_history.append({"role": "assistant", "content": content}) - if self.pending_client_turn is not None: - self.current_turn = self.pending_client_turn + 1 - self.pending_client_turn = None - self.completed_client_turns += 1 - elif self.is_complete(): - pass - else: - logger.warning( - f"Received assistant response for {self.conversation_id} " - f"with no pending client turn (duplicate or out-of-order response)" - ) - self.current_turn = self.current_turn + 1 if self.current_turn > 0 else 1 - self.completed_client_turns += 1 - - if self.is_complete(): - if self.failed_client_turns > 0: - logger.info( - f"Conversation {self.conversation_id} completed with failures: " - f"{self.completed_client_turns - self.failed_client_turns}/" - f"{self.expected_client_turns} successful, " - f"{self.failed_client_turns} failed" - ) - else: - logger.debug( - f"Conversation {self.conversation_id} completed: " - f"{self.completed_client_turns}/{self.expected_client_turns} turns" - ) - - def mark_turn_failed(self, store_in_history: bool = False): - """Mark turn as failed (error/timeout) - still counts as completed for sequencing.""" - if self.pending_client_turn is not None: - self.current_turn = self.pending_client_turn + 1 - self.pending_client_turn = None - self.completed_client_turns += 1 - self.failed_client_turns += 1 - - if store_in_history: - self.message_history.append( - { - "role": "assistant", - "content": "[ERROR: Turn failed or timed out]", - } - ) - - logger.warning( - f"Turn {self.current_turn - 1} failed for conversation {self.conversation_id}" - ) - else: - logger.warning( - f"Attempted to mark failed turn for {self.conversation_id} " - f"with no pending client turn" - ) - - if self.is_complete(): - logger.info( - f"Conversation {self.conversation_id} completed with failures: " - f"{self.completed_client_turns - self.failed_client_turns}/" - f"{self.expected_client_turns} successful, " - f"{self.failed_client_turns} failed" - ) 
+ completed_turns: int = 0 + failed_turns: int = 0 + expected_client_turns: int | None = None def is_complete(self) -> bool: - """Check if conversation is complete (all turns issued and responses received).""" + """Return True when all expected turns have a response.""" if self.expected_client_turns is None: return False - return self.completed_client_turns >= self.expected_client_turns - - def is_ready_for_turn(self) -> bool: - """Check if the previous turn has completed and the next may be issued.""" - return ( - self.pending_client_turn is None - and self.issued_client_turns == self.completed_client_turns - and self.issued_client_turns > 0 - ) + return self.completed_turns >= self.expected_client_turns class ConversationManager: - """Manages conversation sequencing for multi-turn benchmarking. - - Async manager that tracks multiple conversations and enforces turn ordering. - Conversations are identified by unique IDs. Message history is NOT maintained here - β€” it is pre-computed in MultiTurnDataset and passed directly to each request. + """Manages per-conversation state for multi-turn benchmarking. - The manager ensures that: - - Turn N+1 cannot be issued until turn N completes - - Concurrent access to conversation state is async-safe + All methods are synchronous. The pipeline task uses ``ConversationState.turn_done`` + directly for turn-done notification β€” no locks or condition variables needed. - Each ConversationState carries its own asyncio.Condition so that state changes - (turn issued / turn complete) only wake the single pipeline task waiting - on that conversation, not all pipeline tasks across all conversations. - All conversation states are pre-created by the strategy before pipeline - tasks start, so wait_for_turn_issued never races against get_or_create. + All states are pre-created by ``MultiTurnStrategy.execute()`` before any pipeline + task starts, so ``get_or_create()`` requires no locking. """ def __init__(self): - """Initialize conversation manager with empty state.""" + """Initialize with empty state.""" self._conversations: dict[str, ConversationState] = {} - self._lock = asyncio.Lock() def get_state(self, conversation_id: str) -> ConversationState | None: - """Get conversation state without creating (for read-only access).""" + """Return existing state without creating (read-only access).""" return self._conversations.get(conversation_id) - async def get_or_create( + def get_or_create( self, conversation_id: str, expected_client_turns: int | None = None, system_message: dict[str, Any] | None = None, ) -> ConversationState: - """Get existing or create new conversation state. + """Return existing state or create a new one. Args: conversation_id: Unique identifier for conversation. - expected_client_turns: Expected number of client turns (for completion tracking). - system_message: System message dict to pre-populate message_history with. - Only used when use_dataset_history=False and conversation is new. + expected_client_turns: Expected number of client turns. + system_message: System message to prepend to message_history + (only used when use_dataset_history=False and state is new). Returns: ConversationState for this conversation. 
""" - async with self._lock: - if conversation_id not in self._conversations: - initial_history: list[dict[str, Any]] = ( - [system_message] if system_message is not None else [] - ) - state = ConversationState( - conversation_id=conversation_id, - current_turn=0, - pending_client_turn=None, - expected_client_turns=expected_client_turns, - issued_client_turns=0, - completed_client_turns=0, - failed_client_turns=0, - message_history=initial_history, - ) - self._conversations[conversation_id] = state - return self._conversations[conversation_id] - - async def wait_for_turn_ready( - self, conversation_id: str, turn: int, timeout: float | None = None - ) -> bool: - """Block until conversation is ready for this turn. - - Uses the per-conversation asyncio.Condition so only this conversation's pipeline - task is woken on state changes, not all pipeline tasks. - - Args: - conversation_id: Conversation to wait for. - turn: Turn number to wait for (unused in readiness check; kept for - call-site compatibility). - timeout: Maximum seconds to wait (None = infinite). - - Returns: - True if ready, False if timeout. - - Raises: - KeyError: If conversation_id not found in manager. - """ - state = self._conversations.get(conversation_id) - if state is None: - logger.error(f"Conversation {conversation_id} not found in manager") - raise KeyError(f"Conversation {conversation_id} not initialized") - - async with state.condition: - if timeout is None: - await state.condition.wait_for(state.is_ready_for_turn) - return True - try: - async with asyncio.timeout(timeout): - await state.condition.wait_for(state.is_ready_for_turn) - return True - except TimeoutError: - return state.is_ready_for_turn() - - async def wait_for_turn_issued( - self, - conversation_id: str, - min_issued: int, - timeout: float | None = None, - ) -> bool: - """Block until at least min_issued client turns have been issued. - - Args: - conversation_id: Conversation to wait for. - min_issued: Minimum number of issued turns to wait for. - timeout: Maximum seconds to wait (None = infinite). - - Returns: - True if condition met, False if timeout. - - Raises: - KeyError: If conversation_id not found (programming error β€” state must be - pre-created by the strategy before pipeline tasks are spawned). - """ - state = self._conversations[conversation_id] - predicate = lambda: state.issued_client_turns >= min_issued # noqa: E731 - async with state.condition: - if timeout is None: - await state.condition.wait_for(predicate) - return True - try: - async with asyncio.timeout(timeout): - await state.condition.wait_for(predicate) - return True - except TimeoutError: - return state.issued_client_turns >= min_issued - - async def mark_turn_issued( - self, - conversation_id: str, - turn: int, - message: dict[str, Any] | None = None, - ): - """Mark that a client turn has been issued (updates sequencing counters). - - Args: - conversation_id: Conversation ID. - turn: Turn number being issued. - message: Message dict to append to history (used when - use_dataset_history=False). - - Raises: - KeyError: If conversation_id not found in manager. 
- """ - state = self._conversations.get(conversation_id) - if state is None: - raise KeyError(f"Conversation {conversation_id} not initialized") - async with state.condition: - state.add_client_turn(turn, message) - state.condition.notify_all() + if conversation_id not in self._conversations: + initial_history: list[dict[str, Any]] = ( + [system_message] if system_message is not None else [] + ) + self._conversations[conversation_id] = ConversationState( + conversation_id=conversation_id, + expected_client_turns=expected_client_turns, + message_history=initial_history, + ) + return self._conversations[conversation_id] - async def mark_turn_complete( + def mark_turn_complete( self, conversation_id: str, response: str, store_in_history: bool = False, - ): - """Mark that assistant response has arrived. + ) -> None: + """Record a successful response and wake the pipeline task. Args: conversation_id: Conversation ID. - response: Model output (stored in history when store_in_history=True). + response: Model output (appended to history when store_in_history=True). store_in_history: When True, append response to message_history. Raises: - KeyError: If conversation_id not found in manager. + KeyError: If conversation_id not found. """ state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") - async with state.condition: - state.add_assistant_turn(response if store_in_history else None) - state.condition.notify_all() + if store_in_history and response: + state.message_history.append({"role": "assistant", "content": response}) + state.completed_turns += 1 + if state.is_complete(): + if state.failed_turns > 0: + logger.info( + f"Conversation {conversation_id} completed with failures: " + f"{state.completed_turns - state.failed_turns}/" + f"{state.expected_client_turns} successful, " + f"{state.failed_turns} failed" + ) + else: + logger.debug( + f"Conversation {conversation_id} completed: " + f"{state.completed_turns}/{state.expected_client_turns} turns" + ) + state.turn_done.set() - async def mark_turn_failed( - self, conversation_id: str, store_in_history: bool = False - ): - """Mark that assistant response failed (error/timeout). + def mark_turn_failed( + self, + conversation_id: str, + store_in_history: bool = False, + ) -> None: + """Record a failed response and wake the pipeline task. - Failed turns still count toward conversation completion to ensure - turn sequencing progresses even under errors. + Failed turns count toward completion so sequencing progresses under errors. Args: conversation_id: Conversation ID. store_in_history: When True, append error placeholder to message_history. Raises: - KeyError: If conversation_id not found in manager. + KeyError: If conversation_id not found. 
""" state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") - async with state.condition: - state.mark_turn_failed(store_in_history=store_in_history) - state.condition.notify_all() + if store_in_history: + state.message_history.append( + {"role": "assistant", "content": "[ERROR: Turn failed or timed out]"} + ) + state.completed_turns += 1 + state.failed_turns += 1 + logger.warning(f"Turn failed for conversation {conversation_id}") + if state.is_complete(): + logger.info( + f"Conversation {conversation_id} completed with failures: " + f"{state.completed_turns - state.failed_turns}/" + f"{state.expected_client_turns} successful, " + f"{state.failed_turns} failed" + ) + state.turn_done.set() diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 48f5b45f..cfd418bd 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -22,7 +22,7 @@ from ..config.schema import MultiTurnConfig from ..core.types import QueryResult -from .conversation_manager import ConversationManager +from .conversation_manager import ConversationManager, ConversationState from .strategy import PhaseIssuerProtocol logger = logging.getLogger(__name__) @@ -48,11 +48,12 @@ class MultiTurnStrategy: The response routing path: 1. _conv_pipeline issues turn N via phase_issuer.issue(idx) β†’ query_id - 2. _conv_pipeline stores (conv_id, turn) in _inflight[query_id] + 2. _conv_pipeline stores conv_id in _inflight[query_id] 3. BenchmarkSession calls on_sample_complete(result) with the QueryResult 4. on_sample_complete looks up conv_id from _inflight, calls mark_turn_complete - 5. mark_turn_complete notifies the pipeline task waiting on wait_for_turn_ready - 6. _conv_pipeline proceeds to issue turn N+1 + 5. mark_turn_complete sets state.turn_done synchronously + 6. _conv_pipeline's await asyncio.wait_for(state.turn_done.wait()) returns + 7. Pipeline clears the event and issues turn N+1 """ def __init__( @@ -91,8 +92,9 @@ def __init__( ) # Maps query_id -> conversation_id for routing completions. - # Populated by _conv_pipeline after issue() returns query_id. self._inflight: dict[str, str] = {} + # Cached ConversationState refs for O(1) lookup in on_sample_complete. + self._conv_states: dict[str, ConversationState] = {} async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: """Drive multi-turn sample issuance. @@ -108,11 +110,21 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: conv_id = sample_meta["conversation_id"] conv_samples[conv_id].append((sample_index, sample_meta["turn"])) - # Pre-create all conversation states before spawning tasks. + # Pre-create all conversation states before spawning tasks (no locking needed). 
+ sys_prompts = self._dataset_metadata.get("system_prompts_by_conv", {}) for conv_id, turns in conv_samples.items(): - await self._conv_manager.get_or_create( - conv_id, expected_client_turns=len(turns) + sys_content = sys_prompts.get(conv_id) if self._store_in_history else None + system_message = ( + {"role": "system", "content": sys_content} + if sys_content is not None + else None ) + state = self._conv_manager.get_or_create( + conv_id, + expected_client_turns=len(turns), + system_message=system_message, + ) + self._conv_states[conv_id] = state tasks = [ asyncio.create_task( @@ -133,59 +145,53 @@ async def _conv_pipeline( ) -> None: """Process all turns for a single conversation sequentially. - For each turn after the first, waits for the previous turn to complete - (via wait_for_turn_ready) before issuing the next. This enforces strict - sequential ordering: turn N+1 is not issued until turn N's response arrives. + For each turn after the first, waits for state.turn_done before issuing + the next. This enforces strict sequential ordering within the conversation. """ + state = self._conv_states[conv_id] sorted_turns = sorted(turns, key=lambda x: x[1]) for i, (idx, turn) in enumerate(sorted_turns): if i > 0: - # Wait for the previous turn to complete before issuing the next. - ready = await self._conv_manager.wait_for_turn_ready( - conv_id, turn, timeout=self._turn_timeout_s - ) - if not ready: + try: + await asyncio.wait_for( + state.turn_done.wait(), timeout=self._turn_timeout_s + ) + except TimeoutError: logger.warning( f"Turn {turn} of {conv_id} timed out waiting for previous turn" ) - await self._conv_manager.mark_turn_failed(conv_id) + state.failed_turns += 1 break + state.turn_done.clear() - # Acquire concurrency slot before issuing + # Acquire concurrency slot before issuing. if self._sem is not None: await self._sem.acquire() - # For live-history mode: build messages from accumulated history + current turn, - # and pass as data_override so the pre-built messages from the dataset are replaced. + # Live-history mode: build messages from accumulated history + current turn. data_override: dict[str, Any] | None = None - current_turn_message: dict[str, Any] | None = None + current_turn_messages: list[dict[str, Any]] | None = None if self._store_in_history: - pre_built = self._dataset_metadata.get( - "pre_built_messages_by_key", {} - ).get((conv_id, turn), []) - current_turn_message = pre_built[-1] if pre_built else None - state = self._conv_manager.get_state(conv_id) - if state is not None and current_turn_message is not None: - live_messages = state.message_history.copy() + [ - current_turn_message - ] + current_turn_messages = self._dataset_metadata.get( + "current_turn_messages_by_key", {} + ).get((conv_id, turn)) + if current_turn_messages: + live_messages = state.message_history.copy() + current_turn_messages data_override = {"messages": live_messages} query_id = phase_issuer.issue(idx, data_override=data_override) if query_id is None: - # Session stopping β€” release slot and exit + # Session stopping β€” release slot and exit. if self._sem is not None: self._sem.release() break - # Register this query_id -> conv_id mapping for response routing. self._inflight[query_id] = conv_id - # Mark the turn as issued so wait_for_turn_ready can gate the next turn. - await self._conv_manager.mark_turn_issued( - conv_id, turn, message=current_turn_message - ) + # Append current-turn messages to history so the next turn sees them. 
+ if self._store_in_history and current_turn_messages: + state.message_history.extend(current_turn_messages) def on_query_complete(self, query_id: str) -> None: """Called by BenchmarkSession when a QueryResult arrives. @@ -203,27 +209,23 @@ def on_sample_complete(self, result: QueryResult) -> None: """Route completed QueryResult to ConversationManager. Called by execute.py on_sample_complete hook after each response. - Looks up the conversation_id from _inflight and calls mark_turn_complete. + Event.set() is synchronous β€” the pipeline task is woken immediately + without needing asyncio.ensure_future. Args: result: Completed QueryResult from the endpoint. """ - query_id = result.id - conv_id = self._inflight.pop(query_id, None) + conv_id = self._inflight.pop(result.id, None) if conv_id is None: return response_text = result.get_response_output_string() if result.error is not None: - asyncio.ensure_future( - self._conv_manager.mark_turn_failed( - conv_id, store_in_history=self._store_in_history - ) + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history ) else: - asyncio.ensure_future( - self._conv_manager.mark_turn_complete( - conv_id, response_text, store_in_history=self._store_in_history - ) + self._conv_manager.mark_turn_complete( + conv_id, response_text, store_in_history=self._store_in_history ) diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 4830c682..9c6f6ebd 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -111,6 +111,7 @@ def to_endpoint_request(cls, query: Query) -> CreateChatCompletionRequest: stream=query.data.get("stream", False), max_completion_tokens=query.data.get("max_completion_tokens", 100), temperature=query.data.get("temperature", 0.7), + tools=query.data.get("tools"), ) return request diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index ca17a236..87351700 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -16,8 +16,7 @@ """Integration tests for multi-turn benchmarking end-to-end. Validates that MultiTurnDataset + MultiTurnStrategy + BenchmarkSession work -correctly together against a real HTTP echo server (echo tests) and a live -model endpoint (live tests at port 8868). +correctly together against a real HTTP echo server. Tests cover: 1. Dataset-history mode (use_dataset_history=True): pre-built messages are @@ -27,16 +26,12 @@ grow with each turn. 3. Multiple concurrent conversations complete successfully. 4. Turn ordering: turn N+1 is never issued before turn N completes. - 5. Live concurrency: parametrized target_concurrency levels against a real - model endpoint verify all turns complete regardless of throttle setting. 
""" import asyncio -import json import random import time from urllib.parse import urljoin -from urllib.request import urlopen import pandas as pd import pytest @@ -430,303 +425,210 @@ def on_complete(result: QueryResult) -> None: assert complete_times[q_turn3] >= complete_times[q_turn1] -# --------------------------------------------------------------------------- -# Live endpoint fixtures and helpers -# --------------------------------------------------------------------------- - -_LIVE_ENDPOINT = "http://localhost:8868" - - -def _query_model_name(endpoint: str) -> str: - """Return the first model name from the endpoint, or skip if unreachable.""" - try: - with urlopen(f"{endpoint}/v1/models", timeout=5.0) as resp: - data = json.loads(resp.read()) - return data["data"][0]["id"] - except Exception as e: - pytest.skip(f"Live endpoint {endpoint} not reachable: {e}") - return "" - - -def _make_live_rows( - model: str, n_conversations: int = 20, n_user_turns: int = 3 -) -> list[dict]: - """Build a multi-conversation dataset rows list. - - Each conversation has n_user_turns user turns interleaved with scripted - assistant placeholders (needed to satisfy the turn-structure validator but - never sent to the endpoint). The resulting dataset produces - n_conversations Γ— n_user_turns client-turn samples. - """ - rows = [] - _user_prompts = [ - "Reply with exactly one word: the number {n} in English.", - "Add one to the previous number. Reply with only that word.", - "Add one more. Reply with only that word.", +@pytest.mark.integration +@pytest.mark.asyncio +async def test_tool_use_conversation_all_turns_issued(echo_server): + """Tool-use conversation: all client turns (user + tool) are issued and completed.""" + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "search result"}] + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + } ] - for i in range(n_conversations): - conv_id = f"live_conv_{i:03d}" - turn = 1 - for j in range(n_user_turns): - prompt = _user_prompts[j % len(_user_prompts)].format(n=i + 1) - rows.append( - { - "conversation_id": conv_id, - "turn": turn, - "role": "user", - "content": prompt, - "model": model, - "max_completion_tokens": 10, - } - ) - turn += 1 - if j < n_user_turns - 1: - rows.append( - { - "conversation_id": conv_id, - "turn": turn, - "role": "assistant", - "content": "placeholder", - } - ) - turn += 1 - return rows - - -async def _run_live_session( - model: str, - n_conversations: int, - n_user_turns: int, - target_concurrency: int | None, - timeout_s: float = 300.0, -) -> tuple[int, dict[str, str]]: - """Run a live multi-turn session against the endpoint at _LIVE_ENDPOINT. - - Returns (issued_count, {query_id: response_text}). 
- """ - rows = _make_live_rows(model, n_conversations, n_user_turns) - ds = MultiTurnDataset(dataframe=pd.DataFrame(rows)) - ds.load() - - mt_cfg = MultiTurnConfig( - turn_timeout_s=60.0, - use_dataset_history=True, - ) - strategy = MultiTurnStrategy( - conversation_manager=ConversationManager(), - dataset_metadata=ds.conversation_metadata, - multi_turn_config=mt_cfg, - target_concurrency=target_concurrency, - ) - - loop = asyncio.get_running_loop() - responses: dict[str, str] = {} - - def on_complete(result: QueryResult) -> None: - strategy.on_sample_complete(result) - responses[result.id] = result.get_response_output_string() - - http_config = HTTPClientConfig( - endpoint_urls=[f"{_LIVE_ENDPOINT}/v1/chat/completions"], - warmup_connections=0, - num_workers=4, - ) - http_client = await HTTPEndpointClient.create(http_config, loop) - issuer = HttpClientSampleIssuer(http_client) - try: - session = BenchmarkSession( - issuer=issuer, - event_publisher=_NoOpPublisher(), - loop=loop, - on_sample_complete=on_complete, - ) - rt = RuntimeSettings( - metrics.Throughput(1000), - [metrics.Throughput(1000)], - min_duration_ms=0, - max_duration_ms=int(timeout_s * 1000), - n_samples_from_dataset=ds.num_samples(), - n_samples_to_issue=ds.num_samples(), - min_sample_count=1, - rng_sched=random.Random(42), - rng_sample_index=random.Random(42), - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) - result = await asyncio.wait_for(session.run([phase]), timeout=timeout_s) - return result.perf_results[0].issued_count, responses - finally: - await http_client.shutdown_async() + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Find something", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Here is the result", + }, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "Thanks"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + count = await _run_session(echo_server.url, ds, strategy, responses) -# --------------------------------------------------------------------------- -# Live concurrency tests -# --------------------------------------------------------------------------- + # Client turns: turn 1 (user) + turn 3 (tool) + turn 5 (user) = 3 + assert count == 3 + assert len(responses) == 3 @pytest.mark.integration @pytest.mark.asyncio -@pytest.mark.parametrize( - "target_concurrency", - [ - pytest.param(1, id="concurrency_1"), - pytest.param(4, id="concurrency_4"), - pytest.param(None, id="concurrency_unlimited"), - ], -) -async def test_live_concurrency(target_concurrency): - """All turns of 20 concurrent conversations complete for each concurrency level. 
+async def test_conversation_ending_with_tool_row(echo_server): + """Conversation ending with a tool row completes normally (matches agentic_coding dataset pattern).""" + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path": "out.py"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "file written"}] + tools = [ + { + "type": "function", + "function": { + "name": "write_file", + "description": "Write a file", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + }, + }, + } + ] - Uses the live model endpoint at port 8868. Each conversation has 3 user - turns (60 total requests). Verifies that every turn receives a non-empty - response regardless of the concurrency throttle applied by target_concurrency. - """ - model = _query_model_name(_LIVE_ENDPOINT) - n_conversations = 20 - n_user_turns = 3 - expected_turns = n_conversations * n_user_turns # 60 total requests - - issued, responses = await _run_live_session( - model=model, - n_conversations=n_conversations, - n_user_turns=n_user_turns, - target_concurrency=target_concurrency, - timeout_s=300.0, - ) + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Write a file", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} - assert issued == expected_turns, f"Expected {expected_turns} issued, got {issued}" - assert ( - len(responses) == expected_turns - ), f"Expected {expected_turns} responses, got {len(responses)}" - for qid, text in responses.items(): - assert text.strip(), f"Query {qid} returned empty response" + count = await _run_session(echo_server.url, ds, strategy, responses) + + # Client turns: turn 1 (user) + turn 3 (tool) = 2 + assert count == 2 + assert len(responses) == 2 @pytest.mark.integration @pytest.mark.asyncio -async def test_live_turn_ordering_multi_conversation(): - """Turn N+1 of each conversation is always issued after turn N completes. - - Runs 10 conversations with 3 turns each concurrently (30 total requests). - Records per-query completion timestamps and asserts that within every - conversation each successive turn completes no earlier than the previous. 
- """ - model = _query_model_name(_LIVE_ENDPOINT) - n_conversations = 10 - n_user_turns = 3 - rows = _make_live_rows(model, n_conversations, n_user_turns) - - ds = MultiTurnDataset(dataframe=pd.DataFrame(rows)) - ds.load() - - conv_manager = ConversationManager() - mt_cfg = MultiTurnConfig(turn_timeout_s=60.0, use_dataset_history=True) - strategy = MultiTurnStrategy( - conversation_manager=conv_manager, - dataset_metadata=ds.conversation_metadata, - multi_turn_config=mt_cfg, - ) - - complete_times: dict[str, float] = {} - orig_on_sample_complete = strategy.on_sample_complete - - def tracked_complete(result: QueryResult) -> None: - complete_times[result.id] = time.monotonic() - orig_on_sample_complete(result) - - strategy.on_sample_complete = tracked_complete - - loop = asyncio.get_running_loop() - responses: dict[str, str] = {} +async def test_tools_field_forwarded_to_endpoint(echo_server): + """The 'tools' array from the dataset reaches the endpoint in every request payload.""" + received_payloads: list[dict] = [] - http_config = HTTPClientConfig( - endpoint_urls=[f"{_LIVE_ENDPOINT}/v1/chat/completions"], - warmup_connections=0, - num_workers=4, - ) - http_client = await HTTPEndpointClient.create(http_config, loop) - issuer = HttpClientSampleIssuer(http_client) + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass + return await super()._handle_echo_chat_completions_request(request) + server = CapturingEchoServer(port=0) + server.start() try: + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "hello"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "result"}] + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + } + ] - def on_complete(result: QueryResult) -> None: - tracked_complete(result) - responses[result.id] = result.get_response_output_string() - - session = BenchmarkSession( - issuer=issuer, - event_publisher=_NoOpPublisher(), - loop=loop, - on_sample_complete=on_complete, - ) - rt = RuntimeSettings( - metrics.Throughput(1000), - [metrics.Throughput(1000)], - min_duration_ms=0, - max_duration_ms=300_000, - n_samples_from_dataset=ds.num_samples(), - n_samples_to_issue=ds.num_samples(), - min_sample_count=1, - rng_sched=random.Random(42), - rng_sample_index=random.Random(42), - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), - ) - phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) - result = await asyncio.wait_for(session.run([phase]), timeout=300.0) - finally: - await http_client.shutdown_async() - - expected_total = n_conversations * n_user_turns - assert result.perf_results[0].issued_count == expected_total - - # Build index β†’ query_id map and verify per-conversation ordering. - # Samples are grouped by conversation, turns sorted ascending within each: - # conv_0_t1, conv_0_t2, conv_0_t3, conv_1_t1, ... 
- uuid_to_index = result.perf_results[0].uuid_to_index - index_to_query = {v: k for k, v in uuid_to_index.items()} - - for conv_i in range(n_conversations): - base = conv_i * n_user_turns - for turn_j in range(n_user_turns - 1): - q_cur = index_to_query[base + turn_j] - q_next = index_to_query[base + turn_j + 1] - assert complete_times[q_cur] <= complete_times[q_next], ( - f"conv {conv_i}: turn {turn_j + 2} completed before turn {turn_j + 1} " - f"(t{turn_j + 1}={complete_times[q_cur]:.4f}, " - f"t{turn_j + 2}={complete_times[q_next]:.4f})" - ) - + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Search for hello", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=True) + responses: dict = {} -@pytest.mark.integration -@pytest.mark.asyncio -async def test_live_large_concurrency(): - """All turns complete correctly under a large concurrency limit (>=512). - - Uses 200 conversations Γ— 3 turns = 600 total requests with - target_concurrency=512. The semaphore allows up to 512 simultaneous - in-flight requests, so the first wave of 200 first-turns is issued - without throttling, and subsequent turns queue naturally. Verifies - that all 600 turns complete and return non-empty responses, confirming - the semaphore implementation handles large values without deadlock or - starvation. - """ - model = _query_model_name(_LIVE_ENDPOINT) - n_conversations = 200 - n_user_turns = 3 - expected_turns = n_conversations * n_user_turns # 600 total requests - - issued, responses = await _run_live_session( - model=model, - n_conversations=n_conversations, - n_user_turns=n_user_turns, - target_concurrency=512, - timeout_s=300.0, - ) + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 - assert issued == expected_turns, f"Expected {expected_turns} issued, got {issued}" - assert ( - len(responses) == expected_turns - ), f"Expected {expected_turns} responses, got {len(responses)}" - for qid, text in responses.items(): - assert text.strip(), f"Query {qid} returned empty response" + assert len(received_payloads) == 2 + for payload in received_payloads: + assert "tools" in payload + assert len(payload["tools"]) == 1 + assert payload["tools"][0]["function"]["name"] == "search" + finally: + server.stop() diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index 845711da..b1d29cfd 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -463,3 +463,122 @@ def test_openai_completions_endpoint_resolves_adapter(self): assert config.settings.client.api_type is APIType.OPENAI_COMPLETIONS assert config.settings.client.adapter is OpenAITextCompletionsAdapter assert config.settings.client.accumulator is OpenAISSEAccumulator + +class TestMultiTurnValidation: + """Tests for multi-turn config validation and cross-validation.""" + + def _make_online_multi_turn(self, concurrency: int | None = 4, **ds_kwargs): + lp: dict = {"type": "multi_turn"} + if concurrency is not None: + lp["target_concurrency"] = concurrency + return { + "type": TestType.ONLINE, + "model_params": {"name": "M"}, + "endpoint_config": {"endpoints": ["http://x"]}, + "datasets": [{"path": "D", "multi_turn": {}, **ds_kwargs}], + "settings": 
{"load_pattern": lp}, + } + + @pytest.mark.unit + def test_multi_turn_valid_config(self): + config = BenchmarkConfig(**self._make_online_multi_turn(concurrency=16)) + from inference_endpoint.config.schema import LoadPatternType + + assert config.settings.load_pattern.type == LoadPatternType.MULTI_TURN + assert config.settings.load_pattern.target_concurrency == 16 + + @pytest.mark.unit + def test_multi_turn_requires_target_concurrency(self): + with pytest.raises(ValueError, match="Multi-turn requires --concurrency"): + BenchmarkConfig(**self._make_online_multi_turn(concurrency=None)) + + @pytest.mark.unit + def test_multi_turn_without_multi_turn_dataset_rejected(self): + with pytest.raises(ValueError, match="requires at least one dataset"): + BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D"}], + settings={ + "load_pattern": {"type": "multi_turn", "target_concurrency": 4} + }, + ) + + @pytest.mark.unit + def test_multi_turn_dataset_without_multi_turn_load_pattern_rejected(self): + with pytest.raises(ValueError, match="require load_pattern.type=multi_turn"): + BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D", "multi_turn": {}}], + settings={"load_pattern": {"type": "poisson", "target_qps": 10}}, + ) + + +class TestMultiTurnTotalSamples: + """Tests for total_samples_to_issue() with multi_turn load pattern.""" + + @pytest.mark.unit + def test_multi_turn_uses_dataset_size_ignoring_duration(self): + from inference_endpoint.config.runtime_settings import RuntimeSettings + + config = BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D", "multi_turn": {}}], + settings={ + "load_pattern": {"type": "multi_turn", "target_concurrency": 4}, + "runtime": {"min_duration_ms": 600000}, + }, + ) + rt = RuntimeSettings.from_config(config, dataloader_num_samples=4316) + assert rt.total_samples_to_issue() == 4316 + + @pytest.mark.unit + def test_multi_turn_respects_min_sample_count(self): + import random + + from inference_endpoint import metrics + from inference_endpoint.config.runtime_settings import RuntimeSettings + from inference_endpoint.config.schema import LoadPattern, LoadPatternType + + lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) + rt = RuntimeSettings( + metric_target=metrics.Throughput(10.0), + reported_metrics=[metrics.Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=5, + n_samples_to_issue=None, + min_sample_count=100, + rng_sched=random.Random(0), + rng_sample_index=random.Random(0), + load_pattern=lp, + ) + assert rt.total_samples_to_issue() == 100 + + @pytest.mark.unit + def test_multi_turn_explicit_n_samples_takes_precedence(self): + import random + + from inference_endpoint import metrics + from inference_endpoint.config.runtime_settings import RuntimeSettings + from inference_endpoint.config.schema import LoadPattern, LoadPatternType + + lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) + rt = RuntimeSettings( + metric_target=metrics.Throughput(10.0), + reported_metrics=[metrics.Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=4316, + n_samples_to_issue=200, + min_sample_count=1, + rng_sched=random.Random(0), + rng_sample_index=random.Random(0), + load_pattern=lp, + ) + assert 
rt.total_samples_to_issue() == 200 diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index a42fc1f3..a93a3a08 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -1061,3 +1061,366 @@ def test_messages_with_tool_sequence_terminal_assistant(): # The terminal assistant at turn 4 should be included assistant_msgs = [m for m in msgs if m["role"] == "assistant" and m.get("content")] assert any(m["content"] == "The weather is 22Β°C." for m in assistant_msgs) + + +# ============================================================================ +# Tool-use flat dataset regression tests (BUG 1, BUG 2, BUG 3) +# ============================================================================ + + +@pytest.mark.unit +def test_prior_tool_row_expanded_with_tool_call_id(): + """Prior tool rows must expand to messages with tool_call_id and content (BUG 1).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Client turn 3 (user, t=5) has a prior tool row at t=3. + # msgs_t5[3] should be the expanded tool message with proper fields. + msgs_t5 = pbm[("c1", 5)] + tool_msgs = [m for m in msgs_t5 if m.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["tool_call_id"] == "call_c1_0" + assert tool_msgs[0]["content"] == '{"temp": 22}' + + +@pytest.mark.unit +def test_prior_parallel_tool_results_expand_to_multiple_messages(): + """Prior turn with 2 parallel tool_results expands to 2 tool messages.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "Ok"}, + ] + ) + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # user(5) sees prior rows: user(1), assistant(2), tool(3)x2, assistant(4) + msgs_t5 = pbm[("c1", 5)] + tool_msgs = [m for m in msgs_t5 if m.get("role") == "tool"] + assert len(tool_msgs) == 2 + assert tool_msgs[0]["tool_call_id"] == "c_0" + assert tool_msgs[0]["content"] == "r1" + assert tool_msgs[1]["tool_call_id"] == "c_1" + assert tool_msgs[1]["content"] == "r2" + + +@pytest.mark.unit +def test_assistant_content_null_preserved_in_history(): + """Assistant messages with tool_calls and content:null include content key (BUG 2).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + pbm = ds.conversation_metadata["pre_built_messages_by_key"] + + # Client turn 2 (tool, t=3): prior includes assistant(2) with tool_calls + content: null + msgs_t3 = pbm[("c1", 3)] + asst_msg = msgs_t3[2] + assert asst_msg["role"] == "assistant" + assert "tool_calls" in asst_msg + assert "content" in asst_msg + assert asst_msg["content"] is None + + # Also verify in user(5)'s history + msgs_t5 = pbm[("c1", 5)] + asst_tc_msg = msgs_t5[2] + assert 
asst_tc_msg["role"] == "assistant" + assert "tool_calls" in asst_tc_msg + assert "content" in asst_tc_msg + assert asst_tc_msg["content"] is None + + +@pytest.mark.unit +def test_jsonl_round_trip_with_tools_field(): + """Load from JSONL tmpfile with tools field; verify tools survives to sample dict.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Run the test", + "tools": [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Run a bash command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tc_0", + "type": "function", + "function": {"name": "bash", "arguments": '{"cmd": "ls"}'}, + } + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [{"tool_call_id": "tc_0", "content": "file1.py"}], + "tools": [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Run a bash command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "The directory contains file1.py", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # user(1) has tools + s0 = dataset.load_sample(0) + assert "tools" in s0 + assert len(s0["tools"]) == 1 + assert s0["tools"][0]["function"]["name"] == "bash" + + # tool(3) also has tools + s1 = dataset.load_sample(1) + assert "tools" in s1 + assert s1["tools"][0]["function"]["name"] == "bash" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_current_turn_messages_by_key_parallel_tools(): + """current_turn_messages_by_key stores all expanded messages for a tool turn.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Go"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ctm = ds.conversation_metadata["current_turn_messages_by_key"] + + # user(1) current turn is 1 message + assert len(ctm[("c1", 1)]) == 1 + assert ctm[("c1", 1)][0] == {"role": "user", "content": "Go"} + + # tool(3) current turn has 2 expanded messages (parallel tool_results) + assert len(ctm[("c1", 3)]) == 2 + assert ctm[("c1", 3)][0]["tool_call_id"] == "c_0" + assert ctm[("c1", 3)][1]["tool_call_id"] == "c_1" + + +# ============================================================================ +# Fix 1: system_prompts_by_conv in metadata (live-history mode) +# ============================================================================ + + +@pytest.mark.unit +def test_metadata_contains_system_prompts_by_conv(): + """_build_metadata exposes system_prompts_by_conv keyed by conversation_id.""" + data = [ + { + 
"conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hi", + "system": "Be concise", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Ok"}, + # c2 has no system prompt + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "Hello"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + + spc = ds.conversation_metadata["system_prompts_by_conv"] + assert spc["c1"] == "Be concise" + assert spc["c2"] is None + + +@pytest.mark.unit +def test_metadata_system_prompts_multiple_convs(): + """Each conversation gets its own system prompt entry.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "A", + "system": "Sys1", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + { + "conversation_id": "c2", + "turn": 1, + "role": "user", + "content": "C", + "system": "Sys2", + }, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "D"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + + spc = ds.conversation_metadata["system_prompts_by_conv"] + assert spc["c1"] == "Sys1" + assert spc["c2"] == "Sys2" + + +# ============================================================================ +# Fix 2: tool_results / tool_calls stripped from sample dicts +# ============================================================================ + + +@pytest.mark.unit +def test_tool_results_not_in_sample_dict(): + """tool_results must not appear in the pre-baked sample dict for tool turns.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + # Sample 1 is the tool turn (turn 3) + s1 = ds.load_sample(1) + assert s1["role"] == "tool" + assert "tool_results" not in s1 + + +@pytest.mark.unit +def test_tool_calls_not_in_sample_dict(): + """tool_calls must not appear in sample dicts (only relevant on assistant rows).""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Go", + "tool_calls": [ + {"id": "bad", "type": "function", "function": {"name": "f"}} + ], + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Done"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + ds.load() + + s0 = ds.load_sample(0) + assert "tool_calls" not in s0 + + +# ============================================================================ +# Fix 3: no dead current_turn_message / system_content fields in sample dicts +# ============================================================================ + + +@pytest.mark.unit +def test_no_dead_current_turn_message_field(): + """current_turn_message must not appear in pre-baked sample dicts.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + for i in range(ds.num_samples()): + s = ds.load_sample(i) + assert ( + "current_turn_message" not in s + ), f"Sample {i} has dead field current_turn_message" + + +@pytest.mark.unit +def test_no_dead_system_content_field(): + """system_content must not appear in pre-baked sample dicts.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + for i in range(ds.num_samples()): + s = ds.load_sample(i) + assert "system_content" not in s, f"Sample {i} has dead field system_content" diff --git a/tests/unit/load_generator/test_multi_turn_conversation_manager.py b/tests/unit/load_generator/test_multi_turn_conversation_manager.py index 62602626..331e6709 100644 --- a/tests/unit/load_generator/test_multi_turn_conversation_manager.py +++ 
b/tests/unit/load_generator/test_multi_turn_conversation_manager.py @@ -14,7 +14,7 @@ # limitations under the License. import asyncio -import logging +import inspect import pytest from inference_endpoint.load_generator.conversation_manager import ( @@ -25,334 +25,253 @@ @pytest.mark.unit def test_conversation_state_initialization(): - """Test ConversationState initializes with correct default values.""" + """ConversationState initializes with correct defaults.""" state = ConversationState(conversation_id="conv_001") assert state.conversation_id == "conv_001" - assert state.current_turn == 0 - assert state.pending_client_turn is None + assert not state.turn_done.is_set() + assert state.message_history == [] + assert state.completed_turns == 0 + assert state.failed_turns == 0 + assert state.expected_client_turns is None @pytest.mark.unit -def test_conversation_state_add_client_turn(): - """Test adding a client turn updates sequencing state.""" +def test_conversation_state_is_complete_without_expected(): + """is_complete() returns False when expected_client_turns is None.""" state = ConversationState(conversation_id="conv_001") - - state.add_client_turn(1) - - assert state.pending_client_turn == 1 - assert state.issued_client_turns == 1 - assert state.current_turn == 0 # Not incremented until assistant response - - -@pytest.mark.unit -def test_conversation_state_add_assistant_turn(): - """Test adding assistant turn completes turn cycle.""" - state = ConversationState(conversation_id="conv_001") - - state.add_client_turn(1) - state.add_assistant_turn() - - assert state.current_turn == 2 - assert state.pending_client_turn is None - assert state.completed_client_turns == 1 + assert not state.is_complete() + state.completed_turns = 5 + assert not state.is_complete() @pytest.mark.unit -def test_conversation_state_late_response_after_complete_is_silently_ignored(caplog): - """Late response for a conversation that already completed is silently dropped.""" - state = ConversationState(conversation_id="conv_001", expected_client_turns=1) - - state.add_client_turn(1) - state.add_assistant_turn() +def test_conversation_state_is_complete_with_expected(): + """is_complete() returns True once completed_turns >= expected.""" + state = ConversationState(conversation_id="conv_001", expected_client_turns=2) + assert not state.is_complete() + state.completed_turns = 1 + assert not state.is_complete() + state.completed_turns = 2 assert state.is_complete() - completed_before = state.completed_client_turns - current_turn_before = state.current_turn - - with caplog.at_level(logging.WARNING): - state.add_assistant_turn() - - assert state.completed_client_turns == completed_before - assert state.current_turn == current_turn_before - assert "no pending client turn" not in caplog.text - @pytest.mark.unit -def test_conversation_state_is_ready_for_turn(): - """Test turn readiness checks using completion counts.""" - state = ConversationState(conversation_id="conv_001") - - assert not state.is_ready_for_turn() - - state.add_client_turn(1) - assert not state.is_ready_for_turn() - - state.add_assistant_turn() - assert state.is_ready_for_turn() - - state.add_client_turn(2) - assert not state.is_ready_for_turn() - - state.add_assistant_turn() - assert state.is_ready_for_turn() - - -@pytest.mark.unit -def test_conversation_state_multi_turn_sequence(): - """Test multi-turn conversation flow updates current_turn correctly.""" - state = ConversationState(conversation_id="conv_001") - - state.add_client_turn(1) - 
state.add_assistant_turn() - assert state.current_turn == 2 - - state.add_client_turn(3) - state.add_assistant_turn() - assert state.current_turn == 4 - - state.add_client_turn(5) - state.add_assistant_turn() - assert state.current_turn == 6 +def test_create_is_synchronous(): + """get_or_create() must be a plain function, not a coroutine.""" + manager = ConversationManager() + result = manager.get_or_create("conv_001") + assert not inspect.iscoroutine(result), "get_or_create returned a coroutine" + assert isinstance(result, ConversationState) @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_get_or_create(): - """Test get_or_create returns same state for same conversation_id.""" +def test_conversation_manager_get_or_create(): + """get_or_create returns the same state for the same conversation_id.""" manager = ConversationManager() - state1 = await manager.get_or_create("conv_001") - state2 = await manager.get_or_create("conv_001") + state1 = manager.get_or_create("conv_001") + state2 = manager.get_or_create("conv_001") assert state1 is state2 assert state1.conversation_id == "conv_001" @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_multiple_conversations(): - """Test manager can track multiple conversations independently.""" +def test_conversation_manager_multiple_conversations(): + """Manager tracks multiple conversations independently.""" manager = ConversationManager() - state1 = await manager.get_or_create("conv_001") - state2 = await manager.get_or_create("conv_002") + state1 = manager.get_or_create("conv_001") + state2 = manager.get_or_create("conv_002") assert state1 is not state2 - await manager.mark_turn_issued("conv_001", 1) - await manager.mark_turn_complete("conv_001", "Response to conv_001") + manager.mark_turn_complete("conv_001", "Response to conv_001") - assert state1.current_turn == 2 - assert state2.current_turn == 0 + assert state1.completed_turns == 1 + assert state2.completed_turns == 0 @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_mark_turn_issued(): - """Test mark_turn_issued updates sequencing state.""" +def test_conversation_manager_mark_turn_complete(): + """mark_turn_complete increments counter, appends history, sets event.""" manager = ConversationManager() - state = await manager.get_or_create("conv_001") + state = manager.get_or_create("conv_001") - await manager.mark_turn_issued("conv_001", 1) + manager.mark_turn_complete("conv_001", "Assistant response") - assert state.pending_client_turn == 1 - assert state.issued_client_turns == 1 + assert state.completed_turns == 1 + assert state.failed_turns == 0 + assert state.turn_done.is_set() + assert state.message_history == [] # store_in_history=False by default @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_mark_turn_complete(): - """Test mark_turn_complete updates sequencing state.""" +def test_conversation_manager_mark_turn_complete_stores_history(): + """mark_turn_complete appends to history when store_in_history=True.""" manager = ConversationManager() - state = await manager.get_or_create("conv_001") + state = manager.get_or_create("conv_001") - await manager.mark_turn_issued("conv_001", 1) - await manager.mark_turn_complete("conv_001", "Assistant response") + manager.mark_turn_complete("conv_001", "Hello", store_in_history=True) - assert state.current_turn == 2 - assert state.pending_client_turn is None - assert state.completed_client_turns == 1 + assert state.message_history == [{"role": 
"assistant", "content": "Hello"}] @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_wait_for_turn_ready_immediate(): - """Test wait_for_turn_ready returns immediately when previous turn is complete.""" +def test_conversation_manager_mark_turn_failed(): + """mark_turn_failed increments both counters and sets event.""" manager = ConversationManager() - await manager.get_or_create("conv_001") - - await manager.mark_turn_issued("conv_001", 1) - await manager.mark_turn_complete("conv_001", "First response") + state = manager.get_or_create("conv_001", expected_client_turns=2) - result = await manager.wait_for_turn_ready("conv_001", 9, timeout=1.0) + manager.mark_turn_failed("conv_001") - assert result is True + assert state.completed_turns == 1 + assert state.failed_turns == 1 + assert state.turn_done.is_set() @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_wait_for_turn_ready_blocking(): - """Test wait_for_turn_ready blocks until previous turn completes.""" +def test_conversation_completion_tracking(): + """is_complete() returns True after all expected turns receive responses.""" manager = ConversationManager() - await manager.get_or_create("conv_001") - - await manager.mark_turn_issued("conv_001", 1) - - ready_flag = [] - - async def waiter(): - result = await manager.wait_for_turn_ready("conv_001", 3, timeout=2.0) - if result: - ready_flag.append(True) + state = manager.get_or_create("conv_001", expected_client_turns=2) - waiter_task = asyncio.create_task(waiter()) - await asyncio.sleep(0.05) - assert not ready_flag - - await manager.mark_turn_complete("conv_001", "Assistant response") - await asyncio.sleep(0.05) - await waiter_task - - assert ready_flag == [True] + assert not state.is_complete() + manager.mark_turn_complete("conv_001", "r1") + assert not state.is_complete() + manager.mark_turn_complete("conv_001", "r2") + assert state.is_complete() @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_wait_for_turn_ready_timeout(): - """Test wait_for_turn_ready respects timeout.""" +def test_conversation_completion_without_expected_turns(): + """Completion is never True when expected_client_turns is None.""" manager = ConversationManager() - await manager.get_or_create("conv_001") + state = manager.get_or_create("conv_001", expected_client_turns=None) - await manager.mark_turn_issued("conv_001", 1) + manager.mark_turn_complete("conv_001", "r1") - result = await manager.wait_for_turn_ready("conv_001", 3, timeout=0.1) - - assert result is False + assert not state.is_complete() @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_completion_tracking(): - """Test conversation completion detection.""" +def test_conversation_completion_with_failures(): + """Conversations complete even when some turns fail.""" manager = ConversationManager() + state = manager.get_or_create("conv1", expected_client_turns=3) - state = await manager.get_or_create("conv_001", expected_client_turns=2) - + manager.mark_turn_complete("conv1", "Hi") assert not state.is_complete() - await manager.mark_turn_issued("conv_001", 1) + manager.mark_turn_failed("conv1") assert not state.is_complete() - await manager.mark_turn_complete("conv_001", "response 1") - assert not state.is_complete() - - await manager.mark_turn_issued("conv_001", 3) - await manager.mark_turn_complete("conv_001", "response 2") - + manager.mark_turn_complete("conv1", "Bye") assert state.is_complete() + assert state.failed_turns == 1 + assert state.completed_turns 
== 3 @pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_completion_without_expected_turns(): - """Test that completion tracking works when expected_client_turns is None.""" +def test_all_turns_fail(): + """Conversation completes when all turns fail.""" manager = ConversationManager() + state = manager.get_or_create("conv1", expected_client_turns=2) - state = await manager.get_or_create("conv_001", expected_client_turns=None) - - assert not state.is_complete() - - await manager.mark_turn_issued("conv_001", 1) - await manager.mark_turn_complete("conv_001", "response 1") + manager.mark_turn_failed("conv1") + manager.mark_turn_failed("conv1") - assert not state.is_complete() + assert state.is_complete() + assert state.completed_turns == 2 + assert state.failed_turns == 2 @pytest.mark.unit @pytest.mark.asyncio -async def test_conversation_completion_with_failures(): - """Test that conversations complete even when turns fail.""" +async def test_event_set_wakes_waiter(): + """mark_turn_complete sets turn_done so a blocked await returns.""" manager = ConversationManager() - state = await manager.get_or_create("conv1", expected_client_turns=3) + state = manager.get_or_create("conv_001") - await manager.mark_turn_issued("conv1", 1) - await manager.mark_turn_complete("conv1", "Hi there") - assert state.completed_client_turns == 1 - assert not state.is_complete() + woke_up: list[bool] = [] - await manager.mark_turn_issued("conv1", 2) - await manager.mark_turn_failed("conv1") - assert state.completed_client_turns == 2 - assert state.failed_client_turns == 1 - assert not state.is_complete() + async def waiter(): + await state.turn_done.wait() + woke_up.append(True) - await manager.mark_turn_issued("conv1", 3) - await manager.mark_turn_complete("conv1", "Bye!") - assert state.completed_client_turns == 3 - assert state.is_complete() + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.01) + assert not woke_up + + manager.mark_turn_complete("conv_001", "response") + await asyncio.sleep(0.01) + await task + + assert woke_up == [True] @pytest.mark.unit @pytest.mark.asyncio -async def test_mark_turn_failed_with_no_pending(): - """Test that marking failed turn without pending turn logs warning.""" +async def test_failed_sets_event(): + """mark_turn_failed sets turn_done so the pipeline can unblock.""" manager = ConversationManager() - state = await manager.get_or_create("conv1", expected_client_turns=1) + state = manager.get_or_create("conv_001") + + woke_up: list[bool] = [] - await manager.mark_turn_failed("conv1") + async def waiter(): + await state.turn_done.wait() + woke_up.append(True) + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0.01) + + manager.mark_turn_failed("conv_001") + await asyncio.sleep(0.01) + await task - assert state.completed_client_turns == 0 - assert state.failed_client_turns == 0 + assert woke_up == [True] @pytest.mark.unit @pytest.mark.asyncio -async def test_all_turns_fail(): - """Test conversation completion when all turns fail.""" +async def test_event_clear_resets_for_next_turn(): + """Clearing turn_done after wait() properly gates the next turn.""" manager = ConversationManager() - state = await manager.get_or_create("conv1", expected_client_turns=2) + state = manager.get_or_create("conv_001") - await manager.mark_turn_issued("conv1", 1) - await manager.mark_turn_failed("conv1") + # First turn: set then clear + manager.mark_turn_complete("conv_001", "r1") + await state.turn_done.wait() + state.turn_done.clear() + assert not 
state.turn_done.is_set() - await manager.mark_turn_issued("conv1", 2) - await manager.mark_turn_failed("conv1") - - assert state.is_complete() - assert state.completed_client_turns == 2 - assert state.failed_client_turns == 2 + # Second turn: set again + manager.mark_turn_complete("conv_001", "r2") + assert state.turn_done.is_set() @pytest.mark.unit @pytest.mark.asyncio async def test_conversation_manager_concurrent_access(): - """Test async concurrent access to multiple conversations.""" + """Concurrent pipeline tasks on independent conversations complete without errors.""" manager = ConversationManager() num_conversations = 10 - user_turns_per_conv = 5 + turns_per_conv = 5 for i in range(num_conversations): - await manager.get_or_create(f"conv_{i:03d}") + manager.get_or_create(f"conv_{i:03d}", expected_client_turns=turns_per_conv) errors = [] async def process_conversation(conv_id: str): try: - for user_turn_idx in range(user_turns_per_conv): - turn = user_turn_idx * 2 + 1 - - if user_turn_idx > 0: - ready = await manager.wait_for_turn_ready( - conv_id, turn, timeout=5.0 - ) - if not ready: - errors.append(f"{conv_id} turn {turn} timeout") - return - - await manager.mark_turn_issued(conv_id, turn) + state = manager.get_state(conv_id) + assert state is not None + for _ in range(turns_per_conv): + manager.mark_turn_complete(conv_id, "response") await asyncio.sleep(0.001) - await manager.mark_turn_complete(conv_id, f"Response {turn}") except Exception as e: errors.append(f"{conv_id} error: {e}") @@ -362,35 +281,8 @@ async def process_conversation(conv_id: str): ] await asyncio.gather(*tasks) - assert not errors, f"Errors occurred: {errors}" - + assert not errors for i in range(num_conversations): - conv_id = f"conv_{i:03d}" - state = manager._conversations[conv_id] - assert state.current_turn == user_turns_per_conv * 2 - assert state.completed_client_turns == user_turns_per_conv - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_conversation_manager_wait_for_turn_ready_reliably_wakes_on_completion(): - """Test completion wakeups do not depend on timing windows.""" - - async def run_one_iteration(): - mgr = ConversationManager() - await mgr.get_or_create("conv_001") - await mgr.mark_turn_issued("conv_001", 1) - - ready: list[bool] = [] - - async def waiter(m: ConversationManager, r: list) -> None: - r.append(await m.wait_for_turn_ready("conv_001", 3, timeout=0.5)) - - waiter_task = asyncio.create_task(waiter(mgr, ready)) - await asyncio.sleep(0.005) - await mgr.mark_turn_complete("conv_001", "Assistant response") - await asyncio.wait_for(waiter_task, timeout=0.5) - assert ready == [True] - - for _ in range(10): - await run_one_iteration() + state = manager._conversations[f"conv_{i:03d}"] + assert state.completed_turns == turns_per_conv + assert state.is_complete() diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 55c51994..0edbb34f 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -75,7 +75,7 @@ async def complete_turns(): # Mark turn 1 complete state = conv_manager.get_state("conv1") if state: - await conv_manager.mark_turn_complete("conv1", "response 1") + conv_manager.mark_turn_complete("conv1", "response 1") asyncio.create_task(complete_turns()) count = await strategy.execute(issuer) @@ -231,24 +231,20 @@ async def test_on_query_complete_releases_semaphore(): async def test_on_sample_complete_routes_to_manager(): 
"""on_sample_complete marks the turn complete in the ConversationManager.""" conv_manager = ConversationManager() - await conv_manager.get_or_create("conv1", expected_client_turns=1) + conv_manager.get_or_create("conv1", expected_client_turns=1) metadata = _make_dataset_metadata({"conv1": [1]}) strategy = MultiTurnStrategy(conv_manager, metadata) # Simulate issuer registering conv_id in _inflight strategy._inflight["q0001"] = "conv1" - # Pre-issue a turn so the state has pending_client_turn - await conv_manager.mark_turn_issued("conv1", 1) result = QueryResult(id="q0001", response_output=TextModelOutput(output="hello")) strategy.on_sample_complete(result) - # Allow the ensure_future coroutine to run - await asyncio.sleep(0.01) - state = conv_manager.get_state("conv1") assert state is not None - assert state.completed_client_turns == 1 + assert state.completed_turns == 1 + assert state.turn_done.is_set() assert state.is_complete() @@ -259,12 +255,11 @@ async def test_error_response_marks_turn_failed(): from inference_endpoint.core.types import ErrorData conv_manager = ConversationManager() - await conv_manager.get_or_create("conv1", expected_client_turns=1) + conv_manager.get_or_create("conv1", expected_client_turns=1) metadata = _make_dataset_metadata({"conv1": [1]}) strategy = MultiTurnStrategy(conv_manager, metadata) strategy._inflight["q0001"] = "conv1" - await conv_manager.mark_turn_issued("conv1", 1) result = QueryResult( id="q0001", @@ -272,8 +267,114 @@ async def test_error_response_marks_turn_failed(): error=ErrorData(error_type="timeout", error_message="timed out"), ) strategy.on_sample_complete(result) - await asyncio.sleep(0.01) state = conv_manager.get_state("conv1") assert state is not None - assert state.failed_client_turns == 1 + assert state.failed_turns == 1 + + +def _make_metadata_with_system( + conversations: dict[str, list[int]], + system_prompts: dict[str, str | None] | None = None, +) -> dict: + """Build metadata dict including system_prompts_by_conv.""" + samples = [] + sample_index = 0 + for conv_id, turns in conversations.items(): + for turn in turns: + samples.append( + { + "conversation_id": conv_id, + "turn": turn, + "sample_index": sample_index, + } + ) + sample_index += 1 + return { + "samples": samples, + "system_prompts_by_conv": system_prompts or {}, + } + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_live_history_initializes_system_prompt(): + """In live-history mode, ConversationManager.message_history starts with system message.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": "Be helpful"}, + ) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + await conv_manager.mark_turn_complete("conv1", "response") + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # message_history[0] must be the system message + assert len(state.message_history) >= 1 + assert state.message_history[0] == {"role": "system", "content": "Be helpful"} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_live_history_no_system_prompt_when_none(): + """In live-history mode, no system message is prepended when system_prompt is None.""" + 
from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": None}, + ) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + await conv_manager.mark_turn_complete("conv1", "response") + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # No system message should be in history + system_msgs = [m for m in state.message_history if m.get("role") == "system"] + assert len(system_msgs) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_dataset_history_mode_does_not_inject_system_prompt(): + """In dataset-history mode (use_dataset_history=True), system_message is not passed.""" + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": "Some system"}, + ) + # Default: use_dataset_history=True β†’ _store_in_history=False + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + await conv_manager.mark_turn_complete("conv1", "response") + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # message_history should be empty (dataset-history mode doesn't accumulate) + assert len(state.message_history) == 0 diff --git a/tests/unit/openai/test_openai_adapter.py b/tests/unit/openai/test_openai_adapter.py new file mode 100644 index 00000000..506ec3fe --- /dev/null +++ b/tests/unit/openai/test_openai_adapter.py @@ -0,0 +1,147 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for OpenAIAdapter tool serialization.""" + +import json + +import msgspec +import pytest +from inference_endpoint.core.types import Query +from inference_endpoint.openai.openai_adapter import OpenAIAdapter + +_TOOL_DEF = { + "type": "function", + "function": { + "name": "search", + "description": "Search the web", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, +} + +_TOOL_CALLS = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } +] + + +@pytest.mark.unit +def test_tool_definitions_forwarded(): + """tools array in query.data is present in the encoded request.""" + messages = [ + {"role": "user", "content": "Find something"}, + ] + query = Query( + id="q1", + data={ + "model": "test-model", + "messages": messages, + "tools": [_TOOL_DEF], + "max_completion_tokens": 128, + "stream": False, + }, + ) + encoded = OpenAIAdapter.encode_query(query) + payload = json.loads(encoded) + + assert "tools" in payload + assert len(payload["tools"]) == 1 + assert payload["tools"][0]["function"]["name"] == "search" + + +@pytest.mark.unit +def test_tool_use_messages_roundtrip(): + """Full tool-use message sequence encodes and decodes without data loss.""" + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Find something"}, + {"role": "assistant", "content": None, "tool_calls": _TOOL_CALLS}, + {"role": "tool", "content": "search result", "tool_call_id": "call_1"}, + {"role": "assistant", "content": "Here is the answer"}, + ] + query = Query( + id="q1", + data={ + "model": "test-model", + "messages": messages, + "tools": [_TOOL_DEF], + "max_completion_tokens": 128, + "stream": False, + }, + ) + encoded = OpenAIAdapter.encode_query(query) + payload = json.loads(encoded) + + msgs = payload["messages"] + assert msgs[0]["role"] == "system" + assert msgs[1]["role"] == "user" + # assistant tool-dispatch: content is None (Pydantic model_dump includes None fields) + assert msgs[2]["role"] == "assistant" + assert msgs[2]["tool_calls"] == _TOOL_CALLS + assert msgs[2].get("content") is None + # tool result + assert msgs[3]["role"] == "tool" + assert msgs[3]["tool_call_id"] == "call_1" + assert msgs[3]["content"] == "search result" + # terminal assistant + assert msgs[4]["content"] == "Here is the answer" + + +@pytest.mark.unit +def test_encode_request_produces_valid_json_bytes(): + """encode_request returns bytes that msgspec can decode back.""" + messages = [{"role": "user", "content": "Hello"}] + query = Query( + id="q2", + data={ + "model": "m", + "messages": messages, + "max_completion_tokens": 64, + "stream": False, + }, + ) + request = OpenAIAdapter.to_endpoint_request(query) + encoded = OpenAIAdapter.encode_request(request) + + assert isinstance(encoded, bytes) + decoded = msgspec.json.decode(encoded) + assert decoded["messages"][0]["role"] == "user" + + +@pytest.mark.unit +def test_no_tools_key_when_absent(): + """When query.data has no 'tools', the encoded payload has tools=None.""" + messages = [{"role": "user", "content": "Hello"}] + query = Query( + id="q3", + data={ + "model": "m", + "messages": messages, + "max_completion_tokens": 64, + "stream": False, + }, + ) + encoded = OpenAIAdapter.encode_query(query) + payload = json.loads(encoded) + + # Pydantic model_dump includes None fields; tools must be None when not supplied + assert payload.get("tools") is None From 1a418694fee7396fe9f0b5c7819e1e58010d2fb9 Mon Sep 17 
00:00:00 2001 From: "Li, Tianmu" Date: Fri, 24 Apr 2026 20:58:00 -0700 Subject: [PATCH 05/41] docs: add multi-turn quickstart, examples, and conversion scripts Add MULTI_TURN_QUICKSTART.md, examples/09_MultiTurn/ configs and sample data, scripts/convert_agentic_snapshot.py, and README clarifications including conversion script output destination. --- docs/MULTI_TURN_QUICKSTART.md | 96 +++-- examples/09_MultiTurn/README.md | 35 +- .../agentic_coding_benchmark.yaml | 2 +- .../agentic_workflow_benchmark.yaml | 2 +- examples/09_MultiTurn/datasets/.gitkeep | 0 .../09_MultiTurn/multi_turn_benchmark.yaml | 1 - .../multi_turn_with_concurrency.yaml | 1 - scripts/convert_agentic_snapshot.py | 356 ++++++++++++++++++ scripts/validate_jsonl_schema.py | 126 +++++++ .../load_generator/conversation_manager.py | 51 +-- .../load_generator/multi_turn_strategy.py | 5 +- .../test_multi_turn_strategy.py | 111 +++++- 12 files changed, 698 insertions(+), 88 deletions(-) create mode 100644 examples/09_MultiTurn/datasets/.gitkeep create mode 100644 scripts/convert_agentic_snapshot.py create mode 100644 scripts/validate_jsonl_schema.py diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index 73ed6678..99b35aa5 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -1,6 +1,6 @@ # Multi-Turn Conversation Benchmarking - Quick Start Guide -## πŸš€ Quick Start in 5 Minutes +## Quick Start in 5 Minutes ### 1. Prepare Your Dataset @@ -40,7 +40,6 @@ datasets: - name: my_conversations type: performance path: path/to/your/conversations.jsonl - format: ".jsonl" multi_turn: # ← Presence of this block enables multi-turn mode mode: independent # ← Per-conv pipelines; no cross-conv turn barrier turn_timeout_s: 300 # ← Max wait for prev turn @@ -48,7 +47,7 @@ datasets: settings: load_pattern: type: multi_turn # ← Use multi-turn scheduler - target_concurrency: 32 # ← OPTIONAL: limit concurrent requests + target_concurrency: 32 # ← Required: max concurrent requests client: workers: 4 @@ -78,24 +77,26 @@ That's it! Your benchmark will now: --- -## πŸ“Š Understanding Results +## Understanding Results After the benchmark completes, check the directory configured via `report_dir`: -### Events Database +### Events Log -The `events.db` SQLite database includes: +The `events.jsonl` file contains one JSON record per line: -- Standard fields: sample_uuid, event_type, timestamp_ns -- **New fields**: conversation_id, turn_number +- Standard fields: `sample_uuid`, `event_type`, `timestamp_ns` +- **New fields**: `conversation_id`, `turn_number` -Query example: +Query examples: -```sql -SELECT conversation_id, turn_number, event_type, timestamp_ns -FROM events -WHERE conversation_id = 'c1' -ORDER BY turn_number; +```bash +# All events for a specific conversation +grep '"conversation_id": "c1"' logs/my_multi_turn_benchmark/events.jsonl + +# With jq for structured output +jq 'select(.conversation_id == "c1") | {conversation_id, turn_number, event_type, timestamp_ns}' \ + logs/my_multi_turn_benchmark/events.jsonl ``` ### Metrics @@ -109,7 +110,7 @@ _Note: Per-conversation aggregation (e.g., "conversations/sec") is coming in a f --- -## 🎯 Conversation Modes Explained +## Conversation Modes Explained ### Independent Mode (Default) @@ -140,9 +141,9 @@ t=0.8: conv1-turn3 (after conv1-turn2 completes) --- -## πŸŽ›οΈ Concurrency Control (NEW!) 
+## Concurrency Control -For benchmarks with **> 50 conversations**, use `target_concurrency` to prevent endpoint overload: +`target_concurrency` is **required** for the `multi_turn` load pattern. It limits the maximum number of in-flight requests across all conversations and prevents endpoint overload when many conversations run simultaneously. ```yaml settings: @@ -151,17 +152,15 @@ settings: target_concurrency: 32 # ← Limit to 32 concurrent requests ``` -**Why?** Without this, independent mode issues ALL turn-1s at once (could be 100+), overwhelming your endpoint. - -**Rule of thumb**: +**Sizing guide**: -- Small (< 50 convs): No limit needed -- Medium (50-500 convs): `target_concurrency: 32` -- Large (500+ convs): `target_concurrency: 64` +- Small (< 50 convs): `target_concurrency: 32` +- Medium (50-500 convs): `target_concurrency: 64` +- Large (500+ convs): `target_concurrency: 96` or higher --- -## πŸ”§ Common Configurations +## Common Configurations ### Recommended: With Concurrency Control @@ -188,6 +187,9 @@ multi_turn: turn_timeout_s: 600 settings: + load_pattern: + type: multi_turn + target_concurrency: 96 client: workers: 16 # More workers for parallel conversations ``` @@ -198,17 +200,27 @@ settings: multi_turn: mode: independent turn_timeout_s: 1800 # 30 minutes for slow responses + +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 ``` --- -## ❓ Troubleshooting +## Troubleshooting ### "Conversation has invalid role sequence" -**Problem**: Your dataset doesn't alternate between user/assistant. +**Problem**: Your dataset doesn't follow a valid role sequence. + +**Fix**: Check your JSONL. Valid sequences: + +- Plain chat: `user β†’ assistant β†’ user β†’ assistant β†’ ...` +- Agentic (tool-use): `user β†’ assistant β†’ tool β†’ assistant β†’ tool β†’ ... β†’ user` -**Fix**: Check your JSONL - must be: user, assistant, user, assistant, ... +Conversations may also end with a `tool` row (the model's response to the final tool call is the benchmark target). ### "Rows for conversation X are not consecutive" @@ -230,17 +242,19 @@ multi_turn: **Problem**: MultiTurnDataset not recognized. -**Fix**: Ensure `format: ".jsonl"` is specified in config: +**Fix**: Ensure `multi_turn:` block is present in the dataset config. The file format +is auto-detected from the `.jsonl` extension β€” no `format` field is needed: ```yaml datasets: - path: your_file.jsonl - format: ".jsonl" # ← Required for JSONL + multi_turn: + mode: independent ``` --- -## πŸ“ Example Datasets +## Example Datasets ### Simple 2-Turn Conversation @@ -274,12 +288,12 @@ datasets: --- -## πŸ§ͺ Testing Your Setup +## Testing Your Setup ### 1. Use the Example Dataset ```bash -cd examples/multi_turn +cd examples/09_MultiTurn inference-endpoint benchmark from-config --config multi_turn_benchmark.yaml ``` @@ -293,14 +307,14 @@ cat logs/multi_turn_test/benchmark.log ### 3. 
Verify Event Recording ```bash -sqlite3 logs/multi_turn_test/events.db -sqlite> SELECT DISTINCT conversation_id FROM events; +# List all unique conversation IDs in the events log +jq -r '.conversation_id' logs/multi_turn_test/events.jsonl | sort -u # Should show your conversation IDs ``` --- -## πŸ’‘ Tips & Best Practices +## Tips & Best Practices ### Dataset Design @@ -318,28 +332,28 @@ sqlite> SELECT DISTINCT conversation_id FROM events; - **Start small**: Test with 1-2 conversations first - **Single conversation**: Use `mode: independent` with `target_concurrency: 1` -- **Check events.db**: Verify turn ordering in database +- **Check events.jsonl**: Verify turn ordering with `jq` --- -## πŸ”— More Information +## More Information - **Full Documentation**: See `examples/09_MultiTurn/README.md` - **Architecture**: See `AGENTS.md` (Multi-Turn section) --- -## βœ… Checklist +## Checklist Before running your first multi-turn benchmark: -- [ ] Dataset follows format (alternating user/assistant roles) +- [ ] Dataset follows format (user/assistant alternation, or agentic userβ†’assistantβ†’tool sequences) - [ ] All rows for each conversation_id are grouped together - [ ] Config has `multi_turn:` block in the dataset section - [ ] Config has `load_pattern.type: multi_turn` - [ ] Endpoint is running and reachable -- [ ] `format: ".jsonl"` specified for JSONL datasets +- [ ] File uses `.jsonl` extension (format is auto-detected) - [ ] Conversation IDs are unique per conversation - [ ] Turn numbers are sequential (1, 2, 3, ...) -Happy benchmarking! πŸš€ +Happy benchmarking! diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md index 0d3348c7..e7f9505a 100644 --- a/examples/09_MultiTurn/README.md +++ b/examples/09_MultiTurn/README.md @@ -78,17 +78,21 @@ The following commands convert each source snapshot file to the flat-row format Run from the repo root: ```bash +# First argument: input snapshot JSONL; second argument: output flat-row JSONL python scripts/convert_agentic_snapshot.py \ - /path/to/agentic_coding_dataset.jsonl \ # input snapshot JSONL - examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl \ # output flat-row JSONL + /path/to/agentic_coding_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl \ --verify python scripts/convert_agentic_snapshot.py \ - /path/to/agentic_workflow_dataset.jsonl \ # input snapshot JSONL - examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl \ # output flat-row JSONL + /path/to/agentic_workflow_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl \ --verify ``` +The `datasets/` directory under `examples/09_MultiTurn/` is a placeholder; run the conversion +commands above to populate it before benchmarking. + The `--verify` flag cross-checks every client turn's message history against the source snapshot and exits with code 1 if any mismatch is found. 
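For orientation, the flat rows the converter writes look roughly like this (values are illustrative, not taken from a real snapshot; the field shapes follow `scripts/multi_turn_dataset_schema.json` β€” `tool_calls` on tool-dispatching assistant rows, `tool_results` on tool rows, and `system`/`tools` attached to client-turn rows, shown here only on the first user row for brevity):

```jsonl
{"conversation_id": "sim_000001", "turn": 1, "role": "user", "content": "Fix the failing test", "system": "You are a coding agent", "tools": [{"type": "function", "function": {"name": "bash", "description": "Run a shell command", "parameters": {"type": "object", "properties": {"cmd": {"type": "string"}}, "required": ["cmd"]}}]}
{"conversation_id": "sim_000001", "turn": 2, "role": "assistant", "tool_calls": [{"id": "functions.bash:0", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"pytest -x\"}"}}]}
{"conversation_id": "sim_000001", "turn": 3, "role": "tool", "tool_results": [{"tool_call_id": "functions.bash:0", "content": "1 failed, 12 passed"}]}
{"conversation_id": "sim_000001", "turn": 4, "role": "assistant", "content": "The failure is a ZeroDivisionError in foo(); patching now."}
```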
The script also: @@ -138,8 +142,7 @@ inference-endpoint benchmark from-config \ datasets: - name: customer_support type: performance - path: examples/multi_turn/customer_support_conversations.jsonl - format: ".jsonl" + path: examples/09_MultiTurn/customer_support_conversations.jsonl multi_turn: mode: independent turn_timeout_s: 300.0 @@ -147,11 +150,12 @@ datasets: settings: load_pattern: type: multi_turn + target_concurrency: 32 # ← Required for multi_turn load pattern ``` -### Concurrency Control (Optional) +### Concurrency Control -The multi-turn scheduler supports **optional concurrency limiting** to control the maximum number of in-flight requests across all conversations: +The `target_concurrency` field is **required** for the `multi_turn` load pattern and controls the maximum number of in-flight requests across all conversations: ```yaml settings: @@ -162,15 +166,14 @@ settings: **Behavior**: -- Without `target_concurrency`: Unlimited concurrency (all turn-1s issue at t=0 in INDEPENDENT mode) - With `target_concurrency`: Limits total in-flight requests across all conversations - Combines with turn sequencing: Turn N+1 still waits for turn N, AND waits for available slot **Use cases**: -- 🎯 **Prevent endpoint overload**: Control request rate to busy endpoints -- 🎯 **Large-scale testing**: Benchmark 1000+ conversations without overwhelming system -- 🎯 **Resource management**: Stay within port limits, memory constraints +- **Prevent endpoint overload**: Control request rate to busy endpoints +- **Large-scale testing**: Benchmark 1000+ conversations without overwhelming system +- **Resource management**: Stay within port limits, memory constraints **Example**: 100 conversations with `target_concurrency: 32` @@ -221,7 +224,7 @@ If a turn times out waiting for the previous turn, it will be skipped and logged ```bash inference-endpoint benchmark from-config \ - --config examples/multi_turn/multi_turn_benchmark.yaml + --config examples/09_MultiTurn/multi_turn_benchmark.yaml ``` ### Viewing Results @@ -231,7 +234,7 @@ Multi-turn benchmarks produce both per-turn and per-conversation metrics: - **Per-turn metrics**: Latency, TTFT, TPOT for each individual turn - **Per-conversation metrics**: Total conversation latency, conversations per second -Results are stored in the configured `report_dir` with conversation metadata included in the events database. +Results are stored in the configured `report_dir` with conversation metadata included in the events log (`events.jsonl`). 
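As a quick sketch for slicing those results (assuming `jq` is available and the event fields shown in the quickstart β€” `conversation_id`, `turn_number`, `event_type`), you can count events per conversation straight from the log; substitute your configured `report_dir`:

```bash
# Events per conversation, most active first
jq -r 'select(.conversation_id != null) | .conversation_id' <report_dir>/events.jsonl \
  | sort | uniq -c | sort -rn
```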
## Example Datasets @@ -248,8 +251,7 @@ Simple customer support conversations demonstrating basic multi-turn interaction ### Key Components - **ConversationManager**: Tracks conversation state and message history -- **MultiTurnScheduler**: Enforces turn sequencing within conversations -- **ConversationSample**: Sample with conversation metadata +- **MultiTurnStrategy**: Enforces turn sequencing within conversations - **MultiTurnDataset**: Validates and structures multi-turn data ### Turn Sequencing @@ -307,5 +309,4 @@ Planned features: - [ ] Poisson conversation arrival mode implementation - [ ] Per-conversation metrics in reporting - [ ] Conversation-level latency percentiles -- [ ] Support for tool/function calls in conversations - [ ] Dynamic conversation branching diff --git a/examples/09_MultiTurn/agentic_coding_benchmark.yaml b/examples/09_MultiTurn/agentic_coding_benchmark.yaml index 5a1036a7..f3abc3cf 100644 --- a/examples/09_MultiTurn/agentic_coding_benchmark.yaml +++ b/examples/09_MultiTurn/agentic_coding_benchmark.yaml @@ -10,8 +10,8 @@ datasets: - name: agentic_coding type: performance # Run: python scripts/convert_agentic_snapshot.py examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl --verify + # The datasets/ directory is a placeholder; populate it with the conversion script above. path: examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl - format: ".jsonl" multi_turn: mode: independent turn_timeout_s: 600.0 diff --git a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml index e8885465..239e9374 100644 --- a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml +++ b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml @@ -10,8 +10,8 @@ datasets: - name: agentic_workflow type: performance # Run: python scripts/convert_agentic_snapshot.py examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl --verify + # The datasets/ directory is a placeholder; populate it with the conversion script above. 
path: examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl - format: ".jsonl" multi_turn: mode: independent turn_timeout_s: 600.0 diff --git a/examples/09_MultiTurn/datasets/.gitkeep b/examples/09_MultiTurn/datasets/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml index da4773e0..36066aa3 100644 --- a/examples/09_MultiTurn/multi_turn_benchmark.yaml +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -11,7 +11,6 @@ datasets: - name: customer_support_conversations type: performance path: examples/09_MultiTurn/customer_support_conversations.jsonl - format: ".jsonl" samples: 10 multi_turn: mode: independent diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml index ba5362e3..e1d5f37c 100644 --- a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -11,7 +11,6 @@ datasets: - name: customer_support_conversations type: performance path: examples/09_MultiTurn/customer_support_conversations.jsonl - format: ".jsonl" samples: 10 multi_turn: mode: independent # All conv turn-1 start together diff --git a/scripts/convert_agentic_snapshot.py b/scripts/convert_agentic_snapshot.py new file mode 100644 index 00000000..fe217b9b --- /dev/null +++ b/scripts/convert_agentic_snapshot.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert agentic snapshot datasets to the flat-row JSONL format expected by MultiTurnDataset. + +Each snapshot record contains the full conversation history up to a checkpoint: + {"conversation_id": "sim_000001", "conversation_idx": 0, + "messages": [{"role": "system", ...}, ...], "tools": [...], "metadata": {}} + +For each conversation only the final snapshot (highest conversation_idx) is used. +Its messages array is expanded into individual flat rows, one per message. 
+ +Usage: + python scripts/convert_agentic_snapshot.py INPUT.jsonl OUTPUT.jsonl + python scripts/convert_agentic_snapshot.py INPUT.jsonl OUTPUT.jsonl --verify +""" + +import argparse +import json +import sys +from pathlib import Path + +# --------------------------------------------------------------------------- +# Helpers shared between convert() and verify() +# --------------------------------------------------------------------------- + + +def _load_final_snapshots(input_path: Path) -> dict[str, dict]: + """Return {conv_id: record} keeping only the highest conversation_idx per conv.""" + final: dict[str, dict] = {} + with input_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + record = json.loads(line) + conv_id = record["conversation_id"] + if ( + conv_id not in final + or record["conversation_idx"] > final[conv_id]["conversation_idx"] + ): + final[conv_id] = record + return final + + +def _apply_collapses(non_system: list[dict]) -> list[tuple[dict, int]]: + """Apply user-collapse and tool-merge passes, tracking the last source index each + output row covers. + + Returns list of (output_msg, last_source_idx) pairs where last_source_idx is the + 0-based index within non_system of the final source message folded into this row. + """ + # Pass 1: collapse consecutive user messages + collapsed: list[tuple[dict, int]] = [] # (msg, last_source_idx) + for src_idx, msg in enumerate(non_system): + if collapsed and collapsed[-1][0]["role"] == "user" and msg["role"] == "user": + prev_msg, _ = collapsed[-1] + prev_text = prev_msg.get("content") or "" + cur_text = msg.get("content") or "" + collapsed[-1] = ( + {**prev_msg, "content": f"{prev_text}\n\n{cur_text}".strip()}, + src_idx, + ) + else: + collapsed.append((msg, src_idx)) + + # Pass 2: merge consecutive tool messages + # Input messages are raw snapshot wire-format (tool_call_id + content on each msg). + # On merge, upgrade the first message to a tool_results list so the output always + # uses the tool_results array form regardless of how many results there are. + merged: list[tuple[dict, int]] = [] + for msg, last_src in collapsed: + if merged and merged[-1][0]["role"] == "tool" and msg["role"] == "tool": + prev_msg, _ = merged[-1] + tool_results = prev_msg.get("tool_results") + if tool_results is None: + tool_results = [ + { + "tool_call_id": prev_msg.get("tool_call_id"), + "content": prev_msg.get("content"), + } + ] + prev_msg = {"role": "tool", "tool_results": tool_results} + tool_results.append( + { + "tool_call_id": msg.get("tool_call_id"), + "content": msg.get("content"), + } + ) + merged[-1] = (prev_msg, last_src) + else: + merged.append((msg, last_src)) + + return merged + + +def _normalize_msg(msg: dict) -> dict: + """Drop None values for comparison.""" + return {k: v for k, v in msg.items() if v is not None} + + +def _expand_row_to_wire_msgs(row: dict) -> list[dict]: + """Expand a single flat row into one or more OpenAI wire-format messages. 
+ + Handles two tool row forms: + - Output flat rows: tool_results array (always used after conversion) + - Raw snapshot messages passed through verify(): tool_call_id + content directly + """ + if isinstance(row.get("tool_results"), list): + return [ + { + "role": "tool", + "tool_call_id": r.get("tool_call_id"), + "content": r.get("content"), + } + for r in row["tool_results"] + ] + msg: dict = {"role": row["role"], "content": row.get("content")} + if row.get("tool_calls"): + msg["tool_calls"] = row["tool_calls"] + if row.get("tool_call_id"): + msg["tool_call_id"] = row["tool_call_id"] + return [msg] + + +def verify(input_path: Path, output_path: Path) -> bool: + """Cross-check every client-turn's pre_built_messages against the source snapshot. + + For each output client turn, reconstruct the pre_built_messages that + MultiTurnDataset would build from the flat rows and compare it against the + ground-truth messages built directly from the source snapshot up to the same + point (accounting for user-collapse and tool-merge). + + Returns: + True if all checks pass, False if any mismatch found. + """ + final = _load_final_snapshots(input_path) + + # Load converted rows grouped by conversation_id + conv_rows: dict[str, list[dict]] = {} + with output_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + row = json.loads(line) + cid = row["conversation_id"] + conv_rows.setdefault(cid, []).append(row) + for cid in conv_rows: + conv_rows[cid].sort(key=lambda r: r["turn"]) + + errors: list[str] = [] + total_checked = 0 + + for conv_id in sorted(final): + record = final[conv_id] + system_content: str | None = None + non_system: list[dict] = [] + for msg in record["messages"]: + if msg["role"] == "system": + system_content = msg.get("content") + else: + non_system.append(msg) + + # Re-apply the same collapses the converter applies, tracking source coverage + processed = _apply_collapses(non_system) # [(output_msg, last_source_idx), ...] + flat_rows = conv_rows.get(conv_id, []) + + if len(processed) != len(flat_rows): + errors.append( + f"{conv_id}: expected {len(processed)} flat rows after collapses, " + f"got {len(flat_rows)} in output" + ) + continue + + client_turn_pairs = [ + (out_pos, flat_row) + for out_pos, (flat_row, _) in enumerate( + zip(flat_rows, processed, strict=True) + ) + if flat_row["role"] in ("user", "tool") + ] + + for ct_idx, (out_pos, flat_row) in enumerate(client_turn_pairs): + # Ground truth: apply the same collapses the converter applies, then + # build the message list from the processed (collapsed/merged) rows up to + # and including this client turn. This correctly reflects what the + # converter produces β€” consecutive user/tool merges mean history is + # shorter than the raw source but content-equivalent. 
+ expected: list[dict] = [] + if system_content: + expected.append({"role": "system", "content": system_content}) + for proc_msg, _ in processed[: out_pos + 1]: + expected.extend(_expand_row_to_wire_msgs(proc_msg)) + + # Reconstructed output: system + expand all flat rows up to this turn + got: list[dict] = [] + if system_content: + got.append({"role": "system", "content": system_content}) + for row in flat_rows[: out_pos + 1]: + got.extend(_expand_row_to_wire_msgs(row)) + + exp_norm = [_normalize_msg(m) for m in expected] + got_norm = [_normalize_msg(m) for m in got] + + if exp_norm != got_norm: + errors.append( + f"{conv_id} client-turn {ct_idx + 1} (flat turn {flat_row['turn']}):\n" + f" expected {len(exp_norm)} msgs, got {len(got_norm)}\n" + f" EXPECTED: {json.dumps(exp_norm, ensure_ascii=False)[:400]}\n" + f" GOT: {json.dumps(got_norm, ensure_ascii=False)[:400]}" + ) + total_checked += 1 + + if errors: + print( + f"FAIL: {len(errors)} mismatches out of {total_checked} client turns checked.", + file=sys.stderr, + ) + for err in errors[:20]: + print(err, file=sys.stderr) + if len(errors) > 20: + print(f" ... and {len(errors) - 20} more", file=sys.stderr) + return False + + print( + f"OK: all {total_checked} client turns verified against source.", + file=sys.stderr, + ) + return True + + +def convert(input_path: Path, output_path: Path) -> None: + # Group records by conversation_id, keep only the final snapshot per conversation. + final: dict[str, dict] = {} + with input_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + record = json.loads(line) + conv_id = record["conversation_id"] + if ( + conv_id not in final + or record["conversation_idx"] > final[conv_id]["conversation_idx"] + ): + final[conv_id] = record + + print(f"Found {len(final)} conversations in {input_path.name}", file=sys.stderr) + + rows_written = 0 + with output_path.open("w") as out: + for conv_id, record in sorted(final.items()): + messages = record["messages"] + tools = record.get("tools") or [] + + # Extract system message (always first if present). + system_content: str | None = None + non_system: list[dict] = [] + for msg in messages: + if msg["role"] == "system": + system_content = msg.get("content") + else: + non_system.append(msg) + + # Apply the same user-collapse and tool-merge passes used by verify(). + # _apply_collapses returns [(msg, last_source_idx), ...]; strip the indices. + non_system = [msg for msg, _ in _apply_collapses(non_system)] + + first_user_seen = False + for position, msg in enumerate(non_system): + role = msg["role"] + turn = position + 1 # 1-indexed + + row: dict = {"conversation_id": conv_id, "turn": turn, "role": role} + + # System prompt on the first user row only. + if role == "user" and not first_user_seen: + if system_content is not None: + row["system"] = system_content + first_user_seen = True + + # tool_calls for assistant messages that dispatch tools. + if msg.get("tool_calls"): + row["tool_calls"] = msg["tool_calls"] + + if role == "tool": + # All tool rows use tool_results array (single results have one entry). + if msg.get("tool_results"): + row["tool_results"] = msg["tool_results"] + else: + row["tool_results"] = [ + { + "tool_call_id": msg.get("tool_call_id"), + "content": msg.get("content"), + } + ] + else: + # content field (may be None for tool-dispatching assistant messages) + row["content"] = msg.get("content") + + # Attach tool definitions to client-turn rows only (user + tool). 
+ # This avoids duplicating the large tools array on every assistant row + # while still making them available via load_sample(). + if role in ("user", "tool") and tools: + row["tools"] = tools + + out.write(json.dumps(row, ensure_ascii=False) + "\n") + rows_written += 1 + + print(f"Wrote {rows_written} rows to {output_path}", file=sys.stderr) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert agentic snapshot JSONL to MultiTurnDataset flat-row JSONL." + ) + parser.add_argument("input", type=Path, help="Input snapshot JSONL file") + parser.add_argument("output", type=Path, help="Output flat-row JSONL file") + parser.add_argument( + "--verify", + action="store_true", + help=( + "After converting, cross-check every client-turn's pre_built_messages " + "against the source snapshot. Exits with code 1 if any mismatch found." + ), + ) + args = parser.parse_args() + + if not args.input.exists(): + print(f"Error: input file not found: {args.input}", file=sys.stderr) + sys.exit(1) + + convert(args.input, args.output) + + if args.verify: + ok = verify(args.input, args.output) + if not ok: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/validate_jsonl_schema.py b/scripts/validate_jsonl_schema.py new file mode 100644 index 00000000..d2bb7177 --- /dev/null +++ b/scripts/validate_jsonl_schema.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validate multi-turn JSONL dataset files against multi_turn_dataset_schema.json. + +Checks each row's structure against the JSON schema (field types, required fields, +tool_results shape, etc.). Does NOT check cross-row invariants such as turn +numbering or role sequences β€” those are enforced by MultiTurnDataset at load time. + +Usage: + python scripts/validate_jsonl_schema.py FILE [FILE ...] + python scripts/validate_jsonl_schema.py /model/agentic_coding_flat.jsonl /model/agentic_workflow_flat.jsonl +""" + +import argparse +import json +import sys +from pathlib import Path + +try: + import jsonschema +except ImportError: + print( + "Error: jsonschema not installed. Run: pip install jsonschema", file=sys.stderr + ) + sys.exit(1) + + +def validate_file(path: Path, schema: dict, max_errors: int = 50) -> int: + """Validate every row in a JSONL file against the schema. + + Returns the number of validation errors found. 
+ """ + errors: list[str] = [] + validator = jsonschema.Draft7Validator(schema) + + with path.open() as fh: + for lineno, line in enumerate(fh, 1): + line = line.strip() + if not line: + continue + try: + row = json.loads(line) + except json.JSONDecodeError as e: + errors.append(f" line {lineno}: JSON parse error: {e}") + if len(errors) >= max_errors: + break + continue + + conv_id = row.get("conversation_id", "") + turn = row.get("turn", "?") + role = row.get("role", "?") + + row_errors = list(validator.iter_errors(row)) + for err in row_errors: + path_str = " -> ".join(str(p) for p in err.absolute_path) or "(root)" + errors.append( + f" line {lineno} [{conv_id} turn={turn} role={role}] " + f"@ {path_str}: {err.message}" + ) + + if len(errors) >= max_errors: + errors.append(f" ... stopping after {max_errors} errors") + break + + if errors: + print(f"FAIL {path.name}: {len(errors)} error(s)") + for msg in errors: + print(msg) + else: + print(f"OK {path.name}") + + return len(errors) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Validate multi-turn JSONL files against multi_turn_dataset_schema.json." + ) + parser.add_argument("files", nargs="+", type=Path, help="JSONL files to validate") + parser.add_argument( + "--schema", + type=Path, + default=Path(__file__).parent.parent / "multi_turn_dataset_schema.json", + help="Path to the JSON schema file (default: multi_turn_dataset_schema.json)", + ) + parser.add_argument( + "--max-errors", + type=int, + default=50, + help="Stop reporting after this many errors per file (default: 50)", + ) + args = parser.parse_args() + + if not args.schema.exists(): + print(f"Error: schema not found: {args.schema}", file=sys.stderr) + sys.exit(1) + + schema = json.load(args.schema.open()) + + total_errors = 0 + for path in args.files: + if not path.exists(): + print(f"Error: file not found: {path}", file=sys.stderr) + total_errors += 1 + continue + total_errors += validate_file(path, schema, max_errors=args.max_errors) + + sys.exit(1 if total_errors > 0 else 0) + + +if __name__ == "__main__": + main() diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index ba9a02ea..56bb8278 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -101,11 +101,29 @@ def get_or_create( ) return self._conversations[conversation_id] + def _log_if_complete(self, state: ConversationState, conversation_id: str) -> None: + """Log completion status once all expected turns have a response.""" + if not state.is_complete(): + return + if state.failed_turns > 0: + logger.info( + f"Conversation {conversation_id} completed with failures: " + f"{state.completed_turns - state.failed_turns}/" + f"{state.expected_client_turns} successful, " + f"{state.failed_turns} failed" + ) + else: + logger.debug( + f"Conversation {conversation_id} completed: " + f"{state.completed_turns}/{state.expected_client_turns} turns" + ) + def mark_turn_complete( self, conversation_id: str, response: str, store_in_history: bool = False, + metadata: dict[str, Any] | None = None, ) -> None: """Record a successful response and wake the pipeline task. @@ -113,6 +131,8 @@ def mark_turn_complete( conversation_id: Conversation ID. response: Model output (appended to history when store_in_history=True). store_in_history: When True, append response to message_history. 
+ metadata: Optional response metadata; tool_calls are preserved in history + when present (only used when store_in_history=True). Raises: KeyError: If conversation_id not found. @@ -120,22 +140,15 @@ def mark_turn_complete( state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") - if store_in_history and response: - state.message_history.append({"role": "assistant", "content": response}) + if store_in_history: + tool_calls = metadata.get("tool_calls") if metadata else None + if response or tool_calls: + msg: dict[str, Any] = {"role": "assistant", "content": response or None} + if tool_calls: + msg["tool_calls"] = tool_calls + state.message_history.append(msg) state.completed_turns += 1 - if state.is_complete(): - if state.failed_turns > 0: - logger.info( - f"Conversation {conversation_id} completed with failures: " - f"{state.completed_turns - state.failed_turns}/" - f"{state.expected_client_turns} successful, " - f"{state.failed_turns} failed" - ) - else: - logger.debug( - f"Conversation {conversation_id} completed: " - f"{state.completed_turns}/{state.expected_client_turns} turns" - ) + self._log_if_complete(state, conversation_id) state.turn_done.set() def mark_turn_failed( @@ -164,11 +177,5 @@ def mark_turn_failed( state.completed_turns += 1 state.failed_turns += 1 logger.warning(f"Turn failed for conversation {conversation_id}") - if state.is_complete(): - logger.info( - f"Conversation {conversation_id} completed with failures: " - f"{state.completed_turns - state.failed_turns}/" - f"{state.expected_client_turns} successful, " - f"{state.failed_turns} failed" - ) + self._log_if_complete(state, conversation_id) state.turn_done.set() diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index cfd418bd..0f3ba6e8 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -227,5 +227,8 @@ def on_sample_complete(self, result: QueryResult) -> None: ) else: self._conv_manager.mark_turn_complete( - conv_id, response_text, store_in_history=self._store_in_history + conv_id, + response_text, + store_in_history=self._store_in_history, + metadata=result.metadata, ) diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 0edbb34f..1eecc75d 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -313,7 +313,7 @@ async def test_live_history_initializes_system_prompt(): async def complete_turn(): await asyncio.sleep(0.01) - await conv_manager.mark_turn_complete("conv1", "response") + conv_manager.mark_turn_complete("conv1", "response") asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -342,7 +342,7 @@ async def test_live_history_no_system_prompt_when_none(): async def complete_turn(): await asyncio.sleep(0.01) - await conv_manager.mark_turn_complete("conv1", "response") + conv_manager.mark_turn_complete("conv1", "response") asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -369,7 +369,7 @@ async def test_dataset_history_mode_does_not_inject_system_prompt(): async def complete_turn(): await asyncio.sleep(0.01) - await conv_manager.mark_turn_complete("conv1", "response") + conv_manager.mark_turn_complete("conv1", "response") asyncio.create_task(complete_turn()) await 
strategy.execute(issuer) @@ -378,3 +378,108 @@ async def complete_turn(): assert state is not None # message_history should be empty (dataset-history mode doesn't accumulate) assert len(state.message_history) == 0 + + +@pytest.mark.unit +def test_mark_turn_complete_preserves_tool_calls(): + """mark_turn_complete stores tool_calls in history when metadata contains them.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "bash", "arguments": '{"cmd": "ls"}'}, + } + ] + conv_manager.mark_turn_complete( + "conv1", + response="", + store_in_history=True, + metadata={"tool_calls": tool_calls}, + ) + + state = conv_manager.get_state("conv1") + assert state is not None + assert len(state.message_history) == 1 + msg = state.message_history[0] + assert msg["role"] == "assistant" + assert msg["content"] is None + assert msg["tool_calls"] == tool_calls + + +@pytest.mark.unit +def test_mark_turn_complete_with_response_and_tool_calls(): + """mark_turn_complete stores both content and tool_calls when both are present.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ] + conv_manager.mark_turn_complete( + "conv1", + response="Calling search...", + store_in_history=True, + metadata={"tool_calls": tool_calls}, + ) + + state = conv_manager.get_state("conv1") + assert state is not None + msg = state.message_history[0] + assert msg["content"] == "Calling search..." + assert msg["tool_calls"] == tool_calls + + +@pytest.mark.unit +def test_mark_turn_complete_no_history_when_empty(): + """mark_turn_complete does not append when response is empty and no tool_calls.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + conv_manager.mark_turn_complete("conv1", response="", store_in_history=True) + + state = conv_manager.get_state("conv1") + assert state is not None + assert len(state.message_history) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_sample_complete_passes_metadata(): + """on_sample_complete forwards result.metadata (including tool_calls) to ConversationManager.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata_dict = _make_metadata_with_system({"conv1": [1]}) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata_dict, multi_turn_config=mt_cfg) + + conv_manager.get_or_create("conv1", expected_client_turns=1) + strategy._inflight["q0001"] = "conv1" + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "bash", "arguments": "{}"}, + } + ] + result = QueryResult( + id="q0001", + response_output=TextModelOutput(output=""), + metadata={"tool_calls": tool_calls}, + ) + strategy.on_sample_complete(result) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.completed_turns == 1 + assert len(state.message_history) == 1 + assert state.message_history[0]["tool_calls"] == tool_calls + assert state.message_history[0]["content"] is None From 00310b490d4e74bd5f55e7270cc378b21d6d2057 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Sat, 25 Apr 2026 15:24:50 -0700 Subject: [PATCH 06/41] fix: replace hardcoded /model/ path in 
validate_jsonl_schema.py docstring Co-Authored-By: Claude Sonnet 4.6 --- scripts/validate_jsonl_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/validate_jsonl_schema.py b/scripts/validate_jsonl_schema.py index d2bb7177..c2d25eca 100644 --- a/scripts/validate_jsonl_schema.py +++ b/scripts/validate_jsonl_schema.py @@ -22,7 +22,7 @@ Usage: python scripts/validate_jsonl_schema.py FILE [FILE ...] - python scripts/validate_jsonl_schema.py /model/agentic_coding_flat.jsonl /model/agentic_workflow_flat.jsonl + python scripts/validate_jsonl_schema.py examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl """ import argparse From 2961de5d4cd58e5ca8d7287d583b070e6b694664 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Sat, 25 Apr 2026 15:32:13 -0700 Subject: [PATCH 07/41] chore: move multi_turn_dataset_schema.json into scripts/ and update default path Co-Authored-By: Claude Sonnet 4.6 --- scripts/multi_turn_dataset_schema.json | 557 +++++++++++++++++++++++++ scripts/validate_jsonl_schema.py | 8 +- 2 files changed, 561 insertions(+), 4 deletions(-) create mode 100644 scripts/multi_turn_dataset_schema.json diff --git a/scripts/multi_turn_dataset_schema.json b/scripts/multi_turn_dataset_schema.json new file mode 100644 index 00000000..b1b7ca13 --- /dev/null +++ b/scripts/multi_turn_dataset_schema.json @@ -0,0 +1,557 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Multi-Turn Conversation Dataset Schema", + "description": "JSON schema describing the structure and requirements for multi-turn conversation datasets in the MLPerf Inference Endpoint Benchmarking System", + "version": "1.0.0", + + "definitions": { + "basicMessageTypes": { + "title": "Basic Message Types", + "description": "Plain conversational messages without tool calls", + "oneOf": [ + { + "title": "User Message", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation (1-indexed)" + }, + "role": { + "const": "user", + "description": "Message role - user initiates turns" + }, + "content": { + "type": "string", + "description": "Message content from the user" + }, + "system": { + "type": "string", + "description": "System prompt (typically only on first user turn)" + } + }, + "required": ["conversation_id", "turn", "role", "content"], + "additionalProperties": true + }, + { + "title": "Assistant Message", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "assistant", + "description": "Message role - assistant responds to user" + }, + "content": { + "type": "string", + "description": "Message content from the assistant" + } + }, + "required": ["conversation_id", "turn", "role", "content"], + "not": { "required": ["tool_calls"] }, + "additionalProperties": true + } + ] + }, + + "toolCallMessage": { + "title": "Assistant Message with Tool Calls", + "description": "Assistant message that dispatches one or more tool calls. 
Role must be 'assistant' with a non-empty tool_calls array (OpenAI wire format).", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "assistant", + "description": "Role for a tool-dispatching assistant message." + }, + "content": { + "type": ["string", "null"], + "description": "Optional textual prefix alongside tool dispatch (e.g., 'I will investigate this with bash'). Typically null/absent." + }, + "tool_calls": { + "type": "array", + "minItems": 1, + "description": "List of tool calls dispatched by the assistant", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for this tool call (e.g., 'functions.bash:0')" + }, + "type": { + "type": "string", + "const": "function", + "description": "Tool type (currently only 'function' supported)" + }, + "function": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the tool/function to invoke" + }, + "arguments": { + "type": "string", + "description": "JSON string containing function arguments" + } + }, + "required": ["name", "arguments"] + } + }, + "required": ["id", "type", "function"] + } + } + }, + "required": ["conversation_id", "turn", "role", "tool_calls"], + "additionalProperties": true + }, + + "toolMessage": { + "title": "Tool Result Message", + "description": "Tool execution results as a list. Single results have one entry; parallel results have multiple entries.", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "tool", + "description": "Tool result role" + }, + "tool_results": { + "type": "array", + "minItems": 1, + "description": "List of tool execution results. 
Single tool calls have one entry; parallel tool calls have multiple entries.", + "items": { + "type": "object", + "properties": { + "tool_call_id": { + "type": "string", + "description": "ID of the tool call this result corresponds to" + }, + "content": { + "type": "string", + "description": "Output/result content from the tool execution" + } + }, + "required": ["tool_call_id", "content"] + } + } + }, + "required": ["conversation_id", "turn", "role", "tool_results"], + "additionalProperties": true + }, + + "generationParameters": { + "title": "Generation Parameters", + "description": "Optional parameters controlling the model's behavior for generation", + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "Model name override for this turn" + }, + "max_new_tokens": { + "type": "integer", + "minimum": 1, + "description": "Maximum number of tokens to generate" + }, + "max_completion_tokens": { + "type": "integer", + "minimum": 1, + "description": "OpenAI API compatible max tokens parameter" + }, + "stream": { + "type": "boolean", + "description": "Whether to use streaming for this turn" + }, + "temperature": { + "type": "number", + "minimum": 0, + "maximum": 2, + "description": "Sampling temperature (0 = deterministic, higher = more random)" + }, + "top_p": { + "type": "number", + "minimum": 0, + "maximum": 1, + "description": "Nucleus sampling parameter" + }, + "top_k": { + "type": "integer", + "minimum": 1, + "description": "Top-k sampling parameter" + }, + "seed": { + "type": "integer", + "description": "Random seed for reproducibility" + }, + "repetition_penalty": { + "type": "number", + "minimum": 0, + "description": "Penalty for repeating tokens" + }, + "frequency_penalty": { + "type": "number", + "minimum": -2, + "maximum": 2, + "description": "Frequency penalty for tokens" + }, + "presence_penalty": { + "type": "number", + "minimum": -2, + "maximum": 2, + "description": "Presence penalty for tokens" + }, + "stop": { + "oneOf": [ + { "type": "string" }, + { "type": "array", "items": { "type": "string" } } + ], + "description": "Stop sequences for generation" + }, + "n": { + "type": "integer", + "minimum": 1, + "description": "Number of completions to generate" + }, + "logit_bias": { + "type": "object", + "description": "Token probability adjustments (token_id -> bias)" + }, + "name": { + "type": "string", + "description": "Entity name for role tracking (e.g., 'Bob')" + }, + "user": { + "type": "string", + "description": "End-user identifier for monitoring/abuse detection" + }, + "chat_template": { + "type": "string", + "description": "Custom chat formatting template" + }, + "tools": { + "type": "array", + "description": "OpenAI tool definitions for tool-calling models", + "items": { "type": "object" } + } + } + } + }, + + "type": "object", + "oneOf": [ + { + "title": "Plain Conversation Row", + "description": "A single row representing a plain user or assistant message", + "allOf": [ + { "$ref": "#/definitions/basicMessageTypes" }, + { "$ref": "#/definitions/generationParameters" } + ] + }, + { + "title": "Tool Call Row", + "description": "A single row representing an assistant dispatch of tool calls", + "allOf": [ + { "$ref": "#/definitions/toolCallMessage" }, + { "$ref": "#/definitions/generationParameters" } + ] + }, + { + "title": "Tool Result Row", + "description": "A single row representing one or more tool results", + "allOf": [ + { "$ref": "#/definitions/toolMessage" }, + { "$ref": "#/definitions/generationParameters" } + ] + } + ], + + 
"examples": [ + { + "title": "Basic user message", + "data": { + "conversation_id": "conv_001", + "turn": 1, + "role": "user", + "content": "I need help resetting my password", + "system": "You are a helpful customer support agent" + } + }, + { + "title": "Assistant response", + "data": { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "I'd be happy to help. Can you provide your email address?" + } + }, + { + "title": "Assistant with tool calls (converter/OpenAI wire format)", + "data": { + "conversation_id": "sim_001", + "turn": 2, + "role": "assistant", + "tool_calls": [ + { + "id": "functions.bash:0", + "type": "function", + "function": { + "name": "bash", + "arguments": "{\"cmd\": \"cat foo.py\"}" + } + } + ] + } + }, + { + "title": "Tool result from execution", + "data": { + "conversation_id": "sim_001", + "turn": 3, + "role": "tool", + "tool_results": [ + { + "tool_call_id": "functions.bash:0", + "content": "def foo():\n return 1/0" + } + ] + } + }, + { + "title": "Merged parallel tool results", + "data": { + "conversation_id": "sim_002", + "turn": 3, + "role": "tool", + "tool_results": [ + { + "tool_call_id": "functions.bash:0", + "content": "file1.txt" + }, + { + "tool_call_id": "functions.bash:1", + "content": "file2.txt" + } + ] + } + }, + { + "title": "User turn with generation parameters", + "data": { + "conversation_id": "conv_002", + "turn": 5, + "role": "user", + "content": "What's the best way to optimize this code?", + "temperature": 0.7, + "max_new_tokens": 256, + "top_p": 0.9 + } + } + ], + + "documentation": { + "overview": "Multi-turn conversation datasets enable benchmarking of realistic conversational AI workloads where each turn depends on previous responses. The system maintains conversation history and enforces turn sequencing.", + + "requiredFields": [ + { + "field": "conversation_id", + "type": "string", + "description": "Unique identifier for each conversation. All rows belonging to the same conversation must share the same conversation_id." + }, + { + "field": "turn", + "type": "integer", + "description": "Turn number within conversation (1-indexed). Must be consecutive starting at 1 (i.e., 1, 2, 3, …, N with no gaps or duplicates)." + }, + { + "field": "role", + "type": "string", + "enum": ["user", "assistant", "tool"], + "description": "Speaker role. 'user' or 'tool' are client-initiated turns. 'assistant' is the server response β€” either a terminal reply or a tool dispatch (with tool_calls field)." + }, + { + "field": "content", + "type": "string", + "description": "Message content. Required for 'user' role and plain 'assistant' rows. For tool-dispatching assistant rows, content may be omitted (null/absent). For 'tool' rows using tool_results (merged parallel results), top-level content is absent β€” results are in the tool_results array instead." + } + ], + + "optionalFields": [ + { + "field": "system", + "type": "string", + "description": "System prompt (typically only on first user turn). Applied to all messages in the conversation." + }, + { + "field": "tool_calls", + "type": "array", + "description": "Tool calls dispatched by assistant (for tool-dispatching 'assistant' rows). Each element has {id, type, function: {name, arguments}}." + }, + { + "field": "tool_results", + "type": "array", + "description": "Tool execution results (required for all 'tool' role rows). Each element has {tool_call_id, content}. Single results have one entry; parallel results have multiple entries." 
+ }, + { + "field": "model", + "type": "string", + "description": "Model name override for this turn." + }, + { + "field": "max_new_tokens", + "type": "integer", + "description": "Maximum tokens to generate for this turn." + }, + { + "field": "temperature", + "type": "number", + "description": "Sampling temperature (0 to 2)." + }, + { + "field": "top_p", + "type": "number", + "description": "Nucleus sampling parameter (0 to 1)." + }, + { + "field": "tools", + "type": "array", + "description": "OpenAI tool definitions forwarded to the endpoint for tool-calling models. The converter attaches this only to client-turn rows (user and tool) to avoid duplicating the large array on every assistant row. Hand-authored datasets typically place it on the first user turn." + } + ], + + "validRoleSequences": [ + { + "name": "Plain conversation", + "sequence": "user β†’ assistant β†’ user β†’ assistant β†’ ...", + "description": "Standard alternating conversation without tool use." + }, + { + "name": "Agentic with tools", + "sequence": "user β†’ assistant β†’ tool β†’ [assistant β†’ tool]* β†’ assistant β†’ user", + "description": "Agent dispatches tools (assistant with tool_calls), executes them, and returns results before final response. 'tool β†’ user' is also valid when no terminal assistant response is needed before the next user turn." + } + ], + "stateMachine": { + "description": "Complete valid-next-state table from _validate_conversation_structure()", + "transitions": { + "start": ["user"], + "user": ["assistant"], + "assistant": ["tool", "user"], + "tool": ["assistant", "user"] + } + }, + + "validationRules": [ + { + "rule": "Turn numbers must be consecutive starting at 1", + "violation": "Turn sequence is not exactly 1, 2, 3, …, N (missing, duplicate, or out-of-range turns)" + }, + { + "rule": "Role sequences must follow the state machine", + "violation": "Invalid transition (e.g., 'user' directly followed by 'user', consecutive 'assistant' rows). Note: 'tool β†’ user' IS a valid transition. The state machine also implicitly enforces that the first row must be a user turn." + } + ], + "notValidated": [ + "tool_results[*].tool_call_id pairing: the validator does NOT verify that tool_call_id values inside tool_results items reference a prior assistant tool_calls entry. Correct pairing is the dataset author's or converter's responsibility." + ], + + "dataTypes": [ + { + "name": "Basic Message", + "roles": ["user", "assistant"], + "fields": { + "required": ["conversation_id", "turn", "role", "content"], + "optional": ["system", "...generation parameters"] + } + }, + { + "name": "Tool Call Dispatch", + "roles": ["assistant"], + "fields": { + "required": ["conversation_id", "turn", "role", "tool_calls"], + "optional": ["content", "...generation parameters"] + } + }, + { + "name": "Tool Result", + "roles": ["tool"], + "fields": { + "required": ["conversation_id", "turn", "role", "tool_results"], + "optional": ["...generation parameters"], + "note": "Expands to one OpenAI tool message per result entry at the wire layer" + } + } + ], + + "conversionFromSnapshot": { + "description": "Agentic datasets are often stored as full-conversation snapshots. 
Use scripts/convert_agentic_snapshot.py to convert.", + "sourceFormat": "Each JSONL line is a complete conversation snapshot with 'messages' array", + "targetFormat": "Each JSONL line is a single message row with conversation metadata", + "process": [ + "Extract conversation_id, conversation_idx, and messages array", + "Use highest-indexed snapshot per conversation_id", + "Collapse consecutive user messages into a single user row (newline-joined content)", + "Emit all tool result rows as tool_results arrays (single results have one entry; consecutive tool results from parallel dispatch are merged into one row with multiple entries)", + "Flatten messages into individual rows, numbering turns sequentially from 1", + "Attach system prompt to first user row only", + "Attach tools array to client-turn rows (user and tool) only β€” not to assistant rows", + "Tool-dispatching assistant messages are written as role 'assistant' with tool_calls (OpenAI wire format)" + ] + }, + + "performanceNotes": [ + "Pre-built messages: The system pre-computes complete message lists during dataset load() for efficient turn serving", + "Memory efficiency: ~1KB per turn average; 1000 conversations Γ— 10 turns = ~10MB", + "Hot path: Only client turns (user/tool) are issued; assistant turns remain in backing store for history" + ], + + "commonErrors": [ + { + "error": "Invalid role sequence", + "cause": "Violates state machine (e.g., userβ†’user, userβ†’tool, consecutive assistant rows)", + "fix": "Verify alternation or use conversion script for agentic data" + }, + { + "error": "Turn numbers not consecutive", + "cause": "Turn sequence has gaps, duplicates, or doesn't start at 1", + "fix": "Ensure turns are numbered 1, 2, 3, …, N with no missing or duplicate values" + }, + { + "error": "Turn timeout", + "cause": "Previous turn took too long to complete", + "fix": "Increase turn_timeout_s in configuration or check endpoint performance" + } + ] + } +} diff --git a/scripts/validate_jsonl_schema.py b/scripts/validate_jsonl_schema.py index c2d25eca..1be81dd2 100644 --- a/scripts/validate_jsonl_schema.py +++ b/scripts/validate_jsonl_schema.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Validate multi-turn JSONL dataset files against multi_turn_dataset_schema.json. +"""Validate multi-turn JSONL dataset files against scripts/multi_turn_dataset_schema.json. Checks each row's structure against the JSON schema (field types, required fields, tool_results shape, etc.). Does NOT check cross-row invariants such as turn @@ -88,14 +88,14 @@ def validate_file(path: Path, schema: dict, max_errors: int = 50) -> int: def main() -> None: parser = argparse.ArgumentParser( - description="Validate multi-turn JSONL files against multi_turn_dataset_schema.json." + description="Validate multi-turn JSONL files against scripts/multi_turn_dataset_schema.json." 
) parser.add_argument("files", nargs="+", type=Path, help="JSONL files to validate") parser.add_argument( "--schema", type=Path, - default=Path(__file__).parent.parent / "multi_turn_dataset_schema.json", - help="Path to the JSON schema file (default: multi_turn_dataset_schema.json)", + default=Path(__file__).parent / "multi_turn_dataset_schema.json", + help="Path to the JSON schema file (default: scripts/multi_turn_dataset_schema.json)", ) parser.add_argument( "--max-errors", From 039f72ce0910ce0f8bd64e28d4f1b30f82d80c70 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Tue, 28 Apr 2026 10:37:22 -0700 Subject: [PATCH 08/41] fix: address PR #285 review comments for multi-turn implementation Fix 15 review issues across severity levels: - HIGH: metadata=None crash in msgspec adapter, silent exception swallowing in gather - MEDIUM: timeout state consistency, conv_id canonicalization, PromptData fallback, conv_id guard - LOW: enum comparison, frozen config, empty tool_results warning, adapter metadata extraction, groupby deduplication, live-history tool warning, asyncio.Event docs, test TODO Co-Authored-By: Claude Opus 4.6 --- .../config/runtime_settings.py | 3 +- src/inference_endpoint/config/schema.py | 2 +- .../dataset_manager/multi_turn_dataset.py | 50 +++++++++++-------- .../load_generator/conversation_manager.py | 1 + .../load_generator/multi_turn_strategy.py | 32 +++++++++++- .../load_generator/session.py | 10 +++- .../openai/openai_adapter.py | 13 ++++- .../openai/openai_msgspec_adapter.py | 2 +- tests/integration/test_multi_turn.py | 6 ++- 9 files changed, 88 insertions(+), 31 deletions(-) diff --git a/src/inference_endpoint/config/runtime_settings.py b/src/inference_endpoint/config/runtime_settings.py index a3fb3106..eac1aa47 100644 --- a/src/inference_endpoint/config/runtime_settings.py +++ b/src/inference_endpoint/config/runtime_settings.py @@ -32,6 +32,7 @@ from typing import TYPE_CHECKING from .. import metrics +from .schema import LoadPatternType logger = logging.getLogger(__name__) @@ -197,7 +198,7 @@ def total_samples_to_issue( # Multi-turn must issue exactly all client turns β€” QPS-based formulas are meaningless. if ( self.load_pattern is not None - and self.load_pattern.type.value == "multi_turn" + and self.load_pattern.type == LoadPatternType.MULTI_TURN ): result = max(self.min_sample_count, self.n_samples_from_dataset) logger.debug( diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 1f487cbe..b3362899 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -250,7 +250,7 @@ class MultiTurnConfig(BaseModel): use_dataset_history: If True, use pre-built message history from dataset. 
""" - model_config = {"extra": "forbid"} + model_config = ConfigDict(extra="forbid", frozen=True) mode: ConversationMode = ConversationMode.INDEPENDENT turn_timeout_s: float = 300.0 diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index 574619c8..c75f285d 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -15,6 +15,7 @@ """Multi-turn conversation dataset for conversational AI benchmarking.""" +import logging from typing import Any import pandas as pd @@ -29,6 +30,8 @@ get_transforms_for_api_type, ) +logger = logging.getLogger(__name__) + def _expand_tool_results(row: dict) -> list[dict]: """Expand a tool row into one OpenAI tool message per result. @@ -41,6 +44,13 @@ def _expand_tool_results(row: dict) -> list[dict]: tool_results = row.get("tool_results") if not isinstance(tool_results, list): return [] + if not tool_results: + logger.warning( + "Row has empty tool_results list (conversation_id=%s, turn=%s)", + row.get("conversation_id"), + row.get("turn"), + ) + return [] return [ { "role": "tool", @@ -94,6 +104,8 @@ def __init__(self, dataframe: pd.DataFrame, **kwargs): ValueError: If conversation structure is invalid. """ super().__init__(dataframe, **kwargs) + assert self.dataframe is not None, "Dataframe must be initialized" + self._conv_groups = dict(list(self.dataframe.groupby("conversation_id"))) self._validate_conversation_grouping() self._validate_conversation_structure() self._validate_turn_numbering() @@ -128,10 +140,6 @@ def _validate_conversation_structure(self): Raises: ValueError: If any conversation has invalid role sequence. """ - assert self.dataframe is not None, "Dataframe must be initialized" - - # Valid state transitions (flat 4-state machine β€” no assistant_tc node, - # no toolβ†’tool; converter always merges consecutive tool rows into tool_results) VALID_NEXT: dict[str, set[str]] = { "start": {"user"}, "user": {"assistant"}, @@ -139,7 +147,7 @@ def _validate_conversation_structure(self): "tool": {"assistant", "user"}, } - for conv_id, group in self.dataframe.groupby("conversation_id"): + for conv_id, group in self._conv_groups.items(): sorted_group = group.sort_values("turn") state = "start" @@ -159,9 +167,7 @@ def _validate_turn_numbering(self): Raises: ValueError: If turn numbers are not exactly 1, 2, 3, …, N. """ - assert self.dataframe is not None, "Dataframe must be initialized" - - for conv_id, group in self.dataframe.groupby("conversation_id"): + for conv_id, group in self._conv_groups.items(): turns = sorted(group["turn"].tolist()) expected = list(range(1, len(turns) + 1)) if turns != expected: @@ -180,14 +186,13 @@ def _build_metadata(self) -> dict[str, Any]: Metadata dict with samples list, num_conversations, max_turns_per_conv, client_turns_per_conversation, and pre_built_messages_by_key. """ - assert self.dataframe is not None, "Dataframe must be initialized" samples = [] - client_turns_df = self.dataframe[self.dataframe["role"].isin(["user", "tool"])] # Count client turns (user + tool) per conversation for completion tracking - client_turns_per_conv = ( - client_turns_df.groupby("conversation_id").size().to_dict() - ) + client_turns_per_conv = { + str(conv_id): int(group["role"].isin(["user", "tool"]).sum()) + for conv_id, group in self._conv_groups.items() + } # Map (conversation_id, turn) β†’ complete message list ready to send to endpoint. 
# Each entry is: [system (optional)] + all prior rows formatted as messages @@ -198,7 +203,7 @@ def _build_metadata(self) -> dict[str, Any]: current_turn_messages_by_key: dict[tuple, list[dict]] = {} system_prompts_by_conv: dict[str, str | None] = {} - for conv_id, group in self.dataframe.groupby("conversation_id"): + for conv_id, group in self._conv_groups.items(): sorted_group = group.sort_values("turn") client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])] @@ -263,23 +268,24 @@ def _build_metadata(self) -> dict[str, Any]: current_turn_msgs = [cur] messages.extend(current_turn_msgs) - pre_built_messages_by_key[(conv_id, t_n)] = messages - current_turn_messages_by_key[(conv_id, t_n)] = current_turn_msgs + str_conv_id = str(conv_id) + pre_built_messages_by_key[(str_conv_id, t_n)] = messages + current_turn_messages_by_key[(str_conv_id, t_n)] = current_turn_msgs samples.append( { "index": idx, - "conversation_id": conv_id, + "conversation_id": str_conv_id, "turn": t_n, } ) return { "samples": samples, - "num_conversations": self.dataframe["conversation_id"].nunique(), - "max_turns_per_conv": self.dataframe.groupby("conversation_id")["turn"] - .max() - .max(), + "num_conversations": len(self._conv_groups), + "max_turns_per_conv": max( + g["turn"].max() for g in self._conv_groups.values() + ), "client_turns_per_conversation": client_turns_per_conv, "pre_built_messages_by_key": pre_built_messages_by_key, "current_turn_messages_by_key": current_turn_messages_by_key, @@ -388,7 +394,7 @@ def load( sample["stream"] = False # Attach pre-built message list (system + history + current turn). - key = (row["conversation_id"], int(row["turn"])) + key = (str(row["conversation_id"]), int(row["turn"])) messages = pre_built.get(key, []) sample["messages"] = messages diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index 56bb8278..30276d5e 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -42,6 +42,7 @@ class ConversationState: """ conversation_id: str + # Python 3.12+: asyncio.Event no longer requires a running loop at construction. 
turn_done: asyncio.Event = field(default_factory=asyncio.Event) message_history: list[dict[str, Any]] = field(default_factory=list) completed_turns: int = 0 diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 0f3ba6e8..69697129 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -134,7 +134,12 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: for conv_id, turns in conv_samples.items() ] - await asyncio.gather(*tasks, return_exceptions=True) + results = await asyncio.gather(*tasks, return_exceptions=True) + errors = [r for r in results if isinstance(r, BaseException)] + for err in errors: + logger.error(f"Conversation pipeline failed: {err}") + if errors: + raise errors[0] return phase_issuer.issued_count async def _conv_pipeline( @@ -150,6 +155,7 @@ async def _conv_pipeline( """ state = self._conv_states[conv_id] sorted_turns = sorted(turns, key=lambda x: x[1]) + last_query_id: str | None = None for i, (idx, turn) in enumerate(sorted_turns): if i > 0: @@ -161,7 +167,13 @@ async def _conv_pipeline( logger.warning( f"Turn {turn} of {conv_id} timed out waiting for previous turn" ) - state.failed_turns += 1 + if last_query_id is not None: + self._inflight.pop(last_query_id, None) + remaining = len(sorted_turns) - i + for _ in range(remaining): + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) break state.turn_done.clear() @@ -177,6 +189,17 @@ async def _conv_pipeline( "current_turn_messages_by_key", {} ).get((conv_id, turn)) if current_turn_messages: + has_tool_msg = any( + m.get("role") == "tool" for m in current_turn_messages + ) + if has_tool_msg: + logger.warning( + "Live-history mode with tool messages uses dataset " + "tool_call_ids; real endpoint IDs will differ " + "(conv=%s, turn=%d)", + conv_id, + turn, + ) live_messages = state.message_history.copy() + current_turn_messages data_override = {"messages": live_messages} @@ -188,6 +211,7 @@ async def _conv_pipeline( break self._inflight[query_id] = conv_id + last_query_id = query_id # Append current-turn messages to history so the next turn sees them. if self._store_in_history and current_turn_messages: @@ -219,6 +243,10 @@ def on_sample_complete(self, result: QueryResult) -> None: if conv_id is None: return + if self._conv_manager.get_state(conv_id) is None: + logger.warning(f"on_sample_complete: unknown conversation {conv_id}") + return + response_text = result.get_response_output_string() if result.error is not None: diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 2e4f67ef..f4cfc178 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -208,9 +208,15 @@ def issue( # meaningful for ISL reporting on text-only prompts. # Therefore, setting `text=None` for non-string prompts # means that ISL reporting will be unavailable for multimodal samples. 
- prompt = data.get("prompt") + prompt_text = data.get("prompt") + if prompt_text is None and "messages" in data: + prompt_text = " ".join( + m.get("content", "") + for m in data["messages"] + if isinstance(m, dict) and m.get("content") + ) prompt_data = PromptData( - text=prompt if isinstance(prompt, str) else None, + text=prompt_text if isinstance(prompt_text, str) else None, token_ids=tuple(token_ids) if token_ids is not None else None, ) else: diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 9c6f6ebd..85f208ca 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -14,6 +14,7 @@ # limitations under the License. import time +from typing import Any import msgspec from inference_endpoint.core.types import Query, QueryResult, TextModelOutput @@ -128,9 +129,19 @@ def from_endpoint_response( if result_id is None: result_id = response.id + choice = response.choices[0] + metadata: dict[str, Any] = {} + if choice.finish_reason: + metadata["finish_reason"] = choice.finish_reason.value + if choice.message.tool_calls: + metadata["tool_calls"] = [ + tc.model_dump(mode="json") for tc in choice.message.tool_calls + ] + return QueryResult( id=result_id, - response_output=TextModelOutput(output=response.choices[0].message.content), + response_output=TextModelOutput(output=choice.message.content), + metadata=metadata, ) @classmethod diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index e8f15ce6..e512e22b 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -219,7 +219,7 @@ def from_endpoint_response( return QueryResult( id=result_id or response.id, response_output=TextModelOutput(output=choice.message.content or ""), - metadata=metadata if metadata else None, + metadata=metadata, ) @classmethod diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index 87351700..8ea3666f 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -557,7 +557,11 @@ async def test_conversation_ending_with_tool_row(echo_server): @pytest.mark.integration @pytest.mark.asyncio async def test_tools_field_forwarded_to_endpoint(echo_server): - """The 'tools' array from the dataset reaches the endpoint in every request payload.""" + """The 'tools' array from the dataset reaches the endpoint in every request payload. + + TODO: Add a tool-call-aware server that returns dynamic tool_call_ids to + validate live-history mode with real tool_call_id round-tripping. + """ received_payloads: list[dict] = [] class CapturingEchoServer(EchoServer): From c53e5d5cb911912564cad6e6260b0622340dc1ab Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Tue, 28 Apr 2026 13:18:31 -0700 Subject: [PATCH 09/41] fix: improve multi-turn PromptData text and add concurrent stress test Use newline separators (instead of spaces) when flattening messages to text for ISL estimation, and add a 12-conversation concurrent stress test. 
Co-Authored-By: Claude Sonnet 4.6 --- .../load_generator/session.py | 7 +-- tests/integration/test_multi_turn.py | 46 +++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index f4cfc178..f6b07c90 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -210,11 +210,12 @@ def issue( # means that ISL reporting will be unavailable for multimodal samples. prompt_text = data.get("prompt") if prompt_text is None and "messages" in data: - prompt_text = " ".join( - m.get("content", "") + parts: list[str] = [ + m["content"] for m in data["messages"] if isinstance(m, dict) and m.get("content") - ) + ] + prompt_text = "\n".join(parts) if parts else None prompt_data = PromptData( text=prompt_text if isinstance(prompt_text, str) else None, token_ids=tuple(token_ids) if token_ids is not None else None, diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index 8ea3666f..cfe8a68c 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -554,6 +554,52 @@ async def test_conversation_ending_with_tool_row(echo_server): assert len(responses) == 2 +@pytest.mark.integration +@pytest.mark.asyncio +async def test_concurrent_conversations_stress(echo_server): + """12 conversations Γ— 3 turns each complete with correct counts.""" + num_convs = 12 + turns_per_conv = 3 # 2 user turns + 1 assistant turn each + rows = [] + for i in range(num_convs): + conv_id = f"stress_conv_{i}" + rows.append( + { + "conversation_id": conv_id, + "turn": 1, + "role": "user", + "content": f"Q1-{i}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 2, + "role": "assistant", + "content": f"A1-{i}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 3, + "role": "user", + "content": f"Q2-{i}", + } + ) + + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # 12 conversations Γ— 2 client turns each = 24 + expected_client_turns = num_convs * (turns_per_conv - 1) # 24 + assert count == expected_client_turns + assert len(responses) == expected_client_turns + + @pytest.mark.integration @pytest.mark.asyncio async def test_tools_field_forwarded_to_endpoint(echo_server): From 7495a453862c4b3371777718b0610aa385bf9ea7 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Mon, 4 May 2026 11:09:32 -0700 Subject: [PATCH 10/41] refactor: replace semaphore with worker-pool concurrency in MultiTurnStrategy target_concurrency now limits active conversations (not in-flight requests). N worker tasks pull from asyncio.Queue, each processing one full conversation before taking the next. Also adds slots=True back to PhaseConfig and sort=False to groupby for file-order preservation. 
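For reviewers, a minimal self-contained sketch of the worker-pool pattern this commit switches to. Nothing here is repository code: `process_conversation`, the sleep standing in for issue-and-wait, and the inline `main` are placeholders; the real logic is in `MultiTurnStrategy._worker` / `_conv_pipeline` in the diff below.

```python
import asyncio


async def process_conversation(conv_id: str, turns: list[int]) -> None:
    # Placeholder for _conv_pipeline: issue each turn in order and wait for
    # its response before issuing the next one.
    for turn in turns:
        await asyncio.sleep(0.01)  # stands in for issue() + response wait
        print(f"{conv_id}: turn {turn} done")


async def worker(queue: asyncio.Queue) -> None:
    # Each worker drains one conversation at a time, so at most n_workers
    # conversations are active simultaneously.
    while True:
        try:
            conv_id, turns = queue.get_nowait()
        except asyncio.QueueEmpty:
            break
        await process_conversation(conv_id, turns)


async def main(target_concurrency: int = 2) -> None:
    conversations = {f"conv{i}": [1, 2, 3] for i in range(4)}
    queue: asyncio.Queue = asyncio.Queue()
    for item in conversations.items():
        queue.put_nowait(item)
    n_workers = min(target_concurrency, len(conversations))
    await asyncio.gather(*(worker(queue) for _ in range(n_workers)))


asyncio.run(main())
```

With `target_concurrency: 2`, conv0 and conv1 run to completion before conv2 and conv3 start, which is the slot-reuse behavior the new strategy tests assert.
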
Co-Authored-By: Claude Sonnet 4.6 --- .../dataset_manager/multi_turn_dataset.py | 4 +- .../load_generator/multi_turn_strategy.py | 81 ++++++++-------- .../load_generator/session.py | 2 +- .../test_multi_turn_strategy.py | 96 +++++++++++++++---- 4 files changed, 126 insertions(+), 57 deletions(-) diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index c75f285d..cabfac79 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -105,7 +105,9 @@ def __init__(self, dataframe: pd.DataFrame, **kwargs): """ super().__init__(dataframe, **kwargs) assert self.dataframe is not None, "Dataframe must be initialized" - self._conv_groups = dict(list(self.dataframe.groupby("conversation_id"))) + self._conv_groups = dict( + list(self.dataframe.groupby("conversation_id", sort=False)) + ) self._validate_conversation_grouping() self._validate_conversation_structure() self._validate_turn_numbering() diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 69697129..1ce3780e 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -32,18 +32,16 @@ class MultiTurnStrategy: - """Async multi-turn strategy. Spawns per-conversation asyncio.Tasks. + """Async multi-turn strategy. Uses a worker-pool to limit active conversations. - Each conversation runs as an independent asyncio.Task that enforces - sequential turn ordering: turn N+1 cannot be issued until turn N completes. - Conversations run concurrently β€” no cross-conversation synchronization. - - Optional target_concurrency limits total in-flight requests across all - conversations using asyncio.Semaphore. + N worker tasks pull from a queue of conversations. Each worker processes all + turns of one conversation before moving to the next, so at most N conversations + are active simultaneously. When target_concurrency is None, all conversations + run concurrently (one worker per conversation). Integration with BenchmarkSession: - - execute(): spawns conversation tasks, awaits all to complete - - on_query_complete(): releases semaphore slot (concurrency control only) + - execute(): populates queue, spawns workers, awaits all to complete + - on_query_complete(): no-op (required by LoadStrategy protocol) - on_sample_complete(): routes completed QueryResult to ConversationManager The response routing path: @@ -69,7 +67,8 @@ def __init__( conversation_manager: Manages conversation sequencing state. dataset_metadata: Metadata from MultiTurnDataset (samples list). multi_turn_config: Multi-turn conversation configuration. - target_concurrency: Optional maximum concurrent in-flight requests. + target_concurrency: Maximum number of simultaneously active conversations. + None means all conversations run concurrently. 
""" self._conv_manager = conversation_manager self._dataset_metadata = dataset_metadata @@ -80,11 +79,6 @@ def __init__( else _DEFAULT_TURN_TIMEOUT_S ) self._target_concurrency = target_concurrency - self._sem: asyncio.Semaphore | None = ( - asyncio.Semaphore(target_concurrency) - if target_concurrency is not None and target_concurrency > 0 - else None - ) self._store_in_history = ( not multi_turn_config.use_dataset_history if multi_turn_config is not None @@ -110,7 +104,7 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: conv_id = sample_meta["conversation_id"] conv_samples[conv_id].append((sample_index, sample_meta["turn"])) - # Pre-create all conversation states before spawning tasks (no locking needed). + # Pre-create all conversation states before spawning workers (no locking needed). sys_prompts = self._dataset_metadata.get("system_prompts_by_conv", {}) for conv_id, turns in conv_samples.items(): sys_content = sys_prompts.get(conv_id) if self._store_in_history else None @@ -126,15 +120,27 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: ) self._conv_states[conv_id] = state - tasks = [ + # Build queue of (conv_id, turns) pairs for workers to pull from. + conv_queue: asyncio.Queue[tuple[str, list[tuple[int, int]]]] = asyncio.Queue() + for conv_id, turns in conv_samples.items(): + await conv_queue.put((conv_id, turns)) + + n_conversations = len(conv_samples) + n_workers = ( + min(self._target_concurrency, n_conversations) + if self._target_concurrency is not None and self._target_concurrency > 0 + else n_conversations + ) + + worker_tasks = [ asyncio.create_task( - self._conv_pipeline(conv_id, turns, phase_issuer), - name=f"mt-pipeline-{conv_id}", + self._worker(conv_queue, phase_issuer), + name=f"mt-worker-{i}", ) - for conv_id, turns in conv_samples.items() + for i in range(n_workers) ] - results = await asyncio.gather(*tasks, return_exceptions=True) + results = await asyncio.gather(*worker_tasks, return_exceptions=True) errors = [r for r in results if isinstance(r, BaseException)] for err in errors: logger.error(f"Conversation pipeline failed: {err}") @@ -142,6 +148,19 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: raise errors[0] return phase_issuer.issued_count + async def _worker( + self, + conv_queue: asyncio.Queue[tuple[str, list[tuple[int, int]]]], + phase_issuer: PhaseIssuerProtocol, + ) -> None: + """Pull conversations from queue and process each one fully before taking the next.""" + while True: + try: + conv_id, turns = conv_queue.get_nowait() + except asyncio.QueueEmpty: + break + await self._conv_pipeline(conv_id, turns, phase_issuer) + async def _conv_pipeline( self, conv_id: str, @@ -177,10 +196,6 @@ async def _conv_pipeline( break state.turn_done.clear() - # Acquire concurrency slot before issuing. - if self._sem is not None: - await self._sem.acquire() - # Live-history mode: build messages from accumulated history + current turn. data_override: dict[str, Any] | None = None current_turn_messages: list[dict[str, Any]] | None = None @@ -205,9 +220,7 @@ async def _conv_pipeline( query_id = phase_issuer.issue(idx, data_override=data_override) if query_id is None: - # Session stopping β€” release slot and exit. - if self._sem is not None: - self._sem.release() + # Session stopping β€” exit pipeline. 
break self._inflight[query_id] = conv_id @@ -218,16 +231,8 @@ async def _conv_pipeline( state.message_history.extend(current_turn_messages) def on_query_complete(self, query_id: str) -> None: - """Called by BenchmarkSession when a QueryResult arrives. - - Releases the concurrency semaphore slot. Response routing is done - via on_sample_complete (which receives the full QueryResult). - - Args: - query_id: ID of the completed query. - """ - if self._sem is not None: - self._sem.release() + """No-op. Required by LoadStrategy protocol; called by BenchmarkSession.""" + pass def on_sample_complete(self, result: QueryResult) -> None: """Route completed QueryResult to ConversationManager. diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index f6b07c90..f15265f8 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -60,7 +60,7 @@ class PhaseType(str, Enum): WARMUP = "warmup" -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class PhaseConfig: """Configuration for a single benchmark phase.""" diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 1eecc75d..37cdfbad 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -209,23 +209,6 @@ async def test_turn_timeout_triggers_failure(): assert issuer.issued_count == 1 -@pytest.mark.unit -@pytest.mark.asyncio -async def test_on_query_complete_releases_semaphore(): - """on_query_complete releases the concurrency semaphore.""" - conv_manager = ConversationManager() - metadata = _make_dataset_metadata({"conv1": [1]}) - strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=1) - assert strategy._sem is not None - - # Acquire the semaphore manually - await strategy._sem.acquire() - assert strategy._sem._value == 0 # type: ignore[attr-defined] - - strategy.on_query_complete("some-query") - assert strategy._sem._value == 1 # type: ignore[attr-defined] - - @pytest.mark.unit @pytest.mark.asyncio async def test_on_sample_complete_routes_to_manager(): @@ -483,3 +466,82 @@ async def test_on_sample_complete_passes_metadata(): assert len(state.message_history) == 1 assert state.message_history[0]["tool_calls"] == tool_calls assert state.message_history[0]["content"] is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_concurrency_limits_active_conversations(): + """target_concurrency=2 starts at most 2 conversation pipelines simultaneously. + + Uses 2-turn conversations so each pipeline has an await point (turn_done.wait + between turns). With 4 conversations and 2 workers, the 3rd and 4th conversations + cannot start until a worker finishes its current conversation. 
+ """ + conv_manager = ConversationManager() + # 4 two-turn conversations; pipeline awaits turn-1 response before issuing turn-2 + metadata = _make_dataset_metadata( + {"conv1": [1, 2], "conv2": [1, 2], "conv3": [1, 2], "conv4": [1, 2]} + ) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=2) + issuer = FakePhaseIssuer() + + async def auto_respond(): + already_done = 0 + while True: + while already_done < len(issuer.issued): + idx = issuer.issued[already_done] + q = f"q{idx:04d}" + strategy.on_sample_complete( + QueryResult(id=q, response_output=TextModelOutput(output="r")) + ) + already_done += 1 + await asyncio.sleep(0.02) + + responder_task = asyncio.create_task(auto_respond()) + execute_task = asyncio.create_task(strategy.execute(issuer)) + + # Let both workers start and block on turn_done.wait before auto_respond fires + await asyncio.sleep(0.01) + + # Only 2 workers β†’ exactly 2 turn-1 queries issued (conv3/conv4 not started yet) + assert issuer.issued_count == 2 + + await asyncio.wait_for(execute_task, timeout=5.0) + responder_task.cancel() + + assert issuer.issued_count == 8 # 4 conversations Γ— 2 turns + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_slot_reuse(): + """With target_concurrency=1, worker completes conv1 before starting conv2. + + Uses 2-turn conversations so the pipeline has an await between turns. + The single worker must process both turns of conv1 before conv2's turn 1 is issued. + """ + conv_manager = ConversationManager() + # 2 two-turn conversations; sample indices: conv1β†’[0,1], conv2β†’[2,3] + metadata = _make_dataset_metadata({"conv1": [1, 2], "conv2": [1, 2]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=1) + issuer = FakePhaseIssuer() + + async def auto_respond(): + already_done = 0 + while True: + while already_done < len(issuer.issued): + idx = issuer.issued[already_done] + q = f"q{idx:04d}" + strategy.on_sample_complete( + QueryResult(id=q, response_output=TextModelOutput(output="r")) + ) + already_done += 1 + await asyncio.sleep(0.02) + + responder_task = asyncio.create_task(auto_respond()) + await strategy.execute(issuer) + responder_task.cancel() + + # Single worker: conv1 turns (samples 0,1) must be issued before conv2 turns (2,3) + assert issuer.issued[:2] == [0, 1], "Conv1 turns should be issued before conv2" + assert issuer.issued[2:] == [2, 3], "Conv2 turns should follow conv1" From 7aa45f5c4cac0743ded72f30556d1924f01e8fcf Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Mon, 4 May 2026 13:25:55 -0700 Subject: [PATCH 11/41] fix: address remaining PR #285 review comments for multi-turn implementation - openai_adapter: normalize null content to "" instead of literal "None" to avoid polluting conversation history in tool-calling responses - multi_turn_dataset: validate tool_results entries have required tool_call_id and content fields; raise InputValidationError at load time - multi_turn_dataset: remove unused "index" field from samples metadata - multi_turn_strategy: wrap mark_turn_complete/mark_turn_failed in try/except KeyError in on_sample_complete - multi_turn_strategy: clear _inflight at end of execute() with warning if entries remain (transport failure or session abort) - docs: remove prescriptive concurrency sizing guide; replace with definition of what target_concurrency controls - docs: rename "Long Conversations" to "Conversations with Many Turns" - docs: add dataset validation utility reference in Troubleshooting Co-Authored-By: Claude Sonnet 4.6 --- 
docs/MULTI_TURN_QUICKSTART.md | 19 +++++++---- .../dataset_manager/multi_turn_dataset.py | 30 ++++++++++------ .../load_generator/multi_turn_strategy.py | 34 ++++++++++++++----- .../openai/openai_adapter.py | 2 +- .../test_multi_turn_dataset.py | 1 - 5 files changed, 58 insertions(+), 28 deletions(-) diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index 99b35aa5..4e8e5e58 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -152,12 +152,6 @@ settings: target_concurrency: 32 # ← Limit to 32 concurrent requests ``` -**Sizing guide**: - -- Small (< 50 convs): `target_concurrency: 32` -- Medium (50-500 convs): `target_concurrency: 64` -- Large (500+ convs): `target_concurrency: 96` or higher - --- ## Common Configurations @@ -194,7 +188,7 @@ settings: workers: 16 # More workers for parallel conversations ``` -### Long Conversations +### Conversations with Many Turns ```yaml multi_turn: @@ -211,6 +205,17 @@ settings: ## Troubleshooting +### Validate Your Dataset Before Running + +Use the bundled validation script to check your JSONL file for schema errors before benchmarking: + +```bash +python scripts/validate_jsonl_schema.py path/to/your/conversations.jsonl +``` + +This catches missing required fields, invalid role sequences, non-consecutive turn numbers, and +interleaved conversations β€” all errors that would otherwise surface at benchmark startup. + ### "Conversation has invalid role sequence" **Problem**: Your dataset doesn't follow a valid role sequence. diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index cabfac79..d2f21695 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -51,14 +51,24 @@ def _expand_tool_results(row: dict) -> list[dict]: row.get("turn"), ) return [] - return [ - { - "role": "tool", - "tool_call_id": result.get("tool_call_id"), - "content": result.get("content"), - } - for result in tool_results - ] + messages = [] + for i, result in enumerate(tool_results): + tool_call_id = result.get("tool_call_id") + content = result.get("content") + if tool_call_id is None: + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} is missing required field 'tool_call_id'" + ) + if content is None: + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} is missing required field 'content'" + ) + messages.append( + {"role": "tool", "tool_call_id": tool_call_id, "content": content} + ) + return messages class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): @@ -205,6 +215,7 @@ def _build_metadata(self) -> dict[str, Any]: current_turn_messages_by_key: dict[tuple, list[dict]] = {} system_prompts_by_conv: dict[str, str | None] = {} + assert self.dataframe is not None, "Dataframe must be initialized" for conv_id, group in self._conv_groups.items(): sorted_group = group.sort_values("turn") client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])] @@ -218,7 +229,7 @@ def _build_metadata(self) -> dict[str, Any]: break system_prompts_by_conv[str(conv_id)] = system_content - for idx, row in client_rows.iterrows(): + for _, row in client_rows.iterrows(): t_n = int(row["turn"]) messages: list[dict] = [] @@ -276,7 +287,6 @@ def _build_metadata(self) -> dict[str, Any]: samples.append( { - 
"index": idx, "conversation_id": str_conv_id, "turn": t_n, } diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 1ce3780e..42395723 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -144,6 +144,15 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: errors = [r for r in results if isinstance(r, BaseException)] for err in errors: logger.error(f"Conversation pipeline failed: {err}") + + if self._inflight: + logger.warning( + "%d query(ies) never received a response (session stop or transport failure): %s", + len(self._inflight), + list(self._inflight.keys()), + ) + self._inflight.clear() + if errors: raise errors[0] return phase_issuer.issued_count @@ -254,14 +263,21 @@ def on_sample_complete(self, result: QueryResult) -> None: response_text = result.get_response_output_string() - if result.error is not None: - self._conv_manager.mark_turn_failed( - conv_id, store_in_history=self._store_in_history - ) - else: - self._conv_manager.mark_turn_complete( + try: + if result.error is not None: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + else: + self._conv_manager.mark_turn_complete( + conv_id, + response_text, + store_in_history=self._store_in_history, + metadata=result.metadata, + ) + except KeyError: + logger.warning( + "on_sample_complete: conversation %s not found in manager (result=%s)", conv_id, - response_text, - store_in_history=self._store_in_history, - metadata=result.metadata, + result.id, ) diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 85f208ca..ca531ed4 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -186,5 +186,5 @@ def decode_endpoint_response( "content" not in response_dict["choices"][0]["message"] or response_dict["choices"][0]["message"]["content"] is None ): - response_dict["choices"][0]["message"]["content"] = "None" + response_dict["choices"][0]["message"]["content"] = "" return CreateChatCompletionResponse(**response_dict, ignore_extra=True) diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index a93a3a08..62301940 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -211,7 +211,6 @@ def test_multi_turn_dataset_conversation_metadata(valid_multi_turn_jsonl): # Check sample metadata structure sample_meta = metadata["samples"][0] - assert "index" in sample_meta assert "conversation_id" in sample_meta assert "turn" in sample_meta From aedbbe673954851ae34fbb553378364e4f558998 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Mon, 4 May 2026 14:15:30 -0700 Subject: [PATCH 12/41] fix: address remaining PR #285 review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix refusal field set to literal string "None" instead of "" in openai_adapter.py β€” made downstream refusal checks incorrectly truthy - Add test_pipeline_error_propagated to verify execute() re-raises worker exceptions instead of swallowing them via gather(return_exceptions=True) - Clarify MultiTurnStrategy docstring and MULTI_TURN_QUICKSTART.md: target_concurrency = simultaneous conversations (not requests); each active conversation 
has exactly 1 in-flight turn at a time - Remove unjustified "Common Configurations" section from quickstart - Correct misleading "workers = concurrent conversations" tip; clarify client.workers and target_concurrency are independent layers Co-Authored-By: Claude Sonnet 4.6 --- docs/MULTI_TURN_QUICKSTART.md | 77 ++++--------------- .../load_generator/multi_turn_strategy.py | 9 ++- .../openai/openai_adapter.py | 2 +- .../test_multi_turn_strategy.py | 19 +++++ 4 files changed, 40 insertions(+), 67 deletions(-) diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index 4e8e5e58..f3d5e082 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -47,7 +47,7 @@ datasets: settings: load_pattern: type: multi_turn # ← Use multi-turn scheduler - target_concurrency: 32 # ← Required: max concurrent requests + target_concurrency: 32 # ← Required: max simultaneous conversations client: workers: 4 @@ -120,21 +120,18 @@ mode: independent **Behavior**: -- Issues turn-1 of ALL conversations at t=0 -- Then sequences turns within each conversation independently -- Maximum parallelism and throughput +- Up to `target_concurrency` conversations are active simultaneously +- Turns within each conversation are strictly sequenced (turn N+1 waits for turn N) +- Conversations run independently of each other β€” a short conversation can finish while a long one is still on turn 2 -**Use for**: Realistic production load where short conversations finish while long ones are still running. -For single-conversation debugging, use `mode: independent` with `target_concurrency: 1`. -Note: unlike the plain `ConcurrencyScheduler`, multi-turn + `target_concurrency: 1` still enforces -per-conversation turn ordering β€” turn N+1 waits for turn N even at concurrency 1. +**Use for**: Realistic production load simulation. For single-conversation debugging, set `target_concurrency: 1`. -**Example timeline**: +**Example timeline** (target_concurrency: 3, 4 conversations total): ``` -t=0: conv1-turn1, conv2-turn1, conv3-turn1 (all at once) +t=0: conv1-turn1, conv2-turn1, conv3-turn1 ← 3 conversations start t=0.5: conv1-turn2 (after conv1-turn1 completes) -t=0.7: conv2-turn2 (after conv2-turn1 completes) +t=0.7: conv2 finishes β†’ worker picks up conv4-turn1 t=0.8: conv1-turn3 (after conv1-turn2 completes) ... ``` @@ -143,62 +140,16 @@ t=0.8: conv1-turn3 (after conv1-turn2 completes) ## Concurrency Control -`target_concurrency` is **required** for the `multi_turn` load pattern. It limits the maximum number of in-flight requests across all conversations and prevents endpoint overload when many conversations run simultaneously. +`target_concurrency` is **required** for the `multi_turn` load pattern. It controls how many +conversations are active simultaneously. Each active conversation has exactly one in-flight turn +at a time β€” a worker issues turn N, waits for the response, then issues turn N+1. A new +conversation starts only after a worker finishes all turns of its current one. 
```yaml settings: load_pattern: type: multi_turn - target_concurrency: 32 # ← Limit to 32 concurrent requests -``` - ---- - -## Common Configurations - -### Recommended: With Concurrency Control - -```yaml -multi_turn: - mode: independent - -settings: - load_pattern: - type: multi_turn - target_concurrency: 32 # ← Prevents overload - client: - workers: 8 - -datasets: - - samples: 100 -``` - -### High Throughput Testing - -```yaml -multi_turn: - mode: independent - turn_timeout_s: 600 - -settings: - load_pattern: - type: multi_turn - target_concurrency: 96 - client: - workers: 16 # More workers for parallel conversations -``` - -### Conversations with Many Turns - -```yaml -multi_turn: - mode: independent - turn_timeout_s: 1800 # 30 minutes for slow responses - -settings: - load_pattern: - type: multi_turn - target_concurrency: 32 + target_concurrency: 32 # ← 32 conversations active simultaneously ``` --- @@ -329,7 +280,7 @@ jq -r '.conversation_id' logs/multi_turn_test/events.jsonl | sort -u ### Performance -- **Workers**: Set `workers` = number of concurrent conversations +- **Workers**: `client.workers` controls HTTP worker processes, independent of `target_concurrency`. The default (`-1`) auto-tunes based on NUMA topology. - **Timeout**: Set `turn_timeout_s` = 2x your longest expected turn latency - **Memory**: ~1KB per turn, plan accordingly for large datasets diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 42395723..4297863b 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -35,9 +35,12 @@ class MultiTurnStrategy: """Async multi-turn strategy. Uses a worker-pool to limit active conversations. N worker tasks pull from a queue of conversations. Each worker processes all - turns of one conversation before moving to the next, so at most N conversations - are active simultaneously. When target_concurrency is None, all conversations - run concurrently (one worker per conversation). + turns of one conversation before moving to the next. At most N conversations + are active simultaneously, each with exactly 1 in-flight turn β€” a worker + issues turn N, waits for the response, then issues turn N+1. A new conversation + starts only after the worker finishes all turns of its current one. When + target_concurrency is None, all conversations run concurrently (one worker per + conversation). 
Integration with BenchmarkSession: - execute(): populates queue, spawns workers, awaits all to complete diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index ca531ed4..a458688c 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -180,7 +180,7 @@ def decode_endpoint_response( response_dict = msgspec.json.decode(response_bytes) # Set default values for optional fields if missing - response_dict["choices"][0]["message"]["refusal"] = "None" + response_dict["choices"][0]["message"]["refusal"] = "" response_dict["choices"][0]["logprobs"] = {"content": [], "refusal": []} if ( "content" not in response_dict["choices"][0]["message"] diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 37cdfbad..7a9d8be1 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -363,6 +363,25 @@ async def complete_turn(): assert len(state.message_history) == 0 +@pytest.mark.unit +@pytest.mark.asyncio +async def test_pipeline_error_propagated(): + """execute() re-raises when a conversation pipeline raises an exception.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + class ErrorIssuer: + issued_count = 0 + issued: list[int] = [] + + def issue(self, idx: int, data_override: dict | None = None) -> str | None: + raise RuntimeError("simulated pipeline error") + + with pytest.raises(RuntimeError, match="simulated pipeline error"): + await strategy.execute(ErrorIssuer()) + + @pytest.mark.unit def test_mark_turn_complete_preserves_tool_calls(): """mark_turn_complete stores tool_calls in history when metadata contains them.""" From c3cd49751b3d321ada57a9c603b151fe644239ec Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Mon, 4 May 2026 15:06:28 -0700 Subject: [PATCH 13/41] refactor: replace worker-pool with event-driven model in MultiTurnStrategy Rewrites MultiTurnStrategy to issue subsequent turns synchronously inside on_sample_complete() (zero event-loop delay), removing pre-spawned worker tasks and per-conversation asyncio.Event waiting. ConversationState no longer holds an asyncio.Event; sequencing is driven entirely by the strategy. Addresses PR #285 reviewer request to move turn issuance into the sample-complete handler. Co-Authored-By: Claude Sonnet 4.6 --- .../load_generator/conversation_manager.py | 22 +- .../load_generator/multi_turn_strategy.py | 272 ++++++++++-------- .../test_multi_turn_conversation_manager.py | 72 +---- .../test_multi_turn_strategy.py | 45 +-- 4 files changed, 181 insertions(+), 230 deletions(-) diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index 30276d5e..1b0834bb 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -15,7 +15,6 @@ """Conversation state management for multi-turn benchmarking.""" -import asyncio import logging from dataclasses import dataclass, field from typing import Any @@ -27,13 +26,8 @@ class ConversationState: """Per-conversation state for multi-turn benchmarking. 
- The pipeline task awaits ``turn_done`` between turns; ``mark_turn_complete`` - and ``mark_turn_failed`` set it synchronously from ``on_sample_complete``. - Attributes: conversation_id: Unique identifier for this conversation. - turn_done: Event set when a response arrives. Pipeline waits, then clears - it before issuing the next turn. message_history: Accumulated message list (populated only when use_dataset_history=False; empty otherwise). completed_turns: Turns with responses (success or failure) β€” observability only. @@ -42,8 +36,6 @@ class ConversationState: """ conversation_id: str - # Python 3.12+: asyncio.Event no longer requires a running loop at construction. - turn_done: asyncio.Event = field(default_factory=asyncio.Event) message_history: list[dict[str, Any]] = field(default_factory=list) completed_turns: int = 0 failed_turns: int = 0 @@ -59,11 +51,11 @@ def is_complete(self) -> bool: class ConversationManager: """Manages per-conversation state for multi-turn benchmarking. - All methods are synchronous. The pipeline task uses ``ConversationState.turn_done`` - directly for turn-done notification β€” no locks or condition variables needed. + All methods are synchronous. Turn sequencing is driven by MultiTurnStrategy + which calls on_sample_complete() β†’ _issue_next_turn() directly. - All states are pre-created by ``MultiTurnStrategy.execute()`` before any pipeline - task starts, so ``get_or_create()`` requires no locking. + All states are pre-created by MultiTurnStrategy.execute() before any turns + are issued, so get_or_create() requires no locking. """ def __init__(self): @@ -126,7 +118,7 @@ def mark_turn_complete( store_in_history: bool = False, metadata: dict[str, Any] | None = None, ) -> None: - """Record a successful response and wake the pipeline task. + """Record a successful response. Args: conversation_id: Conversation ID. @@ -150,14 +142,13 @@ def mark_turn_complete( state.message_history.append(msg) state.completed_turns += 1 self._log_if_complete(state, conversation_id) - state.turn_done.set() def mark_turn_failed( self, conversation_id: str, store_in_history: bool = False, ) -> None: - """Record a failed response and wake the pipeline task. + """Record a failed response. Failed turns count toward completion so sequencing progresses under errors. @@ -179,4 +170,3 @@ def mark_turn_failed( state.failed_turns += 1 logger.warning(f"Turn failed for conversation {conversation_id}") self._log_if_complete(state, conversation_id) - state.turn_done.set() diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 4297863b..d3f432d7 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -17,7 +17,8 @@ import asyncio import logging -from collections import defaultdict +from collections import defaultdict, deque +from collections.abc import Iterator from typing import Any from ..config.schema import MultiTurnConfig @@ -32,29 +33,28 @@ class MultiTurnStrategy: - """Async multi-turn strategy. Uses a worker-pool to limit active conversations. + """Event-driven multi-turn strategy. Completion of each turn triggers the next. - N worker tasks pull from a queue of conversations. Each worker processes all - turns of one conversation before moving to the next. 
At most N conversations - are active simultaneously, each with exactly 1 in-flight turn β€” a worker - issues turn N, waits for the response, then issues turn N+1. A new conversation - starts only after the worker finishes all turns of its current one. When - target_concurrency is None, all conversations run concurrently (one worker per - conversation). + execute() seeds the first N conversations (issues turn 1 for each), then + awaits _all_done. on_sample_complete() is called synchronously from the + receive coroutine for each response β€” it issues the next turn immediately + (zero event-loop iterations between response and next issuance), or starts + a new conversation when the current one finishes all turns. + + At most target_concurrency conversations are active simultaneously. When + target_concurrency is None, all conversations start at once. Integration with BenchmarkSession: - - execute(): populates queue, spawns workers, awaits all to complete + - execute(): seeds conversations, awaits completion - on_query_complete(): no-op (required by LoadStrategy protocol) - - on_sample_complete(): routes completed QueryResult to ConversationManager + - on_sample_complete(): routes completed QueryResult, issues next turn The response routing path: - 1. _conv_pipeline issues turn N via phase_issuer.issue(idx) β†’ query_id - 2. _conv_pipeline stores conv_id in _inflight[query_id] + 1. _issue_next_turn issues turn N via phase_issuer.issue(idx) β†’ query_id + 2. _issue_next_turn stores conv_id in _inflight[query_id] 3. BenchmarkSession calls on_sample_complete(result) with the QueryResult 4. on_sample_complete looks up conv_id from _inflight, calls mark_turn_complete - 5. mark_turn_complete sets state.turn_done synchronously - 6. _conv_pipeline's await asyncio.wait_for(state.turn_done.wait()) returns - 7. Pipeline clears the event and issues turn N+1 + 5. on_sample_complete calls _issue_next_turn for turn N+1 (synchronously) """ def __init__( @@ -93,6 +93,15 @@ def __init__( # Cached ConversationState refs for O(1) lookup in on_sample_complete. self._conv_states: dict[str, ConversationState] = {} + # Event-driven state β€” populated in execute(). + self._pending_convs: deque[tuple[str, list[tuple[int, int]]]] = deque() + self._active_iters: dict[str, Iterator[tuple[int, int]]] = {} + self._timeout_handles: dict[str, asyncio.TimerHandle] = {} + self._error: BaseException | None = None + self._all_done: asyncio.Event | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._phase_issuer: PhaseIssuerProtocol | None = None + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: """Drive multi-turn sample issuance. @@ -102,12 +111,17 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: Returns: Total count of samples issued. """ + self._phase_issuer = phase_issuer + self._loop = asyncio.get_running_loop() + self._all_done = asyncio.Event() + self._error = None + conv_samples: dict[str, list[tuple[int, int]]] = defaultdict(list) for sample_index, sample_meta in enumerate(self._dataset_metadata["samples"]): conv_id = sample_meta["conversation_id"] conv_samples[conv_id].append((sample_index, sample_meta["turn"])) - # Pre-create all conversation states before spawning workers (no locking needed). + # Pre-create all conversation states before issuing any turns (no locking needed). 
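+        # Descriptive note (derived from the surrounding code): on_sample_complete() and
+        # _issue_next_turn() run synchronously on the event loop, so a fully pre-populated
+        # _conv_states dict lets them look up per-conversation state without any locking.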
sys_prompts = self._dataset_metadata.get("system_prompts_by_conv", {}) for conv_id, turns in conv_samples.items(): sys_content = sys_prompts.get(conv_id) if self._store_in_history else None @@ -123,30 +137,26 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: ) self._conv_states[conv_id] = state - # Build queue of (conv_id, turns) pairs for workers to pull from. - conv_queue: asyncio.Queue[tuple[str, list[tuple[int, int]]]] = asyncio.Queue() + # Build pending queue (sorted turns per conversation). for conv_id, turns in conv_samples.items(): - await conv_queue.put((conv_id, turns)) + self._pending_convs.append((conv_id, sorted(turns, key=lambda x: x[1]))) - n_conversations = len(conv_samples) - n_workers = ( - min(self._target_concurrency, n_conversations) + n_to_start = ( + min(self._target_concurrency, len(self._pending_convs)) if self._target_concurrency is not None and self._target_concurrency > 0 - else n_conversations + else len(self._pending_convs) ) + for _ in range(n_to_start): + self._start_conversation() - worker_tasks = [ - asyncio.create_task( - self._worker(conv_queue, phase_issuer), - name=f"mt-worker-{i}", - ) - for i in range(n_workers) - ] + if not self._active_iters and not self._inflight: + return phase_issuer.issued_count + + await self._all_done.wait() - results = await asyncio.gather(*worker_tasks, return_exceptions=True) - errors = [r for r in results if isinstance(r, BaseException)] - for err in errors: - logger.error(f"Conversation pipeline failed: {err}") + for handle in self._timeout_handles.values(): + handle.cancel() + self._timeout_handles.clear() if self._inflight: logger.warning( @@ -156,102 +166,111 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: ) self._inflight.clear() - if errors: - raise errors[0] + if self._error is not None: + raise self._error return phase_issuer.issued_count - async def _worker( - self, - conv_queue: asyncio.Queue[tuple[str, list[tuple[int, int]]]], - phase_issuer: PhaseIssuerProtocol, - ) -> None: - """Pull conversations from queue and process each one fully before taking the next.""" - while True: - try: - conv_id, turns = conv_queue.get_nowait() - except asyncio.QueueEmpty: - break - await self._conv_pipeline(conv_id, turns, phase_issuer) - - async def _conv_pipeline( - self, - conv_id: str, - turns: list[tuple[int, int]], - phase_issuer: PhaseIssuerProtocol, - ) -> None: - """Process all turns for a single conversation sequentially. - - For each turn after the first, waits for state.turn_done before issuing - the next. This enforces strict sequential ordering within the conversation. 
- """ + def _start_conversation(self) -> None: + """Pop the next conversation from the pending queue and issue its first turn.""" + conv_id, turns = self._pending_convs.popleft() + self._active_iters[conv_id] = iter(turns) + self._issue_next_turn(conv_id) + + def _issue_next_turn(self, conv_id: str) -> None: + """Issue the next turn for conv_id, or mark the conversation done.""" + it = self._active_iters.get(conv_id) + if it is None: + return + + pair = next(it, None) + if pair is None: + del self._active_iters[conv_id] + self._fill_slot() + return + + idx, turn = pair state = self._conv_states[conv_id] - sorted_turns = sorted(turns, key=lambda x: x[1]) - last_query_id: str | None = None - - for i, (idx, turn) in enumerate(sorted_turns): - if i > 0: - try: - await asyncio.wait_for( - state.turn_done.wait(), timeout=self._turn_timeout_s - ) - except TimeoutError: + + data_override: dict[str, Any] | None = None + current_turn_messages: list[dict[str, Any]] | None = None + if self._store_in_history: + current_turn_messages = self._dataset_metadata.get( + "current_turn_messages_by_key", {} + ).get((conv_id, turn)) + if current_turn_messages: + has_tool_msg = any( + m.get("role") == "tool" for m in current_turn_messages + ) + if has_tool_msg: logger.warning( - f"Turn {turn} of {conv_id} timed out waiting for previous turn" + "Live-history mode with tool messages uses dataset " + "tool_call_ids; real endpoint IDs will differ " + "(conv=%s, turn=%d)", + conv_id, + turn, ) - if last_query_id is not None: - self._inflight.pop(last_query_id, None) - remaining = len(sorted_turns) - i - for _ in range(remaining): - self._conv_manager.mark_turn_failed( - conv_id, store_in_history=self._store_in_history - ) - break - state.turn_done.clear() - - # Live-history mode: build messages from accumulated history + current turn. - data_override: dict[str, Any] | None = None - current_turn_messages: list[dict[str, Any]] | None = None - if self._store_in_history: - current_turn_messages = self._dataset_metadata.get( - "current_turn_messages_by_key", {} - ).get((conv_id, turn)) - if current_turn_messages: - has_tool_msg = any( - m.get("role") == "tool" for m in current_turn_messages - ) - if has_tool_msg: - logger.warning( - "Live-history mode with tool messages uses dataset " - "tool_call_ids; real endpoint IDs will differ " - "(conv=%s, turn=%d)", - conv_id, - turn, - ) - live_messages = state.message_history.copy() + current_turn_messages - data_override = {"messages": live_messages} - - query_id = phase_issuer.issue(idx, data_override=data_override) - if query_id is None: - # Session stopping β€” exit pipeline. - break - - self._inflight[query_id] = conv_id - last_query_id = query_id - - # Append current-turn messages to history so the next turn sees them. - if self._store_in_history and current_turn_messages: - state.message_history.extend(current_turn_messages) + live_messages = state.message_history.copy() + current_turn_messages + data_override = {"messages": live_messages} + + assert self._phase_issuer is not None + query_id = self._phase_issuer.issue(idx, data_override=data_override) + if query_id is None: + # Session stopping β€” signal done. 
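+            # Descriptive note: issue() returning None means the session is stopping, so
+            # set _all_done to let execute() return instead of waiting on turns that will
+            # never arrive.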
+ assert self._all_done is not None + self._all_done.set() + return + + self._inflight[query_id] = conv_id + + if self._store_in_history and current_turn_messages: + state.message_history.extend(current_turn_messages) + + assert self._loop is not None + handle = self._loop.call_later( + self._turn_timeout_s, self._handle_timeout, query_id, conv_id + ) + self._timeout_handles[query_id] = handle + + def _fill_slot(self) -> None: + """Start a new conversation from the pending queue, or signal all done.""" + if self._pending_convs: + self._start_conversation() + elif not self._active_iters: + assert self._all_done is not None + self._all_done.set() + + def _handle_timeout(self, query_id: str, conv_id: str) -> None: + """Called by the event loop when a turn response does not arrive in time.""" + if self._inflight.pop(query_id, None) is None: + return + self._timeout_handles.pop(query_id, None) + + logger.warning( + "Turn timed out for conversation %s (query=%s)", conv_id, query_id + ) + + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + it = self._active_iters.pop(conv_id, None) + if it is not None: + for _ in it: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + + self._fill_slot() def on_query_complete(self, query_id: str) -> None: """No-op. Required by LoadStrategy protocol; called by BenchmarkSession.""" pass def on_sample_complete(self, result: QueryResult) -> None: - """Route completed QueryResult to ConversationManager. + """Route completed QueryResult to ConversationManager and issue next turn. - Called by execute.py on_sample_complete hook after each response. - Event.set() is synchronous β€” the pipeline task is woken immediately - without needing asyncio.ensure_future. + Called synchronously from BenchmarkSession._handle_response(). Issues the + next turn immediately (zero event-loop delay) or starts a new conversation + when this one finishes all turns. Args: result: Completed QueryResult from the endpoint. 
@@ -260,9 +279,9 @@ def on_sample_complete(self, result: QueryResult) -> None: if conv_id is None: return - if self._conv_manager.get_state(conv_id) is None: - logger.warning(f"on_sample_complete: unknown conversation {conv_id}") - return + handle = self._timeout_handles.pop(result.id, None) + if handle is not None: + handle.cancel() response_text = result.get_response_output_string() @@ -284,3 +303,12 @@ def on_sample_complete(self, result: QueryResult) -> None: conv_id, result.id, ) + return + + try: + self._issue_next_turn(conv_id) + except Exception as exc: + logger.error("Error issuing next turn for %s: %s", conv_id, exc) + self._error = exc + if self._all_done is not None: + self._all_done.set() diff --git a/tests/unit/load_generator/test_multi_turn_conversation_manager.py b/tests/unit/load_generator/test_multi_turn_conversation_manager.py index 331e6709..c389fb5f 100644 --- a/tests/unit/load_generator/test_multi_turn_conversation_manager.py +++ b/tests/unit/load_generator/test_multi_turn_conversation_manager.py @@ -29,7 +29,6 @@ def test_conversation_state_initialization(): state = ConversationState(conversation_id="conv_001") assert state.conversation_id == "conv_001" - assert not state.turn_done.is_set() assert state.message_history == [] assert state.completed_turns == 0 assert state.failed_turns == 0 @@ -95,7 +94,7 @@ def test_conversation_manager_multiple_conversations(): @pytest.mark.unit def test_conversation_manager_mark_turn_complete(): - """mark_turn_complete increments counter, appends history, sets event.""" + """mark_turn_complete increments counter and appends history.""" manager = ConversationManager() state = manager.get_or_create("conv_001") @@ -103,7 +102,6 @@ def test_conversation_manager_mark_turn_complete(): assert state.completed_turns == 1 assert state.failed_turns == 0 - assert state.turn_done.is_set() assert state.message_history == [] # store_in_history=False by default @@ -120,7 +118,7 @@ def test_conversation_manager_mark_turn_complete_stores_history(): @pytest.mark.unit def test_conversation_manager_mark_turn_failed(): - """mark_turn_failed increments both counters and sets event.""" + """mark_turn_failed increments both counters.""" manager = ConversationManager() state = manager.get_or_create("conv_001", expected_client_turns=2) @@ -128,7 +126,6 @@ def test_conversation_manager_mark_turn_failed(): assert state.completed_turns == 1 assert state.failed_turns == 1 - assert state.turn_done.is_set() @pytest.mark.unit @@ -187,71 +184,6 @@ def test_all_turns_fail(): assert state.failed_turns == 2 -@pytest.mark.unit -@pytest.mark.asyncio -async def test_event_set_wakes_waiter(): - """mark_turn_complete sets turn_done so a blocked await returns.""" - manager = ConversationManager() - state = manager.get_or_create("conv_001") - - woke_up: list[bool] = [] - - async def waiter(): - await state.turn_done.wait() - woke_up.append(True) - - task = asyncio.create_task(waiter()) - await asyncio.sleep(0.01) - assert not woke_up - - manager.mark_turn_complete("conv_001", "response") - await asyncio.sleep(0.01) - await task - - assert woke_up == [True] - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_failed_sets_event(): - """mark_turn_failed sets turn_done so the pipeline can unblock.""" - manager = ConversationManager() - state = manager.get_or_create("conv_001") - - woke_up: list[bool] = [] - - async def waiter(): - await state.turn_done.wait() - woke_up.append(True) - - task = asyncio.create_task(waiter()) - await asyncio.sleep(0.01) - - 
manager.mark_turn_failed("conv_001") - await asyncio.sleep(0.01) - await task - - assert woke_up == [True] - - -@pytest.mark.unit -@pytest.mark.asyncio -async def test_event_clear_resets_for_next_turn(): - """Clearing turn_done after wait() properly gates the next turn.""" - manager = ConversationManager() - state = manager.get_or_create("conv_001") - - # First turn: set then clear - manager.mark_turn_complete("conv_001", "r1") - await state.turn_done.wait() - state.turn_done.clear() - assert not state.turn_done.is_set() - - # Second turn: set again - manager.mark_turn_complete("conv_001", "r2") - assert state.turn_done.is_set() - - @pytest.mark.unit @pytest.mark.asyncio async def test_conversation_manager_concurrent_access(): diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 7a9d8be1..d3c9a22a 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -68,14 +68,12 @@ async def test_single_conversation_single_turn(): strategy = MultiTurnStrategy(conv_manager, metadata) issuer = FakePhaseIssuer() - # Simulate response completion (turn 1 is issued, then completes) async def complete_turns(): - # Wait a tick for the strategy to issue the first turn await asyncio.sleep(0.01) - # Mark turn 1 complete - state = conv_manager.get_state("conv1") - if state: - conv_manager.mark_turn_complete("conv1", "response 1") + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response 1") + ) + strategy.on_sample_complete(result) asyncio.create_task(complete_turns()) count = await strategy.execute(issuer) @@ -107,7 +105,6 @@ def tracked_issue(idx, data_override=None): async def simulate_responses(): await asyncio.sleep(0.01) for turn_q, resp in [("q0000", "r1"), ("q0001", "r2"), ("q0002", "r3")]: - # Signal turn complete via on_sample_complete result = QueryResult( id=turn_q, response_output=TextModelOutput(output=resp) ) @@ -132,7 +129,6 @@ async def test_multiple_conversations_concurrent(): async def simulate_responses(): await asyncio.sleep(0.02) - # Complete all turns for both conversations for q_prefix in range(4): q = f"q{q_prefix:04d}" result = QueryResult(id=q, response_output=TextModelOutput(output="resp")) @@ -174,12 +170,10 @@ async def simulate_responses(): import time await asyncio.sleep(0.02) - # Complete turn 1 (sample 0) after a delay complete_timestamps[0] = time.monotonic() result = QueryResult(id="q0000", response_output=TextModelOutput(output="r1")) strategy.on_sample_complete(result) await asyncio.sleep(0.05) - # Complete turn 2 (sample 1) complete_timestamps[1] = time.monotonic() result = QueryResult(id="q0001", response_output=TextModelOutput(output="r2")) strategy.on_sample_complete(result) @@ -227,7 +221,6 @@ async def test_on_sample_complete_routes_to_manager(): state = conv_manager.get_state("conv1") assert state is not None assert state.completed_turns == 1 - assert state.turn_done.is_set() assert state.is_complete() @@ -296,7 +289,10 @@ async def test_live_history_initializes_system_prompt(): async def complete_turn(): await asyncio.sleep(0.01) - conv_manager.mark_turn_complete("conv1", "response") + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -325,7 +321,10 @@ async def test_live_history_no_system_prompt_when_none(): async def complete_turn(): await 
asyncio.sleep(0.01) - conv_manager.mark_turn_complete("conv1", "response") + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -352,7 +351,10 @@ async def test_dataset_history_mode_does_not_inject_system_prompt(): async def complete_turn(): await asyncio.sleep(0.01) - conv_manager.mark_turn_complete("conv1", "response") + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) asyncio.create_task(complete_turn()) await strategy.execute(issuer) @@ -366,7 +368,7 @@ async def complete_turn(): @pytest.mark.unit @pytest.mark.asyncio async def test_pipeline_error_propagated(): - """execute() re-raises when a conversation pipeline raises an exception.""" + """execute() re-raises when _issue_next_turn raises an exception.""" conv_manager = ConversationManager() metadata = _make_dataset_metadata({"conv1": [1]}) strategy = MultiTurnStrategy(conv_manager, metadata) @@ -492,9 +494,9 @@ async def test_on_sample_complete_passes_metadata(): async def test_concurrency_limits_active_conversations(): """target_concurrency=2 starts at most 2 conversation pipelines simultaneously. - Uses 2-turn conversations so each pipeline has an await point (turn_done.wait - between turns). With 4 conversations and 2 workers, the 3rd and 4th conversations - cannot start until a worker finishes its current conversation. + Uses 2-turn conversations so each pipeline has an await point between turns. + With 4 conversations and 2 workers, the 3rd and 4th conversations cannot start + until a worker finishes its current conversation. """ conv_manager = ConversationManager() # 4 two-turn conversations; pipeline awaits turn-1 response before issuing turn-2 @@ -519,7 +521,7 @@ async def auto_respond(): responder_task = asyncio.create_task(auto_respond()) execute_task = asyncio.create_task(strategy.execute(issuer)) - # Let both workers start and block on turn_done.wait before auto_respond fires + # Let both seed turns get issued before auto_respond fires await asyncio.sleep(0.01) # Only 2 workers β†’ exactly 2 turn-1 queries issued (conv3/conv4 not started yet) @@ -536,8 +538,7 @@ async def auto_respond(): async def test_conversation_slot_reuse(): """With target_concurrency=1, worker completes conv1 before starting conv2. - Uses 2-turn conversations so the pipeline has an await between turns. - The single worker must process both turns of conv1 before conv2's turn 1 is issued. + The single slot must process both turns of conv1 before conv2's turn 1 is issued. 
""" conv_manager = ConversationManager() # 2 two-turn conversations; sample indices: conv1β†’[0,1], conv2β†’[2,3] @@ -561,6 +562,6 @@ async def auto_respond(): await strategy.execute(issuer) responder_task.cancel() - # Single worker: conv1 turns (samples 0,1) must be issued before conv2 turns (2,3) + # Single slot: conv1 turns (samples 0,1) must be issued before conv2 turns (2,3) assert issuer.issued[:2] == [0, 1], "Conv1 turns should be issued before conv2" assert issuer.issued[2:] == [2, 3], "Conv2 turns should follow conv1" From c2ab3a78812611ede46120c9a743535c717f3148 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 6 May 2026 16:32:02 +0000 Subject: [PATCH 14/41] fix: address PR #285 review comments for multi-turn implementation - Remove ConversationMode enum (single-member) and mode field from MultiTurnConfig; drop mode: independent from YAML examples and docs - Merge AddDefaultColumns into AddStaticColumns(overwrite=False) - Replace per-call strategy check with construct-time branch in execute.py - Normalize None tool-calling content to "" in openai_adapter.py - Delete unused Query.metadata, QueryResult.with_metadata, and InFlightRequest.query_metadata plumbing - Add role-specific validation in _validate_conversation_structure: tool rows require non-empty tool_results, assistant rows require content or tool_calls - Backfill explicit sample_index into conversation_metadata["samples"]; MultiTurnStrategy reads sample_meta["sample_index"] instead of enumerate - Add AT-RISK gc=False docstring notes to openai/types.py structs with mutable container fields - Rewrite dataset tool_call_ids with model-generated ids in live-history mode; add test_live_history_remaps_tool_call_id integration test - Lift inline imports to top of test_schema.py Co-Authored-By: Claude Sonnet 4.6 --- docs/MULTI_TURN_QUICKSTART.md | 34 +---- examples/09_MultiTurn/README.md | 24 ---- .../agentic_coding_benchmark.yaml | 1 - .../agentic_workflow_benchmark.yaml | 1 - .../09_MultiTurn/multi_turn_benchmark.yaml | 1 - .../multi_turn_with_concurrency.yaml | 1 - .../commands/benchmark/execute.py | 15 +- src/inference_endpoint/config/schema.py | 8 -- src/inference_endpoint/core/types.py | 30 +--- .../dataset_manager/__init__.py | 2 - .../dataset_manager/multi_turn_dataset.py | 41 +++++- .../dataset_manager/transforms.py | 41 ++---- .../endpoint_client/http.py | 2 - .../endpoint_client/worker.py | 7 +- .../load_generator/conversation_manager.py | 11 +- .../load_generator/multi_turn_strategy.py | 43 ++++-- .../openai/openai_adapter.py | 2 +- src/inference_endpoint/openai/types.py | 25 +++- tests/integration/test_multi_turn.py | 136 ++++++++++++++++++ tests/unit/config/test_schema.py | 6 +- tests/unit/core/test_types.py | 54 ------- tests/unit/dataset_manager/test_transforms.py | 15 +- 22 files changed, 278 insertions(+), 222 deletions(-) diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index f3d5e082..03bf8519 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -41,7 +41,6 @@ datasets: type: performance path: path/to/your/conversations.jsonl multi_turn: # ← Presence of this block enables multi-turn mode - mode: independent # ← Per-conv pipelines; no cross-conv turn barrier turn_timeout_s: 300 # ← Max wait for prev turn settings: @@ -110,34 +109,6 @@ _Note: Per-conversation aggregation (e.g., "conversations/sec") is coming in a f --- -## Conversation Modes Explained - -### Independent Mode (Default) - -```yaml -mode: independent -``` - -**Behavior**: - -- Up to 
`target_concurrency` conversations are active simultaneously -- Turns within each conversation are strictly sequenced (turn N+1 waits for turn N) -- Conversations run independently of each other β€” a short conversation can finish while a long one is still on turn 2 - -**Use for**: Realistic production load simulation. For single-conversation debugging, set `target_concurrency: 1`. - -**Example timeline** (target_concurrency: 3, 4 conversations total): - -``` -t=0: conv1-turn1, conv2-turn1, conv3-turn1 ← 3 conversations start -t=0.5: conv1-turn2 (after conv1-turn1 completes) -t=0.7: conv2 finishes β†’ worker picks up conv4-turn1 -t=0.8: conv1-turn3 (after conv1-turn2 completes) -... -``` - ---- - ## Concurrency Control `target_concurrency` is **required** for the `multi_turn` load pattern. It controls how many @@ -204,8 +175,7 @@ is auto-detected from the `.jsonl` extension β€” no `format` field is needed: ```yaml datasets: - path: your_file.jsonl - multi_turn: - mode: independent + multi_turn: {} ``` --- @@ -287,7 +257,7 @@ jq -r '.conversation_id' logs/multi_turn_test/events.jsonl | sort -u ### Debugging - **Start small**: Test with 1-2 conversations first -- **Single conversation**: Use `mode: independent` with `target_concurrency: 1` +- **Single conversation**: Use `target_concurrency: 1` - **Check events.jsonl**: Verify turn ordering with `jq` --- diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md index e7f9505a..3a6f363a 100644 --- a/examples/09_MultiTurn/README.md +++ b/examples/09_MultiTurn/README.md @@ -144,7 +144,6 @@ datasets: type: performance path: examples/09_MultiTurn/customer_support_conversations.jsonl multi_turn: - mode: independent turn_timeout_s: 300.0 settings: @@ -184,29 +183,6 @@ t=1.0: Turn-1 completes β†’ issue turn-2 of completed conv (slot filled) ... Maintains ~32 in-flight across all conversations ``` -### Conversation Modes - -The default mode is `independent`. - -#### Independent Mode (Default) - -Issues turns for each conversation independently β€” no cross-conversation turn barrier. - -```yaml -multi_turn: - mode: independent - -settings: - load_pattern: - type: multi_turn - target_concurrency: 32 -``` - -**Use case**: Realistic production load where short conversations finish while long ones are -still running. Turn 1 of one conversation and turn 100 of another can be in-flight simultaneously. - -For single-conversation debugging, use `mode: independent` with `target_concurrency: 1`. - ### Turn Timeout Configure maximum wait time for previous turn completion: diff --git a/examples/09_MultiTurn/agentic_coding_benchmark.yaml b/examples/09_MultiTurn/agentic_coding_benchmark.yaml index f3abc3cf..dace765d 100644 --- a/examples/09_MultiTurn/agentic_coding_benchmark.yaml +++ b/examples/09_MultiTurn/agentic_coding_benchmark.yaml @@ -13,7 +13,6 @@ datasets: # The datasets/ directory is a placeholder; populate it with the conversion script above. path: examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl multi_turn: - mode: independent turn_timeout_s: 600.0 settings: diff --git a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml index 239e9374..a66b16c4 100644 --- a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml +++ b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml @@ -13,7 +13,6 @@ datasets: # The datasets/ directory is a placeholder; populate it with the conversion script above. 
path: examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl multi_turn: - mode: independent turn_timeout_s: 600.0 settings: diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml index 36066aa3..8e5933e4 100644 --- a/examples/09_MultiTurn/multi_turn_benchmark.yaml +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -13,7 +13,6 @@ datasets: path: examples/09_MultiTurn/customer_support_conversations.jsonl samples: 10 multi_turn: - mode: independent turn_timeout_s: 300.0 settings: diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml index e1d5f37c..f1466396 100644 --- a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -13,7 +13,6 @@ datasets: path: examples/09_MultiTurn/customer_support_conversations.jsonl samples: 10 multi_turn: - mode: independent # All conv turn-1 start together turn_timeout_s: 300.0 settings: diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index b5230a53..4356249a 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -31,6 +31,7 @@ import signal import tempfile import uuid +from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime from pathlib import Path @@ -556,10 +557,16 @@ async def _run_benchmark_async( target_concurrency=ctx.config.settings.load_pattern.target_concurrency, ) - def _on_sample_complete(result: QueryResult) -> None: - if multi_turn_strategy is not None: - multi_turn_strategy.on_sample_complete(result) - collector.on_complete_hook(result) + _on_sample_complete: Callable[[QueryResult], None] + if multi_turn_strategy is not None: + _mt_strategy = multi_turn_strategy + + def _on_sample_complete(result: QueryResult) -> None: + _mt_strategy.on_sample_complete(result) + collector.on_complete_hook(result) + + else: + _on_sample_complete = collector.on_complete_hook # Create session session = BenchmarkSession( diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index b3362899..a7a19b49 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -65,12 +65,6 @@ class LoadPatternType(str, Enum): STEP = "step" # Step pattern (TODO) -class ConversationMode(str, Enum): - """Multi-turn conversation scheduling modes.""" - - INDEPENDENT = "independent" # Per-conv pipelines; no cross-conv turn barrier - - class OSLDistributionType(str, Enum): """Output Sequence Length distribution types.""" @@ -245,14 +239,12 @@ class MultiTurnConfig(BaseModel): Presence of this block in the dataset config enables multi-turn mode. Attributes: - mode: Conversation scheduling strategy (currently only independent). turn_timeout_s: Maximum seconds to wait for previous turn completion. use_dataset_history: If True, use pre-built message history from dataset. 
""" model_config = ConfigDict(extra="forbid", frozen=True) - mode: ConversationMode = ConversationMode.INDEPENDENT turn_timeout_s: float = 300.0 use_dataset_history: bool = True diff --git a/src/inference_endpoint/core/types.py b/src/inference_endpoint/core/types.py index 5b8209e1..6887462c 100644 --- a/src/inference_endpoint/core/types.py +++ b/src/inference_endpoint/core/types.py @@ -232,7 +232,6 @@ class Query( Attributes: id: Unique identifier for this query (auto-generated UUID). data: Request payload as a dictionary (typically contains prompt, model, etc.). - metadata: Internal metadata that round-trips through transport (e.g., conversation_id). headers: HTTP headers to include in the request (e.g., authorization). created_at: Timestamp when query was created (seconds since epoch). @@ -246,7 +245,7 @@ class Query( gc=False: Safe because data/headers are simple key-value pairs without cycles. Do NOT store self-referential or cyclic structures in data/headers fields. - array_like=True: Encodes as array instead of object (e.g., ["id", {...}, {...}, 0.0] + array_like=True: Encodes as array instead of object (e.g., ["id", {...}, 0.0] instead of {"id": ..., "data": ..., ...}). Provides ~6-50% size reduction and ~6-29% ser/des speedup for ZMQ transport depending on payload size. @@ -256,7 +255,6 @@ class Query( id: str = msgspec.field(default_factory=lambda: str(uuid.uuid4())) data: dict[str, Any] = msgspec.field(default_factory=dict) - metadata: dict[str, Any] = msgspec.field(default_factory=dict) headers: dict[str, str] = msgspec.field(default_factory=dict) created_at: float = msgspec.field(default_factory=time.time) @@ -339,32 +337,6 @@ def get_response_output_string(self) -> str: else: return "" - def with_metadata( - self, additional_metadata: dict[str, Any] | None - ) -> "QueryResult": - """Return a new QueryResult with merged metadata. - - Args: - additional_metadata: Metadata to merge into existing metadata. - Values in additional_metadata override existing keys. - - Returns: - New QueryResult with merged metadata (existing + additional). - If additional_metadata is None or empty, returns self unchanged. 
- """ - if not additional_metadata: - return self - - merged = dict(self.metadata) - merged.update(additional_metadata) - - return QueryResult( - id=self.id, - response_output=self.response_output, - metadata=merged, - error=self.error, - ) - class StreamChunk( msgspec.Struct, diff --git a/src/inference_endpoint/dataset_manager/__init__.py b/src/inference_endpoint/dataset_manager/__init__.py index 403b8730..12938f8e 100644 --- a/src/inference_endpoint/dataset_manager/__init__.py +++ b/src/inference_endpoint/dataset_manager/__init__.py @@ -30,7 +30,6 @@ from .predefined.random import RandomDataset from .predefined.shopify_product_catalogue import ShopifyProductCatalogue from .transforms import ( - AddDefaultColumns, AddStaticColumns, ColumnFilter, ColumnRemap, @@ -47,7 +46,6 @@ "DataLoaderFactory", "ColumnFilter", "ColumnRemap", - "AddDefaultColumns", "AddStaticColumns", "UserPromptFormatter", "FusedRowProcessor", diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index d2f21695..bc3295d1 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -24,7 +24,6 @@ from ..exceptions import InputValidationError from .dataset import Dataset from .transforms import ( - AddDefaultColumns, AddStaticColumns, apply_transforms, get_transforms_for_api_type, @@ -171,6 +170,31 @@ def _validate_conversation_structure(self): f"Conversation {conv_id} has invalid role sequence at turn " f"{row['turn']}: got '{role}' after state '{state}'" ) + + if role == "tool": + tool_results = row.get("tool_results") + if not isinstance(tool_results, list) or len(tool_results) == 0: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "tool rows must have a non-empty 'tool_results' list" + ) + elif role == "assistant": + content = row.get("content") + is_empty_content = ( + content is None + or (isinstance(content, float) and pd.isna(content)) + or content == "" + ) + tool_calls = row.get("tool_calls") + has_tool_calls = ( + isinstance(tool_calls, list) and len(tool_calls) > 0 + ) + if is_empty_content and not has_tool_calls: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "assistant rows must have non-empty 'content' or non-empty 'tool_calls'" + ) + state = role def _validate_turn_numbering(self): @@ -347,7 +371,7 @@ def load( if isinstance(t, AddStaticColumns): defaults.update(t.data) if defaults: - df = AddDefaultColumns(defaults)(df) + df = AddStaticColumns(defaults, overwrite=False)(df) all_rows = df.to_dict(orient="records") @@ -356,6 +380,8 @@ def load( # value is float NaN was absent in the original dataset row. pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}) client_turn_samples: list[dict[str, Any]] = [] + # Maps (conv_id, turn) β†’ dense sample_index for metadata backfill. + key_to_sample_index: dict[tuple[str, int], int] = {} # Collect per-conversation defaults from the first user row so that # fields like model/max_completion_tokens propagate to tool rows. @@ -410,6 +436,17 @@ def load( messages = pre_built.get(key, []) sample["messages"] = messages + # Record dense 0-based index before appending (matches load_sample() position). + key_to_sample_index[key] = len(client_turn_samples) client_turn_samples.append(sample) + # Backfill explicit sample_index into conversation_metadata["samples"]. 
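+        # Descriptive note: MultiTurnStrategy reads sample_meta["sample_index"] directly,
+        # so each index must match that sample's final position in client_turn_samples
+        # (which becomes self.data).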
+ # Drop entries whose key is absent (truncated turns not in client_turn_samples). + updated_samples = [] + for s in self.conversation_metadata["samples"]: + skey: tuple[str, int] = (str(s["conversation_id"]), int(s["turn"])) + if skey in key_to_sample_index: + updated_samples.append({**s, "sample_index": key_to_sample_index[skey]}) + self.conversation_metadata["samples"] = updated_samples + self.data = client_turn_samples diff --git a/src/inference_endpoint/dataset_manager/transforms.py b/src/inference_endpoint/dataset_manager/transforms.py index a288da6d..9ef7b1c3 100644 --- a/src/inference_endpoint/dataset_manager/transforms.py +++ b/src/inference_endpoint/dataset_manager/transforms.py @@ -114,40 +114,29 @@ def process_row(self, row: dict[str, Any]) -> dict[str, Any]: class AddStaticColumns(Transform): - """Transform that adds columns with constant values to a DataFrame.""" + """Transform that adds columns with constant values to a DataFrame. - def __init__(self, data: dict[str, Any]): - """Initialize the AddStaticColumns transform.""" - self.data = data - - def __call__(self, df: pd.DataFrame) -> pd.DataFrame: - """Add the static columns to the row.""" - for key, value in self.data.items(): - df[key] = value - return df - - -class AddDefaultColumns(Transform): - """Add columns only where values are missing (NaN or absent). - - Unlike AddStaticColumns which unconditionally overwrites, this preserves - existing non-null values β€” dataset per-row overrides take precedence over - the supplied defaults. + When overwrite=False, existing non-null values are preserved β€” dataset + per-row overrides take precedence over the supplied defaults. """ - def __init__(self, data: dict[str, Any]): - """Initialize the AddDefaultColumns transform.""" + def __init__(self, data: dict[str, Any], overwrite: bool = True): + """Initialize the AddStaticColumns transform.""" self.data = data + self.overwrite = overwrite def __call__(self, df: pd.DataFrame) -> pd.DataFrame: - """Fill missing columns with defaults without overwriting existing values.""" + """Add the static columns to the dataframe.""" for key, value in self.data.items(): - if value is None: - continue - if key in df.columns: - df[key] = df[key].where(pd.notna(df[key]), value) - else: + if self.overwrite: df[key] = value + else: + if value is None: + continue + if key in df.columns: + df[key] = df[key].where(pd.notna(df[key]), value) + else: + df[key] = value return df diff --git a/src/inference_endpoint/endpoint_client/http.py b/src/inference_endpoint/endpoint_client/http.py index 1e67a023..d9047301 100644 --- a/src/inference_endpoint/endpoint_client/http.py +++ b/src/inference_endpoint/endpoint_client/http.py @@ -792,12 +792,10 @@ class InFlightRequest: query_id: Correlates response back to original Query. http_bytes: Serialized HTTP request for socket.write(). is_streaming: Whether this is a streaming (SSE) request or not. - query_metadata: Internal metadata carried alongside the request. connection: PooledConnection assigned to this request (set once request is fired). 
""" query_id: str http_bytes: bytes is_streaming: bool - query_metadata: dict[str, object] = field(default_factory=dict) connection: PooledConnection = field(default=None, repr=False) # type: ignore[assignment] diff --git a/src/inference_endpoint/endpoint_client/worker.py b/src/inference_endpoint/endpoint_client/worker.py index 8fb69fce..8e0e560e 100644 --- a/src/inference_endpoint/endpoint_client/worker.py +++ b/src/inference_endpoint/endpoint_client/worker.py @@ -341,7 +341,6 @@ def _prepare_request(self, query: Query) -> InFlightRequest: query_id=query.id, http_bytes=http_bytes, is_streaming=is_streaming, - query_metadata=query.metadata, ) return req @@ -430,9 +429,7 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None: self._pool.release(conn) # Send final complete back to main rank - self._responses.send( - accumulator.get_final_output().with_metadata(req.query_metadata) - ) + self._responses.send(accumulator.get_final_output()) @profile async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: @@ -450,7 +447,7 @@ async def _handle_non_streaming_body(self, req: InFlightRequest) -> None: result = self._adapter.decode_response(response_bytes, query_id) # Send result back to main rank - self._responses.send(result.with_metadata(req.query_metadata)) + self._responses.send(result) async def _handle_error(self, query_id: str, error: Exception | str) -> None: """Send error response for a query.""" diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index 1b0834bb..5b8f41f2 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -33,6 +33,8 @@ class ConversationState: completed_turns: Turns with responses (success or failure) β€” observability only. failed_turns: Turns that failed β€” observability only. expected_client_turns: Expected total client turns (for completion detection). + last_assistant_tool_call_ids: Tool call ids from the most recent assistant response; + used to rewrite dataset tool_call_ids in live-history mode. 
""" conversation_id: str @@ -40,6 +42,7 @@ class ConversationState: completed_turns: int = 0 failed_turns: int = 0 expected_client_turns: int | None = None + last_assistant_tool_call_ids: list[str] = field(default_factory=list) def is_complete(self) -> bool: """Return True when all expected turns have a response.""" @@ -133,13 +136,18 @@ def mark_turn_complete( state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") + tool_calls = metadata.get("tool_calls") if metadata else None if store_in_history: - tool_calls = metadata.get("tool_calls") if metadata else None if response or tool_calls: msg: dict[str, Any] = {"role": "assistant", "content": response or None} if tool_calls: msg["tool_calls"] = tool_calls state.message_history.append(msg) + state.last_assistant_tool_call_ids = ( + [tc["id"] for tc in tool_calls if isinstance(tc, dict) and "id" in tc] + if tool_calls + else [] + ) state.completed_turns += 1 self._log_if_complete(state, conversation_id) @@ -166,6 +174,7 @@ def mark_turn_failed( state.message_history.append( {"role": "assistant", "content": "[ERROR: Turn failed or timed out]"} ) + state.last_assistant_tool_call_ids = [] state.completed_turns += 1 state.failed_turns += 1 logger.warning(f"Turn failed for conversation {conversation_id}") diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index d3f432d7..4ca08d1d 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -117,9 +117,11 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: self._error = None conv_samples: dict[str, list[tuple[int, int]]] = defaultdict(list) - for sample_index, sample_meta in enumerate(self._dataset_metadata["samples"]): + for sample_meta in self._dataset_metadata["samples"]: conv_id = sample_meta["conversation_id"] - conv_samples[conv_id].append((sample_index, sample_meta["turn"])) + conv_samples[conv_id].append( + (sample_meta["sample_index"], sample_meta["turn"]) + ) # Pre-create all conversation states before issuing any turns (no locking needed). sys_prompts = self._dataset_metadata.get("system_prompts_by_conv", {}) @@ -198,17 +200,32 @@ def _issue_next_turn(self, conv_id: str) -> None: "current_turn_messages_by_key", {} ).get((conv_id, turn)) if current_turn_messages: - has_tool_msg = any( - m.get("role") == "tool" for m in current_turn_messages - ) - if has_tool_msg: - logger.warning( - "Live-history mode with tool messages uses dataset " - "tool_call_ids; real endpoint IDs will differ " - "(conv=%s, turn=%d)", - conv_id, - turn, - ) + tool_msgs = [ + m for m in current_turn_messages if m.get("role") == "tool" + ] + if tool_msgs: + model_ids = state.last_assistant_tool_call_ids + if len(model_ids) == len(tool_msgs): + # Rewrite dataset-hardcoded tool_call_ids with model-generated ids. + ti = 0 + rewritten: list[dict[str, Any]] = [] + for m in current_turn_messages: + if m.get("role") == "tool": + rewritten.append({**m, "tool_call_id": model_ids[ti]}) + ti += 1 + else: + rewritten.append(m) + current_turn_messages = rewritten + else: + logger.warning( + "Live-history tool_call_id count mismatch for conv=%s turn=%d: " + "model returned %d tool_call(s), dataset expects %d. 
" + "Using dataset ids.", + conv_id, + turn, + len(model_ids), + len(tool_msgs), + ) live_messages = state.message_history.copy() + current_turn_messages data_override = {"messages": live_messages} diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index a458688c..0064bd22 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -140,7 +140,7 @@ def from_endpoint_response( return QueryResult( id=result_id, - response_output=TextModelOutput(output=choice.message.content), + response_output=TextModelOutput(output=choice.message.content or ""), metadata=metadata, ) diff --git a/src/inference_endpoint/openai/types.py b/src/inference_endpoint/openai/types.py index 4c301db5..6dbdb642 100644 --- a/src/inference_endpoint/openai/types.py +++ b/src/inference_endpoint/openai/types.py @@ -36,13 +36,19 @@ # NOTE(vir): msgspec usage # omit_defaults=True: Fields with static defaults are omitted if value equals default (ie those not using default_factory) -# gc=False: Safe for request/response structs with scalar and nested struct fields only. +# gc=False: audit 2026-05: all container fields are populated at construction and never mutated. # frozen=True: Makes structs immutable and hashable, also enables faster struct decoding # (direct attribute access via fixed memory offset vs hash table lookup) +# gc=False: audit 2026-05: tool_calls is set at construction; frozen=True blocks field reassignment. class SSEDelta(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False): # type: ignore[call-arg] - """SSE delta object containing content.""" + """SSE delta object containing content. + + AT-RISK (gc=False): Has mutable container field `tool_calls`. Any change that + mutates `tool_calls` after construction or stores cyclic references in it + must be audited; if so, remove gc=False. + """ content: str = "" reasoning: str = "" @@ -71,11 +77,16 @@ class SSEMessage( # ============================================================================ +# gc=False: audit 2026-05: content/tool_calls set at construction; frozen=True blocks field reassignment. class ChatMessage( msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False ): # type: ignore[call-arg] """Chat message in OpenAI format. + AT-RISK (gc=False): Has mutable container fields `content` (list[dict] for multimodal) + and `tool_calls`. Any change that mutates these after construction or stores cyclic + references in them must be audited; if so, remove gc=False. + content: str for text-only messages; list[dict] for multimodal (vision); None for tool-dispatching assistant messages. tool_calls: list of tool call objects for assistant messages that invoke tools. @@ -89,10 +100,16 @@ class ChatMessage( tool_call_id: str | None = None +# gc=False: audit 2026-05: messages/tools set at construction; frozen=True blocks field reassignment. class ChatCompletionRequest( msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False ): # type: ignore[call-arg] - """OpenAI chat completion request.""" + """OpenAI chat completion request. + + AT-RISK (gc=False): Has mutable container fields `messages`, `tools`, and `logit_bias`. + Any change that mutates these after construction or stores cyclic references in them + must be audited; if so, remove gc=False. 
+ """ model: str messages: list[ChatMessage] @@ -112,6 +129,7 @@ class ChatCompletionRequest( tools: list[dict[str, Any]] | None = None +# gc=False: audit 2026-05: tool_calls set at construction; frozen=True blocks field reassignment. class ChatCompletionResponseMessage( msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False ): # type: ignore[call-arg] @@ -153,6 +171,7 @@ class CompletionUsage( total_tokens: int +# gc=False: audit 2026-05: choices set at construction; frozen=True blocks field reassignment. class ChatCompletionResponse( msgspec.Struct, frozen=True, diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index cfe8a68c..b3a58ad8 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -29,12 +29,15 @@ """ import asyncio +import json import random import time +import uuid from urllib.parse import urljoin import pandas as pd import pytest +from aiohttp import web from inference_endpoint import metrics from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( @@ -682,3 +685,136 @@ async def _handle_echo_chat_completions_request(self, request): assert payload["tools"][0]["function"]["name"] == "search" finally: server.stop() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_live_history_remaps_tool_call_id(): + """In live-history mode, tool message tool_call_id is rewritten to match the + model-generated id from the previous assistant turn, not the dataset's hardcoded id. + """ + received_payloads: list[dict] = [] + # Fresh id the server "generates" for the tool call it dispatches. + model_generated_id = f"call_{uuid.uuid4().hex[:8]}" + + class ToolCallEchoServer(EchoServer): + """Returns a tool_calls response on the first user turn, then echoes normally.""" + + async def _handle_echo_chat_completions_request( + self, request: web.Request + ) -> web.Response: + payload = await request.json() + received_payloads.append(payload) + messages = payload.get("messages", []) + # If the last message is a tool message, echo normally. + if messages and messages[-1].get("role") == "tool": + resp = { + "id": str(uuid.uuid4()), + "object": "chat.completion", + "created": int(time.time()), + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "done", + "refusal": None, + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + }, + "system_fingerprint": None, + } + return web.Response( + text=json.dumps(resp), content_type="application/json" + ) + # First user turn: return a tool_calls response with the model-generated id. 
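+            # Descriptive note: model_generated_id is deliberately different from the
+            # dataset's hardcoded id, so the assertion at the end of the test can tell
+            # whether the tool_call_id rewrite actually happened.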
+ resp = { + "id": str(uuid.uuid4()), + "object": "chat.completion", + "created": int(time.time()), + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "refusal": None, + "tool_calls": [ + { + "id": model_generated_id, + "type": "function", + "function": { + "name": "search", + "arguments": '{"q": "hello"}', + }, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2, + }, + "system_fingerprint": None, + } + return web.Response(text=json.dumps(resp), content_type="application/json") + + server = ToolCallEchoServer(port=0) + server.start() + try: + dataset_id = "call_dataset_hardcoded" + tool_results = [{"tool_call_id": dataset_id, "content": "result"}] + + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Search"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": dataset_id, + "type": "function", + "function": {"name": "search", "arguments": '{"q": "hello"}'}, + } + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=False) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + # The second request (tool turn) must have the model-generated id, not the dataset id. + assert len(received_payloads) == 2 + tool_turn_payload = received_payloads[1] + tool_messages = [ + m for m in tool_turn_payload.get("messages", []) if m.get("role") == "tool" + ] + assert len(tool_messages) == 1 + assert tool_messages[0]["tool_call_id"] == model_generated_id, ( + f"Expected model-generated id {model_generated_id!r}, " + f"got {tool_messages[0]['tool_call_id']!r}" + ) + finally: + server.stop() diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index b1d29cfd..54fb5515 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -16,12 +16,14 @@ """Tests for configuration schema models and validation.""" import pytest +from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( APIType, BenchmarkConfig, Dataset, DatasetType, EvalMethod, + LoadPatternType, ModelParams, OSLDistribution, OSLDistributionType, @@ -482,8 +484,6 @@ def _make_online_multi_turn(self, concurrency: int | None = 4, **ds_kwargs): @pytest.mark.unit def test_multi_turn_valid_config(self): config = BenchmarkConfig(**self._make_online_multi_turn(concurrency=16)) - from inference_endpoint.config.schema import LoadPatternType - assert config.settings.load_pattern.type == LoadPatternType.MULTI_TURN assert config.settings.load_pattern.target_concurrency == 16 @@ -522,8 +522,6 @@ class TestMultiTurnTotalSamples: @pytest.mark.unit def test_multi_turn_uses_dataset_size_ignoring_duration(self): - from inference_endpoint.config.runtime_settings import RuntimeSettings - config = BenchmarkConfig( type=TestType.ONLINE, model_params={"name": "M"}, diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py index 9c3bec15..52bdbe77 100644 --- a/tests/unit/core/test_types.py +++ b/tests/unit/core/test_types.py @@ -891,57 +891,3 @@ def test_numeric_types_in_metadata(self): assert decoded.metadata["large_int"] == 9999999999999999 assert decoded.metadata["negative"] == -123.456 assert 
decoded.metadata["zero"] == 0 - - -@pytest.mark.unit -class TestQueryResultWithMetadata: - """Test QueryResult.with_metadata() method for metadata merging.""" - - def test_with_metadata_merge_behavior(self): - """Test that with_metadata adds new keys and overwrites existing ones.""" - result = QueryResult( - id="test", - response_output=TextModelOutput(output="hello"), - metadata={"key1": "old_value", "key2": "keep_me"}, - ) - - updated = result.with_metadata({"key1": "new_value", "key3": "added"}) - - assert updated.metadata == { - "key1": "new_value", - "key2": "keep_me", - "key3": "added", - } - assert updated.id == "test" - assert updated.response_output == TextModelOutput(output="hello") - - def test_with_metadata_none_returns_self(self): - """Test that with_metadata(None) returns self unchanged.""" - result = QueryResult( - id="test", - response_output=TextModelOutput(output="hello"), - metadata={"key1": "value"}, - ) - assert result.with_metadata(None) is result - - def test_with_metadata_empty_returns_self(self): - """Test that with_metadata({}) returns self unchanged.""" - result = QueryResult( - id="test", - response_output=TextModelOutput(output="hello"), - metadata={"key1": "value"}, - ) - assert result.with_metadata({}) is result - - def test_query_metadata_field_roundtrips(self): - """Test that Query.metadata round-trips through msgspec encoding.""" - query = Query( - data={"prompt": "Hello"}, - metadata={"conversation_id": "conv-1", "turn": 2}, - ) - - encoded = msgspec.json.encode(query) - decoded = msgspec.json.decode(encoded, type=Query) - - assert decoded.metadata["conversation_id"] == "conv-1" - assert decoded.metadata["turn"] == 2 diff --git a/tests/unit/dataset_manager/test_transforms.py b/tests/unit/dataset_manager/test_transforms.py index 5eca41b4..0ea35f24 100644 --- a/tests/unit/dataset_manager/test_transforms.py +++ b/tests/unit/dataset_manager/test_transforms.py @@ -23,7 +23,6 @@ import pandas as pd import pytest from inference_endpoint.dataset_manager.transforms import ( - AddDefaultColumns, AddStaticColumns, ColumnFilter, ColumnRemap, @@ -827,14 +826,14 @@ def test_no_matching_columns(self): assert "prompt" not in result.columns -class TestAddDefaultColumns: - """Unit tests for AddDefaultColumns transform.""" +class TestAddStaticColumnsNoOverwrite: + """Unit tests for AddStaticColumns(overwrite=False) behavior.""" @pytest.mark.unit def test_fills_missing_columns(self): """New columns are added when absent.""" df = pd.DataFrame({"a": [1, 2]}) - result = AddDefaultColumns({"b": 10, "c": "x"})(df) + result = AddStaticColumns({"b": 10, "c": "x"}, overwrite=False)(df) assert list(result["b"]) == [10, 10] assert list(result["c"]) == ["x", "x"] @@ -842,7 +841,7 @@ def test_fills_missing_columns(self): def test_preserves_existing_non_null_values(self): """Existing non-null values are not overwritten.""" df = pd.DataFrame({"a": [1, 2]}) - result = AddDefaultColumns({"a": 99})(df) + result = AddStaticColumns({"a": 99}, overwrite=False)(df) assert list(result["a"]) == [1, 2] @pytest.mark.unit @@ -850,7 +849,7 @@ def test_fills_nan_values_in_existing_column(self): """NaN cells in an existing column are replaced with the default.""" df = pd.DataFrame({"a": [1.0, float("nan"), 3.0]}) - result = AddDefaultColumns({"a": 99})(df) + result = AddStaticColumns({"a": 99}, overwrite=False)(df) assert result["a"].tolist()[0] == 1.0 assert result["a"].tolist()[1] == 99 assert result["a"].tolist()[2] == 3.0 @@ -860,7 +859,7 @@ def test_skips_none_default_values(self): """A None 
default value is ignored; the column is not modified.""" df = pd.DataFrame({"a": [1]}) original_a = df["a"].copy() - result = AddDefaultColumns({"a": None, "b": None})(df) + result = AddStaticColumns({"a": None, "b": None}, overwrite=False)(df) assert list(result["a"]) == list(original_a) assert "b" not in result.columns @@ -869,7 +868,7 @@ def test_mixed_nan_and_real_values(self): """Only NaN cells are filled; real values in the same column are preserved.""" df = pd.DataFrame({"temp": [0.9, float("nan"), 0.5]}) - result = AddDefaultColumns({"temp": 0.7})(df) + result = AddStaticColumns({"temp": 0.7}, overwrite=False)(df) assert result["temp"].tolist()[0] == pytest.approx(0.9) assert result["temp"].tolist()[1] == pytest.approx(0.7) assert result["temp"].tolist()[2] == pytest.approx(0.5) From adaa8b4197051184607756d67f695f532e425121 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 6 May 2026 17:26:58 +0000 Subject: [PATCH 15/41] Import fix Signed-off-by: Li, Tianmu --- tests/unit/config/test_schema.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index 54fb5515..f028712a 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -15,7 +15,10 @@ """Tests for configuration schema models and validation.""" +import random + import pytest +from inference_endpoint import metrics from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( APIType, @@ -23,6 +26,7 @@ Dataset, DatasetType, EvalMethod, + LoadPattern, LoadPatternType, ModelParams, OSLDistribution, @@ -537,12 +541,6 @@ def test_multi_turn_uses_dataset_size_ignoring_duration(self): @pytest.mark.unit def test_multi_turn_respects_min_sample_count(self): - import random - - from inference_endpoint import metrics - from inference_endpoint.config.runtime_settings import RuntimeSettings - from inference_endpoint.config.schema import LoadPattern, LoadPatternType - lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) rt = RuntimeSettings( metric_target=metrics.Throughput(10.0), @@ -560,12 +558,6 @@ def test_multi_turn_respects_min_sample_count(self): @pytest.mark.unit def test_multi_turn_explicit_n_samples_takes_precedence(self): - import random - - from inference_endpoint import metrics - from inference_endpoint.config.runtime_settings import RuntimeSettings - from inference_endpoint.config.schema import LoadPattern, LoadPatternType - lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) rt = RuntimeSettings( metric_target=metrics.Throughput(10.0), From 38d0ef067cd60e015b72e6731554483307b0177d Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 6 May 2026 20:13:19 +0000 Subject: [PATCH 16/41] fix: revert out-of-scope live-history tool_call_id rewriting Remove tool_call_id rewriting from live-history mode (last_assistant_tool_call_ids field, ConversationManager population, MultiTurnStrategy rewrite logic) and the corresponding integration test. Live-history improvements are not in scope for this PR. Also revert the _mt_strategy closure capture in execute.py that was not requested by any review comment, while keeping the is-None branch elimination. 
Co-Authored-By: Claude Sonnet 4.6 --- .../commands/benchmark/execute.py | 3 +- .../load_generator/conversation_manager.py | 11 +- .../load_generator/multi_turn_strategy.py | 37 ++--- tests/integration/test_multi_turn.py | 136 ------------------ 4 files changed, 13 insertions(+), 174 deletions(-) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 4356249a..89297c02 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -559,10 +559,9 @@ async def _run_benchmark_async( _on_sample_complete: Callable[[QueryResult], None] if multi_turn_strategy is not None: - _mt_strategy = multi_turn_strategy def _on_sample_complete(result: QueryResult) -> None: - _mt_strategy.on_sample_complete(result) + multi_turn_strategy.on_sample_complete(result) collector.on_complete_hook(result) else: diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py index 5b8f41f2..1b0834bb 100644 --- a/src/inference_endpoint/load_generator/conversation_manager.py +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -33,8 +33,6 @@ class ConversationState: completed_turns: Turns with responses (success or failure) β€” observability only. failed_turns: Turns that failed β€” observability only. expected_client_turns: Expected total client turns (for completion detection). - last_assistant_tool_call_ids: Tool call ids from the most recent assistant response; - used to rewrite dataset tool_call_ids in live-history mode. """ conversation_id: str @@ -42,7 +40,6 @@ class ConversationState: completed_turns: int = 0 failed_turns: int = 0 expected_client_turns: int | None = None - last_assistant_tool_call_ids: list[str] = field(default_factory=list) def is_complete(self) -> bool: """Return True when all expected turns have a response.""" @@ -136,18 +133,13 @@ def mark_turn_complete( state = self._conversations.get(conversation_id) if state is None: raise KeyError(f"Conversation {conversation_id} not initialized") - tool_calls = metadata.get("tool_calls") if metadata else None if store_in_history: + tool_calls = metadata.get("tool_calls") if metadata else None if response or tool_calls: msg: dict[str, Any] = {"role": "assistant", "content": response or None} if tool_calls: msg["tool_calls"] = tool_calls state.message_history.append(msg) - state.last_assistant_tool_call_ids = ( - [tc["id"] for tc in tool_calls if isinstance(tc, dict) and "id" in tc] - if tool_calls - else [] - ) state.completed_turns += 1 self._log_if_complete(state, conversation_id) @@ -174,7 +166,6 @@ def mark_turn_failed( state.message_history.append( {"role": "assistant", "content": "[ERROR: Turn failed or timed out]"} ) - state.last_assistant_tool_call_ids = [] state.completed_turns += 1 state.failed_turns += 1 logger.warning(f"Turn failed for conversation {conversation_id}") diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 4ca08d1d..8f23d485 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -200,32 +200,17 @@ def _issue_next_turn(self, conv_id: str) -> None: "current_turn_messages_by_key", {} ).get((conv_id, turn)) if current_turn_messages: - tool_msgs = [ - m for m in current_turn_messages if m.get("role") == "tool" - ] - if tool_msgs: - 
model_ids = state.last_assistant_tool_call_ids - if len(model_ids) == len(tool_msgs): - # Rewrite dataset-hardcoded tool_call_ids with model-generated ids. - ti = 0 - rewritten: list[dict[str, Any]] = [] - for m in current_turn_messages: - if m.get("role") == "tool": - rewritten.append({**m, "tool_call_id": model_ids[ti]}) - ti += 1 - else: - rewritten.append(m) - current_turn_messages = rewritten - else: - logger.warning( - "Live-history tool_call_id count mismatch for conv=%s turn=%d: " - "model returned %d tool_call(s), dataset expects %d. " - "Using dataset ids.", - conv_id, - turn, - len(model_ids), - len(tool_msgs), - ) + has_tool_msg = any( + m.get("role") == "tool" for m in current_turn_messages + ) + if has_tool_msg: + logger.warning( + "Live-history mode with tool messages uses dataset " + "tool_call_ids; real endpoint IDs will differ " + "(conv=%s, turn=%d)", + conv_id, + turn, + ) live_messages = state.message_history.copy() + current_turn_messages data_override = {"messages": live_messages} diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index b3a58ad8..cfe8a68c 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -29,15 +29,12 @@ """ import asyncio -import json import random import time -import uuid from urllib.parse import urljoin import pandas as pd import pytest -from aiohttp import web from inference_endpoint import metrics from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( @@ -685,136 +682,3 @@ async def _handle_echo_chat_completions_request(self, request): assert payload["tools"][0]["function"]["name"] == "search" finally: server.stop() - - -@pytest.mark.integration -@pytest.mark.asyncio -async def test_live_history_remaps_tool_call_id(): - """In live-history mode, tool message tool_call_id is rewritten to match the - model-generated id from the previous assistant turn, not the dataset's hardcoded id. - """ - received_payloads: list[dict] = [] - # Fresh id the server "generates" for the tool call it dispatches. - model_generated_id = f"call_{uuid.uuid4().hex[:8]}" - - class ToolCallEchoServer(EchoServer): - """Returns a tool_calls response on the first user turn, then echoes normally.""" - - async def _handle_echo_chat_completions_request( - self, request: web.Request - ) -> web.Response: - payload = await request.json() - received_payloads.append(payload) - messages = payload.get("messages", []) - # If the last message is a tool message, echo normally. - if messages and messages[-1].get("role") == "tool": - resp = { - "id": str(uuid.uuid4()), - "object": "chat.completion", - "created": int(time.time()), - "model": "test-model", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "done", - "refusal": None, - }, - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": 1, - "completion_tokens": 1, - "total_tokens": 2, - }, - "system_fingerprint": None, - } - return web.Response( - text=json.dumps(resp), content_type="application/json" - ) - # First user turn: return a tool_calls response with the model-generated id. 
- resp = { - "id": str(uuid.uuid4()), - "object": "chat.completion", - "created": int(time.time()), - "model": "test-model", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": None, - "refusal": None, - "tool_calls": [ - { - "id": model_generated_id, - "type": "function", - "function": { - "name": "search", - "arguments": '{"q": "hello"}', - }, - } - ], - }, - "finish_reason": "tool_calls", - } - ], - "usage": { - "prompt_tokens": 1, - "completion_tokens": 1, - "total_tokens": 2, - }, - "system_fingerprint": None, - } - return web.Response(text=json.dumps(resp), content_type="application/json") - - server = ToolCallEchoServer(port=0) - server.start() - try: - dataset_id = "call_dataset_hardcoded" - tool_results = [{"tool_call_id": dataset_id, "content": "result"}] - - rows = [ - {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Search"}, - { - "conversation_id": "c1", - "turn": 2, - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": dataset_id, - "type": "function", - "function": {"name": "search", "arguments": '{"q": "hello"}'}, - } - ], - }, - { - "conversation_id": "c1", - "turn": 3, - "role": "tool", - "tool_results": tool_results, - }, - ] - ds = _make_dataset(rows) - strategy = _make_strategy(ds, use_dataset_history=False) - responses: dict = {} - - count = await _run_session(server.url, ds, strategy, responses) - assert count == 2 - - # The second request (tool turn) must have the model-generated id, not the dataset id. - assert len(received_payloads) == 2 - tool_turn_payload = received_payloads[1] - tool_messages = [ - m for m in tool_turn_payload.get("messages", []) if m.get("role") == "tool" - ] - assert len(tool_messages) == 1 - assert tool_messages[0]["tool_call_id"] == model_generated_id, ( - f"Expected model-generated id {model_generated_id!r}, " - f"got {tool_messages[0]['tool_call_id']!r}" - ) - finally: - server.stop() From d2dace88d3f4815b34028de7230086bb1aba070c Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 6 May 2026 23:40:19 +0000 Subject: [PATCH 17/41] Fix issue with tool call accumulation and reasoning content Signed-off-by: Li, Tianmu --- src/inference_endpoint/openai/accumulator.py | 7 ++++--- src/inference_endpoint/openai/types.py | 5 +++-- tests/performance/openai/test_adapter.py | 6 +----- tests/performance/openai/test_msgspec_adapter.py | 6 +----- tests/performance/openai/test_types.py | 2 +- 5 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/inference_endpoint/openai/accumulator.py b/src/inference_endpoint/openai/accumulator.py index a01b7b44..e630a497 100644 --- a/src/inference_endpoint/openai/accumulator.py +++ b/src/inference_endpoint/openai/accumulator.py @@ -69,9 +69,10 @@ def add_chunk(self, choice: SSEChoice | None) -> StreamChunk | None: if delta.content: self.output_chunks.append(delta.content) content = delta.content - elif delta.reasoning: - self.reasoning_chunks.append(delta.reasoning) - content = delta.reasoning + elif delta.reasoning_content or delta.reasoning: + rc = delta.reasoning_content or delta.reasoning + self.reasoning_chunks.append(rc) # type: ignore[arg-type] + content = rc else: return None diff --git a/src/inference_endpoint/openai/types.py b/src/inference_endpoint/openai/types.py index 6dbdb642..558ed7f2 100644 --- a/src/inference_endpoint/openai/types.py +++ b/src/inference_endpoint/openai/types.py @@ -50,8 +50,9 @@ class SSEDelta(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc must be audited; if so, remove gc=False. 
""" - content: str = "" - reasoning: str = "" + content: str | None = None + reasoning_content: str | None = None # SGLang / DeepSeek field name + reasoning: str | None = None # vLLM field name tool_calls: list[dict[str, Any]] | None = None diff --git a/tests/performance/openai/test_adapter.py b/tests/performance/openai/test_adapter.py index e38ec5e8..6ed12ab2 100644 --- a/tests/performance/openai/test_adapter.py +++ b/tests/performance/openai/test_adapter.py @@ -71,11 +71,7 @@ def make_response_bytes(text: str) -> bytes: def make_sse_bytes(text: str) -> bytes: """Create SSE message JSON bytes.""" return json.dumps( - { - "choices": [ - {"delta": {"content": text, "reasoning": ""}, "finish_reason": None} - ] - } + {"choices": [{"delta": {"content": text}, "finish_reason": None}]} ).encode() diff --git a/tests/performance/openai/test_msgspec_adapter.py b/tests/performance/openai/test_msgspec_adapter.py index 9fc6a10b..21bedd25 100644 --- a/tests/performance/openai/test_msgspec_adapter.py +++ b/tests/performance/openai/test_msgspec_adapter.py @@ -71,11 +71,7 @@ def make_response_bytes(text: str) -> bytes: def make_sse_bytes(text: str) -> bytes: """Create SSE message JSON bytes.""" return json.dumps( - { - "choices": [ - {"delta": {"content": text, "reasoning": ""}, "finish_reason": None} - ] - } + {"choices": [{"delta": {"content": text}, "finish_reason": None}]} ).encode() diff --git a/tests/performance/openai/test_types.py b/tests/performance/openai/test_types.py index 43df3a4a..bdf8aa6b 100644 --- a/tests/performance/openai/test_types.py +++ b/tests/performance/openai/test_types.py @@ -61,7 +61,7 @@ def make_sse_message(text: str) -> SSEMessage: return SSEMessage( choices=( SSEChoice( - delta=SSEDelta(content=text, reasoning=""), + delta=SSEDelta(content=text), finish_reason=None, ), ) From a7ef9e57ecab91b6bf4e349e423653a8323d421a Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 7 May 2026 17:20:48 +0000 Subject: [PATCH 18/41] feat: account for tool-call tokens in OSL / TPOT / TPS metrics Tool-call tokens were completely excluded from output sequence length, TPOT, and TPS because they were only stored in QueryResult.metadata and never reached TextModelOutput or EventRecord.data. 
- Add `tool_calls` field to TextModelOutput; __str__ and text_after_first_chunk include JSON-encoded tool calls so the full generation is counted - Add as_message_parts / as_message_parts_after_first_chunk helpers for chat-template-aware tokenization in the metrics pipeline - OpenAI SSE accumulator populates tool_calls in TextModelOutput and emits a zero-length sentinel StreamChunk on the first tool-call delta so TTFT fires for agentic (content-free) responses - Both OpenAI adapters (msgspec and pydantic) route tool_calls into TextModelOutput in addition to metadata - TokenizePool gains token_count_message / token_count_message_async using apply_chat_template + baseline subtraction, with fallback to whitespace tokenization when the template raises - OslTrigger and TpotTrigger override the new _extract_message hook to use the message tokenization path when tool_calls are present - Forward `tools` key through MultiTurnDataset per-conversation defaults Co-Authored-By: Claude Sonnet 4.6 --- .../metrics_aggregator/metrics_table.py | 53 +++++- .../metrics_aggregator/token_metrics.py | 83 ++++++++- src/inference_endpoint/core/types.py | 70 +++++++- .../dataset_manager/multi_turn_dataset.py | 1 + src/inference_endpoint/openai/accumulator.py | 31 +++- .../openai/openai_adapter.py | 13 +- .../openai/openai_msgspec_adapter.py | 8 +- .../services/metrics_aggregator/conftest.py | 16 ++ .../metrics_aggregator/test_metrics_table.py | 129 ++++++++++++++ .../metrics_aggregator/test_token_metrics.py | 77 +++++++++ tests/unit/core/test_types.py | 83 +++++++++ tests/unit/openai/test_accumulator.py | 163 ++++++++++++++++++ tests/unit/openai/test_msgspec_adapter.py | 47 +++++ tests/unit/openai/test_openai_adapter.py | 47 +++++ 14 files changed, 807 insertions(+), 14 deletions(-) create mode 100644 tests/unit/openai/test_accumulator.py diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py index a66c1e8d..15fe1751 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py @@ -174,7 +174,9 @@ class AsyncTokenTrigger(EmitTrigger): Subclasses implement ``_extract_text()`` to pull the text to tokenize from the event record. If text is returned, an async task is created - to tokenize and emit. Subclasses can override ``_compute_value()`` to + to tokenize and emit. Subclasses can also override ``_extract_message()`` + to return (content, reasoning, tool_calls) for chat-template–aware tokenization + when tool calls are present. Subclasses can override ``_compute_value()`` to transform the token count before storing. """ @@ -198,6 +200,16 @@ def _extract_text( """Return the text to tokenize, or None to skip.""" raise NotImplementedError() + def _extract_message( + self, ev_rec: EventRecord, row: SampleRow, pre_change: dict[str, Any] + ) -> tuple[str, str | None, tuple[dict[str, Any], ...] | None] | None: + """Return (content, reasoning, tool_calls) for message-aware tokenization, or None. + + When non-None is returned, ``token_count_message_async`` is used instead of + ``token_count_async``. Default returns None (use text path). 
+ """ + return None + def _compute_value( self, token_count: int, ev_rec: EventRecord, pre_change: dict[str, Any] ) -> int | float | None: @@ -207,6 +219,27 @@ def _compute_value( def fire(self, ev_rec, row, pre_change): if self._pool is None or self._loop is None: return None + + message_parts = self._extract_message(ev_rec, row, pre_change) + if message_parts is not None: + content, reasoning, tool_calls = message_parts + pool, loop = self._pool, self._loop + store, name = self.kv_store, self.metric_name + uuid = row.sample_uuid + + async def _tokenize_message_and_emit() -> None: + try: + count = await pool.token_count_message_async( + content, reasoning, tool_calls, loop + ) + value = self._compute_value(count, ev_rec, pre_change) + if value is not None: + store.update(name, value) + except Exception: + logger.exception("%s tokenization failed for %s", name, uuid) + + return loop.create_task(_tokenize_message_and_emit()) + text = self._extract_text(ev_rec, row, pre_change) if not text: return None @@ -312,10 +345,18 @@ def __init__( def _extract_text(self, ev_rec, row, pre_change): if isinstance(ev_rec.data, TextModelOutput): + if ev_rec.data.tool_calls: + # Delegate to _extract_message for chat-template tokenization. + return None text = str(ev_rec.data) return text if text else None return None + def _extract_message(self, ev_rec, row, pre_change): + if isinstance(ev_rec.data, TextModelOutput) and ev_rec.data.tool_calls: + return ev_rec.data.as_message_parts() + return None + class TpotTrigger(AsyncTokenTrigger): """TPOT = (complete_ns - recv_first_ns) / token_count(text_after_first_chunk). @@ -351,9 +392,19 @@ def _extract_text(self, ev_rec, row, pre_change): if pre_change.get(SampleField.RECV_FIRST_NS) is None: return None if isinstance(ev_rec.data, TextModelOutput): + if ev_rec.data.tool_calls: + # Delegate to _extract_message for chat-template tokenization. + return None return ev_rec.data.text_after_first_chunk() or None return None + def _extract_message(self, ev_rec, row, pre_change): + if pre_change.get(SampleField.RECV_FIRST_NS) is None: + return None + if isinstance(ev_rec.data, TextModelOutput) and ev_rec.data.tool_calls: + return ev_rec.data.as_message_parts_after_first_chunk() + return None + def _compute_value(self, token_count, ev_rec, pre_change): if token_count <= 0: return None diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py index 56dee33f..8dddd9dc 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py @@ -18,15 +18,19 @@ from __future__ import annotations import asyncio +import logging import threading from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +import msgspec from transformers import AutoTokenizer if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase +logger = logging.getLogger(__name__) + class TokenizePool: """A pool of worker threads, each with its own HuggingFace AutoTokenizer. 
@@ -79,6 +83,19 @@ def _get_thread_tokenizer(self) -> PreTrainedTokenizerBase: self._thread_local.tokenizer = AutoTokenizer.from_pretrained( self._tokenizer_name ) + try: + baseline_tokens = self._thread_local.tokenizer.apply_chat_template( + [{"role": "assistant", "content": ""}], + tokenize=True, + add_generation_prompt=False, + ) + self._thread_local.baseline = len(baseline_tokens) + except Exception: + self._thread_local.baseline = 0 + logger.warning( + "Failed to compute chat-template baseline for %s; tool-call token counts may be over-estimated", + self._tokenizer_name, + ) return self._thread_local.tokenizer def _token_count_worker(self, text: str) -> int: @@ -86,6 +103,38 @@ def _token_count_worker(self, text: str) -> int: tokenizer = self._get_thread_tokenizer() return len(tokenizer.tokenize(text)) + def _token_count_message_worker( + self, + content: str, + reasoning: str | None, + tool_calls: tuple[dict[str, Any], ...] | None, + ) -> int: + """Worker entry: tokenize a full assistant message using apply_chat_template. + + Falls back to whitespace-split tokenization if apply_chat_template raises + (e.g. the template does not support tool_calls or reasoning fields). + """ + tokenizer = self._get_thread_tokenizer() + msg: dict[str, Any] = {"role": "assistant", "content": content or ""} + if reasoning: + msg["reasoning_content"] = reasoning + if tool_calls: + msg["tool_calls"] = list(tool_calls) + try: + full = len( + tokenizer.apply_chat_template( + [msg], tokenize=True, add_generation_prompt=False + ) + ) + baseline = getattr(self._thread_local, "baseline", 0) + return max(0, full - baseline) + except Exception: + tool_calls_json = ( + msgspec.json.encode(list(tool_calls)).decode() if tool_calls else "" + ) + fallback_text = (content or "") + (reasoning or "") + tool_calls_json + return self._token_count_worker(fallback_text) + def token_count(self, text: str) -> int: """Return the number of tokens in the input string (blocking).""" if self._executor is None: @@ -93,6 +142,20 @@ def token_count(self, text: str) -> int: future = self._executor.submit(self._token_count_worker, text) return future.result() + def token_count_message( + self, + content: str, + reasoning: str | None, + tool_calls: tuple[dict[str, Any], ...] | None, + ) -> int: + """Return the token count for an assistant message (blocking).""" + if self._executor is None: + raise RuntimeError("TokenizePool is closed") + future = self._executor.submit( + self._token_count_message_worker, content, reasoning, tool_calls + ) + return future.result() + async def token_count_async( self, text: str, loop: asyncio.AbstractEventLoop ) -> int: @@ -107,6 +170,24 @@ async def token_count_async( self._executor, self._token_count_worker, text ) + async def token_count_message_async( + self, + content: str, + reasoning: str | None, + tool_calls: tuple[dict[str, Any], ...] | None, + loop: asyncio.AbstractEventLoop, + ) -> int: + """Return the token count for an assistant message without blocking the event loop.""" + if self._executor is None: + raise RuntimeError("TokenizePool is closed") + return await loop.run_in_executor( + self._executor, + self._token_count_message_worker, + content, + reasoning, + tool_calls, + ) + def close(self) -> None: """Shut down the worker pool. 
Idempotent.""" if self._executor is not None: diff --git a/src/inference_endpoint/core/types.py b/src/inference_endpoint/core/types.py index 6887462c..aaea804f 100644 --- a/src/inference_endpoint/core/types.py +++ b/src/inference_endpoint/core/types.py @@ -90,23 +90,33 @@ class TextModelOutput( ): # type: ignore[call-arg] """Structured output from a text model. - Supports main output and optional reasoning (e.g. chain-of-thought). + Supports main output, optional reasoning (e.g. chain-of-thought), and tool calls. Each field may be a string (non-streaming) or tuple of strings (streaming chunks). + AT-RISK (gc=False): Has mutable container field `tool_calls`. Any change that + mutates `tool_calls` after construction or stores cyclic references in it + must be audited; if so, remove gc=False. + Attributes: output: Main model output. Defaults to empty string. reasoning: Optional reasoning trace. Defaults to None. + tool_calls: Optional structured tool calls. Defaults to None. + Placed after reasoning so wire-format with array_like=True is + backward compatible (missing trailing elements decode as default). """ output: OUTPUT_ELEM_TYPE = "" reasoning: OUTPUT_ELEM_TYPE | None = None + tool_calls: tuple[dict[str, Any], ...] | None = None def __post_init__(self): - """Convert list to tuple for output and reasoning to preserve immutability.""" + """Convert list to tuple for output, reasoning, and tool_calls to preserve immutability.""" if isinstance(self.output, list): msgspec.structs.force_setattr(self, "output", tuple(self.output)) if self.reasoning is not None and isinstance(self.reasoning, list): msgspec.structs.force_setattr(self, "reasoning", tuple(self.reasoning)) + if self.tool_calls is not None and isinstance(self.tool_calls, list): + msgspec.structs.force_setattr(self, "tool_calls", tuple(self.tool_calls)) def __str__(self) -> str: """Return the full output as a single string (joins tuple chunks if streaming).""" @@ -123,6 +133,9 @@ def __str__(self) -> str: elif isinstance(self.output, tuple): parts.extend(self.output) + if self.tool_calls: + parts.append(msgspec.json.encode(list(self.tool_calls)).decode()) + # NOTE: Not sure how output is formatted - there *might* need to be a space or separator between # reasoning and output depending on the accumulator / API. return "".join(parts) @@ -136,6 +149,10 @@ def text_after_first_chunk(self) -> str: For non-streaming (str fields), there is no "first chunk" concept so this returns an empty string. + + Tool calls are always included when present β€” they are accumulated across + multiple deltas and only realized at stream end, so they always contribute + to the TPOT denominator. """ parts: list[str] = [] if self.reasoning: @@ -155,8 +172,57 @@ def text_after_first_chunk(self) -> str: elif len(self.output) > 1: # No reasoning; first chunk is output[0], skip it. parts.extend(self.output[1:]) + if self.tool_calls: + parts.append(msgspec.json.encode(list(self.tool_calls)).decode()) return "".join(parts) + def as_message_parts( + self, + ) -> tuple[str, str | None, tuple[dict[str, Any], ...] 
| None]: + """Return (content, reasoning, tool_calls) for chat-template tokenization.""" + if isinstance(self.output, str): + content = self.output + else: + content = "".join(self.output) + + reasoning_str: str | None = None + if self.reasoning: + if isinstance(self.reasoning, str): + reasoning_str = self.reasoning + else: + reasoning_str = "".join(self.reasoning) + + return content, reasoning_str, self.tool_calls + + def as_message_parts_after_first_chunk( + self, + ) -> tuple[str, str | None, tuple[dict[str, Any], ...] | None]: + """Return (content_after_first, reasoning_after_first, tool_calls) for TPOT tokenization.""" + reasoning_after: str | None = None + if isinstance(self.reasoning, tuple) and len(self.reasoning) > 1: + reasoning_after = "".join(self.reasoning[1:]) + + # has_reasoning_tail: reasoning[1:] is non-empty (mirrors `parts` logic in text_after_first_chunk) + has_reasoning_tail = reasoning_after is not None + + content_after = "" + if self.output: + if isinstance(self.output, str): + # Include if reasoning is any non-empty tuple (it was the first chunk) + if has_reasoning_tail or ( + self.reasoning and isinstance(self.reasoning, tuple) + ): + content_after = self.output + elif isinstance(self.output, tuple): + if has_reasoning_tail or self.reasoning: + # First chunk was in reasoning; include all output chunks. + content_after = "".join(self.output) + elif len(self.output) > 1: + # No reasoning; first chunk is output[0], skip it. + content_after = "".join(self.output[1:]) + + return content_after, reasoning_after, self.tool_calls + OUTPUT_TYPE = TextModelOutput diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index bc3295d1..63e5aa4e 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -390,6 +390,7 @@ def load( "max_completion_tokens", "max_new_tokens", "stream", + "tools", } conv_defaults: dict[str, dict[str, Any]] = {} for row in all_rows: diff --git a/src/inference_endpoint/openai/accumulator.py b/src/inference_endpoint/openai/accumulator.py index e630a497..0c1e88f4 100644 --- a/src/inference_endpoint/openai/accumulator.py +++ b/src/inference_endpoint/openai/accumulator.py @@ -73,6 +73,16 @@ def add_chunk(self, choice: SSEChoice | None) -> StreamChunk | None: rc = delta.reasoning_content or delta.reasoning self.reasoning_chunks.append(rc) # type: ignore[arg-type] content = rc + elif delta.tool_calls and not self.first_chunk_sent: + # Pure tool-call delta with no text: emit a zero-length sentinel so + # RECV_FIRST / TTFT fires for agentic responses that have no content. + sentinel = StreamChunk( + id=self.query_id, + response_chunk="", + metadata={"first_chunk": True}, + ) + self.first_chunk_sent = True + return sentinel else: return None @@ -92,6 +102,12 @@ def add_chunk(self, choice: SSEChoice | None) -> StreamChunk | None: return None def get_final_output(self) -> QueryResult: + tool_calls_tuple: tuple[dict[str, Any], ...] 
| None = ( + tuple(self._tool_calls[i] for i in sorted(self._tool_calls)) + if self._tool_calls + else None + ) + if self.reasoning_chunks: resp_reasoning: list[str] = [self.reasoning_chunks[0]] if len(self.reasoning_chunks) > 1: @@ -99,14 +115,19 @@ def get_final_output(self) -> QueryResult: text_output = TextModelOutput( output="".join(self.output_chunks), reasoning=resp_reasoning, + tool_calls=tool_calls_tuple, ) elif self.output_chunks: resp_output: list[str] = [self.output_chunks[0]] if len(self.output_chunks) > 1: resp_output.append("".join(self.output_chunks[1:])) - text_output = TextModelOutput(output=resp_output, reasoning=None) + text_output = TextModelOutput( + output=resp_output, reasoning=None, tool_calls=tool_calls_tuple + ) else: - text_output = TextModelOutput(output=[], reasoning=None) + text_output = TextModelOutput( + output=[], reasoning=None, tool_calls=tool_calls_tuple + ) metadata: dict[str, Any] = { "first_chunk": not self.first_chunk_sent, @@ -114,10 +135,8 @@ def get_final_output(self) -> QueryResult: } if self._finish_reason: metadata["finish_reason"] = self._finish_reason - if self._tool_calls: - metadata["tool_calls"] = [ - self._tool_calls[i] for i in sorted(self._tool_calls) - ] + if tool_calls_tuple: + metadata["tool_calls"] = list(tool_calls_tuple) return QueryResult( id=self.query_id, diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 0064bd22..30091810 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -134,13 +134,20 @@ def from_endpoint_response( if choice.finish_reason: metadata["finish_reason"] = choice.finish_reason.value if choice.message.tool_calls: - metadata["tool_calls"] = [ - tc.model_dump(mode="json") for tc in choice.message.tool_calls + raw_tool_calls = [ + tc.model_dump(mode="json") for tc in choice.message.tool_calls.root ] + metadata["tool_calls"] = raw_tool_calls + else: + raw_tool_calls = None + tool_calls_tuple = tuple(raw_tool_calls) if raw_tool_calls else None return QueryResult( id=result_id, - response_output=TextModelOutput(output=choice.message.content or ""), + response_output=TextModelOutput( + output=choice.message.content or "", + tool_calls=tool_calls_tuple, + ), metadata=metadata, ) diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index e512e22b..ece0b6d3 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -216,9 +216,15 @@ def from_endpoint_response( if choice.message.reasoning_content: metadata["reasoning_content"] = choice.message.reasoning_content + tool_calls_tuple = ( + tuple(choice.message.tool_calls) if choice.message.tool_calls else None + ) return QueryResult( id=result_id or response.id, - response_output=TextModelOutput(output=choice.message.content or ""), + response_output=TextModelOutput( + output=choice.message.content or "", + tool_calls=tool_calls_tuple, + ), metadata=metadata, ) diff --git a/tests/unit/async_utils/services/metrics_aggregator/conftest.py b/tests/unit/async_utils/services/metrics_aggregator/conftest.py index eb80b2ba..70bf4bee 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/conftest.py +++ b/tests/unit/async_utils/services/metrics_aggregator/conftest.py @@ -119,6 +119,22 @@ async def token_count_async( await asyncio.sleep(self._delay) return len(text.split()) + async def 
token_count_message_async( + self, + content: str, + reasoning: str | None, + tool_calls, + _loop: asyncio.AbstractEventLoop, + ) -> int: + import msgspec + + await asyncio.sleep(self._delay) + tool_calls_str = ( + msgspec.json.encode(list(tool_calls)).decode() if tool_calls else "" + ) + combined = (content or "") + " " + (reasoning or "") + " " + tool_calls_str + return len(combined.split()) + def close(self) -> None: pass diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py b/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py index 8f523224..a9abe353 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio + import msgspec import pytest from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( @@ -262,3 +264,130 @@ def test_multiple_tracking_windows(self): assert table.tracked_blocks[1].duration_ns == 200 # 1000 - 800 assert table.total_tracked_duration_ns == 800 assert table.total_completed_tracked_samples == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestOslTriggerToolCalls: + """OslTrigger routes to message path when tool_calls are present.""" + + async def test_osl_with_tool_calls_uses_message_path(self): + """OslTrigger stores combined content+tool_calls word count.""" + from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( + OslTrigger, + ) + from inference_endpoint.core.types import TextModelOutput + + from .conftest import InMemoryKVStore, MockTokenizePool + + kv = InMemoryKVStore() + loop = asyncio.get_running_loop() + pool = MockTokenizePool(delay=0) + trigger = OslTrigger(kv, pool, loop) + trigger.kv_store.create_key("osl", "series", dtype=int) + + tool_calls = ( + { + "id": "c1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + }, + ) + tmo = TextModelOutput(output="hello world", tool_calls=tool_calls) + ev = EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=1000, + sample_uuid="s1", + data=tmo, + ) + from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( + SampleRow, + ) + + row = SampleRow(sample_uuid="s1") + task = trigger.fire(ev, row, {}) + assert task is not None + await task + + values = kv.get_series_values("osl") + assert len(values) == 1 + assert values[0] > 0 + + async def test_osl_without_tool_calls_uses_text_path(self): + """OslTrigger uses text path for output with no tool_calls (regression guard).""" + from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( + OslTrigger, + SampleRow, + ) + from inference_endpoint.core.types import TextModelOutput + + from .conftest import InMemoryKVStore, MockTokenizePool + + kv = InMemoryKVStore() + loop = asyncio.get_running_loop() + pool = MockTokenizePool(delay=0) + trigger = OslTrigger(kv, pool, loop) + trigger.kv_store.create_key("osl", "series", dtype=int) + + tmo = TextModelOutput(output="hello world") + ev = EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=1000, + sample_uuid="s1", + data=tmo, + ) + row = SampleRow(sample_uuid="s1") + task = trigger.fire(ev, row, {}) + assert task is not None + await task + + values = kv.get_series_values("osl") + assert values == [2] # "hello world" -> 2 words + + +@pytest.mark.unit 
+@pytest.mark.asyncio +class TestTpotTriggerToolCalls: + """TpotTrigger routes to message path when tool_calls are present.""" + + async def test_tpot_tool_calls_only_response(self): + """TpotTrigger includes tool_calls in TPOT denominator for agentic responses.""" + from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( + SampleField, + SampleRow, + TpotTrigger, + ) + from inference_endpoint.core.types import TextModelOutput + + from .conftest import InMemoryKVStore, MockTokenizePool + + kv = InMemoryKVStore() + loop = asyncio.get_running_loop() + pool = MockTokenizePool(delay=0) + trigger = TpotTrigger(kv, pool, loop) + trigger.kv_store.create_key("tpot_ns", "series", dtype=float) + + tool_calls = ( + { + "id": "c1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + }, + ) + tmo = TextModelOutput(output=[], tool_calls=tool_calls) + ev = EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=2000, + sample_uuid="s1", + data=tmo, + ) + row = SampleRow(sample_uuid="s1") + # RECV_FIRST_NS was set at t=1000 + pre_change = {SampleField.RECV_FIRST_NS: 1000} + task = trigger.fire(ev, row, pre_change) + assert task is not None + await task + + values = kv.get_series_values("tpot_ns") + assert len(values) == 1 + assert values[0] > 0 diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py index 25cad157..984bb5f4 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py @@ -103,3 +103,80 @@ def test_context_manager(self): assert pool.token_count("a b c") == 3 with pytest.raises(RuntimeError, match="closed"): pool.token_count("test") + + +class _FakeTokenizerWithTemplate(_FakeTokenizer): + """Tokenizer that supports apply_chat_template for tool-call testing.""" + + def apply_chat_template( + self, messages, tokenize=True, add_generation_prompt=False + ) -> list[int]: + parts = [] + for msg in messages: + parts.append(msg.get("content") or "") + if msg.get("reasoning_content"): + parts.append(msg["reasoning_content"]) + if msg.get("tool_calls"): + import msgspec + + parts.append(msgspec.json.encode(msg["tool_calls"]).decode()) + combined = " ".join(p for p in parts if p) + # Add 2 wrapper tokens to simulate template overhead + return [0, 0] + list(range(len(combined.split()))) + + +@pytest.mark.unit +class TestTokenizePoolMessageTokenization: + def test_token_count_message_subtracts_baseline(self): + """token_count_message returns full_tokens - baseline.""" + with patch(_MOCK_TARGET, _FakeTokenizerWithTemplate): + with TokenizePool("fake", n_workers=1) as pool: + # "hello world" -> 2 content words + 2 wrapper = 4; baseline = 0 + 2 = 2; net = 2 + count = pool.token_count_message("hello world", None, None) + assert count == 2 + + def test_token_count_message_includes_tool_calls(self): + """token_count_message includes tool-call JSON tokens.""" + with patch(_MOCK_TARGET, _FakeTokenizerWithTemplate): + with TokenizePool("fake", n_workers=1) as pool: + tool_calls = ( + { + "id": "c1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + }, + ) + count_without = pool.token_count_message("hello", None, None) + count_with = pool.token_count_message("hello", None, tool_calls) + assert count_with > count_without + + def test_token_count_message_fallback_on_exception(self): + """Falls back to whitespace split when 
apply_chat_template raises.""" + + class _BadTemplateTokenizer(_FakeTokenizer): + def apply_chat_template(self, *args, **kwargs): + raise ValueError("template does not support tool_calls") + + with patch(_MOCK_TARGET, _BadTemplateTokenizer): + with TokenizePool("fake", n_workers=1) as pool: + tool_calls = ( + { + "id": "c1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + }, + ) + # Should not raise; falls back to whitespace tokenizer + count = pool.token_count_message("hello world", None, tool_calls) + assert count > 0 + + @pytest.mark.asyncio + async def test_token_count_message_async(self): + """token_count_message_async returns count without blocking event loop.""" + with patch(_MOCK_TARGET, _FakeTokenizerWithTemplate): + loop = asyncio.get_running_loop() + with TokenizePool("fake", n_workers=1) as pool: + count = await pool.token_count_message_async( + "hello world", None, None, loop + ) + assert count == 2 diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py index 52bdbe77..32d15ab2 100644 --- a/tests/unit/core/test_types.py +++ b/tests/unit/core/test_types.py @@ -891,3 +891,86 @@ def test_numeric_types_in_metadata(self): assert decoded.metadata["large_int"] == 9999999999999999 assert decoded.metadata["negative"] == -123.456 assert decoded.metadata["zero"] == 0 + + +@pytest.mark.unit +class TestTextModelOutputToolCalls: + """Test TextModelOutput.tool_calls field: coercion, __str__, text_after_first_chunk.""" + + _TC = [ + {"id": "c1", "type": "function", "function": {"name": "f", "arguments": "{}"}} + ] + + def test_list_coerced_to_tuple(self): + tmo = TextModelOutput(output="hi", tool_calls=self._TC) + assert isinstance(tmo.tool_calls, tuple) + assert len(tmo.tool_calls) == 1 + + def test_none_tool_calls(self): + tmo = TextModelOutput(output="hi", tool_calls=None) + assert tmo.tool_calls is None + + def test_str_includes_tool_calls_json(self): + tmo = TextModelOutput(output="hello", tool_calls=self._TC) + s = str(tmo) + assert "hello" in s + assert '"name"' in s + assert '"function"' in s + + def test_str_without_tool_calls(self): + tmo = TextModelOutput(output="hello") + assert str(tmo) == "hello" + + def test_text_after_first_chunk_includes_tool_calls(self): + # streaming output with tool_calls: first chunk skipped, tool_calls appended + tmo = TextModelOutput(output=("a", "b"), tool_calls=self._TC) + after = tmo.text_after_first_chunk() + assert "b" in after + assert '"function"' in after + + def test_text_after_first_chunk_tool_calls_only_no_content(self): + # tool_calls only (pure agentic response with no text content) + tmo = TextModelOutput(output=[], tool_calls=self._TC) + after = tmo.text_after_first_chunk() + assert '"function"' in after + + def test_as_message_parts_str_output(self): + tmo = TextModelOutput(output="hello", tool_calls=self._TC) + content, reasoning, tc = tmo.as_message_parts() + assert content == "hello" + assert reasoning is None + assert tc == tmo.tool_calls + + def test_as_message_parts_tuple_output(self): + tmo = TextModelOutput( + output=("a", "b"), reasoning=("r1", "r2"), tool_calls=self._TC + ) + content, reasoning, tc = tmo.as_message_parts() + assert content == "ab" + assert reasoning == "r1r2" + assert tc == tmo.tool_calls + + def test_as_message_parts_after_first_chunk_str_output(self): + tmo = TextModelOutput(output="hello", tool_calls=self._TC) + content, reasoning, tc = tmo.as_message_parts_after_first_chunk() + # non-streaming str output: no "after first chunk" for content + assert content == "" + 
assert reasoning is None + assert tc == tmo.tool_calls + + def test_as_message_parts_after_first_chunk_tuple_output(self): + tmo = TextModelOutput(output=("a", "b", "c"), tool_calls=self._TC) + content, reasoning, tc = tmo.as_message_parts_after_first_chunk() + assert content == "bc" + assert reasoning is None + assert tc == tmo.tool_calls + + def test_serialization_roundtrip_with_tool_calls(self): + import msgspec + + tmo = TextModelOutput(output="hello", tool_calls=self._TC) + encoded = msgspec.msgpack.encode(tmo) + decoded = msgspec.msgpack.decode(encoded, type=TextModelOutput) + assert decoded.tool_calls is not None + assert len(decoded.tool_calls) == 1 + assert decoded.tool_calls[0]["function"]["name"] == "f" diff --git a/tests/unit/openai/test_accumulator.py b/tests/unit/openai/test_accumulator.py new file mode 100644 index 00000000..ae23f7e6 --- /dev/null +++ b/tests/unit/openai/test_accumulator.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for OpenAISSEAccumulator tool-call handling.""" + +import pytest +from inference_endpoint.core.types import StreamChunk, TextModelOutput +from inference_endpoint.openai.accumulator import OpenAISSEAccumulator +from inference_endpoint.openai.types import SSEChoice, SSEDelta + + +def _make_tc_partial(idx: int, tc_id: str, name: str, args: str) -> SSEChoice: + """Create an SSEChoice with a single partial tool_call delta.""" + return SSEChoice( + delta=SSEDelta( + tool_calls=[ + { + "index": idx, + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": args}, + } + ] + ) + ) + + +def _make_content_choice(content: str, first_chunk: bool = False) -> SSEChoice: + return SSEChoice(delta=SSEDelta(content=content)) + + +def _make_reasoning_choice(reasoning: str) -> SSEChoice: + return SSEChoice(delta=SSEDelta(reasoning_content=reasoning)) + + +def _make_finish_choice() -> SSEChoice: + return SSEChoice(finish_reason="tool_calls") + + +@pytest.mark.unit +class TestAccumulatorPureToolCalls: + """Pure tool-call stream (no content/reasoning text chunks).""" + + def test_tool_calls_in_text_output(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_tc_partial(0, "c1", "search", '{"q":')) + acc.add_chunk( + SSEChoice( + delta=SSEDelta( + tool_calls=[{"index": 0, "function": {"arguments": '"test"}'}}] + ) + ) + ) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert isinstance(result.response_output, TextModelOutput) + assert result.response_output.tool_calls is not None + assert len(result.response_output.tool_calls) == 1 + assert result.response_output.tool_calls[0]["function"]["name"] == "search" + assert ( + result.response_output.tool_calls[0]["function"]["arguments"] + == '{"q":"test"}' + ) + + def test_metadata_tool_calls_preserved(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) 
+ acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert "tool_calls" in result.metadata + assert result.metadata["tool_calls"][0]["function"]["name"] == "f" + + def test_first_tool_call_delta_emits_sentinel_stream_chunk(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + sentinel = acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + + assert isinstance(sentinel, StreamChunk) + assert sentinel.id == "qid" + assert sentinel.response_chunk == "" + assert sentinel.metadata.get("first_chunk") is True + + def test_subsequent_tool_call_deltas_return_none(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + # First delta: sentinel emitted + acc.add_chunk(_make_tc_partial(0, "c1", "f", "{")) + # Second delta: no sentinel + second = acc.add_chunk( + SSEChoice( + delta=SSEDelta( + tool_calls=[{"index": 0, "function": {"arguments": "}"}}] + ) + ) + ) + assert second is None + + def test_first_chunk_sent_after_sentinel(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + assert acc.first_chunk_sent is True + + +@pytest.mark.unit +class TestAccumulatorMixedReasoningAndToolCalls: + """Mixed stream: reasoning followed by tool_calls.""" + + def test_reasoning_chunk_is_first_chunk_not_tool_call(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + reasoning_chunk = acc.add_chunk(_make_reasoning_choice("Let me think")) + # Reasoning chunk should be the first chunk + assert isinstance(reasoning_chunk, StreamChunk) + assert reasoning_chunk.metadata.get("first_chunk") is True + + # Now a tool_call delta should NOT emit another sentinel (first_chunk_sent=True) + tc_chunk = acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + assert tc_chunk is None + + def test_tool_calls_in_output_after_reasoning(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_reasoning_choice("Thinking...")) + acc.add_chunk(_make_tc_partial(0, "c1", "search", '{"q":"x"}')) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert isinstance(result.response_output, TextModelOutput) + assert result.response_output.reasoning is not None + assert result.response_output.tool_calls is not None + assert result.response_output.tool_calls[0]["function"]["name"] == "search" + + def test_content_then_tool_calls(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_content_choice("Hello")) + acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert ( + result.response_output.output == ("Hello",) + or result.response_output.output == "Hello" + ) + assert result.response_output.tool_calls is not None + + def test_no_tool_calls_returns_none_field(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_content_choice("Hello world")) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert result.response_output.tool_calls is None + assert "tool_calls" not in result.metadata diff --git a/tests/unit/openai/test_msgspec_adapter.py b/tests/unit/openai/test_msgspec_adapter.py index 8127d199..7360c304 100644 --- a/tests/unit/openai/test_msgspec_adapter.py +++ b/tests/unit/openai/test_msgspec_adapter.py @@ -144,3 +144,50 @@ def test_chat_message_content_optional(): """ChatMessage accepts content=None for tool-dispatching 
assistant turns.""" msg = ChatMessage(role="assistant", tool_calls=[]) assert msg.content is None + + +@pytest.mark.unit +def test_from_endpoint_response_populates_tool_calls_in_text_output(): + """Non-streaming response with tool_calls populates TextModelOutput.tool_calls.""" + import json + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } + ] + response_bytes = json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "refusal": None, + "tool_calls": tool_calls, + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + "system_fingerprint": None, + } + ).encode() + + result = OpenAIMsgspecAdapter.decode_response(response_bytes, "q1") + + from inference_endpoint.core.types import TextModelOutput + + assert isinstance(result.response_output, TextModelOutput) + assert result.response_output.tool_calls is not None + assert len(result.response_output.tool_calls) == 1 + assert result.response_output.tool_calls[0]["function"]["name"] == "search" + # metadata path unchanged + assert result.metadata.get("tool_calls") == tool_calls diff --git a/tests/unit/openai/test_openai_adapter.py b/tests/unit/openai/test_openai_adapter.py index 506ec3fe..82672763 100644 --- a/tests/unit/openai/test_openai_adapter.py +++ b/tests/unit/openai/test_openai_adapter.py @@ -145,3 +145,50 @@ def test_no_tools_key_when_absent(): # Pydantic model_dump includes None fields; tools must be None when not supplied assert payload.get("tools") is None + + +@pytest.mark.unit +def test_from_endpoint_response_populates_tool_calls_in_text_output(): + """Non-streaming response with tool_calls populates TextModelOutput.tool_calls.""" + tool_calls_data = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } + ] + response_bytes = json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "refusal": "", + "tool_calls": tool_calls_data, + }, + "finish_reason": "tool_calls", + "logprobs": {"content": [], "refusal": []}, + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + "system_fingerprint": None, + } + ).encode() + + result = OpenAIAdapter.decode_response(response_bytes, "q1") + + from inference_endpoint.core.types import TextModelOutput + + assert isinstance(result.response_output, TextModelOutput) + assert result.response_output.tool_calls is not None + assert len(result.response_output.tool_calls) == 1 + assert result.response_output.tool_calls[0]["function"]["name"] == "search" + # metadata path preserved + assert result.metadata.get("tool_calls") is not None + assert result.metadata["tool_calls"][0]["function"]["name"] == "search" From 452da2f9639dea1b1fedd767963468cd63829972 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 7 May 2026 23:01:29 +0000 Subject: [PATCH 19/41] fix: correct chat-template tokenization for tool-call messages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two bugs in TokenizePool._token_count_message_worker caused OSL / TPOT to be inflated for every response containing tool_calls: 1. 
tool_calls[].function.arguments arrives as the OpenAI wire-format JSON string, but Hermes-style chat templates (Qwen3-Coder, etc.) iterate arguments as a mapping. Passing a string raises, and the code silently fell through to whitespace-splitting content + reasoning + json(tool_calls) β€” counting every JSON bracket, quote, and escape as its own token. Fixed by parsing arguments to dict before rendering. 2. apply_chat_template rejects assistant-only message lists on several templates ("No user query found in messages"). The render also raised, forcing the fallback path. Fixed by prepending an empty user message and subtracting its token length back out. Also switched the render path from tokenize=True (which returns a single- element [Encoding] in recent transformers, so len() was 1) to tokenize=False followed by tokenizer.tokenize(rendered), matching how _token_count_worker measures plain text. Verified on a real Qwen3.6-35B-A3B response: a tool-calling turn that previously reported 130 tokens now reports 100, matching the raw-bytes reference of 102 (2-token delta is the template's \\n scaffolding). Co-Authored-By: Claude Opus 4.7 --- .../metrics_aggregator/token_metrics.py | 72 ++++++++++++++++--- .../metrics_aggregator/test_token_metrics.py | 18 +++-- 2 files changed, 73 insertions(+), 17 deletions(-) diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py index 8dddd9dc..c6ebfe90 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py @@ -18,6 +18,7 @@ from __future__ import annotations import asyncio +import json import logging import threading from concurrent.futures import ThreadPoolExecutor @@ -26,6 +27,39 @@ import msgspec from transformers import AutoTokenizer +# Minimal user message used to satisfy chat templates that reject assistant-only +# message lists. Its token count is subtracted so only the assistant payload is +# measured. +_PREFIX_USER_MSG: dict[str, str] = {"role": "user", "content": ""} + + +def _normalize_tool_calls_for_template( + tool_calls: tuple[dict[str, Any], ...] | list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Ensure ``function.arguments`` is a dict, not the OpenAI-wire JSON string. + + Hermes-style chat templates iterate ``arguments`` as a mapping; a string + payload raises and forces the fallback path, inflating token counts. + """ + normalized: list[dict[str, Any]] = [] + for tc in tool_calls: + fn = tc.get("function") or {} + args = fn.get("arguments") + if isinstance(args, str): + try: + parsed = json.loads(args) + except (json.JSONDecodeError, TypeError): + normalized.append(tc) + continue + if isinstance(parsed, dict): + new_tc = dict(tc) + new_tc["function"] = {**fn, "arguments": parsed} + normalized.append(new_tc) + continue + normalized.append(tc) + return normalized + + if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase @@ -83,14 +117,30 @@ def _get_thread_tokenizer(self) -> PreTrainedTokenizerBase: self._thread_local.tokenizer = AutoTokenizer.from_pretrained( self._tokenizer_name ) + # Baseline = tokens contributed by a [user, empty-assistant] pair minus + # the [user] prefix alone. Some templates (Qwen3-Coder, etc.) reject + # assistant-only message lists, so a user prefix is required; we + # subtract it out so the baseline reflects only the assistant frame. 
try: - baseline_tokens = self._thread_local.tokenizer.apply_chat_template( - [{"role": "assistant", "content": ""}], - tokenize=True, + tok = self._thread_local.tokenizer + prefix_rendered = tok.apply_chat_template( + [_PREFIX_USER_MSG], + tokenize=False, + add_generation_prompt=False, + ) + prefix_len = len(tok.tokenize(prefix_rendered)) + with_empty_assistant_rendered = tok.apply_chat_template( + [_PREFIX_USER_MSG, {"role": "assistant", "content": ""}], + tokenize=False, add_generation_prompt=False, ) - self._thread_local.baseline = len(baseline_tokens) + with_empty_assistant_len = len( + tok.tokenize(with_empty_assistant_rendered) + ) + self._thread_local.prefix_len = prefix_len + self._thread_local.baseline = with_empty_assistant_len - prefix_len except Exception: + self._thread_local.prefix_len = 0 self._thread_local.baseline = 0 logger.warning( "Failed to compute chat-template baseline for %s; tool-call token counts may be over-estimated", @@ -119,15 +169,17 @@ def _token_count_message_worker( if reasoning: msg["reasoning_content"] = reasoning if tool_calls: - msg["tool_calls"] = list(tool_calls) + msg["tool_calls"] = _normalize_tool_calls_for_template(tool_calls) try: - full = len( - tokenizer.apply_chat_template( - [msg], tokenize=True, add_generation_prompt=False - ) + rendered = tokenizer.apply_chat_template( + [_PREFIX_USER_MSG, msg], + tokenize=False, + add_generation_prompt=False, ) + full = len(tokenizer.tokenize(rendered)) + prefix_len = getattr(self._thread_local, "prefix_len", 0) baseline = getattr(self._thread_local, "baseline", 0) - return max(0, full - baseline) + return max(0, full - prefix_len - baseline) except Exception: tool_calls_json = ( msgspec.json.encode(list(tool_calls)).decode() if tool_calls else "" diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py index 984bb5f4..51c6e80d 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py @@ -109,20 +109,24 @@ class _FakeTokenizerWithTemplate(_FakeTokenizer): """Tokenizer that supports apply_chat_template for tool-call testing.""" def apply_chat_template( - self, messages, tokenize=True, add_generation_prompt=False - ) -> list[int]: - parts = [] + self, messages, tokenize=False, add_generation_prompt=False + ): + # Simulate 2 wrapper tokens for the template frame. 
+ parts = ["WRAPPER", "WRAPPER"] for msg in messages: - parts.append(msg.get("content") or "") + content = msg.get("content") + if content: + parts.append(content) if msg.get("reasoning_content"): parts.append(msg["reasoning_content"]) if msg.get("tool_calls"): import msgspec parts.append(msgspec.json.encode(msg["tool_calls"]).decode()) - combined = " ".join(p for p in parts if p) - # Add 2 wrapper tokens to simulate template overhead - return [0, 0] + list(range(len(combined.split()))) + rendered = " ".join(parts) + if tokenize: + return list(range(len(rendered.split()))) + return rendered @pytest.mark.unit From 7bde10ba2865146e8460c10ec8191410006e9d47 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 8 May 2026 01:25:15 +0000 Subject: [PATCH 20/41] docs: fix stale references and tool-row format in multi-turn docs - QUICKSTART: validate_jsonl_schema.py only does per-row JSON Schema checks; cross-row invariants (role sequences, turn numbering, grouping) are enforced by MultiTurnDataset at load time, not the script - README: collapse single/merged tool rows into unified tool_results form to match what MultiTurnDataset._validate_conversation_structure enforces - multi_turn_dataset.py: fix docstring referencing removed AddDefaultColumns Co-Authored-By: Claude Sonnet 4.6 --- docs/MULTI_TURN_QUICKSTART.md | 6 ++++-- examples/09_MultiTurn/README.md | 13 ++++++------- .../dataset_manager/multi_turn_dataset.py | 3 ++- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index 03bf8519..340b4922 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -135,8 +135,10 @@ Use the bundled validation script to check your JSONL file for schema errors bef python scripts/validate_jsonl_schema.py path/to/your/conversations.jsonl ``` -This catches missing required fields, invalid role sequences, non-consecutive turn numbers, and -interleaved conversations β€” all errors that would otherwise surface at benchmark startup. +This catches per-row schema errors (missing required fields, wrong types, +malformed `tool_results`). Cross-row invariants (consecutive turn numbers, +valid role sequences, grouped conversations) are enforced by +`MultiTurnDataset` at load time and will surface at benchmark startup. ### "Conversation has invalid role sequence" diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md index 3a6f363a..a6b79419 100644 --- a/examples/09_MultiTurn/README.md +++ b/examples/09_MultiTurn/README.md @@ -104,19 +104,18 @@ and exits with code 1 if any mismatch is found. 
The script also: The extra fields supported beyond plain user/assistant: -| Row role | Extra fields | -| -------------------------------- | ------------------------------------------------------------------ | -| `assistant` with tool calls | `tool_calls: [{id, type, function: {name, arguments}}]` | -| `tool` single result | `tool_call_id: `, `content: ` | -| `tool` parallel results (merged) | `tool_results: [{tool_call_id, content}, ...]` | -| `user` or `tool` turns | `tools: [...]` (OpenAI tool definitions forwarded to the endpoint) | +| Row role | Extra fields | +| ------------------------------------------ | ------------------------------------------------------------------ | +| `assistant` with tool calls | `tool_calls: [{id, type, function: {name, arguments}}]` | +| `tool` results (single or merged parallel) | `tool_results: [{tool_call_id, content}, ...]` | +| `user` or `tool` turns | `tools: [...]` (OpenAI tool definitions forwarded to the endpoint) | Example rows from a converted agentic dataset: ```jsonl {"conversation_id": "sim_001", "turn": 1, "role": "user", "content": "Fix the bug in foo.py", "system": "You are a coding agent.", "tools": [...]} {"conversation_id": "sim_001", "turn": 2, "role": "assistant", "tool_calls": [{"id": "functions.bash:0", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"cat foo.py\"}"}}]} -{"conversation_id": "sim_001", "turn": 3, "role": "tool", "tool_call_id": "functions.bash:0", "content": "def foo():\n return 1/0", "tools": [...]} +{"conversation_id": "sim_001", "turn": 3, "role": "tool", "tool_results": [{"tool_call_id": "functions.bash:0", "content": "def foo():\n return 1/0"}], "tools": [...]} {"conversation_id": "sim_001", "turn": 4, "role": "assistant", "content": "The bug is a ZeroDivisionError. Here is the fix: ..."} ``` diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index 63e5aa4e..87a6ae4f 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -339,7 +339,8 @@ def load( Unlike single-turn datasets, multi-turn rows do not have a `prompt` column, so ColumnFilter (which requires prompt) is skipped. AddStaticColumns entries - from the adapter are applied via AddDefaultColumns (fill-missing-only) so that + from the adapter are applied via AddStaticColumns(..., overwrite=False) + (fill-missing-only) so that per-row dataset overrides are preserved. After transforms, only client turns (user + tool) are stored in self.data as From 408ed21be239d4df99d0c73f7983b8e264739add Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 8 May 2026 04:15:29 +0000 Subject: [PATCH 21/41] feat: pre-compute ISL token counts for multi-turn dataset-history mode - Add _precompute_isl_for_multi_turn() in execute.py: runs apply_chat_template(messages, tokenize=True, add_generation_prompt=True) once per client turn at setup time and stores results in sample["input_tokens"], hitting the IslTrigger sync fast path (len(token_ids)) with zero hot-path cost. - Add _extract_prompt_text() in session.py: refactors inline message content extraction to handle list-form multimodal content safely, fixing a crash when content is a list (e.g. vision/tool-call messages). - Add unit tests for both helpers and two integration tests covering target_concurrency cap enforcement and pipeline exception propagation. 
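A rough sketch of the setup-time flow described above (illustration of the idea only, not the
execute.py implementation; "your-model-name" is a placeholder):

```python
# Tokenize each pre-built message list once at setup time so the hot path can
# read len(sample["input_tokens"]) instead of re-tokenizing text per request.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-model-name")
sample = {"messages": [{"role": "user", "content": "Fix the bug in foo.py"}]}

# Setup time (once per client turn):
sample["input_tokens"] = tokenizer.apply_chat_template(
    sample["messages"], tokenize=True, add_generation_prompt=True
)

# Hot path (per request): ISL is a constant-time length check.
isl = len(sample["input_tokens"])
```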
Co-Authored-By: Claude Sonnet 4.6 --- .../commands/benchmark/execute.py | 40 +++++++ .../load_generator/session.py | 27 ++++- tests/integration/test_multi_turn.py | 110 ++++++++++++++++++ tests/unit/commands/test_precompute_isl.py | 106 +++++++++++++++++ .../unit/load_generator/test_async_session.py | 51 ++++++++ 5 files changed, 328 insertions(+), 6 deletions(-) create mode 100644 tests/unit/commands/test_precompute_isl.py diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 89297c02..696092a3 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -299,6 +299,42 @@ def _load_datasets( return dataloader, accuracy_datasets, eval_configs +def _precompute_isl_for_multi_turn( + dataloader: MultiTurnDataset, tokenizer_name: str +) -> None: + """Tokenize pre-built message lists and store token counts in each sample. + + Runs apply_chat_template once per client turn so the hot-path IslTrigger + sync path (len(token_ids)) is used instead of on-the-fly text tokenization. + Only affects dataset-history turns; live-history turns override 'messages' + at runtime so the stored input_tokens are stale (acceptable approximation). + """ + # Local import: optional dependency, circular-import avoidance (consistent + # with _annotate_response_token_counts in this file). + from transformers import AutoTokenizer # noqa: PLC0415 + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + skipped = 0 + for sample in dataloader.data or []: + messages = sample.get("messages") + if not messages: + continue + try: + token_ids: list[int] = tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + ) + sample["input_tokens"] = token_ids + except Exception: # template errors vary by model; skip gracefully + skipped += 1 + if skipped: + logger.warning( + "ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)", + skipped, + ) + + def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkContext: """Load tokenizer, dataset, create scheduler, setup report dir.""" # CPU affinity @@ -328,6 +364,10 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo # Datasets dataloader, accuracy_datasets, eval_configs = _load_datasets(config, report_dir) + if isinstance(dataloader, MultiTurnDataset) and tokenizer_name is not None: + logger.info("Pre-computing ISL token counts for multi-turn dataset…") + _precompute_isl_for_multi_turn(dataloader, tokenizer_name) + # Setup runtime settings using factory method rt_settings = RuntimeSettings.from_config(config, dataloader.num_samples()) diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index f15265f8..c69049ec 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -47,6 +47,26 @@ _WARMUP_ENABLED = os.environ.get("ENABLE_WARMUP") == "1" +def _extract_prompt_text(messages: list[Any]) -> str | None: + """Join text content from an OpenAI messages list; handles list-form multimodal content.""" + parts: list[str] = [] + for m in messages: + if not isinstance(m, dict): + continue + c = m.get("content") + if isinstance(c, str) and c: + parts.append(c) + elif isinstance(c, list): + parts.extend( + p["text"] + for p in c + if isinstance(p, dict) + and p.get("type") == "text" + and isinstance(p.get("text"), str) + ) + return 
"\n".join(parts) if parts else None + + # --------------------------------------------------------------------------- # Phase configuration # --------------------------------------------------------------------------- @@ -210,12 +230,7 @@ def issue( # means that ISL reporting will be unavailable for multimodal samples. prompt_text = data.get("prompt") if prompt_text is None and "messages" in data: - parts: list[str] = [ - m["content"] - for m in data["messages"] - if isinstance(m, dict) and m.get("content") - ] - prompt_text = "\n".join(parts) if parts else None + prompt_text = _extract_prompt_text(data["messages"]) prompt_data = PromptData( text=prompt_text if isinstance(prompt_text, str) else None, token_ids=tuple(token_ids) if token_ids is not None else None, diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index cfe8a68c..a18eee34 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -77,6 +77,7 @@ def _make_dataset(rows: list[dict]) -> MultiTurnDataset: def _make_strategy( ds: MultiTurnDataset, use_dataset_history: bool = True, + target_concurrency: int | None = None, ) -> MultiTurnStrategy: mt_cfg = MultiTurnConfig( turn_timeout_s=10.0, @@ -86,6 +87,7 @@ def _make_strategy( conversation_manager=ConversationManager(), dataset_metadata=ds.conversation_metadata, multi_turn_config=mt_cfg, + target_concurrency=target_concurrency, ) @@ -600,6 +602,114 @@ async def test_concurrent_conversations_stress(echo_server): assert len(responses) == expected_client_turns +@pytest.mark.integration +@pytest.mark.asyncio +async def test_multi_turn_active_conversations_respects_target_concurrency(echo_server): + num_convs = 20 + rows = [] + for i in range(num_convs): + conv_id = f"cap_conv_{i}" + rows += [ + { + "conversation_id": conv_id, + "turn": 1, + "role": "user", + "content": f"Q1-{i}", + }, + { + "conversation_id": conv_id, + "turn": 2, + "role": "assistant", + "content": f"A1-{i}", + }, + { + "conversation_id": conv_id, + "turn": 3, + "role": "user", + "content": f"Q2-{i}", + }, + ] + + ds = _make_dataset(rows) + strategy = _make_strategy(ds, target_concurrency=4) + responses: dict = {} + + observed_max: list[int] = [] + orig_on_sample_complete = strategy.on_sample_complete + + def tracked_on_sample_complete(result) -> None: + observed_max.append(len(strategy._active_iters)) + orig_on_sample_complete(result) + + strategy.on_sample_complete = tracked_on_sample_complete + + await _run_session(echo_server.url, ds, strategy, responses) + + assert len(responses) == num_convs * 2 # 2 client turns per conversation + assert max(observed_max, default=0) <= 4 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_multi_turn_pipeline_exception_propagates(echo_server): + rows = [ + {"conversation_id": "err_c1", "turn": 1, "role": "user", "content": "Q1"}, + {"conversation_id": "err_c1", "turn": 2, "role": "assistant", "content": "A1"}, + {"conversation_id": "err_c1", "turn": 3, "role": "user", "content": "Q2"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + + call_count = 0 + orig_issue_next_turn = strategy._issue_next_turn + + def failing_issue_next_turn(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise RuntimeError("injected pipeline error") + return orig_issue_next_turn(*args, **kwargs) + + strategy._issue_next_turn = failing_issue_next_turn + + loop = asyncio.get_running_loop() + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(echo_server.url, 
"/v1/chat/completions")], + warmup_connections=0, + num_workers=2, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=strategy.on_sample_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=30_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) + + with pytest.raises(RuntimeError, match="injected pipeline error"): + await asyncio.wait_for(session.run([phase]), timeout=30.0) + + assert strategy._inflight == {} + finally: + await http_client.shutdown_async() + + @pytest.mark.integration @pytest.mark.asyncio async def test_tools_field_forwarded_to_endpoint(echo_server): diff --git a/tests/unit/commands/test_precompute_isl.py b/tests/unit/commands/test_precompute_isl.py new file mode 100644 index 00000000..f78e4065 --- /dev/null +++ b/tests/unit/commands/test_precompute_isl.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for _precompute_isl_for_multi_turn.""" + +from unittest.mock import MagicMock, patch + +import pytest +from inference_endpoint.commands.benchmark.execute import _precompute_isl_for_multi_turn + + +def _make_dataloader(samples: list[dict]) -> MagicMock: + dl = MagicMock() + dl.data = samples + return dl + + +class TestPrecomputeIslForMultiTurn: + @pytest.mark.unit + def test_sets_input_tokens_for_samples_with_messages(self): + samples = [ + {"messages": [{"role": "user", "content": "hello"}]}, + {"messages": [{"role": "user", "content": "world"}]}, + ] + dataloader = _make_dataloader(samples) + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.side_effect = lambda msgs, **_: list( + range(len(msgs) * 3) + ) + + with patch("transformers.AutoTokenizer") as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + for sample in samples: + assert "input_tokens" in sample + assert isinstance(sample["input_tokens"], list) + + @pytest.mark.unit + def test_leaves_samples_without_messages_untouched(self): + samples = [ + {"prompt": "no messages here"}, + {"input_tokens": [1, 2, 3]}, + ] + dataloader = _make_dataloader(samples) + mock_tokenizer = MagicMock() + + with patch("transformers.AutoTokenizer") as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + mock_tokenizer.apply_chat_template.assert_not_called() + assert "input_tokens" not in samples[0] + assert samples[1]["input_tokens"] == [1, 2, 3] + + @pytest.mark.unit + def test_skips_failed_template_calls_with_warning(self, caplog): + samples = [ + {"messages": [{"role": "user", "content": "good"}]}, + {"messages": [{"role": "user", "content": "bad"}]}, + ] + dataloader = _make_dataloader(samples) + + def side_effect(msgs, **_): + if msgs[0]["content"] == "bad": + raise ValueError("template error") + return [10, 20, 30] + + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.side_effect = side_effect + + with patch("transformers.AutoTokenizer") as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + with caplog.at_level("WARNING"): + _precompute_isl_for_multi_turn(dataloader, "test-model") + + assert "input_tokens" in samples[0] + assert "input_tokens" not in samples[1] + assert "1 turn(s) skipped" in caplog.text + + @pytest.mark.unit + def test_add_generation_prompt_true(self): + samples = [{"messages": [{"role": "user", "content": "hi"}]}] + dataloader = _make_dataloader(samples) + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.return_value = [1, 2, 3] + + with patch("transformers.AutoTokenizer") as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + _, kwargs = mock_tokenizer.apply_chat_template.call_args + assert kwargs.get("add_generation_prompt") is True + assert kwargs.get("tokenize") is True diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index 38dd014e..41a6635e 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -39,6 +39,7 @@ PhaseResult, PhaseType, SessionResult, + _extract_prompt_text, ) from inference_endpoint.metrics.metric import Throughput @@ -882,3 +883,53 @@ def test_perf_results_filter(self, enable_warmup): assert len(sr.perf_results) == 2 assert len(sr.accuracy_results) == 1 assert 
sr.perf_results[0].name == "perf1" + + +@pytest.mark.unit +class TestExtractPromptText: + def test_string_content_extracted(self): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + assert _extract_prompt_text(messages) == "Hello\nHi" + + def test_multimodal_list_content_text_parts_extracted(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + {"type": "image_url"}, + ], + } + ] + assert _extract_prompt_text(messages) == "Describe this image" + + def test_mixed_string_and_list_content(self): + messages = [ + {"role": "system", "content": "You are helpful"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url"}, + ], + }, + ] + assert _extract_prompt_text(messages) == "You are helpful\nWhat is this?" + + def test_none_content_skipped(self): + messages = [ + {"role": "assistant", "content": None}, + {"role": "user", "content": "Hello"}, + ] + assert _extract_prompt_text(messages) == "Hello" + + def test_list_content_with_no_text_parts_returns_none(self): + messages = [{"role": "user", "content": [{"type": "image_url"}]}] + assert _extract_prompt_text(messages) is None + + def test_non_dict_messages_skipped(self): + messages = ["not a dict", {"role": "user", "content": "Valid"}] + assert _extract_prompt_text(messages) == "Valid" From a003c9a99cb58619a8245bba80cde43bcf7ae1d5 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 8 May 2026 05:54:03 +0000 Subject: [PATCH 22/41] fix: unwrap BatchEncoding from apply_chat_template for Qwen3 tokenizer Qwen3's fast tokenizer returns a BatchEncoding object from apply_chat_template(tokenize=True) instead of a plain list[int]. Storing the BatchEncoding in sample["input_tokens"] caused a msgspec serialization error at benchmark setup time. Extract .input_ids when the return value has that attribute; fall back to the plain list otherwise. Add a regression test using a mock BatchEncoding so this is caught before it can regress again. Co-Authored-By: Claude Opus 4.7 --- .../commands/benchmark/execute.py | 5 ++++- tests/unit/commands/test_precompute_isl.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 696092a3..a26df772 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -320,11 +320,14 @@ def _precompute_isl_for_multi_turn( if not messages: continue try: - token_ids: list[int] = tokenizer.apply_chat_template( + raw = tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, ) + # Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding + # instead of a plain list; extract .input_ids in that case. 
+ token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw sample["input_tokens"] = token_ids except Exception: # template errors vary by model; skip gracefully skipped += 1 diff --git a/tests/unit/commands/test_precompute_isl.py b/tests/unit/commands/test_precompute_isl.py index f78e4065..4c5f1223 100644 --- a/tests/unit/commands/test_precompute_isl.py +++ b/tests/unit/commands/test_precompute_isl.py @@ -90,6 +90,24 @@ def side_effect(msgs, **_): assert "input_tokens" not in samples[1] assert "1 turn(s) skipped" in caplog.text + @pytest.mark.unit + def test_batch_encoding_return_value_is_unwrapped(self): + """Tokenizers like Qwen3 return BatchEncoding instead of list[int].""" + samples = [{"messages": [{"role": "user", "content": "hi"}]}] + dataloader = _make_dataloader(samples) + + batch_encoding = MagicMock() + batch_encoding.input_ids = [1, 2, 3] + + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.return_value = batch_encoding + + with patch("transformers.AutoTokenizer") as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + assert samples[0]["input_tokens"] == [1, 2, 3] + @pytest.mark.unit def test_add_generation_prompt_true(self): samples = [{"messages": [{"role": "user", "content": "hi"}]}] From 72c20f57794e06b7ade50e60e6131d49cadd83d3 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Mon, 11 May 2026 20:18:20 +0000 Subject: [PATCH 23/41] fix: accuracy phases now inherit configured load pattern instead of flooding with MAX_THROUGHPUT Hardcoded LoadPatternType.MAX_THROUGHPUT for accuracy phases caused all requests to be issued simultaneously with no concurrency cap, exhausting the SGLang KV pool and producing truncated outputs (finish_reason=length) with 0/990 final answers for GPQA. Accuracy phases now inherit the perf phase load pattern, downgrading MULTI_TURN to CONCURRENCY (same cap) when the accuracy dataset is not a MultiTurnDataset. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- .../commands/benchmark/execute.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index a26df772..7bee4943 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -421,8 +421,22 @@ def _build_phases( # Accuracy phases β€” use eval_cfg.dataset_name as phase name so it matches # what Scorer._load_sample_index_map() looks up in sample_idx_map.json + perf_lp = ctx.rt_settings.load_pattern for eval_cfg in ctx.eval_configs: acc_ds = eval_cfg.dataset + if ( + perf_lp is not None + and perf_lp.type == LoadPatternType.MULTI_TURN + and not isinstance(acc_ds, MultiTurnDataset) + ): + # Plain accuracy datasets are single-turn; the multi-turn scheduler + # requires MultiTurnDataset. Downgrade to CONCURRENCY with same cap. 
+ acc_load_pattern: LoadPattern | None = LoadPattern( + type=LoadPatternType.CONCURRENCY, + target_concurrency=perf_lp.target_concurrency, + ) + else: + acc_load_pattern = perf_lp acc_settings = RuntimeSettings( metric_target=ctx.rt_settings.metric_target, reported_metrics=ctx.rt_settings.reported_metrics, @@ -433,7 +447,7 @@ def _build_phases( min_sample_count=acc_ds.num_samples() * acc_ds.repeats, rng_sched=ctx.rt_settings.rng_sched, rng_sample_index=ctx.rt_settings.rng_sample_index, - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + load_pattern=acc_load_pattern, ) phases.append( PhaseConfig(eval_cfg.dataset_name, acc_settings, acc_ds, PhaseType.ACCURACY) From 5b8f5159a83e16cda16be4f5ecc4188ec925aca9 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 13 May 2026 02:37:20 +0000 Subject: [PATCH 24/41] Fix pre-commit Signed-off-by: Li, Tianmu --- tests/unit/config/test_schema.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index f028712a..75e8b6e6 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -470,6 +470,7 @@ def test_openai_completions_endpoint_resolves_adapter(self): assert config.settings.client.adapter is OpenAITextCompletionsAdapter assert config.settings.client.accumulator is OpenAISSEAccumulator + class TestMultiTurnValidation: """Tests for multi-turn config validation and cross-validation.""" From 9ad96128fee0a47b9997d8c4557cb922ed91b5c2 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 13 May 2026 03:09:00 +0000 Subject: [PATCH 25/41] Fix CI error for completion Signed-off-by: Li, Tianmu --- tests/unit/openai/test_completions_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/openai/test_completions_adapter.py b/tests/unit/openai/test_completions_adapter.py index 6045eeb5..712bfe31 100644 --- a/tests/unit/openai/test_completions_adapter.py +++ b/tests/unit/openai/test_completions_adapter.py @@ -139,7 +139,7 @@ def test_decode_sse_empty_choices_returns_empty_delta(self): json_bytes = msgspec.json.encode(msg) delta = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) assert isinstance(delta, SSEDelta) - assert delta.content == "" + assert delta.content is None @pytest.mark.unit def test_decode_sse_empty_text_returns_empty_delta(self): From 857db5b4ac538b8773caa75ef00fe3bdc8124bca Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 13 May 2026 03:48:50 +0000 Subject: [PATCH 26/41] Change to SSE choice for test completion Signed-off-by: Li, Tianmu --- .../openai/completions_adapter.py | 11 ++++-- tests/unit/openai/test_completions_adapter.py | 34 ++++++++++--------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/inference_endpoint/openai/completions_adapter.py b/src/inference_endpoint/openai/completions_adapter.py index 0d4cdfa7..de1d1632 100644 --- a/src/inference_endpoint/openai/completions_adapter.py +++ b/src/inference_endpoint/openai/completions_adapter.py @@ -27,6 +27,7 @@ from inference_endpoint.endpoint_client.adapter_protocol import HttpRequestAdapter from .types import ( + SSEChoice, SSEDelta, TextCompletionRequest, TextCompletionResponse, @@ -104,8 +105,12 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: ) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> SSEDelta: + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice: msg = cls._sse_decoder.decode(json_bytes) if not msg.choices: - return SSEDelta() - return 
SSEDelta(content=msg.choices[0].text) + return SSEChoice() + choice = msg.choices[0] + return SSEChoice( + delta=SSEDelta(content=choice.text), + finish_reason=choice.finish_reason, + ) diff --git a/tests/unit/openai/test_completions_adapter.py b/tests/unit/openai/test_completions_adapter.py index 712bfe31..c590d7a0 100644 --- a/tests/unit/openai/test_completions_adapter.py +++ b/tests/unit/openai/test_completions_adapter.py @@ -37,7 +37,7 @@ from inference_endpoint.openai.accumulator import OpenAISSEAccumulator from inference_endpoint.openai.completions_adapter import OpenAITextCompletionsAdapter from inference_endpoint.openai.types import ( - SSEDelta, + SSEChoice, TextCompletionSSEChoice, TextCompletionSSEMessage, ) @@ -126,36 +126,38 @@ def test_decode_empty_choices_raises(self): class TestOpenAITextCompletionsAdapterDecodeSSE: @pytest.mark.unit - def test_decode_sse_message_returns_sse_delta(self): + def test_decode_sse_message_returns_sse_choice(self): msg = TextCompletionSSEMessage(choices=(TextCompletionSSEChoice(text="tok"),)) json_bytes = msgspec.json.encode(msg) - delta = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) - assert isinstance(delta, SSEDelta) - assert delta.content == "tok" + choice = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) + assert isinstance(choice, SSEChoice) + assert choice.delta is not None + assert choice.delta.content == "tok" @pytest.mark.unit - def test_decode_sse_empty_choices_returns_empty_delta(self): + def test_decode_sse_empty_choices_returns_empty_choice(self): msg = TextCompletionSSEMessage(choices=()) json_bytes = msgspec.json.encode(msg) - delta = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) - assert isinstance(delta, SSEDelta) - assert delta.content is None + choice = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) + assert isinstance(choice, SSEChoice) + assert choice.delta is None @pytest.mark.unit - def test_decode_sse_empty_text_returns_empty_delta(self): + def test_decode_sse_empty_text_returns_choice_with_empty_content(self): msg = TextCompletionSSEMessage(choices=(TextCompletionSSEChoice(text=""),)) json_bytes = msgspec.json.encode(msg) - delta = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) - assert isinstance(delta, SSEDelta) - assert delta.content == "" + choice = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) + assert isinstance(choice, SSEChoice) + assert choice.delta is not None + assert choice.delta.content == "" @pytest.mark.unit - def test_sse_delta_compatible_with_openai_accumulator(self): + def test_sse_choice_compatible_with_openai_accumulator(self): acc = OpenAISSEAccumulator(query_id="q1", stream_all_chunks=True) msg = TextCompletionSSEMessage(choices=(TextCompletionSSEChoice(text="hello"),)) json_bytes = msgspec.json.encode(msg) - delta = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) - chunk = acc.add_chunk(delta) + choice = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) + chunk = acc.add_chunk(choice) assert chunk is not None assert chunk.response_chunk == "hello" From 75aa9e25f9eb073fb26bb28be7d8fa5a3e5781ed Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 13 May 2026 17:24:39 +0000 Subject: [PATCH 27/41] fix: address PR #285 review deficiencies in multi-turn stack Correctness fixes: - multi_turn_strategy: derive response_text from TextModelOutput.output directly so tool-call JSON is not duplicated into assistant history - multi_turn_strategy: KeyError path in on_sample_complete now pops _active_iters and 
calls _fill_slot() to prevent session hang - multi_turn_strategy: logger.exception preserves traceback on issuance failure - multi_turn_strategy: raise InputValidationError at __init__ when live-history mode (use_dataset_history=False) is combined with tool turns; removes the has_tool_msg warning that sent unissued tool_call_ids - multi_turn_strategy: _handle_timeout synthesises a failure QueryResult and routes it through the composite callback so accuracy collector and event logger see timed-out turns; counts and logs dropped downstream turns; late responses get a debug log instead of silent drop - execute.py: each hook in _on_sample_complete wrapped independently so a strategy failure cannot suppress accuracy collection - execute.py: move AutoTokenizer import to top-level; narrow ISL precompute exception to (TemplateError, KeyError, ValueError, TypeError); raise RuntimeError when all samples fail - token_metrics: join fallback parts with "\n" to avoid cross-boundary token merging; logger.exception for baseline computation failure - adapter_protocol: per-document SSE try/except so one bad frame does not drop the rest of the buffer; filter None returns from decode_sse_message - openai/types: add role field to SSEDelta to accept streaming first frame - dataset_manager/factory: skip ColumnRemap for MultiTurnDataset - multi_turn_dataset: warn and skip samples missing pre-built messages - config/schema: add gt=0 validator on MultiTurnConfig.turn_timeout_s Docs: - examples/09_MultiTurn/README.md: correct concurrency/timeout semantics, mark per-conversation metrics as planned - examples/10_CollectOutputs: add example for output collection Tests: - test_live_history_rejects_tool_turns: asserts InputValidationError at init - test_isl_precomputed_for_dataset_history: guards ISL precompute hot path - annotate bare except blocks in CapturingEchoServer with explanatory comment Co-Authored-By: Claude Sonnet 4.6 (1M context) --- examples/09_MultiTurn/README.md | 28 ++--- examples/10_CollectOutputs/README.md | 106 ++++++++++++++++++ .../benchmark_with_output_collection.yaml | 35 ++++++ .../metrics_aggregator/token_metrics.py | 9 +- .../commands/benchmark/execute.py | 35 ++++-- src/inference_endpoint/config/schema.py | 2 +- .../dataset_manager/factory.py | 2 +- .../dataset_manager/multi_turn_dataset.py | 8 +- .../endpoint_client/adapter_protocol.py | 28 ++--- .../load_generator/multi_turn_strategy.py | 94 +++++++++++++--- src/inference_endpoint/openai/types.py | 1 + tests/integration/test_multi_turn.py | 72 +++++++++++- 12 files changed, 359 insertions(+), 61 deletions(-) create mode 100644 examples/10_CollectOutputs/README.md create mode 100644 examples/10_CollectOutputs/benchmark_with_output_collection.yaml diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md index a6b79419..94ad257d 100644 --- a/examples/09_MultiTurn/README.md +++ b/examples/09_MultiTurn/README.md @@ -164,8 +164,8 @@ settings: **Behavior**: -- With `target_concurrency`: Limits total in-flight requests across all conversations -- Combines with turn sequencing: Turn N+1 still waits for turn N, AND waits for available slot +- With `target_concurrency`: At most `target_concurrency` conversations are active simultaneously; each active conversation has exactly one in-flight turn at any time. +- Turn sequencing is preserved: turn N+1 is issued only after turn N's response arrives. 
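The sketch below simulates the behaviour described in the two points above. It is illustrative
only β€” not the MultiTurnStrategy scheduler β€” and the conversation/turn counts are arbitrary
example values:

```python
# Simulate a target_concurrency cap with per-conversation turn ordering.
from collections import deque

def simulate(num_conversations: int, turns_per_conv: int, target_concurrency: int) -> int:
    waiting = deque(range(num_conversations))   # conversations not yet started
    in_flight: list[tuple[int, int]] = []       # (conversation, turn) pairs in flight
    completed = 0

    # Seed: start at most target_concurrency conversations, one turn each.
    while waiting and len(in_flight) < target_concurrency:
        in_flight.append((waiting.popleft(), 1))

    while in_flight:
        conv, turn = in_flight.pop(0)           # pretend the oldest turn finishes first
        completed += 1
        if turn < turns_per_conv:
            in_flight.append((conv, turn + 1))  # turn N+1 only after turn N completes
        elif waiting:
            in_flight.append((waiting.popleft(), 1))  # free slot -> next conversation
        assert len(in_flight) <= target_concurrency

    return completed

print(simulate(num_conversations=100, turns_per_conv=3, target_concurrency=32))  # -> 300
```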
**Use cases**: @@ -176,22 +176,22 @@ settings: **Example**: 100 conversations with `target_concurrency: 32` ``` -t=0: Issue first 32 turn-1s (concurrency limit reached) -t=0.5: Turn-1 completes β†’ issue next turn-1 (slot filled) -t=1.0: Turn-1 completes β†’ issue turn-2 of completed conv (slot filled) -... Maintains ~32 in-flight across all conversations +t=0: Start 32 conversations, issue turn-1 for each (32 in-flight) +t=0.5: Turn-1 of conv A completes β†’ issue turn-2 of conv A (still 32 in-flight) +t=1.0: All turns of conv B complete β†’ start conv 33, issue its turn-1 (still 32 in-flight) +... Maintains at most 32 active conversations ``` ### Turn Timeout -Configure maximum wait time for previous turn completion: +Configure the maximum time allowed between issuing a turn and receiving its response: ```yaml multi_turn: turn_timeout_s: 300.0 # 5 minutes ``` -If a turn times out waiting for the previous turn, it will be skipped and logged as a warning. +If a turn does not receive a response within `turn_timeout_s` seconds, that turn is marked failed and all remaining turns in the same conversation are aborted (subsequent turns depend on the timed-out response). The event is logged as a warning. ## Running Multi-Turn Benchmarks @@ -204,10 +204,10 @@ inference-endpoint benchmark from-config \ ### Viewing Results -Multi-turn benchmarks produce both per-turn and per-conversation metrics: +Multi-turn benchmarks produce per-turn metrics: - **Per-turn metrics**: Latency, TTFT, TPOT for each individual turn -- **Per-conversation metrics**: Total conversation latency, conversations per second +- **Per-conversation metrics**: Total conversation latency, conversations per second _(planned β€” not yet implemented)_ Results are stored in the configured `report_dir` with conversation metadata included in the events log (`events.jsonl`). @@ -263,9 +263,9 @@ produce a properly sequenced flat-row file. The valid agentic sequence is: user -> assistant (tool_calls) -> tool -> [tool | assistant (tool_calls)]* -> assistant -> user -> ... ``` -### "Turn timed out waiting for prev turn" +### "Turn timed out" -**Cause**: Previous turn took longer than `turn_timeout_s` to complete. +**Cause**: A turn did not receive a response within `turn_timeout_s` seconds after it was issued. **Fixes**: @@ -281,7 +281,7 @@ Multi-turn logic is only activated when a `multi_turn:` block is present in the Planned features: -- [ ] Poisson conversation arrival mode implementation -- [ ] Per-conversation metrics in reporting +- [ ] Poisson conversation arrival mode +- [ ] Per-conversation metrics in reporting (total conversation latency, conversations per second) - [ ] Conversation-level latency percentiles - [ ] Dynamic conversation branching diff --git a/examples/10_CollectOutputs/README.md b/examples/10_CollectOutputs/README.md new file mode 100644 index 00000000..6fc93a53 --- /dev/null +++ b/examples/10_CollectOutputs/README.md @@ -0,0 +1,106 @@ +# Collecting Outputs for Performance Runs + +This example demonstrates how to collect and log model response outputs during performance benchmarks. + +## Overview + +By default, performance runs (`--mode perf`) do **not** collect response outputs to minimize memory overhead and I/O latency. However, for debugging, analysis, or archival purposes, you can enable output collection using the `--collect-outputs` CLI flag or the `collect_outputs: true` config field. 
+ +## When to Use + +- **Debugging**: Save outputs for analyzing model behavior under load +- **Validation**: Verify response quality without slowing down the benchmark +- **Archival**: Store responses for compliance or future analysis +- **Combined metrics**: Analyze performance alongside response content + +## Usage + +### CLI Flag + +```bash +uv run inference-endpoint benchmark offline \ + --endpoints http://localhost:8000 \ + --model meta-llama/Llama-2-7b-hf \ + --dataset perf:data.jsonl \ + --collect-outputs +``` + +### YAML Config + +```yaml +type: offline +collect_outputs: true +# ... rest of config +``` + +### From-config with override + +```bash +uv run inference-endpoint benchmark from-config \ + --config benchmark_with_output_collection.yaml +``` + +## Output Locations + +When `collect_outputs` is enabled, responses are stored in: + +- **JSONL format**: `{report_dir}/events/` β€” one JSON record per line +- **SQLite format**: `{report_dir}/events.db` β€” queryable database + +Each response is keyed by its query ID for correlation with performance metrics. + +## How It Works + +The `collect_outputs` flag **enables output collection for performance runs**: + +| Mode | Flag | Outputs Collected? | +| ------------- | ------------------- | ------------------ | +| `--mode perf` | none | ❌ | +| `--mode perf` | `--collect-outputs` | βœ… | +| `--mode acc` | n/a | βœ… | +| `--mode both` | n/a | βœ… | + +This allows performance benchmarks to optionally capture outputs without the overhead of full accuracy evaluation. + +## Memory Considerations + +Enabling output collection increases memory usage proportional to: + +- Number of queries issued +- Average response length (tokens β†’ bytes) + +For large-scale benchmarks (e.g., 100k+ queries), consider: + +- Using `--dataset perf:data.jsonl,samples=N` to limit dataset size +- Piping outputs to disk via the EventLogger (default behavior) +- Running in `--mode perf` (without collection) if storage is constrained + +## Integration with Accuracy Evaluation + +If you later want to run accuracy evaluation on collected outputs: + +```bash +uv run inference-endpoint benchmark offline \ + --endpoints http://localhost:8000 \ + --model meta-llama/Llama-2-7b-hf \ + --dataset acc:data.jsonl \ + --mode acc +``` + +The responses collected in the first run are independent; the accuracy run uses responses it collects during its own execution. + +## Example Workflow + +```bash +# 1. Performance benchmark with output collection +uv run inference-endpoint benchmark from-config \ + --config benchmark_with_output_collection.yaml \ + --report-dir results/perf_with_outputs + +# 2. Inspect collected outputs (JSONL) +head results/perf_with_outputs/events.jsonl | jq .data + +# 3. Query responses via SQLite +sqlite3 results/perf_with_outputs/events.db \ + "SELECT sample_uuid, data FROM event_records WHERE event_type LIKE '%complete%';" +``` diff --git a/examples/10_CollectOutputs/benchmark_with_output_collection.yaml b/examples/10_CollectOutputs/benchmark_with_output_collection.yaml new file mode 100644 index 00000000..40af4f80 --- /dev/null +++ b/examples/10_CollectOutputs/benchmark_with_output_collection.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Example: Performance benchmark with output collection enabled +# +# This config demonstrates how to enable output logging for performance runs +# using the collect_outputs flag. 
Outputs (model responses) are collected and +# logged to the events database/JSONL, allowing for offline analysis without +# switching to accuracy mode. + +type: offline +name: offline_benchmark_with_outputs + +model_params: + name: meta-llama/Llama-2-7b-hf + streaming: off + +datasets: + - name: performance_dataset + path: data.jsonl + type: performance + +settings: + max_duration_ms: 60000 + load_pattern: + type: max_throughput + +endpoint_config: + endpoints: + - http://localhost:8000 + api_type: openai + +# Enable output collection for this performance run +# Outputs will be stored in {report_dir}/events/ and {report_dir}/events.db +collect_outputs: true diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py index c6ebfe90..e31b6eb5 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py @@ -142,7 +142,7 @@ def _get_thread_tokenizer(self) -> PreTrainedTokenizerBase: except Exception: self._thread_local.prefix_len = 0 self._thread_local.baseline = 0 - logger.warning( + logger.exception( "Failed to compute chat-template baseline for %s; tool-call token counts may be over-estimated", self._tokenizer_name, ) @@ -182,9 +182,12 @@ def _token_count_message_worker( return max(0, full - prefix_len - baseline) except Exception: tool_calls_json = ( - msgspec.json.encode(list(tool_calls)).decode() if tool_calls else "" + msgspec.json.encode(list(tool_calls)).decode() if tool_calls else None ) - fallback_text = (content or "") + (reasoning or "") + tool_calls_json + parts = [ + p for p in (content or None, reasoning or None, tool_calls_json) if p + ] + fallback_text = "\n".join(parts) return self._token_count_worker(fallback_text) def token_count(self, text: str) -> int: diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 7bee4943..b2f1650c 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -38,9 +38,11 @@ from typing import Any from urllib.parse import urljoin +import jinja2 import msgspec.json from huggingface_hub import model_info from tqdm import tqdm +from transformers import AutoTokenizer from transformers.utils import logging as transformers_logging from inference_endpoint.async_utils.event_publisher import EventPublisherService @@ -309,12 +311,9 @@ def _precompute_isl_for_multi_turn( Only affects dataset-history turns; live-history turns override 'messages' at runtime so the stored input_tokens are stale (acceptable approximation). """ - # Local import: optional dependency, circular-import avoidance (consistent - # with _annotate_response_token_counts in this file). - from transformers import AutoTokenizer # noqa: PLC0415 - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) skipped = 0 + first_failure_logged = False for sample in dataloader.data or []: messages = sample.get("messages") if not messages: @@ -329,13 +328,22 @@ def _precompute_isl_for_multi_turn( # instead of a plain list; extract .input_ids in that case. 
token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw sample["input_tokens"] = token_ids - except Exception: # template errors vary by model; skip gracefully + except (jinja2.TemplateError, KeyError, ValueError, TypeError): + if not first_failure_logged: + logger.exception( + "ISL pre-computation: apply_chat_template failed (first failure shown)" + ) + first_failure_logged = True skipped += 1 if skipped: logger.warning( "ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)", skipped, ) + if skipped == len([s for s in (dataloader.data or []) if s.get("messages")]): + raise RuntimeError( + "ISL precomputation failed for all samples; check tokenizer/template compatibility" + ) def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkContext: @@ -618,8 +626,21 @@ async def _run_benchmark_async( if multi_turn_strategy is not None: def _on_sample_complete(result: QueryResult) -> None: - multi_turn_strategy.on_sample_complete(result) - collector.on_complete_hook(result) + try: + multi_turn_strategy.on_sample_complete(result) + except Exception: + logger.exception( + "multi_turn_strategy.on_sample_complete failed (result=%s)", + result.id, + ) + try: + collector.on_complete_hook(result) + except Exception: + logger.exception( + "collector.on_complete_hook failed (result=%s)", result.id + ) + + multi_turn_strategy._session_on_sample_complete = _on_sample_complete else: _on_sample_complete = collector.on_complete_hook diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index a7a19b49..27a540a4 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -245,7 +245,7 @@ class MultiTurnConfig(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) - turn_timeout_s: float = 300.0 + turn_timeout_s: float = Field(default=300.0, gt=0) use_dataset_history: bool = True diff --git a/src/inference_endpoint/dataset_manager/factory.py b/src/inference_endpoint/dataset_manager/factory.py index 8c1226c6..96a9174d 100644 --- a/src/inference_endpoint/dataset_manager/factory.py +++ b/src/inference_endpoint/dataset_manager/factory.py @@ -101,7 +101,7 @@ def create_loader(config: DatasetConfig, num_repeats: int = 1, **kwargs) -> Data dataset_id = MultiTurnDataset.DATASET_ID transforms: list[Transform] = [] - if remap is not None: + if remap is not None and dataset_id != MultiTurnDataset.DATASET_ID: # Parser convention is {target: source} (e.g. {prompt: article}). # ColumnRemap expects {source: target} β€” flip it. flipped = {src: dst for dst, src in remap.items()} diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index 87a6ae4f..46959261 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -435,8 +435,12 @@ def load( # Attach pre-built message list (system + history + current turn). key = (str(row["conversation_id"]), int(row["turn"])) - messages = pre_built.get(key, []) - sample["messages"] = messages + if key not in pre_built: + logger.warning( + "dropping sample missing pre-built messages: key=%s", key + ) + continue + sample["messages"] = pre_built[key] # Record dense 0-based index before appending (matches load_sample() position). 
key_to_sample_index[key] = len(client_turn_samples) diff --git a/src/inference_endpoint/endpoint_client/adapter_protocol.py b/src/inference_endpoint/endpoint_client/adapter_protocol.py index feb590a4..649424f8 100644 --- a/src/inference_endpoint/endpoint_client/adapter_protocol.py +++ b/src/inference_endpoint/endpoint_client/adapter_protocol.py @@ -17,6 +17,7 @@ from __future__ import annotations +import logging import re from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any @@ -27,6 +28,8 @@ from inference_endpoint.config.schema import ModelParams from inference_endpoint.dataset_manager.transforms import Transform +logger = logging.getLogger(__name__) + class HttpRequestAdapter(ABC): """ @@ -107,11 +110,11 @@ def decode_sse_message(cls, json_bytes: bytes) -> Any: @classmethod def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[Any]: - """ - Parse SSE chunk and extract all chunk objects. + """Parse SSE chunk and extract all chunk objects. Extracts JSON documents from SSE stream and decodes them to chunk objects. - Silently ignores non-content SSE messages (role, finish_reason, etc). + Filters None returns from decode_sse_message and skips frames that fail + to decode (e.g. role-only or finish_reason-only frames). Args: buffer: Byte buffer containing SSE data @@ -121,14 +124,13 @@ def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[Any]: List of chunk objects extracted from the SSE chunk """ json_docs = cls.SSE_DATA_PATTERN.findall(buffer[:end_pos]) - parsed_contents = [] - - try: - for json_doc in json_docs: + parsed: list[Any] = [] + for json_doc in json_docs: + try: content = cls.decode_sse_message(json_doc) - parsed_contents.append(content) - except Exception: - # Normal for non-content SSE messages (role, finish_reason, etc) - pass - - return parsed_contents + except Exception: + logger.debug("skipping non-content SSE frame: %s", json_doc[:120]) + continue + if content is not None: + parsed.append(content) + return parsed diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 8f23d485..03a6439d 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -22,7 +22,8 @@ from typing import Any from ..config.schema import MultiTurnConfig -from ..core.types import QueryResult +from ..core.types import ErrorData, QueryResult, TextModelOutput +from ..exceptions import InputValidationError from .conversation_manager import ConversationManager, ConversationState from .strategy import PhaseIssuerProtocol @@ -44,6 +45,10 @@ class MultiTurnStrategy: At most target_concurrency conversations are active simultaneously. When target_concurrency is None, all conversations start at once. + A turn-level timeout aborts the remaining client turns of that conversation + because subsequent turns depend on the timed-out response. The timed-out + turn and all downstream turns are marked failed. + Integration with BenchmarkSession: - execute(): seeds conversations, awaits completion - on_query_complete(): no-op (required by LoadStrategy protocol) @@ -88,6 +93,31 @@ def __init__( else False ) + if self._store_in_history: + tool_turn_keys = [ + key + for key, msgs in dataset_metadata.get( + "current_turn_messages_by_key", {} + ).items() + if any(m.get("role") == "tool" for m in msgs) + ] + if tool_turn_keys: + raise InputValidationError( + "Multi-turn with tool turns requires use_dataset_history=True. 
" + "Live-history mode (use_dataset_history=False) with tool calls " + "is not implemented yet. " + f"Offending turn(s): {tool_turn_keys[:5]}" + + ( + f" (+{len(tool_turn_keys) - 5} more)" + if len(tool_turn_keys) > 5 + else "" + ) + ) + + # Composite on_sample_complete callback set by execute.py; used by + # _handle_timeout to route synthetic failure results. + self._session_on_sample_complete: Any | None = None + # Maps query_id -> conversation_id for routing completions. self._inflight: dict[str, str] = {} # Cached ConversationState refs for O(1) lookup in on_sample_complete. @@ -200,17 +230,6 @@ def _issue_next_turn(self, conv_id: str) -> None: "current_turn_messages_by_key", {} ).get((conv_id, turn)) if current_turn_messages: - has_tool_msg = any( - m.get("role") == "tool" for m in current_turn_messages - ) - if has_tool_msg: - logger.warning( - "Live-history mode with tool messages uses dataset " - "tool_call_ids; real endpoint IDs will differ " - "(conv=%s, turn=%d)", - conv_id, - turn, - ) live_messages = state.message_history.copy() + current_turn_messages data_override = {"messages": live_messages} @@ -254,12 +273,39 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: self._conv_manager.mark_turn_failed( conv_id, store_in_history=self._store_in_history ) + + # Route a synthetic failure result so the accuracy collector and event + # logger see the timed-out turn. + if self._session_on_sample_complete is not None: + timeout_result = QueryResult( + id=query_id, + error=ErrorData( + error_type="TurnTimeout", + error_message=f"turn timeout after {self._turn_timeout_s}s", + ), + ) + try: + self._session_on_sample_complete(timeout_result) + except Exception: + logger.exception( + "on_sample_complete callback raised for timeout result (query=%s)", + query_id, + ) + it = self._active_iters.pop(conv_id, None) + dropped = 0 if it is not None: for _ in it: self._conv_manager.mark_turn_failed( conv_id, store_in_history=self._store_in_history ) + dropped += 1 + if dropped: + logger.warning( + "turn timeout on conv=%s dropped %d remaining client turn(s)", + conv_id, + dropped, + ) self._fill_slot() @@ -279,13 +325,25 @@ def on_sample_complete(self, result: QueryResult) -> None: """ conv_id = self._inflight.pop(result.id, None) if conv_id is None: + logger.debug( + "dropping late response result=%s (no matching in-flight entry)", + result.id, + ) return handle = self._timeout_handles.pop(result.id, None) if handle is not None: handle.cancel() - response_text = result.get_response_output_string() + output = result.response_output + if isinstance(output, TextModelOutput): + response_text: str | None = ( + "".join(output.output) + if isinstance(output.output, tuple) + else output.output + ) or None + else: + response_text = output if isinstance(output, str) else None try: if result.error is not None: @@ -295,13 +353,15 @@ def on_sample_complete(self, result: QueryResult) -> None: else: self._conv_manager.mark_turn_complete( conv_id, - response_text, + response_text or "", store_in_history=self._store_in_history, metadata=result.metadata, ) except KeyError: - logger.warning( - "on_sample_complete: conversation %s not found in manager (result=%s)", + self._active_iters.pop(conv_id, None) + self._fill_slot() + logger.exception( + "on_sample_complete routing miss for conv=%s result=%s", conv_id, result.id, ) @@ -310,7 +370,7 @@ def on_sample_complete(self, result: QueryResult) -> None: try: self._issue_next_turn(conv_id) except Exception as exc: - logger.error("Error issuing next turn 
for %s: %s", conv_id, exc) + logger.exception("Error issuing next turn for %s", conv_id) self._error = exc if self._all_done is not None: self._all_done.set() diff --git a/src/inference_endpoint/openai/types.py b/src/inference_endpoint/openai/types.py index 558ed7f2..5f4cf2a6 100644 --- a/src/inference_endpoint/openai/types.py +++ b/src/inference_endpoint/openai/types.py @@ -50,6 +50,7 @@ class SSEDelta(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc must be audited; if so, remove gc=False. """ + role: str | None = None content: str | None = None reasoning_content: str | None = None # SGLang / DeepSeek field name reasoning: str | None = None # vLLM field name diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index a18eee34..d637f521 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -36,6 +36,7 @@ import pandas as pd import pytest from inference_endpoint import metrics +from inference_endpoint.commands.benchmark.execute import _precompute_isl_for_multi_turn from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( LoadPattern, @@ -48,6 +49,7 @@ from inference_endpoint.endpoint_client.config import HTTPClientConfig from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer +from inference_endpoint.exceptions import InputValidationError from inference_endpoint.load_generator.conversation_manager import ConversationManager from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy from inference_endpoint.load_generator.session import ( @@ -247,7 +249,7 @@ async def _handle_echo_chat_completions_request(self, request): payload = await request.json() received_payloads.append(payload) except Exception: - pass + pass # request body may not be JSON return await super()._handle_echo_chat_completions_request(request) server = CapturingEchoServer(port=0) @@ -306,7 +308,7 @@ async def _handle_echo_chat_completions_request(self, request): payload = await request.json() received_payloads.append(payload) except Exception: - pass + pass # request body may not be JSON return await super()._handle_echo_chat_completions_request(request) server = CapturingEchoServer(port=0) @@ -726,7 +728,7 @@ async def _handle_echo_chat_completions_request(self, request): payload = await request.json() received_payloads.append(payload) except Exception: - pass + pass # request body may not be JSON return await super()._handle_echo_chat_completions_request(request) server = CapturingEchoServer(port=0) @@ -792,3 +794,67 @@ async def _handle_echo_chat_completions_request(self, request): assert payload["tools"][0]["function"]["name"] == "search" finally: server.stop() + + +@pytest.mark.integration +def test_live_history_rejects_tool_turns(): + """MultiTurnStrategy raises InputValidationError at __init__ when use_dataset_history=False + and the dataset contains tool-role turns. 
+ """ + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "hello"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "result"}] + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Search"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + }, + ] + ds = _make_dataset(rows) + mt_cfg = MultiTurnConfig(turn_timeout_s=10.0, use_dataset_history=False) + with pytest.raises(InputValidationError, match="use_dataset_history=True"): + MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + +@pytest.mark.integration +def test_isl_precomputed_for_dataset_history(): + """_precompute_isl_for_multi_turn populates input_tokens for every sample with messages.""" + pytest.importorskip("transformers") + + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Hi there", + }, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "How are you?"}, + ] + ds = _make_dataset(rows) + _precompute_isl_for_multi_turn(ds, "meta-llama/Llama-3.1-8B-Instruct") + samples_with_messages = [s for s in (ds.data or []) if s.get("messages")] + assert samples_with_messages, "expected at least one sample with messages" + for sample in samples_with_messages: + assert "input_tokens" in sample, f"sample missing input_tokens: {sample}" + assert isinstance(sample["input_tokens"], list) + assert len(sample["input_tokens"]) > 0 From 8abfc3038a76df0e4b3b347bfb410e049e2d00c5 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 13 May 2026 22:32:23 +0000 Subject: [PATCH 28/41] fix: close residual PR #285 review deficiencies - execute.py: guard accuracy phase against MULTI_TURN load pattern on non-MultiTurnDataset (currently unreachable, made explicit) - runtime_settings.py: clamp multi-turn sample count to dataset size; warn when min_sample_count exceeds client-turn count - token_metrics.py: emit one-shot per-(tokenizer, exc-class) warning when apply_chat_template falls back to whitespace tokenization - strategy.py, schema.py, README.md: fix stale docstrings/docs Co-Authored-By: Claude Sonnet 4.6 (1M context) --- examples/09_MultiTurn/README.md | 3 ++- .../metrics_aggregator/token_metrics.py | 13 +++++++++- .../commands/benchmark/execute.py | 12 +++++++++ .../config/runtime_settings.py | 14 +++++++--- src/inference_endpoint/config/schema.py | 5 +++- .../load_generator/strategy.py | 7 ++--- tests/integration/test_multi_turn.py | 26 ------------------- tests/unit/config/test_schema.py | 4 +-- 8 files changed, 47 insertions(+), 37 deletions(-) diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md index 94ad257d..8f623f1a 100644 --- a/examples/09_MultiTurn/README.md +++ b/examples/09_MultiTurn/README.md @@ -32,7 +32,8 @@ Multi-turn datasets use JSONL format with the following structure: ### Validation Rules 1. All rows for a given `conversation_id` must appear **consecutively** in the file (no interleaving - with rows from other conversations). Turns within a conversation must be in order. + with rows from other conversations). 
File-order within a conversation does not matter β€” the + loader sorts by the `turn` column when building conversation history. The flat-row format is intentional: it enables row-by-row streaming without loading entire conversations into memory first. 2. Conversations must follow a valid role sequence: diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py index e31b6eb5..b46990f3 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py @@ -90,6 +90,7 @@ def __init__(self, tokenizer_name: str, n_workers: int) -> None: self._tokenizer_name = tokenizer_name self._n_workers = n_workers self._thread_local = threading.local() + self._fallback_warned: set[str] = set() self._executor: ThreadPoolExecutor | None = ThreadPoolExecutor( max_workers=n_workers, thread_name_prefix="TokenizePool", @@ -180,7 +181,17 @@ def _token_count_message_worker( prefix_len = getattr(self._thread_local, "prefix_len", 0) baseline = getattr(self._thread_local, "baseline", 0) return max(0, full - prefix_len - baseline) - except Exception: + except Exception as exc: + key = f"{self._tokenizer_name}:{type(exc).__name__}" + if key not in self._fallback_warned: + self._fallback_warned.add(key) + logger.exception( + "apply_chat_template failed for %s (%s); falling back to " + "whitespace tokenization. Tool-call OSL/TPOT may diverge " + "from server-side counts for this run.", + self._tokenizer_name, + type(exc).__name__, + ) tool_calls_json = ( msgspec.json.encode(list(tool_calls)).decode() if tool_calls else None ) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index b2f1650c..7061e05c 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -445,6 +445,18 @@ def _build_phases( ) else: acc_load_pattern = perf_lp + if ( + acc_load_pattern is not None + and acc_load_pattern.type == LoadPatternType.MULTI_TURN + and not isinstance(acc_ds, MultiTurnDataset) + ): + raise InputValidationError( + f"Accuracy phase '{eval_cfg.dataset_name}' would use MULTI_TURN " + "load pattern but its dataset is not a MultiTurnDataset. This is " + "currently blocked by schema validation; if you're seeing this, " + "update _build_phases to construct a dedicated MultiTurnStrategy " + "for the accuracy phase." + ) acc_settings = RuntimeSettings( metric_target=ctx.rt_settings.metric_target, reported_metrics=ctx.rt_settings.reported_metrics, diff --git a/src/inference_endpoint/config/runtime_settings.py b/src/inference_endpoint/config/runtime_settings.py index eac1aa47..7259bda7 100644 --- a/src/inference_endpoint/config/runtime_settings.py +++ b/src/inference_endpoint/config/runtime_settings.py @@ -200,11 +200,19 @@ def total_samples_to_issue( self.load_pattern is not None and self.load_pattern.type == LoadPatternType.MULTI_TURN ): - result = max(self.min_sample_count, self.n_samples_from_dataset) + if self.n_samples_from_dataset < self.min_sample_count: + logger.warning( + "Multi-turn run: min_sample_count=%d exceeds dataset " + "client-turn count=%d; using dataset size. 
Multi-turn cannot " + "issue more samples than the dataset provides.", + self.min_sample_count, + self.n_samples_from_dataset, + ) logger.debug( - f"Sample count: {result} (multi-turn: issuing all {self.n_samples_from_dataset} client turns)" + "Sample count: %d (multi-turn: issuing all client turns)", + self.n_samples_from_dataset, ) - return result + return self.n_samples_from_dataset # If min_duration is 0, use all dataset samples (new CLI default behavior) if self.min_duration_ms == 0: diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 27a540a4..db5aaebc 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -239,7 +239,10 @@ class MultiTurnConfig(BaseModel): Presence of this block in the dataset config enables multi-turn mode. Attributes: - turn_timeout_s: Maximum seconds to wait for previous turn completion. + turn_timeout_s: Deadline between issuing a turn and receiving its + response. A timeout aborts that turn and all remaining client + turns of the same conversation because subsequent turns depend + on the timed-out response. use_dataset_history: If True, use pre-built message history from dataset. """ diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py index 8ee13722..5e7db26a 100644 --- a/src/inference_endpoint/load_generator/strategy.py +++ b/src/inference_endpoint/load_generator/strategy.py @@ -54,9 +54,10 @@ def issue( Args: sample_index: Index into the dataset. - data_override: If provided, use this as Query.data instead of - loading from the dataset. Used by MultiTurnStrategy for - live-history mode where the messages array is built at runtime. + data_override: If provided, merged over the loaded sample β€” keys in + data_override take precedence. Used by MultiTurnStrategy to inject + a runtime-assembled `messages` array while still inheriting + `model`/`max_completion_tokens`/`tools`/`stream` from the dataset row. """ ... 
diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index d637f521..404535f0 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -36,7 +36,6 @@ import pandas as pd import pytest from inference_endpoint import metrics -from inference_endpoint.commands.benchmark.execute import _precompute_isl_for_multi_turn from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( LoadPattern, @@ -833,28 +832,3 @@ def test_live_history_rejects_tool_turns(): dataset_metadata=ds.conversation_metadata, multi_turn_config=mt_cfg, ) - - -@pytest.mark.integration -def test_isl_precomputed_for_dataset_history(): - """_precompute_isl_for_multi_turn populates input_tokens for every sample with messages.""" - pytest.importorskip("transformers") - - rows = [ - {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello"}, - { - "conversation_id": "c1", - "turn": 2, - "role": "assistant", - "content": "Hi there", - }, - {"conversation_id": "c1", "turn": 3, "role": "user", "content": "How are you?"}, - ] - ds = _make_dataset(rows) - _precompute_isl_for_multi_turn(ds, "meta-llama/Llama-3.1-8B-Instruct") - samples_with_messages = [s for s in (ds.data or []) if s.get("messages")] - assert samples_with_messages, "expected at least one sample with messages" - for sample in samples_with_messages: - assert "input_tokens" in sample, f"sample missing input_tokens: {sample}" - assert isinstance(sample["input_tokens"], list) - assert len(sample["input_tokens"]) > 0 diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index 75e8b6e6..b837ab4b 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -541,7 +541,7 @@ def test_multi_turn_uses_dataset_size_ignoring_duration(self): assert rt.total_samples_to_issue() == 4316 @pytest.mark.unit - def test_multi_turn_respects_min_sample_count(self): + def test_multi_turn_clamps_to_dataset_size(self): lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) rt = RuntimeSettings( metric_target=metrics.Throughput(10.0), @@ -555,7 +555,7 @@ def test_multi_turn_respects_min_sample_count(self): rng_sample_index=random.Random(0), load_pattern=lp, ) - assert rt.total_samples_to_issue() == 100 + assert rt.total_samples_to_issue() == 5 @pytest.mark.unit def test_multi_turn_explicit_n_samples_takes_precedence(self): From 5169265f1b27846ab9f55b17de74a87acc5eb390 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Wed, 13 May 2026 23:35:34 +0000 Subject: [PATCH 29/41] fix: drop jinja2 import and fix test mocks for ISL precompute MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Top-level `import jinja2` in execute.py required an undeclared runtime dep, aborting CI at collection time. Revert the narrow except tuple back to `except Exception:` β€” `logger.exception` already delivers the traceback visibility the reviewer asked for. Fix `test_precompute_isl.py` patch targets (broken since AutoTokenizer moved to module top) and guard the all-failed check against the zero-samples-with-messages edge case. 
Co-Authored-By: Claude Opus 4.7 --- .../commands/benchmark/execute.py | 6 +++--- tests/unit/commands/test_precompute_isl.py | 20 ++++++++++++++----- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 7061e05c..3e8a0a73 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -38,7 +38,6 @@ from typing import Any from urllib.parse import urljoin -import jinja2 import msgspec.json from huggingface_hub import model_info from tqdm import tqdm @@ -328,7 +327,7 @@ def _precompute_isl_for_multi_turn( # instead of a plain list; extract .input_ids in that case. token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw sample["input_tokens"] = token_ids - except (jinja2.TemplateError, KeyError, ValueError, TypeError): + except Exception: if not first_failure_logged: logger.exception( "ISL pre-computation: apply_chat_template failed (first failure shown)" @@ -340,7 +339,8 @@ def _precompute_isl_for_multi_turn( "ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)", skipped, ) - if skipped == len([s for s in (dataloader.data or []) if s.get("messages")]): + total_with_messages = len([s for s in (dataloader.data or []) if s.get("messages")]) + if total_with_messages > 0 and skipped == total_with_messages: raise RuntimeError( "ISL precomputation failed for all samples; check tokenizer/template compatibility" ) diff --git a/tests/unit/commands/test_precompute_isl.py b/tests/unit/commands/test_precompute_isl.py index 4c5f1223..7bb27e4d 100644 --- a/tests/unit/commands/test_precompute_isl.py +++ b/tests/unit/commands/test_precompute_isl.py @@ -40,7 +40,9 @@ def test_sets_input_tokens_for_samples_with_messages(self): range(len(msgs) * 3) ) - with patch("transformers.AutoTokenizer") as mock_cls: + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: mock_cls.from_pretrained.return_value = mock_tokenizer _precompute_isl_for_multi_turn(dataloader, "test-model") @@ -57,7 +59,9 @@ def test_leaves_samples_without_messages_untouched(self): dataloader = _make_dataloader(samples) mock_tokenizer = MagicMock() - with patch("transformers.AutoTokenizer") as mock_cls: + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: mock_cls.from_pretrained.return_value = mock_tokenizer _precompute_isl_for_multi_turn(dataloader, "test-model") @@ -81,7 +85,9 @@ def side_effect(msgs, **_): mock_tokenizer = MagicMock() mock_tokenizer.apply_chat_template.side_effect = side_effect - with patch("transformers.AutoTokenizer") as mock_cls: + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: mock_cls.from_pretrained.return_value = mock_tokenizer with caplog.at_level("WARNING"): _precompute_isl_for_multi_turn(dataloader, "test-model") @@ -102,7 +108,9 @@ def test_batch_encoding_return_value_is_unwrapped(self): mock_tokenizer = MagicMock() mock_tokenizer.apply_chat_template.return_value = batch_encoding - with patch("transformers.AutoTokenizer") as mock_cls: + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: mock_cls.from_pretrained.return_value = mock_tokenizer _precompute_isl_for_multi_turn(dataloader, "test-model") @@ -115,7 +123,9 @@ def test_add_generation_prompt_true(self): mock_tokenizer = MagicMock() 
mock_tokenizer.apply_chat_template.return_value = [1, 2, 3] - with patch("transformers.AutoTokenizer") as mock_cls: + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: mock_cls.from_pretrained.return_value = mock_tokenizer _precompute_isl_for_multi_turn(dataloader, "test-model") From 191c320902c1bbbbaa6c1ffbc7d349b583a04104 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 14 May 2026 00:53:17 +0000 Subject: [PATCH 30/41] fix: address Copilot review comments on multi-turn implementation Code fixes: - Timeout handler now decrements PhaseIssuer.inflight, preventing _drain_inflight() from hanging when any turn times out (#1) - _precompute_isl_for_multi_turn normalizes tool_calls arguments before apply_chat_template, fixing Hermes-style template failures on tool-call-bearing samples (#2) - Non-streaming reasoning_content now set on TextModelOutput.reasoning so OSL/TPOT tokenization includes it (#3) - live-history data_override clears stale input_tokens and token_ids (#18) - execute() cleanup runs in finally block, surviving CancelledError (#21) - Validator rejects plain-assistant->tool (missing tool_calls) (#13) - Message builders include 'name' field for prior and current turns (#19, #20) - max_new_tokens sets all three adapter aliases (#22) - Multi-turn accuracy datasets raise InputValidationError (not yet supported) (#4) Examples and docs: - Remove ineffective samples: 10 from multi-turn YAML examples (#8, #9) - Fix events.jsonl docs (no conversation_id/turn_number fields) (#6, #7, #14) - Fix workers -> num_workers in quickstart config snippet (#10) - Fix agentic role-sequence grammar in README (#15) - Document that multi-turn accuracy datasets are not yet supported - Delete examples/10_CollectOutputs/ (#11, #12, #16, #17) Co-Authored-By: Claude Sonnet 4.6 (1M context) --- docs/MULTI_TURN_QUICKSTART.md | 35 +++--- examples/09_MultiTurn/README.md | 6 +- .../09_MultiTurn/multi_turn_benchmark.yaml | 1 - .../multi_turn_with_concurrency.yaml | 1 - examples/10_CollectOutputs/README.md | 106 ------------------ .../benchmark_with_output_collection.yaml | 35 ------ .../commands/benchmark/execute.py | 38 ++++--- .../dataset_manager/multi_turn_dataset.py | 34 ++++-- .../load_generator/multi_turn_strategy.py | 46 +++++--- .../openai/openai_msgspec_adapter.py | 1 + 10 files changed, 93 insertions(+), 210 deletions(-) delete mode 100644 examples/10_CollectOutputs/README.md delete mode 100644 examples/10_CollectOutputs/benchmark_with_output_collection.yaml diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index 340b4922..ddcc9828 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -49,7 +49,7 @@ settings: target_concurrency: 32 # ← Required: max simultaneous conversations client: - workers: 4 + num_workers: 4 endpoint_config: endpoints: @@ -71,8 +71,7 @@ That's it! 
Your benchmark will now: - βœ… Enforce turn ordering (turn N+1 waits for turn N) - βœ… Include conversation history in each request -- βœ… Track per-turn and per-conversation metrics -- βœ… Log all turns with conversation metadata +- βœ… Log all turns to events.jsonl --- @@ -82,21 +81,12 @@ After the benchmark completes, check the directory configured via `report_dir`: ### Events Log -The `events.jsonl` file contains one JSON record per line: - -- Standard fields: `sample_uuid`, `event_type`, `timestamp_ns` -- **New fields**: `conversation_id`, `turn_number` - -Query examples: - -```bash -# All events for a specific conversation -grep '"conversation_id": "c1"' logs/my_multi_turn_benchmark/events.jsonl - -# With jq for structured output -jq 'select(.conversation_id == "c1") | {conversation_id, turn_number, event_type, timestamp_ns}' \ - logs/my_multi_turn_benchmark/events.jsonl -``` +The `events.jsonl` file contains one JSON record per line, with the standard +`sample_uuid`, `event_type`, and `timestamp_ns` fields. Events are keyed by +`sample_uuid` only. To correlate events with conversations, join through +`sample_idx_map.json` (written next to `events.jsonl`) and the multi-turn +dataset's `conversation_metadata["samples"]`, which maps sample indices to +`(conversation_id, turn)` tuples. ### Metrics @@ -235,9 +225,9 @@ cat logs/multi_turn_test/benchmark.log ### 3. Verify Event Recording ```bash -# List all unique conversation IDs in the events log -jq -r '.conversation_id' logs/multi_turn_test/events.jsonl | sort -u -# Should show your conversation IDs +# List all sample UUIDs in the events log +jq -r '.sample_uuid' logs/multi_turn_test/events.jsonl | sort -u +# Should show UUIDs; correlate to conversations via sample_idx_map.json ``` --- @@ -252,7 +242,7 @@ jq -r '.conversation_id' logs/multi_turn_test/events.jsonl | sort -u ### Performance -- **Workers**: `client.workers` controls HTTP worker processes, independent of `target_concurrency`. The default (`-1`) auto-tunes based on NUMA topology. +- **Workers**: `client.num_workers` controls HTTP worker processes, independent of `target_concurrency`. The default (`-1`) auto-tunes based on NUMA topology. - **Timeout**: Set `turn_timeout_s` = 2x your longest expected turn latency - **Memory**: ~1KB per turn, plan accordingly for large datasets @@ -283,5 +273,6 @@ Before running your first multi-turn benchmark: - [ ] File uses `.jsonl` extension (format is auto-detected) - [ ] Conversation IDs are unique per conversation - [ ] Turn numbers are sequential (1, 2, 3, ...) +- [ ] Dataset is configured as `type: performance` (accuracy evaluation of multi-turn datasets is not yet supported) Happy benchmarking! diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md index 8f623f1a..1f6a24fc 100644 --- a/examples/09_MultiTurn/README.md +++ b/examples/09_MultiTurn/README.md @@ -38,7 +38,7 @@ Multi-turn datasets use JSONL format with the following structure: conversations into memory first. 2. Conversations must follow a valid role sequence: - Plain chat: `user β†’ assistant β†’ user β†’ ...` - - Agentic: `user β†’ assistant (with tool_calls) β†’ tool β†’ [tool | assistant (with tool_calls)]* β†’ assistant β†’ user β†’ ...` + - Agentic: `user β†’ assistant (with tool_calls) β†’ tool (tool_results list; parallel results merged) β†’ [assistant (with tool_calls) β†’ tool]* β†’ assistant β†’ user β†’ ...` 3. First turn must be "user" role 4. Turn numbers must be sequential (1, 2, 3, ...) 5. 
Each conversation must have at least one turn @@ -210,7 +210,9 @@ Multi-turn benchmarks produce per-turn metrics: - **Per-turn metrics**: Latency, TTFT, TPOT for each individual turn - **Per-conversation metrics**: Total conversation latency, conversations per second _(planned β€” not yet implemented)_ -Results are stored in the configured `report_dir` with conversation metadata included in the events log (`events.jsonl`). +**Note**: Multi-turn datasets are only supported as performance datasets. Using a multi-turn dataset as an accuracy dataset (`type: accuracy`) is not yet supported and will raise an error at startup. + +Results are stored in the configured `report_dir`. The `events.jsonl` log contains one record per turn keyed by `sample_uuid`. To correlate events with conversations, join through `sample_idx_map.json` and the dataset's `conversation_metadata["samples"]`. ## Example Datasets diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml index 8e5933e4..2cc815c9 100644 --- a/examples/09_MultiTurn/multi_turn_benchmark.yaml +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -11,7 +11,6 @@ datasets: - name: customer_support_conversations type: performance path: examples/09_MultiTurn/customer_support_conversations.jsonl - samples: 10 multi_turn: turn_timeout_s: 300.0 diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml index f1466396..c1fcf26f 100644 --- a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -11,7 +11,6 @@ datasets: - name: customer_support_conversations type: performance path: examples/09_MultiTurn/customer_support_conversations.jsonl - samples: 10 multi_turn: turn_timeout_s: 300.0 diff --git a/examples/10_CollectOutputs/README.md b/examples/10_CollectOutputs/README.md deleted file mode 100644 index 6fc93a53..00000000 --- a/examples/10_CollectOutputs/README.md +++ /dev/null @@ -1,106 +0,0 @@ -# Collecting Outputs for Performance Runs - -This example demonstrates how to collect and log model response outputs during performance benchmarks. - -## Overview - -By default, performance runs (`--mode perf`) do **not** collect response outputs to minimize memory overhead and I/O latency. However, for debugging, analysis, or archival purposes, you can enable output collection using the `--collect-outputs` CLI flag or the `collect_outputs: true` config field. - -## When to Use - -- **Debugging**: Save outputs for analyzing model behavior under load -- **Validation**: Verify response quality without slowing down the benchmark -- **Archival**: Store responses for compliance or future analysis -- **Combined metrics**: Analyze performance alongside response content - -## Usage - -### CLI Flag - -```bash -uv run inference-endpoint benchmark offline \ - --endpoints http://localhost:8000 \ - --model meta-llama/Llama-2-7b-hf \ - --dataset perf:data.jsonl \ - --collect-outputs -``` - -### YAML Config - -```yaml -type: offline -collect_outputs: true -# ... 
rest of config -``` - -### From-config with override - -```bash -uv run inference-endpoint benchmark from-config \ - --config benchmark_with_output_collection.yaml -``` - -## Output Locations - -When `collect_outputs` is enabled, responses are stored in: - -- **JSONL format**: `{report_dir}/events/` β€” one JSON record per line -- **SQLite format**: `{report_dir}/events.db` β€” queryable database - -Each response is keyed by its query ID for correlation with performance metrics. - -## How It Works - -The `collect_outputs` flag **enables output collection for performance runs**: - -| Mode | Flag | Outputs Collected? | -| ------------- | ------------------- | ------------------ | -| `--mode perf` | none | ❌ | -| `--mode perf` | `--collect-outputs` | βœ… | -| `--mode acc` | n/a | βœ… | -| `--mode both` | n/a | βœ… | - -This allows performance benchmarks to optionally capture outputs without the overhead of full accuracy evaluation. - -## Memory Considerations - -Enabling output collection increases memory usage proportional to: - -- Number of queries issued -- Average response length (tokens β†’ bytes) - -For large-scale benchmarks (e.g., 100k+ queries), consider: - -- Using `--dataset perf:data.jsonl,samples=N` to limit dataset size -- Piping outputs to disk via the EventLogger (default behavior) -- Running in `--mode perf` (without collection) if storage is constrained - -## Integration with Accuracy Evaluation - -If you later want to run accuracy evaluation on collected outputs: - -```bash -uv run inference-endpoint benchmark offline \ - --endpoints http://localhost:8000 \ - --model meta-llama/Llama-2-7b-hf \ - --dataset acc:data.jsonl \ - --mode acc -``` - -The responses collected in the first run are independent; the accuracy run uses responses it collects during its own execution. - -## Example Workflow - -```bash -# 1. Performance benchmark with output collection -uv run inference-endpoint benchmark from-config \ - --config benchmark_with_output_collection.yaml \ - --report-dir results/perf_with_outputs - -# 2. Inspect collected outputs (JSONL) -head results/perf_with_outputs/events.jsonl | jq .data - -# 3. Query responses via SQLite -sqlite3 results/perf_with_outputs/events.db \ - "SELECT sample_uuid, data FROM event_records WHERE event_type LIKE '%complete%';" -``` diff --git a/examples/10_CollectOutputs/benchmark_with_output_collection.yaml b/examples/10_CollectOutputs/benchmark_with_output_collection.yaml deleted file mode 100644 index 40af4f80..00000000 --- a/examples/10_CollectOutputs/benchmark_with_output_collection.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -# Example: Performance benchmark with output collection enabled -# -# This config demonstrates how to enable output logging for performance runs -# using the collect_outputs flag. Outputs (model responses) are collected and -# logged to the events database/JSONL, allowing for offline analysis without -# switching to accuracy mode. 
- -type: offline -name: offline_benchmark_with_outputs - -model_params: - name: meta-llama/Llama-2-7b-hf - streaming: off - -datasets: - - name: performance_dataset - path: data.jsonl - type: performance - -settings: - max_duration_ms: 60000 - load_pattern: - type: max_throughput - -endpoint_config: - endpoints: - - http://localhost:8000 - api_type: openai - -# Enable output collection for this performance run -# Outputs will be stored in {report_dir}/events/ and {report_dir}/events.db -collect_outputs: true diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 3e8a0a73..1afc4c22 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -59,6 +59,9 @@ from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( MetricSeriesKey, ) +from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import ( + _normalize_tool_calls_for_template, +) from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( @@ -318,8 +321,18 @@ def _precompute_isl_for_multi_turn( if not messages: continue try: + normalized_messages = [] + for msg in messages: + if msg.get("tool_calls"): + msg = { + **msg, + "tool_calls": _normalize_tool_calls_for_template( + msg["tool_calls"] + ), + } + normalized_messages.append(msg) raw = tokenizer.apply_chat_template( - messages, + normalized_messages, tokenize=True, add_generation_prompt=True, ) @@ -432,11 +445,12 @@ def _build_phases( perf_lp = ctx.rt_settings.load_pattern for eval_cfg in ctx.eval_configs: acc_ds = eval_cfg.dataset - if ( - perf_lp is not None - and perf_lp.type == LoadPatternType.MULTI_TURN - and not isinstance(acc_ds, MultiTurnDataset) - ): + if isinstance(acc_ds, MultiTurnDataset): + raise InputValidationError( + f"Accuracy dataset '{eval_cfg.dataset_name}' is a MultiTurnDataset, " + "which is not yet supported for accuracy evaluation." + ) + if perf_lp is not None and perf_lp.type == LoadPatternType.MULTI_TURN: # Plain accuracy datasets are single-turn; the multi-turn scheduler # requires MultiTurnDataset. Downgrade to CONCURRENCY with same cap. acc_load_pattern: LoadPattern | None = LoadPattern( @@ -445,18 +459,6 @@ def _build_phases( ) else: acc_load_pattern = perf_lp - if ( - acc_load_pattern is not None - and acc_load_pattern.type == LoadPatternType.MULTI_TURN - and not isinstance(acc_ds, MultiTurnDataset) - ): - raise InputValidationError( - f"Accuracy phase '{eval_cfg.dataset_name}' would use MULTI_TURN " - "load pattern but its dataset is not a MultiTurnDataset. This is " - "currently blocked by schema validation; if you're seeing this, " - "update _build_phases to construct a dedicated MultiTurnStrategy " - "for the accuracy phase." 
- ) acc_settings = RuntimeSettings( metric_target=ctx.rt_settings.metric_target, reported_metrics=ctx.rt_settings.reported_metrics, diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index 46959261..507d9768 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -161,6 +161,7 @@ def _validate_conversation_structure(self): for conv_id, group in self._conv_groups.items(): sorted_group = group.sort_values("turn") state = "start" + prev_assistant_had_tool_calls = False for _, row in sorted_group.iterrows(): role = row["role"] @@ -172,6 +173,11 @@ def _validate_conversation_structure(self): ) if role == "tool": + if state == "assistant" and not prev_assistant_had_tool_calls: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "'tool' row must follow an 'assistant' row that has non-empty 'tool_calls'" + ) tool_results = row.get("tool_results") if not isinstance(tool_results, list) or len(tool_results) == 0: raise InputValidationError( @@ -194,6 +200,10 @@ def _validate_conversation_structure(self): f"Conversation {conv_id} turn {row['turn']}: " "assistant rows must have non-empty 'content' or non-empty 'tool_calls'" ) + prev_assistant_had_tool_calls = has_tool_calls + + if role != "assistant": + prev_assistant_had_tool_calls = False state = role @@ -265,7 +275,13 @@ def _build_metadata(self) -> dict[str, Any]: prior_rows = sorted_group[sorted_group["turn"] < t_n] for _, prior_row in prior_rows.iterrows(): msg: dict[str, Any] = {} - for key in ("role", "content", "tool_calls", "tool_results"): + for key in ( + "role", + "content", + "name", + "tool_calls", + "tool_results", + ): val = prior_row.get(key) if val is not None and not ( isinstance(val, float) and pd.isna(val) @@ -296,7 +312,7 @@ def _build_metadata(self) -> dict[str, Any]: current_turn_msgs = expanded else: cur: dict[str, Any] = {} - for key in ("role", "content"): + for key in ("role", "content", "name"): val = row.get(key) if val is not None and not ( isinstance(val, float) and pd.isna(val) @@ -425,11 +441,15 @@ def load( if k not in sample: sample[k] = v - # max_new_tokens β†’ max_completion_tokens alias - if "max_completion_tokens" not in sample and "max_new_tokens" in sample: - sample["max_completion_tokens"] = sample.pop("max_new_tokens") - if "max_completion_tokens" not in sample: - sample["max_completion_tokens"] = 128 + # Normalize max-tokens across all adapter aliases. 
+ max_tokens_val = ( + sample.pop("max_new_tokens", None) + or sample.get("max_completion_tokens") + or 128 + ) + sample["max_new_tokens"] = max_tokens_val + sample["max_completion_tokens"] = max_tokens_val + sample["max_tokens"] = max_tokens_val if "stream" not in sample: sample["stream"] = False diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 03a6439d..f36b68a4 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -184,23 +184,22 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: if not self._active_iters and not self._inflight: return phase_issuer.issued_count - await self._all_done.wait() - - for handle in self._timeout_handles.values(): - handle.cancel() - self._timeout_handles.clear() - - if self._inflight: - logger.warning( - "%d query(ies) never received a response (session stop or transport failure): %s", - len(self._inflight), - list(self._inflight.keys()), - ) - self._inflight.clear() - - if self._error is not None: - raise self._error - return phase_issuer.issued_count + try: + await self._all_done.wait() + if self._error is not None: + raise self._error + return phase_issuer.issued_count + finally: + for handle in self._timeout_handles.values(): + handle.cancel() + self._timeout_handles.clear() + if self._inflight: + logger.warning( + "%d query(ies) never received a response (session stop or transport failure): %s", + len(self._inflight), + list(self._inflight.keys()), + ) + self._inflight.clear() def _start_conversation(self) -> None: """Pop the next conversation from the pending queue and issue its first turn.""" @@ -231,7 +230,11 @@ def _issue_next_turn(self, conv_id: str) -> None: ).get((conv_id, turn)) if current_turn_messages: live_messages = state.message_history.copy() + current_turn_messages - data_override = {"messages": live_messages} + data_override = { + "messages": live_messages, + "input_tokens": None, + "token_ids": None, + } assert self._phase_issuer is not None query_id = self._phase_issuer.issue(idx, data_override=data_override) @@ -266,6 +269,13 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: return self._timeout_handles.pop(query_id, None) + if ( + self._phase_issuer is not None + and hasattr(self._phase_issuer, "uuid_to_index") + and query_id in self._phase_issuer.uuid_to_index # type: ignore[attr-defined] + ): + self._phase_issuer.inflight -= 1 # type: ignore[attr-defined] + logger.warning( "Turn timed out for conversation %s (query=%s)", conv_id, query_id ) diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index ece0b6d3..91ac7756 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -223,6 +223,7 @@ def from_endpoint_response( id=result_id or response.id, response_output=TextModelOutput( output=choice.message.content or "", + reasoning=choice.message.reasoning_content, tool_calls=tool_calls_tuple, ), metadata=metadata, From 51d37dd6392fca2171c672c3ab13f3c7632518e6 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 14 May 2026 04:56:49 +0000 Subject: [PATCH 31/41] refactor: typed ConversationMetadata dataclass, single build in load(), tool-call ISL test - R1: Move _build_metadata() from __init__ to load() so pre_built_messages_by_key is always built against the post-transform 
dataframe (latent desync bug). conversation_metadata is None until load() is called. - R2: Replace conversation_metadata dict[str, Any] with @dataclass ConversationMetadata (+ ConversationSampleEntry). Attribute access in MultiTurnStrategy replaces .get() calls; mypy now flags typos at the call site. - R3: Append unit test to TestPrecomputeIslForMultiTurn asserting that _normalize_tool_calls_for_template converts tool_calls[].function.arguments from JSON string to dict before apply_chat_template (previously untested branch). Co-Authored-By: Claude Sonnet 4.6 (1M context) --- .../dataset_manager/multi_turn_dataset.py | 101 ++++++++++++------ .../load_generator/multi_turn_strategy.py | 26 +++-- tests/unit/commands/test_precompute_isl.py | 39 +++++++ .../test_multi_turn_dataset.py | 46 ++++---- .../test_multi_turn_strategy.py | 50 +++++---- 5 files changed, 177 insertions(+), 85 deletions(-) diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index 507d9768..afea505c 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -16,6 +16,7 @@ """Multi-turn conversation dataset for conversational AI benchmarking.""" import logging +from dataclasses import dataclass, field, replace from typing import Any import pandas as pd @@ -32,6 +33,41 @@ logger = logging.getLogger(__name__) +@dataclass +class ConversationSampleEntry: + """One client-turn entry in ConversationMetadata.samples. + + sample_index is populated after transforms in MultiTurnDataset.load(); + None before load() is called. + """ + + conversation_id: str + turn: int + sample_index: int | None = None + + +@dataclass +class ConversationMetadata: + """Bundle of maps/lists consumed by MultiTurnStrategy. + + Produced by MultiTurnDataset._build_metadata() from the post-transform dataframe. + Keys in the *_by_key dicts are (str(conversation_id), int(turn)). + Populated by load(); None before load() is called. + """ + + samples: list[ConversationSampleEntry] + num_conversations: int + max_turns_per_conv: int + client_turns_per_conversation: dict[str, int] + pre_built_messages_by_key: dict[tuple[str, int], list[dict]] = field( + default_factory=dict + ) + current_turn_messages_by_key: dict[tuple[str, int], list[dict]] = field( + default_factory=dict + ) + system_prompts_by_conv: dict[str, str | None] = field(default_factory=dict) + + def _expand_tool_results(row: dict) -> list[dict]: """Expand a tool row into one OpenAI tool message per result. @@ -94,10 +130,9 @@ class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): - max_new_tokens / max_completion_tokens: Max tokens for this turn (alias; mapped to max_completion_tokens) Attributes: - conversation_metadata: Metadata dict containing: - - samples: List of user turn metadata (index, conversation_id, turn, system) - - num_conversations: Total number of unique conversations - - max_turns_per_conv: Maximum turns in any conversation + conversation_metadata: ConversationMetadata populated by load() (None before). + Validators run at construction; metadata is built once in load() against + the post-transform dataframe so pre_built_messages_by_key is always in sync. 
""" COLUMN_NAMES = ["conversation_id", "turn", "role", "content"] @@ -120,7 +155,8 @@ def __init__(self, dataframe: pd.DataFrame, **kwargs): self._validate_conversation_grouping() self._validate_conversation_structure() self._validate_turn_numbering() - self.conversation_metadata = self._build_metadata() + # Populated by load() after transforms; None until then. + self.conversation_metadata: ConversationMetadata | None = None def _validate_conversation_grouping(self) -> None: """Validate that all rows for each conversation_id appear consecutively in file order. @@ -222,17 +258,16 @@ def _validate_turn_numbering(self): f"got {turns}" ) - def _build_metadata(self) -> dict[str, Any]: + def _build_metadata(self) -> ConversationMetadata: """Build metadata for scheduler (maps sample index to conversation context). Pre-computes the complete message list for each client turn so that conversation history does not need to be accumulated at runtime. Returns: - Metadata dict with samples list, num_conversations, max_turns_per_conv, - client_turns_per_conversation, and pre_built_messages_by_key. + ConversationMetadata with samples, counts, and pre-built message maps. """ - samples = [] + samples: list[ConversationSampleEntry] = [] # Count client turns (user + tool) per conversation for completion tracking client_turns_per_conv = { @@ -326,23 +361,21 @@ def _build_metadata(self) -> dict[str, Any]: current_turn_messages_by_key[(str_conv_id, t_n)] = current_turn_msgs samples.append( - { - "conversation_id": str_conv_id, - "turn": t_n, - } + ConversationSampleEntry( + conversation_id=str_conv_id, + turn=t_n, + ) ) - return { - "samples": samples, - "num_conversations": len(self._conv_groups), - "max_turns_per_conv": max( - g["turn"].max() for g in self._conv_groups.values() - ), - "client_turns_per_conversation": client_turns_per_conv, - "pre_built_messages_by_key": pre_built_messages_by_key, - "current_turn_messages_by_key": current_turn_messages_by_key, - "system_prompts_by_conv": system_prompts_by_conv, - } + return ConversationMetadata( + samples=samples, + num_conversations=len(self._conv_groups), + max_turns_per_conv=max(g["turn"].max() for g in self._conv_groups.values()), + client_turns_per_conversation=client_turns_per_conv, + pre_built_messages_by_key=pre_built_messages_by_key, + current_turn_messages_by_key=current_turn_messages_by_key, + system_prompts_by_conv=system_prompts_by_conv, + ) def load( self, @@ -390,12 +423,18 @@ def load( if defaults: df = AddStaticColumns(defaults, overwrite=False)(df) + # Rebuild conv_groups + metadata from the final post-transform df so + # pre_built_messages_by_key reflects any transforms applied above. + self.dataframe = df + self._conv_groups = dict(list(df.groupby("conversation_id", sort=False))) + self.conversation_metadata = self._build_metadata() + all_rows = df.to_dict(orient="records") # Pre-bake: assemble one complete sample dict per client turn. # NaN filtering replaces the GENERATION_PARAMS allowlist β€” any key whose # value is float NaN was absent in the original dataset row. - pre_built = self.conversation_metadata.get("pre_built_messages_by_key", {}) + pre_built = self.conversation_metadata.pre_built_messages_by_key client_turn_samples: list[dict[str, Any]] = [] # Maps (conv_id, turn) β†’ dense sample_index for metadata backfill. 
key_to_sample_index: dict[tuple[str, int], int] = {} @@ -466,13 +505,15 @@ def load( key_to_sample_index[key] = len(client_turn_samples) client_turn_samples.append(sample) - # Backfill explicit sample_index into conversation_metadata["samples"]. + # Backfill explicit sample_index into conversation_metadata.samples. # Drop entries whose key is absent (truncated turns not in client_turn_samples). updated_samples = [] - for s in self.conversation_metadata["samples"]: - skey: tuple[str, int] = (str(s["conversation_id"]), int(s["turn"])) + for s in self.conversation_metadata.samples: + skey: tuple[str, int] = (str(s.conversation_id), int(s.turn)) if skey in key_to_sample_index: - updated_samples.append({**s, "sample_index": key_to_sample_index[skey]}) - self.conversation_metadata["samples"] = updated_samples + updated_samples.append( + replace(s, sample_index=key_to_sample_index[skey]) + ) + self.conversation_metadata.samples = updated_samples self.data = client_turn_samples diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index f36b68a4..b6e473c4 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -23,6 +23,7 @@ from ..config.schema import MultiTurnConfig from ..core.types import ErrorData, QueryResult, TextModelOutput +from ..dataset_manager.multi_turn_dataset import ConversationMetadata from ..exceptions import InputValidationError from .conversation_manager import ConversationManager, ConversationState from .strategy import PhaseIssuerProtocol @@ -65,7 +66,7 @@ class MultiTurnStrategy: def __init__( self, conversation_manager: ConversationManager, - dataset_metadata: dict[str, Any], + dataset_metadata: ConversationMetadata, multi_turn_config: MultiTurnConfig | None = None, target_concurrency: int | None = None, ): @@ -73,7 +74,7 @@ def __init__( Args: conversation_manager: Manages conversation sequencing state. - dataset_metadata: Metadata from MultiTurnDataset (samples list). + dataset_metadata: ConversationMetadata from MultiTurnDataset (after load()). multi_turn_config: Multi-turn conversation configuration. target_concurrency: Maximum number of simultaneously active conversations. None means all conversations run concurrently. @@ -96,9 +97,7 @@ def __init__( if self._store_in_history: tool_turn_keys = [ key - for key, msgs in dataset_metadata.get( - "current_turn_messages_by_key", {} - ).items() + for key, msgs in dataset_metadata.current_turn_messages_by_key.items() if any(m.get("role") == "tool" for m in msgs) ] if tool_turn_keys: @@ -147,14 +146,13 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: self._error = None conv_samples: dict[str, list[tuple[int, int]]] = defaultdict(list) - for sample_meta in self._dataset_metadata["samples"]: - conv_id = sample_meta["conversation_id"] - conv_samples[conv_id].append( - (sample_meta["sample_index"], sample_meta["turn"]) - ) + for sample_meta in self._dataset_metadata.samples: + conv_id = sample_meta.conversation_id + assert sample_meta.sample_index is not None + conv_samples[conv_id].append((sample_meta.sample_index, sample_meta.turn)) # Pre-create all conversation states before issuing any turns (no locking needed). 
- sys_prompts = self._dataset_metadata.get("system_prompts_by_conv", {}) + sys_prompts = self._dataset_metadata.system_prompts_by_conv for conv_id, turns in conv_samples.items(): sys_content = sys_prompts.get(conv_id) if self._store_in_history else None system_message = ( @@ -225,9 +223,9 @@ def _issue_next_turn(self, conv_id: str) -> None: data_override: dict[str, Any] | None = None current_turn_messages: list[dict[str, Any]] | None = None if self._store_in_history: - current_turn_messages = self._dataset_metadata.get( - "current_turn_messages_by_key", {} - ).get((conv_id, turn)) + current_turn_messages = ( + self._dataset_metadata.current_turn_messages_by_key.get((conv_id, turn)) + ) if current_turn_messages: live_messages = state.message_history.copy() + current_turn_messages data_override = { diff --git a/tests/unit/commands/test_precompute_isl.py b/tests/unit/commands/test_precompute_isl.py index 7bb27e4d..d4ccf852 100644 --- a/tests/unit/commands/test_precompute_isl.py +++ b/tests/unit/commands/test_precompute_isl.py @@ -132,3 +132,42 @@ def test_add_generation_prompt_true(self): _, kwargs = mock_tokenizer.apply_chat_template.call_args assert kwargs.get("add_generation_prompt") is True assert kwargs.get("tokenize") is True + + @pytest.mark.unit + def test_normalizes_tool_call_arguments_before_apply_chat_template(self): + """_normalize_tool_calls_for_template converts arguments strings to dicts.""" + samples = [ + { + "messages": [ + {"role": "user", "content": "use a tool"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": { + "name": "bash", + "arguments": '{"cmd": "ls"}', + }, + } + ], + }, + ] + }, + ] + dataloader = _make_dataloader(samples) + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.return_value = [1, 2, 3] + + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + # Production code builds new dicts, so call_args captures the normalized value. 
+ passed_msgs = mock_tokenizer.apply_chat_template.call_args[0][0] + asst_msg = next(m for m in passed_msgs if m.get("role") == "assistant") + assert asst_msg["tool_calls"][0]["function"]["arguments"] == {"cmd": "ls"} diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index 62301940..e893a47c 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -194,25 +194,19 @@ def test_multi_turn_dataset_conversation_metadata(valid_multi_turn_jsonl): metadata = dataset.conversation_metadata - # Check metadata structure - assert "samples" in metadata - assert "num_conversations" in metadata - assert "max_turns_per_conv" in metadata - assert "client_turns_per_conversation" in metadata - # Should have 3 client turn samples (fixture has only user turns, no tool turns) - assert len(metadata["samples"]) == 3 + assert len(metadata.samples) == 3 # Should have 2 conversations - assert metadata["num_conversations"] == 2 + assert metadata.num_conversations == 2 # Max turns per conversation should be 3 (conv_001 has 3 turns) - assert metadata["max_turns_per_conv"] == 3 + assert metadata.max_turns_per_conv == 3 # Check sample metadata structure - sample_meta = metadata["samples"][0] - assert "conversation_id" in sample_meta - assert "turn" in sample_meta + sample_meta = metadata.samples[0] + assert sample_meta.conversation_id is not None + assert sample_meta.turn is not None @pytest.mark.unit @@ -273,8 +267,8 @@ def test_multi_turn_dataset_multiple_conversations(): # Metadata checks metadata = dataset.conversation_metadata - assert metadata["num_conversations"] == 3 - assert metadata["max_turns_per_conv"] == 4 # c2 has 4 turns + assert metadata.num_conversations == 3 + assert metadata.max_turns_per_conv == 4 # c2 has 4 turns # Verify user turns are correctly indexed samples = [dataset.load_sample(i) for i in range(5)] @@ -898,8 +892,9 @@ def test_build_metadata_pre_built_messages(): """ df = _make_tool_sequence_df() ds = MultiTurnDataset(df) + ds.load() - pbm = ds.conversation_metadata["pre_built_messages_by_key"] + pbm = ds.conversation_metadata.pre_built_messages_by_key # Client turn 1 (user, t=1): [system, user(1)] msgs_t1 = pbm[("c1", 1)] @@ -936,7 +931,8 @@ def test_build_metadata_pre_built_messages_no_tools(): ] ) ds = MultiTurnDataset(df) - pbm = ds.conversation_metadata["pre_built_messages_by_key"] + ds.load() + pbm = ds.conversation_metadata.pre_built_messages_by_key # Turn 1: just the user message (no system, no prior rows) assert pbm[("c1", 1)] == [{"role": "user", "content": "A"}] @@ -1072,7 +1068,8 @@ def test_prior_tool_row_expanded_with_tool_call_id(): """Prior tool rows must expand to messages with tool_call_id and content (BUG 1).""" df = _make_tool_sequence_df() ds = MultiTurnDataset(df) - pbm = ds.conversation_metadata["pre_built_messages_by_key"] + ds.load() + pbm = ds.conversation_metadata.pre_built_messages_by_key # Client turn 3 (user, t=5) has a prior tool row at t=3. # msgs_t5[3] should be the expanded tool message with proper fields. 
@@ -1126,7 +1123,8 @@ def test_prior_parallel_tool_results_expand_to_multiple_messages(): ] ) ds = MultiTurnDataset(df) - pbm = ds.conversation_metadata["pre_built_messages_by_key"] + ds.load() + pbm = ds.conversation_metadata.pre_built_messages_by_key # user(5) sees prior rows: user(1), assistant(2), tool(3)x2, assistant(4) msgs_t5 = pbm[("c1", 5)] @@ -1143,7 +1141,8 @@ def test_assistant_content_null_preserved_in_history(): """Assistant messages with tool_calls and content:null include content key (BUG 2).""" df = _make_tool_sequence_df() ds = MultiTurnDataset(df) - pbm = ds.conversation_metadata["pre_built_messages_by_key"] + ds.load() + pbm = ds.conversation_metadata.pre_built_messages_by_key # Client turn 2 (tool, t=3): prior includes assistant(2) with tool_calls + content: null msgs_t3 = pbm[("c1", 3)] @@ -1284,7 +1283,8 @@ def test_current_turn_messages_by_key_parallel_tools(): ] ) ds = MultiTurnDataset(df) - ctm = ds.conversation_metadata["current_turn_messages_by_key"] + ds.load() + ctm = ds.conversation_metadata.current_turn_messages_by_key # user(1) current turn is 1 message assert len(ctm[("c1", 1)]) == 1 @@ -1318,8 +1318,9 @@ def test_metadata_contains_system_prompts_by_conv(): ] df = pd.DataFrame(data) ds = MultiTurnDataset(df) + ds.load() - spc = ds.conversation_metadata["system_prompts_by_conv"] + spc = ds.conversation_metadata.system_prompts_by_conv assert spc["c1"] == "Be concise" assert spc["c2"] is None @@ -1347,8 +1348,9 @@ def test_metadata_system_prompts_multiple_convs(): ] df = pd.DataFrame(data) ds = MultiTurnDataset(df) + ds.load() - spc = ds.conversation_metadata["system_prompts_by_conv"] + spc = ds.conversation_metadata.system_prompts_by_conv assert spc["c1"] == "Sys1" assert spc["c2"] == "Sys2" diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index d3c9a22a..0f9e41c8 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -19,6 +19,10 @@ import pytest from inference_endpoint.core.types import QueryResult, TextModelOutput +from inference_endpoint.dataset_manager.multi_turn_dataset import ( + ConversationMetadata, + ConversationSampleEntry, +) from inference_endpoint.load_generator.conversation_manager import ConversationManager from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy @@ -42,21 +46,26 @@ def issue(self, sample_index: int, data_override: dict | None = None) -> str | N return query_id -def _make_dataset_metadata(conversations: dict[str, list[int]]) -> dict: - """Build dataset_metadata dict from {conv_id: [turn_numbers]} mapping.""" +def _make_dataset_metadata(conversations: dict[str, list[int]]) -> ConversationMetadata: + """Build ConversationMetadata from {conv_id: [turn_numbers]} mapping.""" samples = [] sample_index = 0 for conv_id, turns in conversations.items(): for turn in turns: samples.append( - { - "conversation_id": conv_id, - "turn": turn, - "sample_index": sample_index, - } + ConversationSampleEntry( + conversation_id=conv_id, + turn=turn, + sample_index=sample_index, + ) ) sample_index += 1 - return {"samples": samples} + return ConversationMetadata( + samples=samples, + num_conversations=len(conversations), + max_turns_per_conv=max((max(t) for t in conversations.values()), default=0), + client_turns_per_conversation={c: len(t) for c, t in conversations.items()}, + ) @pytest.mark.unit @@ -252,24 +261,27 @@ async def test_error_response_marks_turn_failed(): def 
_make_metadata_with_system( conversations: dict[str, list[int]], system_prompts: dict[str, str | None] | None = None, -) -> dict: - """Build metadata dict including system_prompts_by_conv.""" +) -> ConversationMetadata: + """Build ConversationMetadata including system_prompts_by_conv.""" samples = [] sample_index = 0 for conv_id, turns in conversations.items(): for turn in turns: samples.append( - { - "conversation_id": conv_id, - "turn": turn, - "sample_index": sample_index, - } + ConversationSampleEntry( + conversation_id=conv_id, + turn=turn, + sample_index=sample_index, + ) ) sample_index += 1 - return { - "samples": samples, - "system_prompts_by_conv": system_prompts or {}, - } + return ConversationMetadata( + samples=samples, + num_conversations=len(conversations), + max_turns_per_conv=max((max(t) for t in conversations.values()), default=0), + client_turns_per_conversation={c: len(t) for c, t in conversations.items()}, + system_prompts_by_conv=system_prompts or {}, + ) @pytest.mark.unit From dda44bb3eba97dfa521d25a934103cbddfdf356a Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 14 May 2026 05:37:20 +0000 Subject: [PATCH 32/41] fix: address PR #285 round-4 review comments - Assert conversation_metadata is not None before MultiTurnStrategy construction (fixes mypy CI failure at execute.py:634 and test_multi_turn.py:89 introduced by R1 moving build to load()) - Narrow SSE decode exception catch from bare Exception to msgspec.DecodeError/ValidationError and raise log level to warning - Clear uuid_to_index entry on turn timeout to prevent inflight counter going negative when a late response arrives after timeout - Add @pytest.mark.unit to 7 unmarked test classes in test_types.py so they are included in pytest -m unit CI lane Co-Authored-By: Claude Sonnet 4.6 (1M context) --- src/inference_endpoint/commands/benchmark/execute.py | 1 + src/inference_endpoint/endpoint_client/adapter_protocol.py | 6 ++++-- .../load_generator/multi_turn_strategy.py | 1 + tests/integration/test_multi_turn.py | 1 + tests/unit/core/test_types.py | 7 +++++++ 5 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 1afc4c22..eb478228 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -629,6 +629,7 @@ async def _run_benchmark_async( ) if perf_ds_cfg is not None: mt_cfg = perf_ds_cfg.multi_turn + assert ctx.dataloader.conversation_metadata is not None multi_turn_strategy = MultiTurnStrategy( conversation_manager=ConversationManager(), dataset_metadata=ctx.dataloader.conversation_metadata, diff --git a/src/inference_endpoint/endpoint_client/adapter_protocol.py b/src/inference_endpoint/endpoint_client/adapter_protocol.py index 649424f8..ffa33485 100644 --- a/src/inference_endpoint/endpoint_client/adapter_protocol.py +++ b/src/inference_endpoint/endpoint_client/adapter_protocol.py @@ -22,6 +22,8 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +import msgspec + from inference_endpoint.core.types import Query, QueryResult if TYPE_CHECKING: @@ -128,8 +130,8 @@ def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[Any]: for json_doc in json_docs: try: content = cls.decode_sse_message(json_doc) - except Exception: - logger.debug("skipping non-content SSE frame: %s", json_doc[:120]) + except (msgspec.DecodeError, msgspec.ValidationError): + logger.warning("skipping malformed SSE frame: 
%s", json_doc[:120]) continue if content is not None: parsed.append(content) diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index b6e473c4..1ff12bc4 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -273,6 +273,7 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: and query_id in self._phase_issuer.uuid_to_index # type: ignore[attr-defined] ): self._phase_issuer.inflight -= 1 # type: ignore[attr-defined] + del self._phase_issuer.uuid_to_index[query_id] # type: ignore[attr-defined] logger.warning( "Turn timed out for conversation %s (query=%s)", conv_id, query_id diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py index 404535f0..bd1f82a9 100644 --- a/tests/integration/test_multi_turn.py +++ b/tests/integration/test_multi_turn.py @@ -84,6 +84,7 @@ def _make_strategy( turn_timeout_s=10.0, use_dataset_history=use_dataset_history, ) + assert ds.conversation_metadata is not None return MultiTurnStrategy( conversation_manager=ConversationManager(), dataset_metadata=ds.conversation_metadata, diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py index 32d15ab2..6dc53a28 100644 --- a/tests/unit/core/test_types.py +++ b/tests/unit/core/test_types.py @@ -33,6 +33,7 @@ ) +@pytest.mark.unit class TestErrorData: """Test ErrorData string representation.""" @@ -47,6 +48,7 @@ def test_error_data_str_without_message(self): assert str(err) == "TimeoutError" +@pytest.mark.unit class TestQuerySerialization: """Test Query msgspec.msgpack serialization with various field combinations.""" @@ -186,6 +188,7 @@ def test_query_multiple_roundtrips(self): assert decoded2.created_at == original.created_at +@pytest.mark.unit class TestQueryResultSerialization: """Test QueryResult msgspec.msgpack serialization with various field combinations.""" @@ -467,6 +470,7 @@ def test_query_result_multiple_roundtrips(self): assert decoded2.metadata == original.metadata +@pytest.mark.unit class TestStreamChunkSerialization: """Test StreamChunk msgspec.msgpack serialization with various field combinations.""" @@ -589,6 +593,7 @@ def test_stream_chunk_multiple_roundtrips(self): assert decoded2.metadata == original.metadata +@pytest.mark.unit class TestQueryResultWorkerPatterns: """Test QueryResult serialization patterns used by worker.py (TextModelOutput).""" @@ -690,6 +695,7 @@ def test_query_result_single_output_chunk(self): assert len(decoded.response_output.output) == 1 +@pytest.mark.unit class TestTextAfterFirstChunk: """Test TextModelOutput.text_after_first_chunk() for all reasoning/output combos.""" @@ -744,6 +750,7 @@ def test_text_after_first_chunk(self, reasoning, output, expected): assert tmo.text_after_first_chunk() == expected +@pytest.mark.unit class TestMixedTypeSerialization: """Test serialization of mixed type combinations and edge cases.""" From 595faf442fcbe71f75e4351315a7e90ff20bb7cd Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 14 May 2026 17:05:21 +0000 Subject: [PATCH 33/41] Send conversation_id and turn number Signed-off-by: Li, Tianmu --- .../services/event_logger/sql_writer.py | 5 +++++ src/inference_endpoint/core/record.py | 2 ++ .../load_generator/multi_turn_strategy.py | 8 +++++++- .../load_generator/session.py | 20 ++++++++++++++++++- .../load_generator/strategy.py | 6 +++++- 5 files changed, 38 insertions(+), 3 deletions(-) diff --git 
a/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py b/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py index fd6a1559..52c9b010 100644 --- a/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py +++ b/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py @@ -50,6 +50,9 @@ class EventRowModel(Base): timestamp_ns: Mapped[int] = mapped_column(BigInteger, nullable=False) """Monotonic timestamp in nanoseconds.""" + conversation_id: Mapped[str] = mapped_column(String, nullable=False, default="") + turn: Mapped[int | None] = mapped_column(Integer, nullable=True) + data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False, default=b"") """JSON-encoded event data.""" @@ -61,6 +64,8 @@ def _record_to_row(record: EventRecord) -> EventRowModel: sample_uuid=record.sample_uuid, event_type=topic, timestamp_ns=record.timestamp_ns, + conversation_id=record.conversation_id, + turn=record.turn, data=msgspec.json.encode(record.data), ) diff --git a/src/inference_endpoint/core/record.py b/src/inference_endpoint/core/record.py index da35389e..9ac60e1b 100644 --- a/src/inference_endpoint/core/record.py +++ b/src/inference_endpoint/core/record.py @@ -153,6 +153,8 @@ class EventRecord(msgspec.Struct, kw_only=True, frozen=True, gc=False): # type: event_type: EventType timestamp_ns: int = msgspec.field(default_factory=time.monotonic_ns) sample_uuid: str = "" + conversation_id: str = "" + turn: int | None = None data: OUTPUT_TYPE | PromptData | ErrorData | None = None diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 1ff12bc4..aa847ba5 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -235,7 +235,12 @@ def _issue_next_turn(self, conv_id: str) -> None: } assert self._phase_issuer is not None - query_id = self._phase_issuer.issue(idx, data_override=data_override) + query_id = self._phase_issuer.issue( + idx, + data_override=data_override, + conversation_id=conv_id, + turn=turn, + ) if query_id is None: # Session stopping β€” signal done. 
assert self._all_done is not None @@ -274,6 +279,7 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: ): self._phase_issuer.inflight -= 1 # type: ignore[attr-defined] del self._phase_issuer.uuid_to_index[query_id] # type: ignore[attr-defined] + self._phase_issuer.uuid_to_conv_info.pop(query_id, None) # type: ignore[attr-defined] logger.warning( "Turn timed out for conversation %s (query=%s)", conv_id, query_id diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index c69049ec..13ee0b0a 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -174,6 +174,7 @@ class PhaseIssuer: "_publisher", "_stop_check", "uuid_to_index", + "uuid_to_conv_info", "inflight", "issued_count", ) @@ -190,11 +191,16 @@ def __init__( self._publisher = publisher self._stop_check = stop_check self.uuid_to_index: dict[str, int] = {} + self.uuid_to_conv_info: dict[str, tuple[str, int | None]] = {} self.inflight: int = 0 self.issued_count: int = 0 def issue( - self, sample_index: int, data_override: dict[str, Any] | None = None + self, + sample_index: int, + data_override: dict[str, Any] | None = None, + conversation_id: str = "", + turn: int | None = None, ) -> str | None: """Load data, build Query, publish ISSUED, send to endpoint. @@ -218,6 +224,7 @@ def issue( data = {**data, **data_override} query = Query(id=query_id, data=data) self.uuid_to_index[query_id] = sample_index + self.uuid_to_conv_info[query_id] = (conversation_id, turn) ts = time.monotonic_ns() prompt_data: PromptData if isinstance(data, dict): @@ -242,6 +249,8 @@ def issue( event_type=SampleEventType.ISSUED, timestamp_ns=ts, sample_uuid=query_id, + conversation_id=conversation_id, + turn=turn, data=prompt_data, ) ) @@ -443,6 +452,11 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if isinstance(resp, QueryResult): query_id = resp.id + conv_id_str, turn_num = ("", None) + if phase_issuer is not None: + conv_id_str, turn_num = phase_issuer.uuid_to_conv_info.pop( + query_id, ("", None) + ) self._publisher.publish( EventRecord( event_type=SampleEventType.COMPLETE, @@ -450,6 +464,8 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if isinstance(resp.completed_at, int) else time.monotonic_ns(), sample_uuid=query_id, + conversation_id=conv_id_str, + turn=turn_num, data=resp.response_output, ) ) @@ -459,6 +475,8 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: event_type=ErrorEventType.GENERIC, timestamp_ns=time.monotonic_ns(), sample_uuid=query_id, + conversation_id=conv_id_str, + turn=turn_num, data=resp.error, ) ) diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py index 5e7db26a..8e019429 100644 --- a/src/inference_endpoint/load_generator/strategy.py +++ b/src/inference_endpoint/load_generator/strategy.py @@ -48,7 +48,11 @@ class PhaseIssuerProtocol(Protocol): """Minimal interface that strategies see for issuing samples.""" def issue( - self, sample_index: int, data_override: dict[str, Any] | None = None + self, + sample_index: int, + data_override: dict[str, Any] | None = None, + conversation_id: str = "", + turn: int | None = None, ) -> str | None: """Issue a sample. Returns query_id, or None if the session is stopping. 
From 1272386f6468c7563db7fb450ce5bd853b4e70fb Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 14 May 2026 21:17:09 +0000 Subject: [PATCH 34/41] Address perf concerns Signed-off-by: Li, Tianmu --- .../endpoint_client/adapter_protocol.py | 14 +++++++------- .../openai/openai_msgspec_adapter.py | 8 ++------ 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/inference_endpoint/endpoint_client/adapter_protocol.py b/src/inference_endpoint/endpoint_client/adapter_protocol.py index ffa33485..a96faaa0 100644 --- a/src/inference_endpoint/endpoint_client/adapter_protocol.py +++ b/src/inference_endpoint/endpoint_client/adapter_protocol.py @@ -127,12 +127,12 @@ def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[Any]: """ json_docs = cls.SSE_DATA_PATTERN.findall(buffer[:end_pos]) parsed: list[Any] = [] - for json_doc in json_docs: - try: + # Note: if one frame is malformed, remaining frames are skipped + try: + for json_doc in json_docs: content = cls.decode_sse_message(json_doc) - except (msgspec.DecodeError, msgspec.ValidationError): - logger.warning("skipping malformed SSE frame: %s", json_doc[:120]) - continue - if content is not None: - parsed.append(content) + if content is not None: + parsed.append(content) + except (msgspec.DecodeError, msgspec.ValidationError) as exc: + logger.warning("skipping malformed SSE batch (%s)", type(exc).__name__) return parsed diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index 91ac7756..e68a4872 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -144,12 +144,8 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest: Returns: msgspec.Struct ChatCompletionRequest """ - if "messages" in query.data and isinstance(query.data["messages"], list): - messages = [] - for message in query.data["messages"]: - if not isinstance(message, dict): - raise ValueError("messages entries must be dicts") - messages.append(_chat_message_from_dict(message)) + if "messages" in query.data: + messages = [_chat_message_from_dict(m) for m in query.data["messages"]] else: if "prompt" not in query.data: raise ValueError("prompt not found in query.data") From af035d99df40334b4cf0943d61debe56c0dd0adf Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 14 May 2026 22:56:17 +0000 Subject: [PATCH 35/41] fix: address PR #285 round-5 Copilot review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Revert 595faf4 (conversation_id/turn in EventRecord/PhaseIssuer/session) - Validate user row content and reject assistant(tool_calls)β†’user transition without intervening tool row in _validate_conversation_structure - Reject malformed tool_calls (present but not a non-empty list) on assistant rows - Guard _expand_tool_results against non-dict entries in tool_results list - Add "tools" to ColumnFilter optional_columns in both OpenAI adapters so single-turn datasets with a tools column are not silently stripped - Replace RuntimeError with logger.warning in _precompute_isl_for_multi_turn so template-incompatible tokenizers fall back instead of aborting the benchmark - Fix schema cross-validation to check only the performance dataset for multi_turn config instead of any dataset - Move seeding loop inside try/finally in MultiTurnStrategy.execute so cleanup runs even if _start_conversation raises - Add inflight/uuid_to_index to FakePhaseIssuer for 
_handle_timeout coverage - Strengthen test_no_matching_columns to assert unrelated columns are preserved - Update MULTI_TURN_QUICKSTART.md to accurately describe which turns produce sample events and how to correlate events back to conversations Co-Authored-By: Claude Sonnet 4.6 (1M context) --- docs/MULTI_TURN_QUICKSTART.md | 4 +-- .../services/event_logger/sql_writer.py | 5 --- .../commands/benchmark/execute.py | 7 ++-- src/inference_endpoint/config/schema.py | 27 +++++++++++---- src/inference_endpoint/core/record.py | 2 -- .../dataset_manager/multi_turn_dataset.py | 33 +++++++++++++++++++ .../load_generator/multi_turn_strategy.py | 18 ++++------ .../load_generator/session.py | 18 +--------- .../load_generator/strategy.py | 6 +--- .../openai/openai_adapter.py | 2 +- .../openai/openai_msgspec_adapter.py | 1 + .../test_multi_turn_dataset.py | 17 +++++----- tests/unit/dataset_manager/test_transforms.py | 4 ++- .../test_multi_turn_strategy.py | 2 ++ 14 files changed, 83 insertions(+), 63 deletions(-) diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index ddcc9828..87669af7 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -71,7 +71,7 @@ That's it! Your benchmark will now: - βœ… Enforce turn ordering (turn N+1 waits for turn N) - βœ… Include conversation history in each request -- βœ… Log all turns to events.jsonl +- βœ… Log all issued (client) turns to events.jsonl β€” scripted assistant rows are context only and do not produce sample events --- @@ -93,7 +93,7 @@ dataset's `conversation_metadata["samples"]`, which maps sample indices to Currently available: - **Per-turn metrics**: Latency, TTFT, TPOT for each turn -- **Conversation tracking**: All events tagged with conversation_id +- **Conversation tracking**: events are keyed by `sample_uuid` only; correlate any event back to a conversation by joining through `sample_idx_map.json` and `conversation_metadata["samples"]` _Note: Per-conversation aggregation (e.g., "conversations/sec") is coming in a future update._ diff --git a/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py b/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py index 52c9b010..fd6a1559 100644 --- a/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py +++ b/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py @@ -50,9 +50,6 @@ class EventRowModel(Base): timestamp_ns: Mapped[int] = mapped_column(BigInteger, nullable=False) """Monotonic timestamp in nanoseconds.""" - conversation_id: Mapped[str] = mapped_column(String, nullable=False, default="") - turn: Mapped[int | None] = mapped_column(Integer, nullable=True) - data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False, default=b"") """JSON-encoded event data.""" @@ -64,8 +61,6 @@ def _record_to_row(record: EventRecord) -> EventRowModel: sample_uuid=record.sample_uuid, event_type=topic, timestamp_ns=record.timestamp_ns, - conversation_id=record.conversation_id, - turn=record.turn, data=msgspec.json.encode(record.data), ) diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 472631cc..076ea548 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -350,8 +350,11 @@ def _precompute_isl_for_multi_turn( ) total_with_messages = len([s for s in (dataloader.data or []) if s.get("messages")]) if total_with_messages > 0 and skipped == 
total_with_messages: - raise RuntimeError( - "ISL precomputation failed for all samples; check tokenizer/template compatibility" + logger.warning( + "ISL precomputation: all %d turn(s) failed apply_chat_template; " + "ISL metrics will use text-tokenization fallback. " + "Check tokenizer/template compatibility.", + total_with_messages, ) diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index db5aaebc..d3390483 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -626,17 +626,30 @@ def _resolve_and_validate(self) -> Self: "Online mode requires --load-pattern (poisson, concurrency, or multi_turn)" ) - # Cross-validate load_pattern.type=multi_turn ↔ dataset.multi_turn config - has_multi_turn_dataset = any( - d.multi_turn is not None for d in (self.datasets or []) + # Cross-validate load_pattern.type=multi_turn ↔ performance dataset.multi_turn config + has_multi_turn_perf_dataset = any( + d.multi_turn is not None + for d in (self.datasets or []) + if d.type == DatasetType.PERFORMANCE ) - if lp.type == LoadPatternType.MULTI_TURN and not has_multi_turn_dataset: + has_multi_turn_non_perf_dataset = any( + d.multi_turn is not None + for d in (self.datasets or []) + if d.type != DatasetType.PERFORMANCE + ) + if has_multi_turn_non_perf_dataset: + raise ValueError( + "multi_turn config is only supported on performance datasets; " + "accuracy datasets with multi_turn are not supported" + ) + if lp.type == LoadPatternType.MULTI_TURN and not has_multi_turn_perf_dataset: raise ValueError( - "load_pattern.type=multi_turn requires at least one dataset with multi_turn config" + "load_pattern.type=multi_turn requires the performance dataset to have multi_turn config" ) - if has_multi_turn_dataset and lp.type != LoadPatternType.MULTI_TURN: + if has_multi_turn_perf_dataset and lp.type != LoadPatternType.MULTI_TURN: raise ValueError( - f"Datasets with multi_turn config require load_pattern.type=multi_turn, got '{lp.type}'" + f"Performance dataset with multi_turn config requires load_pattern.type=multi_turn, " + f"got '{lp.type}'" ) return self diff --git a/src/inference_endpoint/core/record.py b/src/inference_endpoint/core/record.py index 9ac60e1b..da35389e 100644 --- a/src/inference_endpoint/core/record.py +++ b/src/inference_endpoint/core/record.py @@ -153,8 +153,6 @@ class EventRecord(msgspec.Struct, kw_only=True, frozen=True, gc=False): # type: event_type: EventType timestamp_ns: int = msgspec.field(default_factory=time.monotonic_ns) sample_uuid: str = "" - conversation_id: str = "" - turn: int | None = None data: OUTPUT_TYPE | PromptData | ErrorData | None = None diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index afea505c..447aba42 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -88,6 +88,11 @@ def _expand_tool_results(row: dict) -> list[dict]: return [] messages = [] for i, result in enumerate(tool_results): + if not isinstance(result, dict): + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} must be a dict, got {type(result).__name__}" + ) tool_call_id = result.get("tool_call_id") content = result.get("content") if tool_call_id is None: @@ -236,7 +241,35 @@ def _validate_conversation_structure(self): f"Conversation {conv_id} turn {row['turn']}: " 
"assistant rows must have non-empty 'content' or non-empty 'tool_calls'" ) + if ( + tool_calls is not None + and not (isinstance(tool_calls, float) and pd.isna(tool_calls)) + and not has_tool_calls + ): + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "'tool_calls' field is present but is not a non-empty list; " + "omit the field or provide a valid non-empty list" + ) prev_assistant_had_tool_calls = has_tool_calls + elif role == "user": + content = row.get("content") + is_empty_content = ( + content is None + or (isinstance(content, float) and pd.isna(content)) + or content == "" + ) + if is_empty_content: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "user rows must have non-empty 'content'" + ) + if state == "assistant" and prev_assistant_had_tool_calls: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "'user' row cannot follow an assistant row with 'tool_calls'; " + "a 'tool' result row is required first" + ) if role != "assistant": prev_assistant_had_tool_calls = False diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index aa847ba5..e877a8d2 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -176,13 +176,13 @@ async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: if self._target_concurrency is not None and self._target_concurrency > 0 else len(self._pending_convs) ) - for _ in range(n_to_start): - self._start_conversation() + try: + for _ in range(n_to_start): + self._start_conversation() - if not self._active_iters and not self._inflight: - return phase_issuer.issued_count + if not self._active_iters and not self._inflight: + return phase_issuer.issued_count - try: await self._all_done.wait() if self._error is not None: raise self._error @@ -235,12 +235,7 @@ def _issue_next_turn(self, conv_id: str) -> None: } assert self._phase_issuer is not None - query_id = self._phase_issuer.issue( - idx, - data_override=data_override, - conversation_id=conv_id, - turn=turn, - ) + query_id = self._phase_issuer.issue(idx, data_override=data_override) if query_id is None: # Session stopping β€” signal done. 
assert self._all_done is not None @@ -279,7 +274,6 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: ): self._phase_issuer.inflight -= 1 # type: ignore[attr-defined] del self._phase_issuer.uuid_to_index[query_id] # type: ignore[attr-defined] - self._phase_issuer.uuid_to_conv_info.pop(query_id, None) # type: ignore[attr-defined] logger.warning( "Turn timed out for conversation %s (query=%s)", conv_id, query_id diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index ca6e0b73..4f6fa747 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -174,7 +174,6 @@ class PhaseIssuer: "_publisher", "_stop_check", "uuid_to_index", - "uuid_to_conv_info", "inflight", "issued_count", ) @@ -191,16 +190,11 @@ def __init__( self._publisher = publisher self._stop_check = stop_check self.uuid_to_index: dict[str, int] = {} - self.uuid_to_conv_info: dict[str, tuple[str, int | None]] = {} self.inflight: int = 0 self.issued_count: int = 0 def issue( - self, - sample_index: int, - data_override: dict[str, Any] | None = None, - conversation_id: str = "", - turn: int | None = None, + self, sample_index: int, data_override: dict[str, Any] | None = None ) -> str | None: """Load data, build Query, publish ISSUED, send to endpoint. @@ -224,7 +218,6 @@ def issue( data = {**data, **data_override} query = Query(id=query_id, data=data) self.uuid_to_index[query_id] = sample_index - self.uuid_to_conv_info[query_id] = (conversation_id, turn) ts = time.monotonic_ns() prompt_data: PromptData if isinstance(data, dict): @@ -249,8 +242,6 @@ def issue( event_type=SampleEventType.ISSUED, timestamp_ns=ts, sample_uuid=query_id, - conversation_id=conversation_id, - turn=turn, data=prompt_data, ) ) @@ -452,11 +443,6 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if isinstance(resp, QueryResult): query_id = resp.id - conv_id_str, turn_num = ("", None) - if phase_issuer is not None: - conv_id_str, turn_num = phase_issuer.uuid_to_conv_info.pop( - query_id, ("", None) - ) # Emit ERROR before COMPLETE for failed queries so downstream # consumers (notably the metrics aggregator) see the ERROR # while the in-flight tracked row still exists. COMPLETE @@ -485,8 +471,6 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if isinstance(resp.completed_at, int) else time.monotonic_ns(), sample_uuid=query_id, - conversation_id=conv_id_str, - turn=turn_num, data=resp.response_output, ) ) diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py index 8e019429..5e7db26a 100644 --- a/src/inference_endpoint/load_generator/strategy.py +++ b/src/inference_endpoint/load_generator/strategy.py @@ -48,11 +48,7 @@ class PhaseIssuerProtocol(Protocol): """Minimal interface that strategies see for issuing samples.""" def issue( - self, - sample_index: int, - data_override: dict[str, Any] | None = None, - conversation_id: str = "", - turn: int | None = None, + self, sample_index: int, data_override: dict[str, Any] | None = None ) -> str | None: """Issue a sample. Returns query_id, or None if the session is stopping. 
diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 30091810..8f516b7b 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -58,7 +58,7 @@ def dataset_transforms(cls, model_params: ModelParams) -> list[Transform]: return [ ColumnFilter( required_columns=["prompt"], - optional_columns=["system"], + optional_columns=["system", "tools"], ), AddStaticColumns(metadata), ] diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index e68a4872..9ef5b3a8 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -95,6 +95,7 @@ def dataset_transforms(cls, model_params: ModelParams) -> list[Transform]: "logit_bias", "user", "chat_template", + "tools", ] return [ ColumnFilter( diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index e893a47c..6b0a94e6 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -22,6 +22,7 @@ import pytest from inference_endpoint.dataset_manager.dataset import DatasetFormat from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset +from inference_endpoint.exceptions import InputValidationError @pytest.fixture @@ -223,15 +224,13 @@ def test_multi_turn_dataset_validation_invalid_role_sequence( @pytest.mark.unit def test_multi_turn_dataset_validation_missing_fields(missing_fields_jsonl): - """Missing content field is preserved as None in the loaded sample.""" - dataset = MultiTurnDataset.load_from_file( - missing_fields_jsonl, format=DatasetFormat.JSONL - ) - dataset.load() - - sample = dataset.load_sample(0) - # Missing content is no longer propagated to the sample dict - assert "content" not in sample + """User rows with missing content are rejected at construction time.""" + with pytest.raises( + InputValidationError, match="user rows must have non-empty 'content'" + ): + MultiTurnDataset.load_from_file( + missing_fields_jsonl, format=DatasetFormat.JSONL + ) @pytest.mark.unit diff --git a/tests/unit/dataset_manager/test_transforms.py b/tests/unit/dataset_manager/test_transforms.py index 0ea35f24..7dee620f 100644 --- a/tests/unit/dataset_manager/test_transforms.py +++ b/tests/unit/dataset_manager/test_transforms.py @@ -822,8 +822,10 @@ def test_no_matching_columns(self): transform = MakeAdapterCompatible() result = transform(df) - # Should not raise error or create prompt column + # Should not raise error, should not create prompt, and must preserve unrelated columns assert "prompt" not in result.columns + assert "unrelated" in result.columns + assert list(result["unrelated"]) == ["data"] class TestAddStaticColumnsNoOverwrite: diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 0f9e41c8..8fa82614 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -35,6 +35,8 @@ def __init__(self, stop_after: int | None = None): self._stop_after = stop_after self.issued: list[int] = [] self.issued_count = 0 + self.inflight: int = 0 + self.uuid_to_index: dict[str, int] = {} def issue(self, sample_index: int, data_override: dict | None = None) -> str | None: if self._stop_after is not None and 
self._count >= self._stop_after: From dd477964438d6f880cb4f8f023ce26eee4dfa274 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 14 May 2026 23:01:33 +0000 Subject: [PATCH 36/41] fix: update test_schema.py error message assertions for Fix 5 Match new error messages from the multi_turn schema cross-validation fix that scopes validation to the performance dataset specifically. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- tests/unit/config/test_schema.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index b837ab4b..ea7dc9ee 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -499,7 +499,10 @@ def test_multi_turn_requires_target_concurrency(self): @pytest.mark.unit def test_multi_turn_without_multi_turn_dataset_rejected(self): - with pytest.raises(ValueError, match="requires at least one dataset"): + with pytest.raises( + ValueError, + match="requires the performance dataset to have multi_turn config", + ): BenchmarkConfig( type=TestType.ONLINE, model_params={"name": "M"}, @@ -512,7 +515,7 @@ def test_multi_turn_without_multi_turn_dataset_rejected(self): @pytest.mark.unit def test_multi_turn_dataset_without_multi_turn_load_pattern_rejected(self): - with pytest.raises(ValueError, match="require load_pattern.type=multi_turn"): + with pytest.raises(ValueError, match="requires load_pattern.type=multi_turn"): BenchmarkConfig( type=TestType.ONLINE, model_params={"name": "M"}, From 15a4108165bd555ef277e18c55fcc306dca031cc Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Thu, 14 May 2026 23:13:59 +0000 Subject: [PATCH 37/41] Fix ci test failure post merge Signed-off-by: Li, Tianmu --- .../metrics_aggregator/metrics_table.py | 4 +- .../metrics_aggregator/test_metrics_table.py | 42 ++++++++----------- 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py index 887c11a8..46a17e92 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py @@ -231,7 +231,7 @@ def fire(self, ev_rec, row, pre_change): if message_parts is not None: content, reasoning, tool_calls = message_parts pool, loop = self._pool, self._loop - store, name = self.kv_store, self.metric_name + registry, name = self.registry, self.metric_name uuid = row.sample_uuid async def _tokenize_message_and_emit() -> None: @@ -241,7 +241,7 @@ async def _tokenize_message_and_emit() -> None: ) value = self._compute_value(count, ev_rec, pre_change) if value is not None: - store.update(name, value) + registry.record(name, value) except Exception: logger.exception("%s tokenization failed for %s", name, uuid) diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py b/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py index 8d5e2a6d..d3781ad9 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py @@ -290,16 +290,17 @@ async def test_osl_with_tool_calls_uses_message_path(self): """OslTrigger stores combined content+tool_calls word count.""" from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( OslTrigger, + SampleRow, ) from 
inference_endpoint.core.types import TextModelOutput - from .conftest import InMemoryKVStore, MockTokenizePool + from .conftest import MockTokenizePool, snapshot_series_count - kv = InMemoryKVStore() + registry = MetricsRegistry() + registry.register_series("osl", hdr_low=1, hdr_high=100_000) loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0) - trigger = OslTrigger(kv, pool, loop) - trigger.kv_store.create_key("osl", "series", dtype=int) + trigger = OslTrigger(registry, pool, loop) tool_calls = ( { @@ -315,18 +316,12 @@ async def test_osl_with_tool_calls_uses_message_path(self): sample_uuid="s1", data=tmo, ) - from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( - SampleRow, - ) - row = SampleRow(sample_uuid="s1") task = trigger.fire(ev, row, {}) assert task is not None await task - values = kv.get_series_values("osl") - assert len(values) == 1 - assert values[0] > 0 + assert snapshot_series_count(registry, "osl") == 1 async def test_osl_without_tool_calls_uses_text_path(self): """OslTrigger uses text path for output with no tool_calls (regression guard).""" @@ -336,13 +331,13 @@ async def test_osl_without_tool_calls_uses_text_path(self): ) from inference_endpoint.core.types import TextModelOutput - from .conftest import InMemoryKVStore, MockTokenizePool + from .conftest import MockTokenizePool, snapshot_series_count - kv = InMemoryKVStore() + registry = MetricsRegistry() + registry.register_series("osl", hdr_low=1, hdr_high=100_000) loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0) - trigger = OslTrigger(kv, pool, loop) - trigger.kv_store.create_key("osl", "series", dtype=int) + trigger = OslTrigger(registry, pool, loop) tmo = TextModelOutput(output="hello world") ev = EventRecord( @@ -356,8 +351,7 @@ async def test_osl_without_tool_calls_uses_text_path(self): assert task is not None await task - values = kv.get_series_values("osl") - assert values == [2] # "hello world" -> 2 words + assert snapshot_series_count(registry, "osl") == 1 @pytest.mark.unit @@ -374,13 +368,15 @@ async def test_tpot_tool_calls_only_response(self): ) from inference_endpoint.core.types import TextModelOutput - from .conftest import InMemoryKVStore, MockTokenizePool + from .conftest import MockTokenizePool, snapshot_series_count - kv = InMemoryKVStore() + registry = MetricsRegistry() + registry.register_series( + "tpot_ns", hdr_low=1, hdr_high=100_000_000_000, dtype=float + ) loop = asyncio.get_running_loop() pool = MockTokenizePool(delay=0) - trigger = TpotTrigger(kv, pool, loop) - trigger.kv_store.create_key("tpot_ns", "series", dtype=float) + trigger = TpotTrigger(registry, pool, loop) tool_calls = ( { @@ -403,6 +399,4 @@ async def test_tpot_tool_calls_only_response(self): assert task is not None await task - values = kv.get_series_values("tpot_ns") - assert len(values) == 1 - assert values[0] > 0 + assert snapshot_series_count(registry, "tpot_ns") == 1 From 2ac66be46b36a13583696b5f523dc3a16c4464bb Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 15 May 2026 02:02:44 +0000 Subject: [PATCH 38/41] fix: address PR #285 round-6 Copilot review comments Co-Authored-By: Claude Sonnet 4.6 (1M context) --- docs/MULTI_TURN_QUICKSTART.md | 6 +- .../commands/benchmark/execute.py | 29 ++++--- .../dataset_manager/multi_turn_dataset.py | 1 + .../load_generator/multi_turn_strategy.py | 65 ++++++++++++++-- .../load_generator/session.py | 4 + .../test_multi_turn_dataset.py | 1 + .../unit/load_generator/test_async_session.py | 20 +++++ 
.../test_multi_turn_strategy.py | 75 ++++++++++++++++++- 8 files changed, 179 insertions(+), 22 deletions(-) diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md index 87669af7..5a48ab75 100644 --- a/docs/MULTI_TURN_QUICKSTART.md +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -211,8 +211,10 @@ datasets: ### 1. Use the Example Dataset ```bash -cd examples/09_MultiTurn -inference-endpoint benchmark from-config --config multi_turn_benchmark.yaml +# Run from the repository root β€” dataset paths in the bundled YAML are +# repo-relative (e.g. examples/09_MultiTurn/customer_support_conversations.jsonl). +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/multi_turn_benchmark.yaml ``` ### 2. Check the Logs diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 076ea548..1a569f60 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -309,7 +309,15 @@ def _precompute_isl_for_multi_turn( Only affects dataset-history turns; live-history turns override 'messages' at runtime so the stored input_tokens are stale (acceptable approximation). """ - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + except Exception: + logger.exception( + "ISL pre-computation: failed to load tokenizer %s; " + "falling back to text-tokenization at runtime", + tokenizer_name, + ) + return skipped = 0 first_failure_logged = False for sample in dataloader.data or []: @@ -327,8 +335,10 @@ def _precompute_isl_for_multi_turn( ), } normalized_messages.append(msg) + tools = sample.get("tools") raw = tokenizer.apply_chat_template( normalized_messages, + tools=tools if tools else None, tokenize=True, add_generation_prompt=True, ) @@ -441,7 +451,6 @@ def _build_phases( # Accuracy phases β€” use eval_cfg.dataset_name as phase name so it matches # what Scorer._load_sample_index_map() looks up in sample_idx_map.json - perf_lp = ctx.rt_settings.load_pattern for eval_cfg in ctx.eval_configs: acc_ds = eval_cfg.dataset if isinstance(acc_ds, MultiTurnDataset): @@ -449,15 +458,12 @@ def _build_phases( f"Accuracy dataset '{eval_cfg.dataset_name}' is a MultiTurnDataset, " "which is not yet supported for accuracy evaluation." ) - if perf_lp is not None and perf_lp.type == LoadPatternType.MULTI_TURN: - # Plain accuracy datasets are single-turn; the multi-turn scheduler - # requires MultiTurnDataset. Downgrade to CONCURRENCY with same cap. - acc_load_pattern: LoadPattern | None = LoadPattern( - type=LoadPatternType.CONCURRENCY, - target_concurrency=perf_lp.target_concurrency, - ) - else: - acc_load_pattern = perf_lp + # Accuracy phases run at MAX_THROUGHPUT; inheriting perf_lp (e.g. POISSON) + # would silently rate-limit evaluation until a multi-turn accuracy strategy + # and QPS-budgeting support are added. 
+ acc_load_pattern: LoadPattern | None = LoadPattern( + type=LoadPatternType.MAX_THROUGHPUT + ) acc_settings = RuntimeSettings( metric_target=ctx.rt_settings.metric_target, reported_metrics=ctx.rt_settings.reported_metrics, @@ -660,6 +666,7 @@ def _on_sample_complete(result: QueryResult) -> None: ) multi_turn_strategy._session_on_sample_complete = _on_sample_complete + multi_turn_strategy._session_publisher = publisher else: _on_sample_complete = collector.on_complete_hook diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index 447aba42..7b0484bd 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -517,6 +517,7 @@ def load( max_tokens_val = ( sample.pop("max_new_tokens", None) or sample.get("max_completion_tokens") + or sample.get("max_tokens") or 128 ) sample["max_new_tokens"] = max_tokens_val diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index e877a8d2..f5482481 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -17,11 +17,13 @@ import asyncio import logging +import time from collections import defaultdict, deque from collections.abc import Iterator from typing import Any from ..config.schema import MultiTurnConfig +from ..core.record import ErrorEventType, EventRecord, SampleEventType from ..core.types import ErrorData, QueryResult, TextModelOutput from ..dataset_manager.multi_turn_dataset import ConversationMetadata from ..exceptions import InputValidationError @@ -116,6 +118,7 @@ def __init__( # Composite on_sample_complete callback set by execute.py; used by # _handle_timeout to route synthetic failure results. self._session_on_sample_complete: Any | None = None + self._session_publisher: Any | None = None # Maps query_id -> conversation_id for routing completions. self._inflight: dict[str, str] = {} @@ -273,7 +276,6 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: and query_id in self._phase_issuer.uuid_to_index # type: ignore[attr-defined] ): self._phase_issuer.inflight -= 1 # type: ignore[attr-defined] - del self._phase_issuer.uuid_to_index[query_id] # type: ignore[attr-defined] logger.warning( "Turn timed out for conversation %s (query=%s)", conv_id, query_id @@ -285,14 +287,40 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: # Route a synthetic failure result so the accuracy collector and event # logger see the timed-out turn. + timeout_result = QueryResult( + id=query_id, + error=ErrorData( + error_type="TurnTimeout", + error_message=f"turn timeout after {self._turn_timeout_s}s", + ), + ) + + # Publish ERROR + COMPLETE so the metrics aggregator and event logger + # see the timeout (matches BenchmarkSession._handle_response ordering). 
+ if self._session_publisher is not None: + try: + self._session_publisher.publish( + EventRecord( + event_type=ErrorEventType.GENERIC, + timestamp_ns=time.monotonic_ns(), + sample_uuid=query_id, + data=timeout_result.error, + ) + ) + self._session_publisher.publish( + EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=time.monotonic_ns(), + sample_uuid=query_id, + data=None, + ) + ) + except Exception: + logger.exception( + "Failed to publish timeout EventRecords (query=%s)", query_id + ) + if self._session_on_sample_complete is not None: - timeout_result = QueryResult( - id=query_id, - error=ErrorData( - error_type="TurnTimeout", - error_message=f"turn timeout after {self._turn_timeout_s}s", - ), - ) try: self._session_on_sample_complete(timeout_result) except Exception: @@ -376,6 +404,27 @@ def on_sample_complete(self, result: QueryResult) -> None: ) return + # If this turn failed, abandon the rest of the conversation: replaying + # later turns against a corrupt history (assistant placeholder / + # missing tool result) is meaningless and matches the timeout path. + if result.error is not None: + it = self._active_iters.pop(conv_id, None) + dropped = 0 + if it is not None: + for _ in it: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + dropped += 1 + if dropped: + logger.warning( + "turn error on conv=%s dropped %d remaining client turn(s)", + conv_id, + dropped, + ) + self._fill_slot() + return + try: self._issue_next_turn(conv_id) except Exception as exc: diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 4f6fa747..76ddb5a5 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -21,6 +21,7 @@ from __future__ import annotations import asyncio +import json import logging import os import time @@ -64,6 +65,9 @@ def _extract_prompt_text(messages: list[Any]) -> str | None: and p.get("type") == "text" and isinstance(p.get("text"), str) ) + tc = m.get("tool_calls") + if tc: + parts.append(json.dumps(tc, separators=(",", ":"))) return "\n".join(parts) if parts else None diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py index 6b0a94e6..03659307 100644 --- a/tests/unit/dataset_manager/test_multi_turn_dataset.py +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -479,6 +479,7 @@ def test_multi_turn_dataset_additional_fields(): sample = dataset.load_sample(0) assert sample["model"] == "gpt-4" assert sample["max_completion_tokens"] == 256 + assert sample["max_tokens"] == 256 assert sample["temperature"] == pytest.approx(0.7) finally: diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index 53104c39..a82ffabd 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -945,3 +945,23 @@ def test_list_content_with_no_text_parts_returns_none(self): def test_non_dict_messages_skipped(self): messages = ["not a dict", {"role": "user", "content": "Valid"}] assert _extract_prompt_text(messages) == "Valid" + + def test_tool_calls_included(self): + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + ] + result = 
_extract_prompt_text(messages) + assert result is not None + assert "What's the weather?" in result + assert "get_weather" in result diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 8fa82614..03c2f1a2 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -16,9 +16,11 @@ """Unit tests for MultiTurnStrategy.""" import asyncio +from unittest.mock import MagicMock import pytest -from inference_endpoint.core.types import QueryResult, TextModelOutput +from inference_endpoint.core.record import ErrorEventType, SampleEventType +from inference_endpoint.core.types import ErrorData, QueryResult, TextModelOutput from inference_endpoint.dataset_manager.multi_turn_dataset import ( ConversationMetadata, ConversationSampleEntry, @@ -247,6 +249,7 @@ async def test_error_response_marks_turn_failed(): strategy = MultiTurnStrategy(conv_manager, metadata) strategy._inflight["q0001"] = "conv1" + strategy._all_done = asyncio.Event() result = QueryResult( id="q0001", @@ -579,3 +582,73 @@ async def auto_respond(): # Single slot: conv1 turns (samples 0,1) must be issued before conv2 turns (2,3) assert issuer.issued[:2] == [0, 1], "Conv1 turns should be issued before conv2" assert issuer.issued[2:] == [2, 3], "Conv2 turns should follow conv1" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_timeout_publishes_error_and_complete_events(): + """_handle_timeout publishes ERROR then COMPLETE EventRecords via session_publisher.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv-x", expected_client_turns=1) + metadata = _make_dataset_metadata({"conv-x": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + publisher = MagicMock() + strategy._session_publisher = publisher + + # Seed _inflight so _handle_timeout finds the entry + strategy._inflight["q-x"] = "conv-x" + strategy._active_iters["conv-x"] = iter([]) + + strategy._all_done = asyncio.Event() + strategy._loop = asyncio.get_running_loop() + strategy._phase_issuer = None + + strategy._handle_timeout("q-x", "conv-x") + + assert publisher.publish.call_count == 2 + first_call, second_call = publisher.publish.call_args_list + first_record = first_call.args[0] + second_record = second_call.args[0] + + assert first_record.event_type == ErrorEventType.GENERIC + assert first_record.sample_uuid == "q-x" + assert second_record.event_type == SampleEventType.COMPLETE + assert second_record.sample_uuid == "q-x" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_error_turn_aborts_remaining_turns(): + """on_sample_complete with result.error aborts and marks-failed remaining turns.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=3) + metadata = _make_dataset_metadata({"conv1": [1, 2, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + strategy._all_done = asyncio.Event() + strategy._loop = asyncio.get_running_loop() + strategy._phase_issuer = issuer + + # Seed: conv1 is active with turns 2 and 3 still pending + remaining_turns = iter([(1, 2), (2, 3)]) + strategy._active_iters["conv1"] = remaining_turns + strategy._inflight["q0001"] = "conv1" + strategy._conv_states["conv1"] = conv_manager.get_state("conv1") + + result = QueryResult( + id="q0001", + response_output=None, + error=ErrorData( + error_type="endpoint_error", error_message="500 Internal Server Error" + ), + ) + 
strategy.on_sample_complete(result) + + # Conversation should no longer be active + assert "conv1" not in strategy._active_iters + # Remaining 2 turns were marked failed + state = conv_manager.get_state("conv1") + assert state is not None + assert state.failed_turns == 3 # the failing turn + 2 dropped From 44423646552ce3d3fd6f7100bf622445c27b4fca Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 15 May 2026 02:56:09 +0000 Subject: [PATCH 39/41] Fix double-firing of timed-out turns Signed-off-by: Li, Tianmu --- .../load_generator/multi_turn_strategy.py | 2 ++ .../load_generator/session.py | 11 +++++++ .../unit/load_generator/test_async_session.py | 31 +++++++++++++++++++ .../test_multi_turn_strategy.py | 9 +++++- 4 files changed, 52 insertions(+), 1 deletion(-) diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index f5482481..73c9224f 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -276,6 +276,8 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: and query_id in self._phase_issuer.uuid_to_index # type: ignore[attr-defined] ): self._phase_issuer.inflight -= 1 # type: ignore[attr-defined] + if hasattr(self._phase_issuer, "completed_uuids"): + self._phase_issuer.completed_uuids.add(query_id) # type: ignore[attr-defined] logger.warning( "Turn timed out for conversation %s (query=%s)", conv_id, query_id diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 76ddb5a5..a695b280 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -178,6 +178,7 @@ class PhaseIssuer: "_publisher", "_stop_check", "uuid_to_index", + "completed_uuids", "inflight", "issued_count", ) @@ -194,6 +195,7 @@ def __init__( self._publisher = publisher self._stop_check = stop_check self.uuid_to_index: dict[str, int] = {} + self.completed_uuids: set[str] = set() self.inflight: int = 0 self.issued_count: int = 0 @@ -447,6 +449,14 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if isinstance(resp, QueryResult): query_id = resp.id + # Drop late responses for queries already terminated (e.g. by + # MultiTurnStrategy._handle_timeout). Without this gate, a real + # response arriving after timeout double-publishes ERROR/COMPLETE + # and double-decrements inflight (no per-request HTTP timeout + # exists in endpoint_client; late arrivals are possible). + if phase_issuer is not None and query_id in phase_issuer.completed_uuids: + return + # Emit ERROR before COMPLETE for failed queries so downstream # consumers (notably the metrics aggregator) see the ERROR # while the in-flight tracked row still exists. COMPLETE @@ -479,6 +489,7 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: ) ) if phase_issuer is not None and query_id in phase_issuer.uuid_to_index: + phase_issuer.completed_uuids.add(query_id) phase_issuer.inflight -= 1 if phase_issuer.inflight <= 0: self._drain_event.set() diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index a82ffabd..0b84117c 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -965,3 +965,34 @@ def test_tool_calls_included(self): assert result is not None assert "What's the weather?" 
in result assert "get_weather" in result + + +@pytest.mark.unit +class TestBenchmarkSessionHandleResponse: + """Direct invocation of BenchmarkSession._handle_response (no session.run).""" + + @pytest.mark.asyncio + async def test_drops_late_response_after_timeout(self): + """A late QueryResult for a query already in completed_uuids must be a no-op: + no duplicate ERROR/COMPLETE publish and no second inflight decrement.""" + loop = asyncio.get_running_loop() + dataset = FakeDataset(1) + issuer = FakeIssuer() + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: False) + + phase_issuer.uuid_to_index["q-late"] = 0 + phase_issuer.completed_uuids.add("q-late") + phase_issuer.inflight = 1 + + session = BenchmarkSession(issuer, publisher, loop) + session._current_phase_issuer = phase_issuer + + late_resp = QueryResult( + id="q-late", + error=ErrorData(error_type="late", error_message="late arrival"), + ) + session._handle_response(late_resp) + + assert publisher.events == [] + assert phase_issuer.inflight == 1 diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 03c2f1a2..8792db50 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -39,6 +39,7 @@ def __init__(self, stop_after: int | None = None): self.issued_count = 0 self.inflight: int = 0 self.uuid_to_index: dict[str, int] = {} + self.completed_uuids: set[str] = set() def issue(self, sample_index: int, data_override: dict | None = None) -> str | None: if self._stop_after is not None and self._count >= self._stop_after: @@ -600,13 +601,19 @@ async def test_timeout_publishes_error_and_complete_events(): strategy._inflight["q-x"] = "conv-x" strategy._active_iters["conv-x"] = iter([]) + issuer = FakePhaseIssuer() + issuer.uuid_to_index["q-x"] = 0 + issuer.inflight = 1 + strategy._all_done = asyncio.Event() strategy._loop = asyncio.get_running_loop() - strategy._phase_issuer = None + strategy._phase_issuer = issuer strategy._handle_timeout("q-x", "conv-x") assert publisher.publish.call_count == 2 + assert issuer.inflight == 0 + assert "q-x" in issuer.completed_uuids first_call, second_call = publisher.publish.call_args_list first_record = first_call.args[0] second_record = second_call.args[0] From 03b96b9be97a4063fdd868d1c843b5d18b93d4e6 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 15 May 2026 04:31:42 +0000 Subject: [PATCH 40/41] feat: stamp conversation_id and turn on EventRecord pipeline Thread (conversation_id, turn) from the load generator through the event record / SQL writer pipeline so multi-turn runs can group ISSUED, COMPLETE, and ERROR events by conversation in the event log. 
Co-Authored-By: Claude Opus 4.7 --- .../services/event_logger/sql_writer.py | 5 ++ src/inference_endpoint/core/record.py | 2 + .../load_generator/multi_turn_strategy.py | 9 ++- .../load_generator/session.py | 21 +++++- .../load_generator/strategy.py | 10 ++- .../services/event_logger/test_sql_writer.py | 65 ++++++++++++++++- tests/unit/core/test_record.py | 18 +++++ .../unit/load_generator/test_async_session.py | 61 ++++++++++++++++ .../test_multi_turn_strategy.py | 72 +++++++++++++++++-- 9 files changed, 253 insertions(+), 10 deletions(-) diff --git a/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py b/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py index fd6a1559..52c9b010 100644 --- a/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py +++ b/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py @@ -50,6 +50,9 @@ class EventRowModel(Base): timestamp_ns: Mapped[int] = mapped_column(BigInteger, nullable=False) """Monotonic timestamp in nanoseconds.""" + conversation_id: Mapped[str] = mapped_column(String, nullable=False, default="") + turn: Mapped[int | None] = mapped_column(Integer, nullable=True) + data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False, default=b"") """JSON-encoded event data.""" @@ -61,6 +64,8 @@ def _record_to_row(record: EventRecord) -> EventRowModel: sample_uuid=record.sample_uuid, event_type=topic, timestamp_ns=record.timestamp_ns, + conversation_id=record.conversation_id, + turn=record.turn, data=msgspec.json.encode(record.data), ) diff --git a/src/inference_endpoint/core/record.py b/src/inference_endpoint/core/record.py index da35389e..9ac60e1b 100644 --- a/src/inference_endpoint/core/record.py +++ b/src/inference_endpoint/core/record.py @@ -153,6 +153,8 @@ class EventRecord(msgspec.Struct, kw_only=True, frozen=True, gc=False): # type: event_type: EventType timestamp_ns: int = msgspec.field(default_factory=time.monotonic_ns) sample_uuid: str = "" + conversation_id: str = "" + turn: int | None = None data: OUTPUT_TYPE | PromptData | ErrorData | None = None diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py index 73c9224f..3803e225 100644 --- a/src/inference_endpoint/load_generator/multi_turn_strategy.py +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -238,7 +238,12 @@ def _issue_next_turn(self, conv_id: str) -> None: } assert self._phase_issuer is not None - query_id = self._phase_issuer.issue(idx, data_override=data_override) + query_id = self._phase_issuer.issue( + idx, + data_override=data_override, + conversation_id=conv_id, + turn=turn, + ) if query_id is None: # Session stopping β€” signal done. 
assert self._all_done is not None @@ -278,6 +283,8 @@ def _handle_timeout(self, query_id: str, conv_id: str) -> None: self._phase_issuer.inflight -= 1 # type: ignore[attr-defined] if hasattr(self._phase_issuer, "completed_uuids"): self._phase_issuer.completed_uuids.add(query_id) # type: ignore[attr-defined] + if hasattr(self._phase_issuer, "uuid_to_conv_info"): + self._phase_issuer.uuid_to_conv_info.pop(query_id, None) # type: ignore[attr-defined] logger.warning( "Turn timed out for conversation %s (query=%s)", conv_id, query_id diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index a695b280..28ec4180 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -178,6 +178,7 @@ class PhaseIssuer: "_publisher", "_stop_check", "uuid_to_index", + "uuid_to_conv_info", "completed_uuids", "inflight", "issued_count", @@ -195,12 +196,17 @@ def __init__( self._publisher = publisher self._stop_check = stop_check self.uuid_to_index: dict[str, int] = {} + self.uuid_to_conv_info: dict[str, tuple[str, int | None]] = {} self.completed_uuids: set[str] = set() self.inflight: int = 0 self.issued_count: int = 0 def issue( - self, sample_index: int, data_override: dict[str, Any] | None = None + self, + sample_index: int, + data_override: dict[str, Any] | None = None, + conversation_id: str = "", + turn: int | None = None, ) -> str | None: """Load data, build Query, publish ISSUED, send to endpoint. @@ -224,6 +230,7 @@ def issue( data = {**data, **data_override} query = Query(id=query_id, data=data) self.uuid_to_index[query_id] = sample_index + self.uuid_to_conv_info[query_id] = (conversation_id, turn) ts = time.monotonic_ns() prompt_data: PromptData if isinstance(data, dict): @@ -248,6 +255,8 @@ def issue( event_type=SampleEventType.ISSUED, timestamp_ns=ts, sample_uuid=query_id, + conversation_id=conversation_id, + turn=turn, data=prompt_data, ) ) @@ -457,6 +466,12 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if phase_issuer is not None and query_id in phase_issuer.completed_uuids: return + conv_id_str, turn_num = ("", None) + if phase_issuer is not None: + conv_id_str, turn_num = phase_issuer.uuid_to_conv_info.pop( + query_id, ("", None) + ) + # Emit ERROR before COMPLETE for failed queries so downstream # consumers (notably the metrics aggregator) see the ERROR # while the in-flight tracked row still exists. 
COMPLETE @@ -475,6 +490,8 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: event_type=ErrorEventType.GENERIC, timestamp_ns=time.monotonic_ns(), sample_uuid=query_id, + conversation_id=conv_id_str, + turn=turn_num, data=resp.error, ) ) @@ -485,6 +502,8 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if isinstance(resp.completed_at, int) else time.monotonic_ns(), sample_uuid=query_id, + conversation_id=conv_id_str, + turn=turn_num, data=resp.response_output, ) ) diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py index 5e7db26a..1dd6fa27 100644 --- a/src/inference_endpoint/load_generator/strategy.py +++ b/src/inference_endpoint/load_generator/strategy.py @@ -48,7 +48,11 @@ class PhaseIssuerProtocol(Protocol): """Minimal interface that strategies see for issuing samples.""" def issue( - self, sample_index: int, data_override: dict[str, Any] | None = None + self, + sample_index: int, + data_override: dict[str, Any] | None = None, + conversation_id: str = "", + turn: int | None = None, ) -> str | None: """Issue a sample. Returns query_id, or None if the session is stopping. @@ -58,6 +62,10 @@ def issue( data_override take precedence. Used by MultiTurnStrategy to inject a runtime-assembled `messages` array while still inheriting `model`/`max_completion_tokens`/`tools`/`stream` from the dataset row. + conversation_id: Conversation identifier (multi-turn). Empty string + for single-turn issues; propagated onto the published EventRecords + so downstream consumers can group by conversation. + turn: Turn number within a conversation (multi-turn), or None. """ ... diff --git a/tests/unit/async_utils/services/event_logger/test_sql_writer.py b/tests/unit/async_utils/services/event_logger/test_sql_writer.py index 4d6bbcc6..ae443f54 100644 --- a/tests/unit/async_utils/services/event_logger/test_sql_writer.py +++ b/tests/unit/async_utils/services/event_logger/test_sql_writer.py @@ -34,9 +34,14 @@ from sqlalchemy.orm import Session -def _record(event_type, uuid="", ts=0, data=None): +def _record(event_type, uuid="", ts=0, data=None, conversation_id="", turn=None): return EventRecord( - event_type=event_type, timestamp_ns=ts, sample_uuid=uuid, data=data + event_type=event_type, + timestamp_ns=ts, + sample_uuid=uuid, + conversation_id=conversation_id, + turn=turn, + data=data, ) @@ -58,6 +63,9 @@ def test_session_event_topic(self): assert row.event_type == "session.ended" assert row.sample_uuid == "" assert row.timestamp_ns == 42 + # Defaults for non-multi-turn events: empty conversation_id, NULL turn. 
+ assert row.conversation_id == "" + assert row.turn is None def test_error_event_topic(self): row = _record_to_row(_record(ErrorEventType.GENERIC, ts=99)) @@ -74,6 +82,19 @@ def test_none_data_encodes_to_null(self): decoded = msgspec.json.decode(row.data) assert decoded is None + def test_conversation_id_and_turn_copied_to_row(self): + row = _record_to_row( + _record( + SampleEventType.ISSUED, + uuid="q1", + ts=10, + conversation_id="conv-x", + turn=2, + ) + ) + assert row.conversation_id == "conv-x" + assert row.turn == 2 + # --------------------------------------------------------------------------- # EventRowModel schema @@ -260,3 +281,43 @@ def test_mixed_event_types(self, tmp_path): "session.ended", ] engine.dispose() + + def test_conversation_id_and_turn_persisted(self, tmp_path): + writer = SQLWriter(tmp_path / "events", flush_interval=1) + try: + writer.write( + _record( + SampleEventType.ISSUED, + uuid="q1", + ts=10, + conversation_id="conv-a", + turn=1, + ) + ) + writer.write( + _record( + SampleEventType.COMPLETE, + uuid="q1", + ts=20, + conversation_id="conv-a", + turn=1, + ) + ) + # Single-turn / non-conversation event leaves defaults. + writer.write(_record(SessionEventType.STARTED, ts=0)) + finally: + writer.close() + + engine = create_engine(f"sqlite:///{tmp_path / 'events.db'}") + with Session(engine) as session: + rows = ( + session.execute(select(EventRowModel).order_by(EventRowModel.id)) + .scalars() + .all() + ) + assert [(r.conversation_id, r.turn) for r in rows] == [ + ("conv-a", 1), + ("conv-a", 1), + ("", None), + ] + engine.dispose() diff --git a/tests/unit/core/test_record.py b/tests/unit/core/test_record.py index 86b108a0..cd9173f6 100644 --- a/tests/unit/core/test_record.py +++ b/tests/unit/core/test_record.py @@ -66,6 +66,8 @@ def test_construction_with_only_event_type_uses_defaults(self): after = time.monotonic_ns() assert before <= record.timestamp_ns <= after assert record.sample_uuid == "" + assert record.conversation_id == "" + assert record.turn is None assert record.data is None @@ -192,6 +194,8 @@ def test_record_with_only_event_type_round_trips_with_defaults(self): decoded = _codec.decode(payload) assert decoded.event_type.topic == SessionEventType.ENDED.topic assert decoded.sample_uuid == "" + assert decoded.conversation_id == "" + assert decoded.turn is None assert decoded.data is None assert decoded.timestamp_ns > 0 @@ -205,6 +209,20 @@ def test_explicit_timestamp_ns_preserved_round_trip(self): decoded = _codec.decode(payload) assert decoded.timestamp_ns == ts + def test_conversation_id_and_turn_round_trip(self): + record = EventRecord( + event_type=SampleEventType.ISSUED, + sample_uuid="q-mt", + conversation_id="conv-7", + turn=4, + data=PromptData(text="hi"), + ) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) + assert decoded.conversation_id == "conv-7" + assert decoded.turn == 4 + assert decoded.sample_uuid == "q-mt" + class TestEventRecordCodecOnDecodeError: """Tests for the two branches of EventRecordCodec.on_decode_error. 
diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index 0b84117c..8dccb64b 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -161,11 +161,15 @@ def test_issue_builds_query_and_publishes(self): assert len(issuer.issued_queries) == 1 assert issuer.issued_queries[0].id == result assert 3 in phase_issuer.uuid_to_index.values() + # Single-turn callers omit conv_id/turn β€” defaults flow through. + assert phase_issuer.uuid_to_conv_info[result] == ("", None) # Should have published ISSUED event issued_events = publisher.events_of_type(SampleEventType.ISSUED) assert len(issued_events) == 1 assert issued_events[0].sample_uuid == result + assert issued_events[0].conversation_id == "" + assert issued_events[0].turn is None def test_issue_returns_none_when_stopped(self): dataset = FakeDataset(5) @@ -188,6 +192,23 @@ def test_uuid_is_unique_per_issue(self): ids = [phase_issuer.issue(i % 5) for i in range(10)] assert len(set(ids)) == 10 + def test_issue_stamps_conversation_id_and_turn_on_issued_event(self): + dataset = FakeDataset(5) + issuer = FakeIssuer() + issuer._auto_respond = False + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: False) + + query_id = phase_issuer.issue(2, conversation_id="conv-1", turn=3) + assert query_id is not None + assert phase_issuer.uuid_to_conv_info[query_id] == ("conv-1", 3) + + issued = publisher.events_of_type(SampleEventType.ISSUED) + assert len(issued) == 1 + assert issued[0].sample_uuid == query_id + assert issued[0].conversation_id == "conv-1" + assert issued[0].turn == 3 + # --------------------------------------------------------------------------- # BenchmarkSession tests @@ -571,6 +592,46 @@ async def inject_error(): f"complete at idx {complete_idx}" ) + @pytest.mark.asyncio + async def test_handle_response_stamps_conversation_id_and_turn(self): + """Both COMPLETE and ERROR events inherit the (conv_id, turn) seeded at + issue time; the entry is popped so late duplicates can't reuse it.""" + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + publisher = FakePublisher() + session = BenchmarkSession(issuer, publisher, loop) + + phase_issuer = PhaseIssuer(FakeDataset(2), issuer, publisher, lambda: False) + session._current_phase_issuer = phase_issuer + + # Success path: COMPLETE inherits conv info, entry is popped. + phase_issuer.uuid_to_index["q-ok"] = 0 + phase_issuer.uuid_to_conv_info["q-ok"] = ("conv-9", 5) + phase_issuer.inflight = 1 + session._handle_response( + QueryResult(id="q-ok", response_output="ok", completed_at=12345) + ) + complete = publisher.events_of_type(SampleEventType.COMPLETE) + assert [(e.conversation_id, e.turn) for e in complete] == [("conv-9", 5)] + assert "q-ok" not in phase_issuer.uuid_to_conv_info + + # Error path: ERROR (emitted before COMPLETE) also carries conv info. 
+ phase_issuer.uuid_to_index["q-err"] = 1 + phase_issuer.uuid_to_conv_info["q-err"] = ("conv-err", 2) + phase_issuer.inflight = 1 + session._handle_response( + QueryResult( + id="q-err", + error=ErrorData(error_type="boom", error_message="x"), + ) + ) + error_events = [ + e for e in publisher.events if isinstance(e.event_type, ErrorEventType) + ] + assert [(e.conversation_id, e.turn) for e in error_events] == [("conv-err", 2)] + complete = publisher.events_of_type(SampleEventType.COMPLETE) + assert (complete[-1].conversation_id, complete[-1].turn) == ("conv-err", 2) + @pytest.mark.unit class TestBenchmarkSessionPoissonIntegration: diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py index 8792db50..bd70a2e2 100644 --- a/tests/unit/load_generator/test_multi_turn_strategy.py +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -39,15 +39,23 @@ def __init__(self, stop_after: int | None = None): self.issued_count = 0 self.inflight: int = 0 self.uuid_to_index: dict[str, int] = {} + self.uuid_to_conv_info: dict[str, tuple[str, int | None]] = {} self.completed_uuids: set[str] = set() - def issue(self, sample_index: int, data_override: dict | None = None) -> str | None: + def issue( + self, + sample_index: int, + data_override: dict | None = None, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: if self._stop_after is not None and self._count >= self._stop_after: return None self._count += 1 self.issued_count += 1 query_id = f"q{sample_index:04d}" self.issued.append(sample_index) + self.uuid_to_conv_info[query_id] = (conversation_id, turn) return query_id @@ -108,8 +116,13 @@ async def test_single_conversation_multi_turn(): issued_order: list[str] = [] original_issue = issuer.issue - def tracked_issue(idx, data_override=None): - q = original_issue(idx, data_override=data_override) + def tracked_issue(idx, data_override=None, conversation_id="", turn=None): + q = original_issue( + idx, + data_override=data_override, + conversation_id=conversation_id, + turn=turn, + ) if q: issued_order.append(q) return q @@ -170,7 +183,13 @@ class TimedIssuer: issued_count = 0 issued: list[int] = [] - def issue(self, idx: int, data_override: dict | None = None) -> str | None: + def issue( + self, + idx: int, + data_override: dict | None = None, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: import time issue_timestamps[idx] = time.monotonic() @@ -395,7 +414,13 @@ class ErrorIssuer: issued_count = 0 issued: list[int] = [] - def issue(self, idx: int, data_override: dict | None = None) -> str | None: + def issue( + self, + idx: int, + data_override: dict | None = None, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: raise RuntimeError("simulated pipeline error") with pytest.raises(RuntimeError, match="simulated pipeline error"): @@ -603,6 +628,7 @@ async def test_timeout_publishes_error_and_complete_events(): issuer = FakePhaseIssuer() issuer.uuid_to_index["q-x"] = 0 + issuer.uuid_to_conv_info["q-x"] = ("conv-x", 1) issuer.inflight = 1 strategy._all_done = asyncio.Event() @@ -614,6 +640,8 @@ async def test_timeout_publishes_error_and_complete_events(): assert publisher.publish.call_count == 2 assert issuer.inflight == 0 assert "q-x" in issuer.completed_uuids + # Conv info must be cleared so a late real response can't reuse stale state. 
+ assert "q-x" not in issuer.uuid_to_conv_info first_call, second_call = publisher.publish.call_args_list first_record = first_call.args[0] second_record = second_call.args[0] @@ -624,6 +652,40 @@ async def test_timeout_publishes_error_and_complete_events(): assert second_record.sample_uuid == "q-x" +@pytest.mark.unit +@pytest.mark.asyncio +async def test_issue_passes_conversation_id_and_turn_to_phase_issuer(): + """MultiTurnStrategy must forward (conv_id, turn) to phase_issuer.issue().""" + conv_manager = ConversationManager() + conversations = {"conv-A": [1, 2], "conv-B": [1]} + metadata = _make_dataset_metadata(conversations) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + # Build the expected (query_id -> (conv_id, turn)) map from the same + # sample_index ordering _make_dataset_metadata uses, so the test does not + # encode that ordering as magic numbers. + expected: dict[str, tuple[str, int]] = {} + sample_index = 0 + for conv_id, turns in conversations.items(): + for turn in turns: + expected[f"q{sample_index:04d}"] = (conv_id, turn) + sample_index += 1 + + async def respond_in_order(): + await asyncio.sleep(0.01) + for query_id in expected: + strategy.on_sample_complete( + QueryResult(id=query_id, response_output=TextModelOutput(output="ok")) + ) + await asyncio.sleep(0.005) + + asyncio.create_task(respond_in_order()) + await strategy.execute(issuer) + + assert issuer.uuid_to_conv_info == expected + + @pytest.mark.unit @pytest.mark.asyncio async def test_error_turn_aborts_remaining_turns(): From fb2ff32d23e580d3592d28cbaac772abe83927f5 Mon Sep 17 00:00:00 2001 From: "Li, Tianmu" Date: Fri, 15 May 2026 05:22:03 +0000 Subject: [PATCH 41/41] fix: sum_sq overflow, streaming conv stamping, tool-call metric coverage Change 0: cast sum_sq to float in _hdr_stat/_exact_stat to prevent msgpack OverflowError for ns-range latencies >= 4.3s. SeriesStat wire schema already accepts int|float. Change 1: stamp conversation_id/turn on RECV_FIRST/RECV_NON_FIRST events using .get() (not .pop()) so the uuid_to_conv_info entry remains available for the terminal QueryResult pop. Change 2: add test_tpot_osl_for_tool_call_complete to TestAsyncTriggers with exact OSL/TPOT value assertions for tool-call streaming responses. 
Co-Authored-By: Claude Sonnet 4.6 --- .../services/metrics_aggregator/registry.py | 4 +- .../load_generator/session.py | 7 ++++ .../metrics_aggregator/test_aggregator.py | 42 ++++++++++++++++++- .../metrics_aggregator/test_registry.py | 7 ++++ .../metrics_aggregator/test_snapshot.py | 12 +++++- .../unit/load_generator/test_async_session.py | 22 ++++++++-- 6 files changed, 86 insertions(+), 8 deletions(-) diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/registry.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/registry.py index 1268ba9e..800238f3 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/registry.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/registry.py @@ -274,7 +274,7 @@ def _hdr_stat(self) -> SeriesStat: total=self._total, min=self._min, max=self._max, - sum_sq=self._sum_sq, + sum_sq=float(self._sum_sq), percentiles=perc_dict, histogram=histogram, ) @@ -313,7 +313,7 @@ def _exact_stat(self) -> SeriesStat: total=self._total, min=self._min, max=self._max, - sum_sq=self._sum_sq, + sum_sq=float(self._sum_sq), percentiles=perc_dict, histogram=histogram, ) diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 28ec4180..b860db0c 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -524,11 +524,18 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if is_first else SampleEventType.RECV_NON_FIRST ) + conv_id_str, turn_num = ("", None) + if phase_issuer is not None: + conv_id_str, turn_num = phase_issuer.uuid_to_conv_info.get( + resp.id, ("", None) + ) self._publisher.publish( EventRecord( event_type=event_type, timestamp_ns=time.monotonic_ns(), sample_uuid=resp.id, + conversation_id=conv_id_str, + turn=turn_num, ) ) diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py index 0669a31e..9877aee5 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py @@ -41,7 +41,7 @@ SampleEventType, SessionEventType, ) -from inference_endpoint.core.types import ErrorData, PromptData +from inference_endpoint.core.types import ErrorData, PromptData, TextModelOutput from .conftest import ( MockTokenizePool, @@ -1120,3 +1120,43 @@ def __exit__(self, *args): ) finally: agg.close() + + @pytest.mark.asyncio + async def test_tpot_osl_for_tool_call_complete(self, tmp_path): + """OSL and TPOT use message-path tokenization when COMPLETE carries tool_calls.""" + loop = asyncio.get_event_loop() + pool = MockTokenizePool(delay=0.0) + with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: + agg, registry, _ = make_aggregator( + ctx, loop, "agg_tpot_osl_tool_call", tokenize_pool=pool + ) + try: + tool_call = { + "id": "c1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + await agg.process( + [ + session_event( + SessionEventType.START_PERFORMANCE_TRACKING, ts=0 + ), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=2000), + sample_event( + SampleEventType.COMPLETE, + "s1", + ts=5000, + data=TextModelOutput(output="ok", tool_calls=(tool_call,)), + ), + ] + ) + await agg._table.drain_tasks() + # OSL = token_count("ok" + tool_calls_json) = 2 + assert snapshot_series_total(registry, 
MetricSeriesKey.OSL.value) == 2 + # tpot = (5000 - 2000) / token_count(tool_calls_json) = 3000 / 1 = 3000 + assert snapshot_series_total( + registry, MetricSeriesKey.TPOT_NS.value + ) == pytest.approx(3000.0) + finally: + agg.close() diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_registry.py b/tests/unit/async_utils/services/metrics_aggregator/test_registry.py index d6759973..ce8fa704 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_registry.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_registry.py @@ -207,6 +207,13 @@ def test_float_dtype(self): assert stat.count == 3 assert stat.total == pytest.approx(7.5) + def test_ns_range_sum_sq_is_float(self): + s = self._make() + s.record(_NS_HIGH) + s.record(_NS_HIGH) + assert isinstance(s.build_stat(exact=False).sum_sq, float) + assert isinstance(s.build_stat(exact=True).sum_sq, float) + @pytest.mark.unit class TestMetricsRegistry: diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_snapshot.py b/tests/unit/async_utils/services/metrics_aggregator/test_snapshot.py index ba861094..db07020c 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_snapshot.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_snapshot.py @@ -47,14 +47,22 @@ def test_float_value(self): @pytest.mark.unit class TestSeriesStat: - def test_roundtrip(self): + @pytest.mark.parametrize( + "sum_sq", + [ + 55000, + # Float exceeding uint64 max β€” would overflow msgpack if encoded as int. + 2.0 * (2**64 - 1), + ], + ) + def test_roundtrip(self, sum_sq): stat = SeriesStat( name="ttft_ns", count=5, total=500, min=50, max=150, - sum_sq=55000, + sum_sq=sum_sq, percentiles={"50": 100.0, "99": 145.0}, histogram=[((50.0, 100.0), 2), ((100.0, 150.0), 3)], ) diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index 8dccb64b..81baff6e 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -594,16 +594,32 @@ async def inject_error(): @pytest.mark.asyncio async def test_handle_response_stamps_conversation_id_and_turn(self): - """Both COMPLETE and ERROR events inherit the (conv_id, turn) seeded at - issue time; the entry is popped so late duplicates can't reuse it.""" + """All event types inherit (conv_id, turn) seeded at issue time; streaming + events use .get() so the entry survives for the terminal QueryResult pop.""" loop = asyncio.get_running_loop() issuer = FakeIssuer() publisher = FakePublisher() session = BenchmarkSession(issuer, publisher, loop) - phase_issuer = PhaseIssuer(FakeDataset(2), issuer, publisher, lambda: False) + phase_issuer = PhaseIssuer(FakeDataset(3), issuer, publisher, lambda: False) session._current_phase_issuer = phase_issuer + # Streaming path: entry stays available for the terminal COMPLETE pop. 
+ phase_issuer.uuid_to_conv_info["q-stream"] = ("conv-s", 7) + session._handle_response( + StreamChunk(id="q-stream", metadata={"first_chunk": True}) + ) + session._handle_response(StreamChunk(id="q-stream", response_chunk="delta")) + assert ( + publisher.events_of_type(SampleEventType.RECV_FIRST)[0].conversation_id, + publisher.events_of_type(SampleEventType.RECV_FIRST)[0].turn, + ) == ("conv-s", 7) + assert ( + publisher.events_of_type(SampleEventType.RECV_NON_FIRST)[0].conversation_id, + publisher.events_of_type(SampleEventType.RECV_NON_FIRST)[0].turn, + ) == ("conv-s", 7) + assert "q-stream" in phase_issuer.uuid_to_conv_info + # Success path: COMPLETE inherits conv info, entry is popped. phase_issuer.uuid_to_index["q-ok"] = 0 phase_issuer.uuid_to_conv_info["q-ok"] = ("conv-9", 5)