diff --git a/docs/MULTI_TURN_QUICKSTART.md b/docs/MULTI_TURN_QUICKSTART.md new file mode 100644 index 00000000..5a48ab75 --- /dev/null +++ b/docs/MULTI_TURN_QUICKSTART.md @@ -0,0 +1,280 @@ +# Multi-Turn Conversation Benchmarking - Quick Start Guide + +## Quick Start in 5 Minutes + +### 1. Prepare Your Dataset + +Create a JSONL file with your conversations. All rows for a given `conversation_id` must appear +**consecutively** in the file (no interleaving with other conversations): + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello!", "system": "You are a helpful assistant"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi! How can I help?"} +{"conversation_id": "c1", "turn": 3, "role": "user", "content": "What's 2+2?"} +{"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "2+2 equals 4."} +``` + +**Rules**: + +- Alternate between "user" and "assistant" roles +- Start with "user" role +- Sequential turn numbers (1, 2, 3, ...) +- Same `conversation_id` for all turns in a conversation +- All rows for the same `conversation_id` must be grouped together + +### 2. Create Configuration File + +Save as `multi_turn_config.yaml`: + +```yaml +name: "my-multi-turn-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: my_conversations + type: performance + path: path/to/your/conversations.jsonl + multi_turn: # ← Presence of this block enables multi-turn mode + turn_timeout_s: 300 # ← Max wait for prev turn + +settings: + load_pattern: + type: multi_turn # ← Use multi-turn scheduler + target_concurrency: 32 # ← Required: max simultaneous conversations + + client: + num_workers: 4 + +endpoint_config: + endpoints: + - "http://your-endpoint:8000" + api_type: openai + +report_dir: logs/my_multi_turn_benchmark +``` + +Results are written to `report_dir` (here: `logs/my_multi_turn_benchmark/`). + +### 3. Run Benchmark + +```bash +inference-endpoint benchmark from-config --config multi_turn_config.yaml +``` + +That's it! Your benchmark will now: + +- ✅ Enforce turn ordering (turn N+1 waits for turn N) +- ✅ Include conversation history in each request +- ✅ Log all issued (client) turns to events.jsonl — scripted assistant rows are context only and do not produce sample events + +--- + +## Understanding Results + +After the benchmark completes, check the directory configured via `report_dir`: + +### Events Log + +The `events.jsonl` file contains one JSON record per line, with the standard +`sample_uuid`, `event_type`, and `timestamp_ns` fields. Events are keyed by +`sample_uuid` only. To correlate events with conversations, join through +`sample_idx_map.json` (written next to `events.jsonl`) and the multi-turn +dataset's `conversation_metadata["samples"]`, which maps sample indices to +`(conversation_id, turn)` tuples. + +### Metrics + +Currently available: + +- **Per-turn metrics**: Latency, TTFT, TPOT for each turn +- **Conversation tracking**: events are keyed by `sample_uuid` only; correlate any event back to a conversation by joining through `sample_idx_map.json` and `conversation_metadata["samples"]` + +_Note: Per-conversation aggregation (e.g., "conversations/sec") is coming in a future update._ + +--- + +## Concurrency Control + +`target_concurrency` is **required** for the `multi_turn` load pattern. It controls how many +conversations are active simultaneously. 
Each active conversation has exactly one in-flight turn +at a time — a worker issues turn N, waits for the response, then issues turn N+1. A new +conversation starts only after a worker finishes all turns of its current one. + +```yaml +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← 32 conversations active simultaneously +``` + +--- + +## Troubleshooting + +### Validate Your Dataset Before Running + +Use the bundled validation script to check your JSONL file for schema errors before benchmarking: + +```bash +python scripts/validate_jsonl_schema.py path/to/your/conversations.jsonl +``` + +This catches per-row schema errors (missing required fields, wrong types, +malformed `tool_results`). Cross-row invariants (consecutive turn numbers, +valid role sequences, grouped conversations) are enforced by +`MultiTurnDataset` at load time and will surface at benchmark startup. + +### "Conversation has invalid role sequence" + +**Problem**: Your dataset doesn't follow a valid role sequence. + +**Fix**: Check your JSONL. Valid sequences: + +- Plain chat: `user → assistant → user → assistant → ...` +- Agentic (tool-use): `user → assistant → tool → assistant → tool → ... → user` + +Conversations may also end with a `tool` row (the model's response to the final tool call is the benchmark target). + +### "Rows for conversation X are not consecutive" + +**Problem**: Rows for the same `conversation_id` are interleaved with rows from other conversations. + +**Fix**: Sort your JSONL so all rows for each conversation appear together. + +### "Turn timed out waiting for prev turn" + +**Problem**: Previous turn took longer than `turn_timeout_s`. + +**Fixes**: + +1. Increase `turn_timeout_s` in config +2. Check if your endpoint is slow or unresponsive +3. Look for errors in the endpoint logs + +### Dataset not loading + +**Problem**: MultiTurnDataset not recognized. + +**Fix**: Ensure `multi_turn:` block is present in the dataset config. The file format +is auto-detected from the `.jsonl` extension — no `format` field is needed: + +```yaml +datasets: + - path: your_file.jsonl + multi_turn: {} +``` + +--- + +## Example Datasets + +### Simple 2-Turn Conversation + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hello!"} +``` + +### With System Prompt + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Who won?", "system": "You are a sports expert"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "The Lakers won."} +``` + +### Multiple Conversations + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hello!"} +{"conversation_id": "c2", "turn": 1, "role": "user", "content": "Hey"} +{"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "Hi there!"} +``` + +### With Model Override + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Summarize this", "model": "gpt-4"} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Here's the summary..."} +``` + +--- + +## Testing Your Setup + +### 1. Use the Example Dataset + +```bash +# Run from the repository root — dataset paths in the bundled YAML are +# repo-relative (e.g. examples/09_MultiTurn/customer_support_conversations.jsonl). 
+inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/multi_turn_benchmark.yaml +``` + +### 2. Check the Logs + +```bash +cat logs/multi_turn_test/benchmark.log +# Look for: "Turn X of conversation_id issued" +``` + +### 3. Verify Event Recording + +```bash +# List all sample UUIDs in the events log +jq -r '.sample_uuid' logs/multi_turn_test/events.jsonl | sort -u +# Should show UUIDs; correlate to conversations via sample_idx_map.json +``` + +--- + +## Tips & Best Practices + +### Dataset Design + +- **Keep conversations realistic**: 2-10 turns typical +- **Test edge cases**: 1-turn conversations, very long conversations +- **Include system prompts**: Helps model understand context + +### Performance + +- **Workers**: `client.num_workers` controls HTTP worker processes, independent of `target_concurrency`. The default (`-1`) auto-tunes based on NUMA topology. +- **Timeout**: Set `turn_timeout_s` = 2x your longest expected turn latency +- **Memory**: ~1KB per turn, plan accordingly for large datasets + +### Debugging + +- **Start small**: Test with 1-2 conversations first +- **Single conversation**: Use `target_concurrency: 1` +- **Check events.jsonl**: Verify turn ordering with `jq` + +--- + +## More Information + +- **Full Documentation**: See `examples/09_MultiTurn/README.md` +- **Architecture**: See `AGENTS.md` (Multi-Turn section) + +--- + +## Checklist + +Before running your first multi-turn benchmark: + +- [ ] Dataset follows format (user/assistant alternation, or agentic user→assistant→tool sequences) +- [ ] All rows for each conversation_id are grouped together +- [ ] Config has `multi_turn:` block in the dataset section +- [ ] Config has `load_pattern.type: multi_turn` +- [ ] Endpoint is running and reachable +- [ ] File uses `.jsonl` extension (format is auto-detected) +- [ ] Conversation IDs are unique per conversation +- [ ] Turn numbers are sequential (1, 2, 3, ...) +- [ ] Dataset is configured as `type: performance` (accuracy evaluation of multi-turn datasets is not yet supported) + +Happy benchmarking! diff --git a/examples/09_MultiTurn/README.md b/examples/09_MultiTurn/README.md new file mode 100644 index 00000000..1f6a24fc --- /dev/null +++ b/examples/09_MultiTurn/README.md @@ -0,0 +1,290 @@ +# Multi-Turn Conversation Benchmarking Examples + +This directory contains examples for benchmarking conversational AI workloads with multi-turn conversation support. + +## Overview + +Multi-turn conversation benchmarking enables testing realistic conversational AI scenarios where each turn depends on previous responses. The system maintains conversation history and enforces turn sequencing to simulate real-world multi-turn interactions. + +## Dataset Format + +Multi-turn datasets use JSONL format with the following structure: + +```jsonl +{"conversation_id": "c1", "turn": 1, "role": "user", "content": "...", "system": "..."} +{"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "..."} +{"conversation_id": "c1", "turn": 3, "role": "user", "content": "..."} +``` + +### Required Fields + +- `conversation_id`: Unique identifier for each conversation +- `turn`: Turn number within conversation (1-indexed) +- `role`: Speaker role ("user" or "assistant") +- `content`: Message content + +### Optional Fields + +- `system`: System prompt (typically only on first user turn) +- `model`: Model name override for this turn +- `max_new_tokens`: Maximum tokens to generate for this turn + +### Validation Rules + +1. 
All rows for a given `conversation_id` must appear **consecutively** in the file (no interleaving + with rows from other conversations). File-order within a conversation does not matter — the + loader sorts by the `turn` column when building conversation history. + The flat-row format is intentional: it enables row-by-row streaming without loading entire + conversations into memory first. +2. Conversations must follow a valid role sequence: + - Plain chat: `user → assistant → user → ...` + - Agentic: `user → assistant (with tool_calls) → tool (tool_results list; parallel results merged) → [assistant (with tool_calls) → tool]* → assistant → user → ...` +3. First turn must be "user" role +4. Turn numbers must be sequential (1, 2, 3, ...) +5. Each conversation must have at least one turn + +## Agentic (Tool-Sequence) Datasets + +For agentic workloads where the model dispatches tools, the dataset must include tool-call +metadata. The source format for these datasets is a **snapshot JSONL** — each line contains the +full conversation history at a particular checkpoint. The benchmarker requires **flat-row JSONL** +(one row per message), so a conversion step is needed first. + +### Source snapshot format + +Each line in the source file represents one snapshot of a conversation: + +```json +{ + "conversation_id": "sim_001", + "conversation_idx": 5, + "messages": [ + {"role": "system", "content": "..."}, + {"role": "user", "content": "..."}, + {"role": "assistant", "tool_calls": [{"id": "...", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"ls\"}"}}]}, + {"role": "tool", "tool_call_id": "...", "content": "file1.txt\nfile2.txt"}, + {"role": "assistant", "content": "Done."} + ], + "tools": [...], + "metadata": {} +} +``` + +Multiple snapshots may exist per `conversation_id` (one per `conversation_idx`); only the +highest-indexed snapshot per conversation is used. + +### Converting to flat-row format + +The following commands convert each source snapshot file to the flat-row format required by the benchmarker. +Run from the repo root: + +```bash +# First argument: input snapshot JSONL; second argument: output flat-row JSONL +python scripts/convert_agentic_snapshot.py \ + /path/to/agentic_coding_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl \ + --verify + +python scripts/convert_agentic_snapshot.py \ + /path/to/agentic_workflow_dataset.jsonl \ + examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl \ + --verify +``` + +The `datasets/` directory under `examples/09_MultiTurn/` is a placeholder; run the conversion +commands above to populate it before benchmarking. + +The `--verify` flag cross-checks every client turn's message history against the source snapshot +and exits with code 1 if any mismatch is found. 
The script also: + +- Collapses consecutive `user` messages into one (keeps turn sequencing clean) +- Merges consecutive `tool` messages for the same assistant dispatch into a single row with a + `tool_results` list (so all parallel results are sent together in one API call) + +### Flat-row format after conversion + +The extra fields supported beyond plain user/assistant: + +| Row role | Extra fields | +| ------------------------------------------ | ------------------------------------------------------------------ | +| `assistant` with tool calls | `tool_calls: [{id, type, function: {name, arguments}}]` | +| `tool` results (single or merged parallel) | `tool_results: [{tool_call_id, content}, ...]` | +| `user` or `tool` turns | `tools: [...]` (OpenAI tool definitions forwarded to the endpoint) | + +Example rows from a converted agentic dataset: + +```jsonl +{"conversation_id": "sim_001", "turn": 1, "role": "user", "content": "Fix the bug in foo.py", "system": "You are a coding agent.", "tools": [...]} +{"conversation_id": "sim_001", "turn": 2, "role": "assistant", "tool_calls": [{"id": "functions.bash:0", "type": "function", "function": {"name": "bash", "arguments": "{\"cmd\": \"cat foo.py\"}"}}]} +{"conversation_id": "sim_001", "turn": 3, "role": "tool", "tool_results": [{"tool_call_id": "functions.bash:0", "content": "def foo():\n return 1/0"}], "tools": [...]} +{"conversation_id": "sim_001", "turn": 4, "role": "assistant", "content": "The bug is a ZeroDivisionError. Here is the fix: ..."} +``` + +### Running agentic benchmarks + +After converting the datasets, update the `path` field in the config files and run: + +```bash +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/agentic_coding_benchmark.yaml + +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/agentic_workflow_benchmark.yaml +``` + +--- + +## Configuration + +### Basic Configuration + +```yaml +datasets: + - name: customer_support + type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + multi_turn: + turn_timeout_s: 300.0 + +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Required for multi_turn load pattern +``` + +### Concurrency Control + +The `target_concurrency` field is **required** for the `multi_turn` load pattern and controls the maximum number of in-flight requests across all conversations: + +```yaml +settings: + load_pattern: + type: multi_turn + target_concurrency: 32 # ← Limit to 32 concurrent requests +``` + +**Behavior**: + +- With `target_concurrency`: At most `target_concurrency` conversations are active simultaneously; each active conversation has exactly one in-flight turn at any time. +- Turn sequencing is preserved: turn N+1 is issued only after turn N's response arrives. + +**Use cases**: + +- **Prevent endpoint overload**: Control request rate to busy endpoints +- **Large-scale testing**: Benchmark 1000+ conversations without overwhelming system +- **Resource management**: Stay within port limits, memory constraints + +**Example**: 100 conversations with `target_concurrency: 32` + +``` +t=0: Start 32 conversations, issue turn-1 for each (32 in-flight) +t=0.5: Turn-1 of conv A completes → issue turn-2 of conv A (still 32 in-flight) +t=1.0: All turns of conv B complete → start conv 33, issue its turn-1 (still 32 in-flight) +... 
Maintains at most 32 active conversations +``` + +### Turn Timeout + +Configure the maximum time allowed between issuing a turn and receiving its response: + +```yaml +multi_turn: + turn_timeout_s: 300.0 # 5 minutes +``` + +If a turn does not receive a response within `turn_timeout_s` seconds, that turn is marked failed and all remaining turns in the same conversation are aborted (subsequent turns depend on the timed-out response). The event is logged as a warning. + +## Running Multi-Turn Benchmarks + +### Using Configuration File + +```bash +inference-endpoint benchmark from-config \ + --config examples/09_MultiTurn/multi_turn_benchmark.yaml +``` + +### Viewing Results + +Multi-turn benchmarks produce per-turn metrics: + +- **Per-turn metrics**: Latency, TTFT, TPOT for each individual turn +- **Per-conversation metrics**: Total conversation latency, conversations per second _(planned — not yet implemented)_ + +**Note**: Multi-turn datasets are only supported as performance datasets. Using a multi-turn dataset as an accuracy dataset (`type: accuracy`) is not yet supported and will raise an error at startup. + +Results are stored in the configured `report_dir`. The `events.jsonl` log contains one record per turn keyed by `sample_uuid`. To correlate events with conversations, join through `sample_idx_map.json` and the dataset's `conversation_metadata["samples"]`. + +## Example Datasets + +### customer_support_conversations.jsonl + +Simple customer support conversations demonstrating basic multi-turn interactions: + +- 3 conversations +- 2-4 turns per conversation +- Customer support agent system prompt + +## Architecture Notes + +### Key Components + +- **ConversationManager**: Tracks conversation state and message history +- **MultiTurnStrategy**: Enforces turn sequencing within conversations +- **MultiTurnDataset**: Validates and structures multi-turn data + +### Turn Sequencing + +The system ensures that: + +1. Turn N+1 cannot be issued until turn N completes +2. Message history is included in subsequent requests +3. Concurrent conversations are supported (in independent mode) + +### Memory Considerations + +Each conversation maintains message history in memory. For large-scale benchmarks with long conversations: + +- Memory usage: ~1KB per turn (approximate) +- 1000 conversations × 10 turns = ~10MB + +## Troubleshooting + +### "Conversation has invalid role sequence" + +**Cause**: Conversation doesn't follow a valid role sequence. + +**Fix**: For plain chat, ensure the dataset alternates between user and assistant: + +``` +user -> assistant -> user -> assistant -> ... +``` + +For agentic datasets, use the conversion script (`scripts/convert_agentic_snapshot.py`) to +produce a properly sequenced flat-row file. The valid agentic sequence is: + +``` +user -> assistant (tool_calls) -> tool -> [tool | assistant (tool_calls)]* -> assistant -> user -> ... +``` + +### "Turn timed out" + +**Cause**: A turn did not receive a response within `turn_timeout_s` seconds after it was issued. + +**Fixes**: + +- Increase `turn_timeout_s` in configuration +- Check endpoint performance +- Verify endpoint is responding + +### Single-turn benchmarks unaffected + +Multi-turn logic is only activated when a `multi_turn:` block is present in the dataset configuration. Existing single-turn benchmarks continue to work unchanged with zero performance overhead. 
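+
+### Pre-checking role sequences offline
+
+If you want to locate a bad transition before starting a benchmark run, a small standalone
+script can walk each conversation and apply the documented state machine
+(`user → assistant`, `assistant → tool | user`, `tool → assistant | user`). The sketch below is
+illustrative only: it is not part of the benchmarker, the file name `check_roles.py` is
+hypothetical, and it assumes only the documented flat-row fields (`conversation_id`, `turn`,
+`role`). `MultiTurnDataset` still performs the authoritative validation at load time.
+
+```python
+#!/usr/bin/env python3
+"""Illustrative pre-check for multi-turn role sequences (not part of the benchmarker)."""
+
+import json
+import sys
+from collections import defaultdict
+from pathlib import Path
+
+# Valid next roles, mirroring the state machine documented in
+# scripts/multi_turn_dataset_schema.json ("start" = before the first row).
+VALID_NEXT = {
+    "start": {"user"},
+    "user": {"assistant"},
+    "assistant": {"tool", "user"},
+    "tool": {"assistant", "user"},
+}
+
+
+def check_file(path: Path) -> int:
+    """Return the number of role-sequence violations in a flat-row JSONL file."""
+    rows_by_conv: dict[str, list[dict]] = defaultdict(list)
+    with path.open() as fh:
+        for line in fh:
+            if line.strip():
+                row = json.loads(line)
+                rows_by_conv[row["conversation_id"]].append(row)
+
+    errors = 0
+    for conv_id, rows in rows_by_conv.items():
+        # The loader sorts by the turn column, so do the same here.
+        rows.sort(key=lambda r: r["turn"])
+        state = "start"
+        for row in rows:
+            role = row["role"]
+            if role not in VALID_NEXT.get(state, set()):
+                print(f"{conv_id} turn {row['turn']}: '{state} -> {role}' is not a valid transition")
+                errors += 1
+            state = role
+    return errors
+
+
+if __name__ == "__main__":
+    sys.exit(1 if check_file(Path(sys.argv[1])) else 0)
+```
+
+Run it as `python check_roles.py path/to/your/conversations.jsonl`; a non-zero exit code means at
+least one invalid transition was found.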
+ +## Future Enhancements + +Planned features: + +- [ ] Poisson conversation arrival mode +- [ ] Per-conversation metrics in reporting (total conversation latency, conversations per second) +- [ ] Conversation-level latency percentiles +- [ ] Dynamic conversation branching diff --git a/examples/09_MultiTurn/agentic_coding_benchmark.yaml b/examples/09_MultiTurn/agentic_coding_benchmark.yaml new file mode 100644 index 00000000..dace765d --- /dev/null +++ b/examples/09_MultiTurn/agentic_coding_benchmark.yaml @@ -0,0 +1,32 @@ +name: "agentic-coding-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" # Replace with your actual model name + max_new_tokens: 1024 + +datasets: + - name: agentic_coding + type: performance + # Run: python scripts/convert_agentic_snapshot.py examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl --verify + # The datasets/ directory is a placeholder; populate it with the conversion script above. + path: examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl + multi_turn: + turn_timeout_s: 600.0 + +settings: + runtime: + min_duration_ms: 0 + max_duration_ms: 3600000 + + load_pattern: + type: multi_turn + target_concurrency: 4096 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/agentic_coding diff --git a/examples/09_MultiTurn/agentic_workflow_benchmark.yaml b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml new file mode 100644 index 00000000..a66b16c4 --- /dev/null +++ b/examples/09_MultiTurn/agentic_workflow_benchmark.yaml @@ -0,0 +1,32 @@ +name: "agentic-workflow-benchmark" +version: "1.0" +type: "online" + +model_params: + name: "your-model-name" # Replace with your actual model name + max_new_tokens: 512 + +datasets: + - name: agentic_workflow + type: performance + # Run: python scripts/convert_agentic_snapshot.py examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl --verify + # The datasets/ directory is a placeholder; populate it with the conversion script above. + path: examples/09_MultiTurn/datasets/agentic_workflow_flat.jsonl + multi_turn: + turn_timeout_s: 600.0 + +settings: + runtime: + min_duration_ms: 0 + max_duration_ms: 3600000 + + load_pattern: + type: multi_turn + target_concurrency: 96 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/agentic_workflow diff --git a/examples/09_MultiTurn/customer_support_conversations.jsonl b/examples/09_MultiTurn/customer_support_conversations.jsonl new file mode 100644 index 00000000..ac19e907 --- /dev/null +++ b/examples/09_MultiTurn/customer_support_conversations.jsonl @@ -0,0 +1,10 @@ +{"conversation_id": "conv_001", "turn": 1, "role": "user", "content": "I need help resetting my password", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_001", "turn": 2, "role": "assistant", "content": "I'd be happy to help you reset your password. Can you provide your email address?"} +{"conversation_id": "conv_001", "turn": 3, "role": "user", "content": "It's user@example.com"} +{"conversation_id": "conv_001", "turn": 4, "role": "assistant", "content": "Thank you. I've sent a password reset link to user@example.com. Please check your inbox and follow the instructions."} +{"conversation_id": "conv_002", "turn": 1, "role": "user", "content": "What are your business hours?", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_002", "turn": 2, "role": "assistant", "content": "We're open Monday-Friday, 9 AM to 5 PM EST. 
How can I assist you today?"} +{"conversation_id": "conv_002", "turn": 3, "role": "user", "content": "Do you offer weekend support?"} +{"conversation_id": "conv_002", "turn": 4, "role": "assistant", "content": "For urgent issues, we offer limited support on weekends from 10 AM to 2 PM EST. For non-urgent matters, please contact us during our regular business hours."} +{"conversation_id": "conv_003", "turn": 1, "role": "user", "content": "Can I cancel my subscription?", "system": "You are a helpful customer support agent"} +{"conversation_id": "conv_003", "turn": 2, "role": "assistant", "content": "Yes, you can cancel your subscription at any time. Would you like me to guide you through the cancellation process?"} diff --git a/examples/09_MultiTurn/datasets/.gitkeep b/examples/09_MultiTurn/datasets/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/examples/09_MultiTurn/multi_turn_benchmark.yaml b/examples/09_MultiTurn/multi_turn_benchmark.yaml new file mode 100644 index 00000000..2cc815c9 --- /dev/null +++ b/examples/09_MultiTurn/multi_turn_benchmark.yaml @@ -0,0 +1,34 @@ +name: "multi-turn-customer-support" +version: "1.0" +type: "online" + +model_params: + name: "meta-llama/Llama-3.2-1B-Instruct" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: customer_support_conversations + type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + multi_turn: + turn_timeout_s: 300.0 + +settings: + runtime: + min_duration_ms: 60000 + max_duration_ms: 300000 + + load_pattern: + type: multi_turn + target_concurrency: 32 + + client: + warmup_connections: 0 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/multi_turn_test diff --git a/examples/09_MultiTurn/multi_turn_with_concurrency.yaml b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml new file mode 100644 index 00000000..c1fcf26f --- /dev/null +++ b/examples/09_MultiTurn/multi_turn_with_concurrency.yaml @@ -0,0 +1,34 @@ +name: "multi-turn-with-concurrency-control" +version: "1.0" +type: "online" + +model_params: + name: "meta-llama/Llama-3.2-1B-Instruct" + temperature: 0.7 + max_new_tokens: 256 + +datasets: + - name: customer_support_conversations + type: performance + path: examples/09_MultiTurn/customer_support_conversations.jsonl + multi_turn: + turn_timeout_s: 300.0 + +settings: + runtime: + min_duration_ms: 60000 + max_duration_ms: 300000 + + load_pattern: + type: multi_turn + target_concurrency: 32 # ← NEW: Limit to 32 concurrent requests + + client: + warmup_connections: 0 + +endpoint_config: + endpoints: + - "http://localhost:8868" + api_type: openai + +report_dir: logs/multi_turn_with_concurrency diff --git a/scripts/convert_agentic_snapshot.py b/scripts/convert_agentic_snapshot.py new file mode 100644 index 00000000..fe217b9b --- /dev/null +++ b/scripts/convert_agentic_snapshot.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert agentic snapshot datasets to the flat-row JSONL format expected by MultiTurnDataset. + +Each snapshot record contains the full conversation history up to a checkpoint: + {"conversation_id": "sim_000001", "conversation_idx": 0, + "messages": [{"role": "system", ...}, ...], "tools": [...], "metadata": {}} + +For each conversation only the final snapshot (highest conversation_idx) is used. +Its messages array is expanded into individual flat rows, one per message. + +Usage: + python scripts/convert_agentic_snapshot.py INPUT.jsonl OUTPUT.jsonl + python scripts/convert_agentic_snapshot.py INPUT.jsonl OUTPUT.jsonl --verify +""" + +import argparse +import json +import sys +from pathlib import Path + +# --------------------------------------------------------------------------- +# Helpers shared between convert() and verify() +# --------------------------------------------------------------------------- + + +def _load_final_snapshots(input_path: Path) -> dict[str, dict]: + """Return {conv_id: record} keeping only the highest conversation_idx per conv.""" + final: dict[str, dict] = {} + with input_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + record = json.loads(line) + conv_id = record["conversation_id"] + if ( + conv_id not in final + or record["conversation_idx"] > final[conv_id]["conversation_idx"] + ): + final[conv_id] = record + return final + + +def _apply_collapses(non_system: list[dict]) -> list[tuple[dict, int]]: + """Apply user-collapse and tool-merge passes, tracking the last source index each + output row covers. + + Returns list of (output_msg, last_source_idx) pairs where last_source_idx is the + 0-based index within non_system of the final source message folded into this row. + """ + # Pass 1: collapse consecutive user messages + collapsed: list[tuple[dict, int]] = [] # (msg, last_source_idx) + for src_idx, msg in enumerate(non_system): + if collapsed and collapsed[-1][0]["role"] == "user" and msg["role"] == "user": + prev_msg, _ = collapsed[-1] + prev_text = prev_msg.get("content") or "" + cur_text = msg.get("content") or "" + collapsed[-1] = ( + {**prev_msg, "content": f"{prev_text}\n\n{cur_text}".strip()}, + src_idx, + ) + else: + collapsed.append((msg, src_idx)) + + # Pass 2: merge consecutive tool messages + # Input messages are raw snapshot wire-format (tool_call_id + content on each msg). + # On merge, upgrade the first message to a tool_results list so the output always + # uses the tool_results array form regardless of how many results there are. 
+ merged: list[tuple[dict, int]] = [] + for msg, last_src in collapsed: + if merged and merged[-1][0]["role"] == "tool" and msg["role"] == "tool": + prev_msg, _ = merged[-1] + tool_results = prev_msg.get("tool_results") + if tool_results is None: + tool_results = [ + { + "tool_call_id": prev_msg.get("tool_call_id"), + "content": prev_msg.get("content"), + } + ] + prev_msg = {"role": "tool", "tool_results": tool_results} + tool_results.append( + { + "tool_call_id": msg.get("tool_call_id"), + "content": msg.get("content"), + } + ) + merged[-1] = (prev_msg, last_src) + else: + merged.append((msg, last_src)) + + return merged + + +def _normalize_msg(msg: dict) -> dict: + """Drop None values for comparison.""" + return {k: v for k, v in msg.items() if v is not None} + + +def _expand_row_to_wire_msgs(row: dict) -> list[dict]: + """Expand a single flat row into one or more OpenAI wire-format messages. + + Handles two tool row forms: + - Output flat rows: tool_results array (always used after conversion) + - Raw snapshot messages passed through verify(): tool_call_id + content directly + """ + if isinstance(row.get("tool_results"), list): + return [ + { + "role": "tool", + "tool_call_id": r.get("tool_call_id"), + "content": r.get("content"), + } + for r in row["tool_results"] + ] + msg: dict = {"role": row["role"], "content": row.get("content")} + if row.get("tool_calls"): + msg["tool_calls"] = row["tool_calls"] + if row.get("tool_call_id"): + msg["tool_call_id"] = row["tool_call_id"] + return [msg] + + +def verify(input_path: Path, output_path: Path) -> bool: + """Cross-check every client-turn's pre_built_messages against the source snapshot. + + For each output client turn, reconstruct the pre_built_messages that + MultiTurnDataset would build from the flat rows and compare it against the + ground-truth messages built directly from the source snapshot up to the same + point (accounting for user-collapse and tool-merge). + + Returns: + True if all checks pass, False if any mismatch found. + """ + final = _load_final_snapshots(input_path) + + # Load converted rows grouped by conversation_id + conv_rows: dict[str, list[dict]] = {} + with output_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + row = json.loads(line) + cid = row["conversation_id"] + conv_rows.setdefault(cid, []).append(row) + for cid in conv_rows: + conv_rows[cid].sort(key=lambda r: r["turn"]) + + errors: list[str] = [] + total_checked = 0 + + for conv_id in sorted(final): + record = final[conv_id] + system_content: str | None = None + non_system: list[dict] = [] + for msg in record["messages"]: + if msg["role"] == "system": + system_content = msg.get("content") + else: + non_system.append(msg) + + # Re-apply the same collapses the converter applies, tracking source coverage + processed = _apply_collapses(non_system) # [(output_msg, last_source_idx), ...] + flat_rows = conv_rows.get(conv_id, []) + + if len(processed) != len(flat_rows): + errors.append( + f"{conv_id}: expected {len(processed)} flat rows after collapses, " + f"got {len(flat_rows)} in output" + ) + continue + + client_turn_pairs = [ + (out_pos, flat_row) + for out_pos, (flat_row, _) in enumerate( + zip(flat_rows, processed, strict=True) + ) + if flat_row["role"] in ("user", "tool") + ] + + for ct_idx, (out_pos, flat_row) in enumerate(client_turn_pairs): + # Ground truth: apply the same collapses the converter applies, then + # build the message list from the processed (collapsed/merged) rows up to + # and including this client turn. 
This correctly reflects what the + # converter produces — consecutive user/tool merges mean history is + # shorter than the raw source but content-equivalent. + expected: list[dict] = [] + if system_content: + expected.append({"role": "system", "content": system_content}) + for proc_msg, _ in processed[: out_pos + 1]: + expected.extend(_expand_row_to_wire_msgs(proc_msg)) + + # Reconstructed output: system + expand all flat rows up to this turn + got: list[dict] = [] + if system_content: + got.append({"role": "system", "content": system_content}) + for row in flat_rows[: out_pos + 1]: + got.extend(_expand_row_to_wire_msgs(row)) + + exp_norm = [_normalize_msg(m) for m in expected] + got_norm = [_normalize_msg(m) for m in got] + + if exp_norm != got_norm: + errors.append( + f"{conv_id} client-turn {ct_idx + 1} (flat turn {flat_row['turn']}):\n" + f" expected {len(exp_norm)} msgs, got {len(got_norm)}\n" + f" EXPECTED: {json.dumps(exp_norm, ensure_ascii=False)[:400]}\n" + f" GOT: {json.dumps(got_norm, ensure_ascii=False)[:400]}" + ) + total_checked += 1 + + if errors: + print( + f"FAIL: {len(errors)} mismatches out of {total_checked} client turns checked.", + file=sys.stderr, + ) + for err in errors[:20]: + print(err, file=sys.stderr) + if len(errors) > 20: + print(f" ... and {len(errors) - 20} more", file=sys.stderr) + return False + + print( + f"OK: all {total_checked} client turns verified against source.", + file=sys.stderr, + ) + return True + + +def convert(input_path: Path, output_path: Path) -> None: + # Group records by conversation_id, keep only the final snapshot per conversation. + final: dict[str, dict] = {} + with input_path.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + record = json.loads(line) + conv_id = record["conversation_id"] + if ( + conv_id not in final + or record["conversation_idx"] > final[conv_id]["conversation_idx"] + ): + final[conv_id] = record + + print(f"Found {len(final)} conversations in {input_path.name}", file=sys.stderr) + + rows_written = 0 + with output_path.open("w") as out: + for conv_id, record in sorted(final.items()): + messages = record["messages"] + tools = record.get("tools") or [] + + # Extract system message (always first if present). + system_content: str | None = None + non_system: list[dict] = [] + for msg in messages: + if msg["role"] == "system": + system_content = msg.get("content") + else: + non_system.append(msg) + + # Apply the same user-collapse and tool-merge passes used by verify(). + # _apply_collapses returns [(msg, last_source_idx), ...]; strip the indices. + non_system = [msg for msg, _ in _apply_collapses(non_system)] + + first_user_seen = False + for position, msg in enumerate(non_system): + role = msg["role"] + turn = position + 1 # 1-indexed + + row: dict = {"conversation_id": conv_id, "turn": turn, "role": role} + + # System prompt on the first user row only. + if role == "user" and not first_user_seen: + if system_content is not None: + row["system"] = system_content + first_user_seen = True + + # tool_calls for assistant messages that dispatch tools. + if msg.get("tool_calls"): + row["tool_calls"] = msg["tool_calls"] + + if role == "tool": + # All tool rows use tool_results array (single results have one entry). 
+ if msg.get("tool_results"): + row["tool_results"] = msg["tool_results"] + else: + row["tool_results"] = [ + { + "tool_call_id": msg.get("tool_call_id"), + "content": msg.get("content"), + } + ] + else: + # content field (may be None for tool-dispatching assistant messages) + row["content"] = msg.get("content") + + # Attach tool definitions to client-turn rows only (user + tool). + # This avoids duplicating the large tools array on every assistant row + # while still making them available via load_sample(). + if role in ("user", "tool") and tools: + row["tools"] = tools + + out.write(json.dumps(row, ensure_ascii=False) + "\n") + rows_written += 1 + + print(f"Wrote {rows_written} rows to {output_path}", file=sys.stderr) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert agentic snapshot JSONL to MultiTurnDataset flat-row JSONL." + ) + parser.add_argument("input", type=Path, help="Input snapshot JSONL file") + parser.add_argument("output", type=Path, help="Output flat-row JSONL file") + parser.add_argument( + "--verify", + action="store_true", + help=( + "After converting, cross-check every client-turn's pre_built_messages " + "against the source snapshot. Exits with code 1 if any mismatch found." + ), + ) + args = parser.parse_args() + + if not args.input.exists(): + print(f"Error: input file not found: {args.input}", file=sys.stderr) + sys.exit(1) + + convert(args.input, args.output) + + if args.verify: + ok = verify(args.input, args.output) + if not ok: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/multi_turn_dataset_schema.json b/scripts/multi_turn_dataset_schema.json new file mode 100644 index 00000000..b1b7ca13 --- /dev/null +++ b/scripts/multi_turn_dataset_schema.json @@ -0,0 +1,557 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "title": "Multi-Turn Conversation Dataset Schema", + "description": "JSON schema describing the structure and requirements for multi-turn conversation datasets in the MLPerf Inference Endpoint Benchmarking System", + "version": "1.0.0", + + "definitions": { + "basicMessageTypes": { + "title": "Basic Message Types", + "description": "Plain conversational messages without tool calls", + "oneOf": [ + { + "title": "User Message", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation (1-indexed)" + }, + "role": { + "const": "user", + "description": "Message role - user initiates turns" + }, + "content": { + "type": "string", + "description": "Message content from the user" + }, + "system": { + "type": "string", + "description": "System prompt (typically only on first user turn)" + } + }, + "required": ["conversation_id", "turn", "role", "content"], + "additionalProperties": true + }, + { + "title": "Assistant Message", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "assistant", + "description": "Message role - assistant responds to user" + }, + "content": { + "type": "string", + "description": "Message content from the assistant" + } + }, + "required": ["conversation_id", "turn", "role", "content"], + "not": { "required": ["tool_calls"] }, + "additionalProperties": true + } + ] + 
}, + + "toolCallMessage": { + "title": "Assistant Message with Tool Calls", + "description": "Assistant message that dispatches one or more tool calls. Role must be 'assistant' with a non-empty tool_calls array (OpenAI wire format).", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "assistant", + "description": "Role for a tool-dispatching assistant message." + }, + "content": { + "type": ["string", "null"], + "description": "Optional textual prefix alongside tool dispatch (e.g., 'I will investigate this with bash'). Typically null/absent." + }, + "tool_calls": { + "type": "array", + "minItems": 1, + "description": "List of tool calls dispatched by the assistant", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for this tool call (e.g., 'functions.bash:0')" + }, + "type": { + "type": "string", + "const": "function", + "description": "Tool type (currently only 'function' supported)" + }, + "function": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the tool/function to invoke" + }, + "arguments": { + "type": "string", + "description": "JSON string containing function arguments" + } + }, + "required": ["name", "arguments"] + } + }, + "required": ["id", "type", "function"] + } + } + }, + "required": ["conversation_id", "turn", "role", "tool_calls"], + "additionalProperties": true + }, + + "toolMessage": { + "title": "Tool Result Message", + "description": "Tool execution results as a list. Single results have one entry; parallel results have multiple entries.", + "type": "object", + "properties": { + "conversation_id": { + "type": "string", + "description": "Unique identifier for the conversation" + }, + "turn": { + "type": "integer", + "minimum": 1, + "description": "Turn number within conversation" + }, + "role": { + "const": "tool", + "description": "Tool result role" + }, + "tool_results": { + "type": "array", + "minItems": 1, + "description": "List of tool execution results. 
Single tool calls have one entry; parallel tool calls have multiple entries.", + "items": { + "type": "object", + "properties": { + "tool_call_id": { + "type": "string", + "description": "ID of the tool call this result corresponds to" + }, + "content": { + "type": "string", + "description": "Output/result content from the tool execution" + } + }, + "required": ["tool_call_id", "content"] + } + } + }, + "required": ["conversation_id", "turn", "role", "tool_results"], + "additionalProperties": true + }, + + "generationParameters": { + "title": "Generation Parameters", + "description": "Optional parameters controlling the model's behavior for generation", + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "Model name override for this turn" + }, + "max_new_tokens": { + "type": "integer", + "minimum": 1, + "description": "Maximum number of tokens to generate" + }, + "max_completion_tokens": { + "type": "integer", + "minimum": 1, + "description": "OpenAI API compatible max tokens parameter" + }, + "stream": { + "type": "boolean", + "description": "Whether to use streaming for this turn" + }, + "temperature": { + "type": "number", + "minimum": 0, + "maximum": 2, + "description": "Sampling temperature (0 = deterministic, higher = more random)" + }, + "top_p": { + "type": "number", + "minimum": 0, + "maximum": 1, + "description": "Nucleus sampling parameter" + }, + "top_k": { + "type": "integer", + "minimum": 1, + "description": "Top-k sampling parameter" + }, + "seed": { + "type": "integer", + "description": "Random seed for reproducibility" + }, + "repetition_penalty": { + "type": "number", + "minimum": 0, + "description": "Penalty for repeating tokens" + }, + "frequency_penalty": { + "type": "number", + "minimum": -2, + "maximum": 2, + "description": "Frequency penalty for tokens" + }, + "presence_penalty": { + "type": "number", + "minimum": -2, + "maximum": 2, + "description": "Presence penalty for tokens" + }, + "stop": { + "oneOf": [ + { "type": "string" }, + { "type": "array", "items": { "type": "string" } } + ], + "description": "Stop sequences for generation" + }, + "n": { + "type": "integer", + "minimum": 1, + "description": "Number of completions to generate" + }, + "logit_bias": { + "type": "object", + "description": "Token probability adjustments (token_id -> bias)" + }, + "name": { + "type": "string", + "description": "Entity name for role tracking (e.g., 'Bob')" + }, + "user": { + "type": "string", + "description": "End-user identifier for monitoring/abuse detection" + }, + "chat_template": { + "type": "string", + "description": "Custom chat formatting template" + }, + "tools": { + "type": "array", + "description": "OpenAI tool definitions for tool-calling models", + "items": { "type": "object" } + } + } + } + }, + + "type": "object", + "oneOf": [ + { + "title": "Plain Conversation Row", + "description": "A single row representing a plain user or assistant message", + "allOf": [ + { "$ref": "#/definitions/basicMessageTypes" }, + { "$ref": "#/definitions/generationParameters" } + ] + }, + { + "title": "Tool Call Row", + "description": "A single row representing an assistant dispatch of tool calls", + "allOf": [ + { "$ref": "#/definitions/toolCallMessage" }, + { "$ref": "#/definitions/generationParameters" } + ] + }, + { + "title": "Tool Result Row", + "description": "A single row representing one or more tool results", + "allOf": [ + { "$ref": "#/definitions/toolMessage" }, + { "$ref": "#/definitions/generationParameters" } + ] + } + ], + + 
"examples": [ + { + "title": "Basic user message", + "data": { + "conversation_id": "conv_001", + "turn": 1, + "role": "user", + "content": "I need help resetting my password", + "system": "You are a helpful customer support agent" + } + }, + { + "title": "Assistant response", + "data": { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "I'd be happy to help. Can you provide your email address?" + } + }, + { + "title": "Assistant with tool calls (converter/OpenAI wire format)", + "data": { + "conversation_id": "sim_001", + "turn": 2, + "role": "assistant", + "tool_calls": [ + { + "id": "functions.bash:0", + "type": "function", + "function": { + "name": "bash", + "arguments": "{\"cmd\": \"cat foo.py\"}" + } + } + ] + } + }, + { + "title": "Tool result from execution", + "data": { + "conversation_id": "sim_001", + "turn": 3, + "role": "tool", + "tool_results": [ + { + "tool_call_id": "functions.bash:0", + "content": "def foo():\n return 1/0" + } + ] + } + }, + { + "title": "Merged parallel tool results", + "data": { + "conversation_id": "sim_002", + "turn": 3, + "role": "tool", + "tool_results": [ + { + "tool_call_id": "functions.bash:0", + "content": "file1.txt" + }, + { + "tool_call_id": "functions.bash:1", + "content": "file2.txt" + } + ] + } + }, + { + "title": "User turn with generation parameters", + "data": { + "conversation_id": "conv_002", + "turn": 5, + "role": "user", + "content": "What's the best way to optimize this code?", + "temperature": 0.7, + "max_new_tokens": 256, + "top_p": 0.9 + } + } + ], + + "documentation": { + "overview": "Multi-turn conversation datasets enable benchmarking of realistic conversational AI workloads where each turn depends on previous responses. The system maintains conversation history and enforces turn sequencing.", + + "requiredFields": [ + { + "field": "conversation_id", + "type": "string", + "description": "Unique identifier for each conversation. All rows belonging to the same conversation must share the same conversation_id." + }, + { + "field": "turn", + "type": "integer", + "description": "Turn number within conversation (1-indexed). Must be consecutive starting at 1 (i.e., 1, 2, 3, …, N with no gaps or duplicates)." + }, + { + "field": "role", + "type": "string", + "enum": ["user", "assistant", "tool"], + "description": "Speaker role. 'user' or 'tool' are client-initiated turns. 'assistant' is the server response — either a terminal reply or a tool dispatch (with tool_calls field)." + }, + { + "field": "content", + "type": "string", + "description": "Message content. Required for 'user' role and plain 'assistant' rows. For tool-dispatching assistant rows, content may be omitted (null/absent). For 'tool' rows using tool_results (merged parallel results), top-level content is absent — results are in the tool_results array instead." + } + ], + + "optionalFields": [ + { + "field": "system", + "type": "string", + "description": "System prompt (typically only on first user turn). Applied to all messages in the conversation." + }, + { + "field": "tool_calls", + "type": "array", + "description": "Tool calls dispatched by assistant (for tool-dispatching 'assistant' rows). Each element has {id, type, function: {name, arguments}}." + }, + { + "field": "tool_results", + "type": "array", + "description": "Tool execution results (required for all 'tool' role rows). Each element has {tool_call_id, content}. Single results have one entry; parallel results have multiple entries." 
+ }, + { + "field": "model", + "type": "string", + "description": "Model name override for this turn." + }, + { + "field": "max_new_tokens", + "type": "integer", + "description": "Maximum tokens to generate for this turn." + }, + { + "field": "temperature", + "type": "number", + "description": "Sampling temperature (0 to 2)." + }, + { + "field": "top_p", + "type": "number", + "description": "Nucleus sampling parameter (0 to 1)." + }, + { + "field": "tools", + "type": "array", + "description": "OpenAI tool definitions forwarded to the endpoint for tool-calling models. The converter attaches this only to client-turn rows (user and tool) to avoid duplicating the large array on every assistant row. Hand-authored datasets typically place it on the first user turn." + } + ], + + "validRoleSequences": [ + { + "name": "Plain conversation", + "sequence": "user → assistant → user → assistant → ...", + "description": "Standard alternating conversation without tool use." + }, + { + "name": "Agentic with tools", + "sequence": "user → assistant → tool → [assistant → tool]* → assistant → user", + "description": "Agent dispatches tools (assistant with tool_calls), executes them, and returns results before final response. 'tool → user' is also valid when no terminal assistant response is needed before the next user turn." + } + ], + "stateMachine": { + "description": "Complete valid-next-state table from _validate_conversation_structure()", + "transitions": { + "start": ["user"], + "user": ["assistant"], + "assistant": ["tool", "user"], + "tool": ["assistant", "user"] + } + }, + + "validationRules": [ + { + "rule": "Turn numbers must be consecutive starting at 1", + "violation": "Turn sequence is not exactly 1, 2, 3, …, N (missing, duplicate, or out-of-range turns)" + }, + { + "rule": "Role sequences must follow the state machine", + "violation": "Invalid transition (e.g., 'user' directly followed by 'user', consecutive 'assistant' rows). Note: 'tool → user' IS a valid transition. The state machine also implicitly enforces that the first row must be a user turn." + } + ], + "notValidated": [ + "tool_results[*].tool_call_id pairing: the validator does NOT verify that tool_call_id values inside tool_results items reference a prior assistant tool_calls entry. Correct pairing is the dataset author's or converter's responsibility." + ], + + "dataTypes": [ + { + "name": "Basic Message", + "roles": ["user", "assistant"], + "fields": { + "required": ["conversation_id", "turn", "role", "content"], + "optional": ["system", "...generation parameters"] + } + }, + { + "name": "Tool Call Dispatch", + "roles": ["assistant"], + "fields": { + "required": ["conversation_id", "turn", "role", "tool_calls"], + "optional": ["content", "...generation parameters"] + } + }, + { + "name": "Tool Result", + "roles": ["tool"], + "fields": { + "required": ["conversation_id", "turn", "role", "tool_results"], + "optional": ["...generation parameters"], + "note": "Expands to one OpenAI tool message per result entry at the wire layer" + } + } + ], + + "conversionFromSnapshot": { + "description": "Agentic datasets are often stored as full-conversation snapshots. 
Use scripts/convert_agentic_snapshot.py to convert.", + "sourceFormat": "Each JSONL line is a complete conversation snapshot with 'messages' array", + "targetFormat": "Each JSONL line is a single message row with conversation metadata", + "process": [ + "Extract conversation_id, conversation_idx, and messages array", + "Use highest-indexed snapshot per conversation_id", + "Collapse consecutive user messages into a single user row (newline-joined content)", + "Emit all tool result rows as tool_results arrays (single results have one entry; consecutive tool results from parallel dispatch are merged into one row with multiple entries)", + "Flatten messages into individual rows, numbering turns sequentially from 1", + "Attach system prompt to first user row only", + "Attach tools array to client-turn rows (user and tool) only — not to assistant rows", + "Tool-dispatching assistant messages are written as role 'assistant' with tool_calls (OpenAI wire format)" + ] + }, + + "performanceNotes": [ + "Pre-built messages: The system pre-computes complete message lists during dataset load() for efficient turn serving", + "Memory efficiency: ~1KB per turn average; 1000 conversations × 10 turns = ~10MB", + "Hot path: Only client turns (user/tool) are issued; assistant turns remain in backing store for history" + ], + + "commonErrors": [ + { + "error": "Invalid role sequence", + "cause": "Violates state machine (e.g., user→user, user→tool, consecutive assistant rows)", + "fix": "Verify alternation or use conversion script for agentic data" + }, + { + "error": "Turn numbers not consecutive", + "cause": "Turn sequence has gaps, duplicates, or doesn't start at 1", + "fix": "Ensure turns are numbered 1, 2, 3, …, N with no missing or duplicate values" + }, + { + "error": "Turn timeout", + "cause": "Previous turn took too long to complete", + "fix": "Increase turn_timeout_s in configuration or check endpoint performance" + } + ] + } +} diff --git a/scripts/validate_jsonl_schema.py b/scripts/validate_jsonl_schema.py new file mode 100644 index 00000000..1be81dd2 --- /dev/null +++ b/scripts/validate_jsonl_schema.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Validate multi-turn JSONL dataset files against scripts/multi_turn_dataset_schema.json. + +Checks each row's structure against the JSON schema (field types, required fields, +tool_results shape, etc.). Does NOT check cross-row invariants such as turn +numbering or role sequences — those are enforced by MultiTurnDataset at load time. + +Usage: + python scripts/validate_jsonl_schema.py FILE [FILE ...] + python scripts/validate_jsonl_schema.py examples/09_MultiTurn/datasets/agentic_coding_flat.jsonl +""" + +import argparse +import json +import sys +from pathlib import Path + +try: + import jsonschema +except ImportError: + print( + "Error: jsonschema not installed. 
Run: pip install jsonschema", file=sys.stderr + ) + sys.exit(1) + + +def validate_file(path: Path, schema: dict, max_errors: int = 50) -> int: + """Validate every row in a JSONL file against the schema. + + Returns the number of validation errors found. + """ + errors: list[str] = [] + validator = jsonschema.Draft7Validator(schema) + + with path.open() as fh: + for lineno, line in enumerate(fh, 1): + line = line.strip() + if not line: + continue + try: + row = json.loads(line) + except json.JSONDecodeError as e: + errors.append(f" line {lineno}: JSON parse error: {e}") + if len(errors) >= max_errors: + break + continue + + conv_id = row.get("conversation_id", "") + turn = row.get("turn", "?") + role = row.get("role", "?") + + row_errors = list(validator.iter_errors(row)) + for err in row_errors: + path_str = " -> ".join(str(p) for p in err.absolute_path) or "(root)" + errors.append( + f" line {lineno} [{conv_id} turn={turn} role={role}] " + f"@ {path_str}: {err.message}" + ) + + if len(errors) >= max_errors: + errors.append(f" ... stopping after {max_errors} errors") + break + + if errors: + print(f"FAIL {path.name}: {len(errors)} error(s)") + for msg in errors: + print(msg) + else: + print(f"OK {path.name}") + + return len(errors) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Validate multi-turn JSONL files against scripts/multi_turn_dataset_schema.json." + ) + parser.add_argument("files", nargs="+", type=Path, help="JSONL files to validate") + parser.add_argument( + "--schema", + type=Path, + default=Path(__file__).parent / "multi_turn_dataset_schema.json", + help="Path to the JSON schema file (default: scripts/multi_turn_dataset_schema.json)", + ) + parser.add_argument( + "--max-errors", + type=int, + default=50, + help="Stop reporting after this many errors per file (default: 50)", + ) + args = parser.parse_args() + + if not args.schema.exists(): + print(f"Error: schema not found: {args.schema}", file=sys.stderr) + sys.exit(1) + + schema = json.load(args.schema.open()) + + total_errors = 0 + for path in args.files: + if not path.exists(): + print(f"Error: file not found: {path}", file=sys.stderr) + total_errors += 1 + continue + total_errors += validate_file(path, schema, max_errors=args.max_errors) + + sys.exit(1 if total_errors > 0 else 0) + + +if __name__ == "__main__": + main() diff --git a/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py b/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py index fd6a1559..52c9b010 100644 --- a/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py +++ b/src/inference_endpoint/async_utils/services/event_logger/sql_writer.py @@ -50,6 +50,9 @@ class EventRowModel(Base): timestamp_ns: Mapped[int] = mapped_column(BigInteger, nullable=False) """Monotonic timestamp in nanoseconds.""" + conversation_id: Mapped[str] = mapped_column(String, nullable=False, default="") + turn: Mapped[int | None] = mapped_column(Integer, nullable=True) + data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False, default=b"") """JSON-encoded event data.""" @@ -61,6 +64,8 @@ def _record_to_row(record: EventRecord) -> EventRowModel: sample_uuid=record.sample_uuid, event_type=topic, timestamp_ns=record.timestamp_ns, + conversation_id=record.conversation_id, + turn=record.turn, data=msgspec.json.encode(record.data), ) diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py 
b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py index ae66821a..46a17e92 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py @@ -181,7 +181,9 @@ class AsyncTokenTrigger(EmitTrigger): Subclasses implement ``_extract_text()`` to pull the text to tokenize from the event record. If text is returned, an async task is created - to tokenize and emit. Subclasses can override ``_compute_value()`` to + to tokenize and emit. Subclasses can also override ``_extract_message()`` + to return (content, reasoning, tool_calls) for chat-template–aware tokenization + when tool calls are present. Subclasses can override ``_compute_value()`` to transform the token count before storing. """ @@ -205,6 +207,16 @@ def _extract_text( """Return the text to tokenize, or None to skip.""" raise NotImplementedError() + def _extract_message( + self, ev_rec: EventRecord, row: SampleRow, pre_change: dict[str, Any] + ) -> tuple[str, str | None, tuple[dict[str, Any], ...] | None] | None: + """Return (content, reasoning, tool_calls) for message-aware tokenization, or None. + + When non-None is returned, ``token_count_message_async`` is used instead of + ``token_count_async``. Default returns None (use text path). + """ + return None + def _compute_value( self, token_count: int, ev_rec: EventRecord, pre_change: dict[str, Any] ) -> int | float | None: @@ -214,6 +226,27 @@ def _compute_value( def fire(self, ev_rec, row, pre_change): if self._pool is None or self._loop is None: return None + + message_parts = self._extract_message(ev_rec, row, pre_change) + if message_parts is not None: + content, reasoning, tool_calls = message_parts + pool, loop = self._pool, self._loop + registry, name = self.registry, self.metric_name + uuid = row.sample_uuid + + async def _tokenize_message_and_emit() -> None: + try: + count = await pool.token_count_message_async( + content, reasoning, tool_calls, loop + ) + value = self._compute_value(count, ev_rec, pre_change) + if value is not None: + registry.record(name, value) + except Exception: + logger.exception("%s tokenization failed for %s", name, uuid) + + return loop.create_task(_tokenize_message_and_emit()) + text = self._extract_text(ev_rec, row, pre_change) if not text: return None @@ -319,10 +352,18 @@ def __init__( def _extract_text(self, ev_rec, row, pre_change): if isinstance(ev_rec.data, TextModelOutput): + if ev_rec.data.tool_calls: + # Delegate to _extract_message for chat-template tokenization. + return None text = str(ev_rec.data) return text if text else None return None + def _extract_message(self, ev_rec, row, pre_change): + if isinstance(ev_rec.data, TextModelOutput) and ev_rec.data.tool_calls: + return ev_rec.data.as_message_parts() + return None + class TpotTrigger(AsyncTokenTrigger): """TPOT = (complete_ns - recv_first_ns) / token_count(text_after_first_chunk). @@ -358,9 +399,19 @@ def _extract_text(self, ev_rec, row, pre_change): if pre_change.get(SampleField.RECV_FIRST_NS) is None: return None if isinstance(ev_rec.data, TextModelOutput): + if ev_rec.data.tool_calls: + # Delegate to _extract_message for chat-template tokenization. 
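+ # (the message path re-renders content, reasoning, and tool_calls through apply_chat_template in TokenizePool)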
+ return None return ev_rec.data.text_after_first_chunk() or None return None + def _extract_message(self, ev_rec, row, pre_change): + if pre_change.get(SampleField.RECV_FIRST_NS) is None: + return None + if isinstance(ev_rec.data, TextModelOutput) and ev_rec.data.tool_calls: + return ev_rec.data.as_message_parts_after_first_chunk() + return None + def _compute_value(self, token_count, ev_rec, pre_change): if token_count <= 0: return None diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/registry.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/registry.py index 1268ba9e..800238f3 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/registry.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/registry.py @@ -274,7 +274,7 @@ def _hdr_stat(self) -> SeriesStat: total=self._total, min=self._min, max=self._max, - sum_sq=self._sum_sq, + sum_sq=float(self._sum_sq), percentiles=perc_dict, histogram=histogram, ) @@ -313,7 +313,7 @@ def _exact_stat(self) -> SeriesStat: total=self._total, min=self._min, max=self._max, - sum_sq=self._sum_sq, + sum_sq=float(self._sum_sq), percentiles=perc_dict, histogram=histogram, ) diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py index 56dee33f..b46990f3 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py @@ -18,15 +18,53 @@ from __future__ import annotations import asyncio +import json +import logging import threading from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +import msgspec from transformers import AutoTokenizer +# Minimal user message used to satisfy chat templates that reject assistant-only +# message lists. Its token count is subtracted so only the assistant payload is +# measured. +_PREFIX_USER_MSG: dict[str, str] = {"role": "user", "content": ""} + + +def _normalize_tool_calls_for_template( + tool_calls: tuple[dict[str, Any], ...] | list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Ensure ``function.arguments`` is a dict, not the OpenAI-wire JSON string. + + Hermes-style chat templates iterate ``arguments`` as a mapping; a string + payload raises and forces the fallback path, inflating token counts. + """ + normalized: list[dict[str, Any]] = [] + for tc in tool_calls: + fn = tc.get("function") or {} + args = fn.get("arguments") + if isinstance(args, str): + try: + parsed = json.loads(args) + except (json.JSONDecodeError, TypeError): + normalized.append(tc) + continue + if isinstance(parsed, dict): + new_tc = dict(tc) + new_tc["function"] = {**fn, "arguments": parsed} + normalized.append(new_tc) + continue + normalized.append(tc) + return normalized + + if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase +logger = logging.getLogger(__name__) + class TokenizePool: """A pool of worker threads, each with its own HuggingFace AutoTokenizer. 
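A minimal sketch, not part of the patch, of what `_normalize_tool_calls_for_template` does to an OpenAI-wire tool call before the chat template sees it; the import path mirrors the one `execute.py` uses later in this diff, and the payload values are illustrative only.

```python
from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import (
    _normalize_tool_calls_for_template,
)

# OpenAI wire format: function.arguments arrives as a JSON-encoded string.
wire_call = {
    "id": "call_1",
    "type": "function",
    "function": {"name": "get_weather", "arguments": '{"city": "Oslo"}'},
}

normalized = _normalize_tool_calls_for_template([wire_call])
# arguments is now a dict, so Hermes-style templates can iterate it as a mapping.
assert normalized[0]["function"]["arguments"] == {"city": "Oslo"}

# A call whose arguments string is not valid JSON is passed through unchanged.
broken = {"id": "call_2", "type": "function", "function": {"name": "f", "arguments": "not json"}}
assert _normalize_tool_calls_for_template([broken]) == [broken]
```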
@@ -52,6 +90,7 @@ def __init__(self, tokenizer_name: str, n_workers: int) -> None: self._tokenizer_name = tokenizer_name self._n_workers = n_workers self._thread_local = threading.local() + self._fallback_warned: set[str] = set() self._executor: ThreadPoolExecutor | None = ThreadPoolExecutor( max_workers=n_workers, thread_name_prefix="TokenizePool", @@ -79,6 +118,35 @@ def _get_thread_tokenizer(self) -> PreTrainedTokenizerBase: self._thread_local.tokenizer = AutoTokenizer.from_pretrained( self._tokenizer_name ) + # Baseline = tokens contributed by a [user, empty-assistant] pair minus + # the [user] prefix alone. Some templates (Qwen3-Coder, etc.) reject + # assistant-only message lists, so a user prefix is required; we + # subtract it out so the baseline reflects only the assistant frame. + try: + tok = self._thread_local.tokenizer + prefix_rendered = tok.apply_chat_template( + [_PREFIX_USER_MSG], + tokenize=False, + add_generation_prompt=False, + ) + prefix_len = len(tok.tokenize(prefix_rendered)) + with_empty_assistant_rendered = tok.apply_chat_template( + [_PREFIX_USER_MSG, {"role": "assistant", "content": ""}], + tokenize=False, + add_generation_prompt=False, + ) + with_empty_assistant_len = len( + tok.tokenize(with_empty_assistant_rendered) + ) + self._thread_local.prefix_len = prefix_len + self._thread_local.baseline = with_empty_assistant_len - prefix_len + except Exception: + self._thread_local.prefix_len = 0 + self._thread_local.baseline = 0 + logger.exception( + "Failed to compute chat-template baseline for %s; tool-call token counts may be over-estimated", + self._tokenizer_name, + ) return self._thread_local.tokenizer def _token_count_worker(self, text: str) -> int: @@ -86,6 +154,53 @@ def _token_count_worker(self, text: str) -> int: tokenizer = self._get_thread_tokenizer() return len(tokenizer.tokenize(text)) + def _token_count_message_worker( + self, + content: str, + reasoning: str | None, + tool_calls: tuple[dict[str, Any], ...] | None, + ) -> int: + """Worker entry: tokenize a full assistant message using apply_chat_template. + + Falls back to whitespace-split tokenization if apply_chat_template raises + (e.g. the template does not support tool_calls or reasoning fields). + """ + tokenizer = self._get_thread_tokenizer() + msg: dict[str, Any] = {"role": "assistant", "content": content or ""} + if reasoning: + msg["reasoning_content"] = reasoning + if tool_calls: + msg["tool_calls"] = _normalize_tool_calls_for_template(tool_calls) + try: + rendered = tokenizer.apply_chat_template( + [_PREFIX_USER_MSG, msg], + tokenize=False, + add_generation_prompt=False, + ) + full = len(tokenizer.tokenize(rendered)) + prefix_len = getattr(self._thread_local, "prefix_len", 0) + baseline = getattr(self._thread_local, "baseline", 0) + return max(0, full - prefix_len - baseline) + except Exception as exc: + key = f"{self._tokenizer_name}:{type(exc).__name__}" + if key not in self._fallback_warned: + self._fallback_warned.add(key) + logger.exception( + "apply_chat_template failed for %s (%s); falling back to " + "whitespace tokenization. 
Tool-call OSL/TPOT may diverge " + "from server-side counts for this run.", + self._tokenizer_name, + type(exc).__name__, + ) + tool_calls_json = ( + msgspec.json.encode(list(tool_calls)).decode() if tool_calls else None + ) + parts = [ + p for p in (content or None, reasoning or None, tool_calls_json) if p + ] + fallback_text = "\n".join(parts) + return self._token_count_worker(fallback_text) + def token_count(self, text: str) -> int: """Return the number of tokens in the input string (blocking).""" if self._executor is None: @@ -93,6 +208,20 @@ def token_count(self, text: str) -> int: future = self._executor.submit(self._token_count_worker, text) return future.result() + def token_count_message( + self, + content: str, + reasoning: str | None, + tool_calls: tuple[dict[str, Any], ...] | None, + ) -> int: + """Return the token count for an assistant message (blocking).""" + if self._executor is None: + raise RuntimeError("TokenizePool is closed") + future = self._executor.submit( + self._token_count_message_worker, content, reasoning, tool_calls + ) + return future.result() + async def token_count_async( self, text: str, loop: asyncio.AbstractEventLoop ) -> int: @@ -107,6 +236,24 @@ async def token_count_async( self._executor, self._token_count_worker, text ) + async def token_count_message_async( + self, + content: str, + reasoning: str | None, + tool_calls: tuple[dict[str, Any], ...] | None, + loop: asyncio.AbstractEventLoop, + ) -> int: + """Return the token count for an assistant message without blocking the event loop.""" + if self._executor is None: + raise RuntimeError("TokenizePool is closed") + return await loop.run_in_executor( + self._executor, + self._token_count_message_worker, + content, + reasoning, + tool_calls, + ) + def close(self) -> None: """Shut down the worker pool. 
Idempotent.""" if self._executor is not None: diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index 1abb44c3..1a569f60 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -30,6 +30,7 @@ import signal import tempfile import uuid +from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime from pathlib import Path @@ -40,6 +41,7 @@ import msgspec.json from huggingface_hub import model_info from tqdm import tqdm +from transformers import AutoTokenizer from transformers.utils import logging as transformers_logging from inference_endpoint.async_utils.event_publisher import EventPublisherService @@ -54,6 +56,9 @@ from inference_endpoint.async_utils.services.metrics_aggregator.subscriber import ( MetricsSnapshotSubscriber, ) +from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import ( + _normalize_tool_calls_for_template, +) from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( @@ -69,6 +74,7 @@ from inference_endpoint.core.types import QueryResult from inference_endpoint.dataset_manager.dataset import Dataset from inference_endpoint.dataset_manager.factory import DataLoaderFactory +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset from inference_endpoint.endpoint_client.cpu_affinity import AffinityPlan, pin_loadgen from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer @@ -79,6 +85,8 @@ InputValidationError, SetupError, ) +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy from inference_endpoint.load_generator.session import ( BenchmarkSession, PhaseConfig, @@ -291,6 +299,75 @@ def _load_datasets( return dataloader, accuracy_datasets, eval_configs +def _precompute_isl_for_multi_turn( + dataloader: MultiTurnDataset, tokenizer_name: str +) -> None: + """Tokenize pre-built message lists and store token counts in each sample. + + Runs apply_chat_template once per client turn so the hot-path IslTrigger + sync path (len(token_ids)) is used instead of on-the-fly text tokenization. + Only affects dataset-history turns; live-history turns override 'messages' + at runtime so the stored input_tokens are stale (acceptable approximation). + """ + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + except Exception: + logger.exception( + "ISL pre-computation: failed to load tokenizer %s; " + "falling back to text-tokenization at runtime", + tokenizer_name, + ) + return + skipped = 0 + first_failure_logged = False + for sample in dataloader.data or []: + messages = sample.get("messages") + if not messages: + continue + try: + normalized_messages = [] + for msg in messages: + if msg.get("tool_calls"): + msg = { + **msg, + "tool_calls": _normalize_tool_calls_for_template( + msg["tool_calls"] + ), + } + normalized_messages.append(msg) + tools = sample.get("tools") + raw = tokenizer.apply_chat_template( + normalized_messages, + tools=tools if tools else None, + tokenize=True, + add_generation_prompt=True, + ) + # Some tokenizers (e.g. 
Qwen3 fast tokenizer) return BatchEncoding + # instead of a plain list; extract .input_ids in that case. + token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw + sample["input_tokens"] = token_ids + except Exception: + if not first_failure_logged: + logger.exception( + "ISL pre-computation: apply_chat_template failed (first failure shown)" + ) + first_failure_logged = True + skipped += 1 + if skipped: + logger.warning( + "ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)", + skipped, + ) + total_with_messages = len([s for s in (dataloader.data or []) if s.get("messages")]) + if total_with_messages > 0 and skipped == total_with_messages: + logger.warning( + "ISL precomputation: all %d turn(s) failed apply_chat_template; " + "ISL metrics will use text-tokenization fallback. " + "Check tokenizer/template compatibility.", + total_with_messages, + ) + + def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkContext: """Load tokenizer, dataset, create scheduler, setup report dir.""" # CPU affinity @@ -320,6 +397,10 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo # Datasets dataloader, accuracy_datasets, eval_configs = _load_datasets(config, report_dir) + if isinstance(dataloader, MultiTurnDataset) and tokenizer_name is not None: + logger.info("Pre-computing ISL token counts for multi-turn dataset…") + _precompute_isl_for_multi_turn(dataloader, tokenizer_name) + # Setup runtime settings using factory method rt_settings = RuntimeSettings.from_config(config, dataloader.num_samples()) @@ -350,14 +431,21 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo ) -def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: +def _build_phases( + ctx: BenchmarkContext, + perf_strategy: MultiTurnStrategy | None = None, +) -> list[PhaseConfig]: """Build the phase list from BenchmarkContext.""" phases: list[PhaseConfig] = [] # Performance phase phases.append( PhaseConfig( - "performance", ctx.rt_settings, ctx.dataloader, PhaseType.PERFORMANCE + "performance", + ctx.rt_settings, + ctx.dataloader, + PhaseType.PERFORMANCE, + strategy=perf_strategy, ) ) @@ -365,6 +453,17 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: # what Scorer._load_sample_index_map() looks up in sample_idx_map.json for eval_cfg in ctx.eval_configs: acc_ds = eval_cfg.dataset + if isinstance(acc_ds, MultiTurnDataset): + raise InputValidationError( + f"Accuracy dataset '{eval_cfg.dataset_name}' is a MultiTurnDataset, " + "which is not yet supported for accuracy evaluation." + ) + # Accuracy phases run at MAX_THROUGHPUT; inheriting perf_lp (e.g. POISSON) + # would silently rate-limit evaluation until a multi-turn accuracy strategy + # and QPS-budgeting support are added. 
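+ # (multi-turn accuracy datasets were rejected just above, so only single-turn accuracy data reaches this branch)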
+ acc_load_pattern: LoadPattern | None = LoadPattern( + type=LoadPatternType.MAX_THROUGHPUT + ) acc_settings = RuntimeSettings( metric_target=ctx.rt_settings.metric_target, reported_metrics=ctx.rt_settings.reported_metrics, @@ -375,7 +474,7 @@ def _build_phases(ctx: BenchmarkContext) -> list[PhaseConfig]: min_sample_count=acc_ds.num_samples() * acc_ds.repeats, rng_sched=ctx.rt_settings.rng_sched, rng_sample_index=ctx.rt_settings.rng_sample_index, - load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + load_pattern=acc_load_pattern, ) phases.append( PhaseConfig(eval_cfg.dataset_name, acc_settings, acc_ds, PhaseType.ACCURACY) @@ -525,16 +624,63 @@ async def _run_benchmark_async( launcher.kill_all() raise SetupError(f"Failed to connect to endpoint: {e}") from e + # Build multi-turn strategy if the performance dataset is a MultiTurnDataset. + multi_turn_strategy: MultiTurnStrategy | None = None + if isinstance(ctx.dataloader, MultiTurnDataset): + mt_cfg = None + if ctx.config.datasets: + perf_ds_cfg = next( + ( + d + for d in ctx.config.datasets + if d.type == DatasetType.PERFORMANCE + ), + None, + ) + if perf_ds_cfg is not None: + mt_cfg = perf_ds_cfg.multi_turn + assert ctx.dataloader.conversation_metadata is not None + multi_turn_strategy = MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ctx.dataloader.conversation_metadata, + multi_turn_config=mt_cfg, + target_concurrency=ctx.config.settings.load_pattern.target_concurrency, + ) + + _on_sample_complete: Callable[[QueryResult], None] + if multi_turn_strategy is not None: + + def _on_sample_complete(result: QueryResult) -> None: + try: + multi_turn_strategy.on_sample_complete(result) + except Exception: + logger.exception( + "multi_turn_strategy.on_sample_complete failed (result=%s)", + result.id, + ) + try: + collector.on_complete_hook(result) + except Exception: + logger.exception( + "collector.on_complete_hook failed (result=%s)", result.id + ) + + multi_turn_strategy._session_on_sample_complete = _on_sample_complete + multi_turn_strategy._session_publisher = publisher + + else: + _on_sample_complete = collector.on_complete_hook + # Create session session = BenchmarkSession( issuer=issuer, event_publisher=publisher, loop=loop, - on_sample_complete=collector.on_complete_hook, + on_sample_complete=_on_sample_complete, session_id=session_id, ) - phases = _build_phases(ctx) + phases = _build_phases(ctx, perf_strategy=multi_turn_strategy) report: Report | None = None loop.add_signal_handler(signal.SIGINT, session.stop) diff --git a/src/inference_endpoint/config/runtime_settings.py b/src/inference_endpoint/config/runtime_settings.py index fb349a02..7259bda7 100644 --- a/src/inference_endpoint/config/runtime_settings.py +++ b/src/inference_endpoint/config/runtime_settings.py @@ -32,6 +32,7 @@ from typing import TYPE_CHECKING from .. import metrics +from .schema import LoadPatternType logger = logging.getLogger(__name__) @@ -194,6 +195,25 @@ def total_samples_to_issue( ) return self.n_samples_to_issue + # Multi-turn must issue exactly all client turns — QPS-based formulas are meaningless. + if ( + self.load_pattern is not None + and self.load_pattern.type == LoadPatternType.MULTI_TURN + ): + if self.n_samples_from_dataset < self.min_sample_count: + logger.warning( + "Multi-turn run: min_sample_count=%d exceeds dataset " + "client-turn count=%d; using dataset size. 
Multi-turn cannot " + "issue more samples than the dataset provides.", + self.min_sample_count, + self.n_samples_from_dataset, + ) + logger.debug( + "Sample count: %d (multi-turn: issuing all client turns)", + self.n_samples_from_dataset, + ) + return self.n_samples_from_dataset + # If min_duration is 0, use all dataset samples (new CLI default behavior) if self.min_duration_ms == 0: result = max(self.min_sample_count, self.n_samples_from_dataset) diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 8ab8f3b0..d3390483 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -60,6 +60,7 @@ class LoadPatternType(str, Enum): MAX_THROUGHPUT = "max_throughput" # Offline: all queries at t=0 POISSON = "poisson" # Online: fixed QPS with Poisson distribution CONCURRENCY = "concurrency" # Online: fixed concurrent requests + MULTI_TURN = "multi_turn" # Multi-turn conversations with turn sequencing BURST = "burst" # Burst pattern (TODO) STEP = "step" # Step pattern (TODO) @@ -230,6 +231,27 @@ def get_ruleset_instance(self) -> BenchmarkSuiteRuleset: return get_ruleset(self.ruleset) +class MultiTurnConfig(BaseModel): + """Multi-turn conversation configuration. + + Configuration for benchmarking conversational AI workloads with turn sequencing. + Enables testing multi-turn conversations where each turn depends on previous responses. + Presence of this block in the dataset config enables multi-turn mode. + + Attributes: + turn_timeout_s: Deadline between issuing a turn and receiving its + response. A timeout aborts that turn and all remaining client + turns of the same conversation because subsequent turns depend + on the timed-out response. + use_dataset_history: If True, use pre-built message history from dataset. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + turn_timeout_s: float = Field(default=300.0, gt=0) + use_dataset_history: bool = True + + class Dataset(BaseModel): """Dataset configuration. 
@@ -260,6 +282,9 @@ class Dataset(BaseModel): accuracy_config: AccuracyConfig | None = Field( None, description="Accuracy evaluation settings" ) + multi_turn: MultiTurnConfig | None = Field( + None, description="Multi-turn conversation configuration" + ) @model_validator(mode="after") def _auto_derive_name(self) -> Self: @@ -389,6 +414,12 @@ def _validate_completeness(self) -> Self: raise ValueError( "Concurrency requires --concurrency (e.g., --concurrency 10)" ) + if self.type == LoadPatternType.MULTI_TURN and ( + not self.target_concurrency or self.target_concurrency <= 0 + ): + raise ValueError( + "Multi-turn requires --concurrency (e.g., --concurrency 96)" + ) return self @@ -586,11 +617,41 @@ def _resolve_and_validate(self) -> Self: f"Offline benchmarks must use 'max_throughput', got '{lp.type}'" ) elif effective_mode == TestType.ONLINE: - if lp.type not in (LoadPatternType.POISSON, LoadPatternType.CONCURRENCY): + if lp.type not in ( + LoadPatternType.POISSON, + LoadPatternType.CONCURRENCY, + LoadPatternType.MULTI_TURN, + ): raise ValueError( - "Online mode requires --load-pattern (poisson or concurrency)" + "Online mode requires --load-pattern (poisson, concurrency, or multi_turn)" ) + # Cross-validate load_pattern.type=multi_turn ↔ performance dataset.multi_turn config + has_multi_turn_perf_dataset = any( + d.multi_turn is not None + for d in (self.datasets or []) + if d.type == DatasetType.PERFORMANCE + ) + has_multi_turn_non_perf_dataset = any( + d.multi_turn is not None + for d in (self.datasets or []) + if d.type != DatasetType.PERFORMANCE + ) + if has_multi_turn_non_perf_dataset: + raise ValueError( + "multi_turn config is only supported on performance datasets; " + "accuracy datasets with multi_turn are not supported" + ) + if lp.type == LoadPatternType.MULTI_TURN and not has_multi_turn_perf_dataset: + raise ValueError( + "load_pattern.type=multi_turn requires the performance dataset to have multi_turn config" + ) + if has_multi_turn_perf_dataset and lp.type != LoadPatternType.MULTI_TURN: + raise ValueError( + f"Performance dataset with multi_turn config requires load_pattern.type=multi_turn, " + f"got '{lp.type}'" + ) + return self @model_validator(mode="after") diff --git a/src/inference_endpoint/config/templates/concurrency_template.yaml b/src/inference_endpoint/config/templates/concurrency_template.yaml index db87b71c..d9087994 100644 --- a/src/inference_endpoint/config/templates/concurrency_template.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template.yaml @@ -14,7 +14,7 @@ settings: max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) n_samples_to_issue: null # Sample count override load_pattern: - type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_concurrency: 32 # Concurrent requests endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. 
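A two-line sketch, not part of the patch, of how the `multi_turn` string listed in the template comments above maps onto the enum member added in `schema.py`; `LoadPatternType` is a `str`-valued enum, so both forms compare equal.

```python
from inference_endpoint.config.schema import LoadPatternType

assert LoadPatternType("multi_turn") is LoadPatternType.MULTI_TURN
assert LoadPatternType.MULTI_TURN == "multi_turn"  # str-valued enum, plain string comparison works
```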
diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index 2e16bc0d..24e2f337 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -22,6 +22,7 @@ datasets: # Dataset configs parser: # Column remapping: {prompt: , system: } prompt: text_input accuracy_config: null # Accuracy evaluation settings + multi_turn: null # Multi-turn conversation configuration - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -36,6 +37,7 @@ datasets: # Dataset configs ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation + multi_turn: null # Multi-turn conversation configuration settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) @@ -44,7 +46,7 @@ settings: scheduler_random_seed: 42 # Scheduler RNG seed dataloader_random_seed: 42 # Dataloader RNG seed load_pattern: - type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: null # Target QPS target_concurrency: 32 # Concurrent requests client: diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index a3b4ed0a..1eceb584 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -22,6 +22,7 @@ datasets: # Dataset configs parser: # Column remapping: {prompt: , system: } prompt: text_input accuracy_config: null # Accuracy evaluation settings + multi_turn: null # Multi-turn conversation configuration - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -36,6 +37,7 @@ datasets: # Dataset configs ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation + multi_turn: null # Multi-turn conversation configuration settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) @@ -44,7 +46,7 @@ settings: scheduler_random_seed: 42 # Scheduler RNG seed dataloader_random_seed: 42 # Dataloader RNG seed load_pattern: - type: max_throughput # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: max_throughput # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: null # Target QPS target_concurrency: null # Concurrent requests client: diff --git a/src/inference_endpoint/config/templates/online_template.yaml b/src/inference_endpoint/config/templates/online_template.yaml index c8431687..c3348353 100644 --- a/src/inference_endpoint/config/templates/online_template.yaml +++ b/src/inference_endpoint/config/templates/online_template.yaml @@ -14,7 +14,7 @@ settings: max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) 
n_samples_to_issue: null # Sample count override load_pattern: - type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: 10.0 # Target QPS endpoint_config: endpoints: # Endpoint URL(s). Must include scheme, e.g. 'http://host:port'. diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index d6c44dcc..c03ae175 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -22,6 +22,7 @@ datasets: # Dataset configs parser: # Column remapping: {prompt: , system: } prompt: text_input accuracy_config: null # Accuracy evaluation settings + multi_turn: null # Multi-turn conversation configuration - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -36,6 +37,7 @@ datasets: # Dataset configs ground_truth: ground_truth # Ground truth column name extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation + multi_turn: null # Multi-turn conversation configuration settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) @@ -44,7 +46,7 @@ settings: scheduler_random_seed: 42 # Scheduler RNG seed dataloader_random_seed: 42 # Dataloader RNG seed load_pattern: - type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, multi_turn, burst, step target_qps: 10.0 # Target QPS target_concurrency: null # Concurrent requests client: diff --git a/src/inference_endpoint/core/record.py b/src/inference_endpoint/core/record.py index da35389e..9ac60e1b 100644 --- a/src/inference_endpoint/core/record.py +++ b/src/inference_endpoint/core/record.py @@ -153,6 +153,8 @@ class EventRecord(msgspec.Struct, kw_only=True, frozen=True, gc=False): # type: event_type: EventType timestamp_ns: int = msgspec.field(default_factory=time.monotonic_ns) sample_uuid: str = "" + conversation_id: str = "" + turn: int | None = None data: OUTPUT_TYPE | PromptData | ErrorData | None = None diff --git a/src/inference_endpoint/core/types.py b/src/inference_endpoint/core/types.py index f09f8a7e..aaea804f 100644 --- a/src/inference_endpoint/core/types.py +++ b/src/inference_endpoint/core/types.py @@ -90,23 +90,33 @@ class TextModelOutput( ): # type: ignore[call-arg] """Structured output from a text model. - Supports main output and optional reasoning (e.g. chain-of-thought). + Supports main output, optional reasoning (e.g. chain-of-thought), and tool calls. Each field may be a string (non-streaming) or tuple of strings (streaming chunks). + AT-RISK (gc=False): Has mutable container field `tool_calls`. Any change that + mutates `tool_calls` after construction or stores cyclic references in it + must be audited; if so, remove gc=False. + Attributes: output: Main model output. Defaults to empty string. reasoning: Optional reasoning trace. Defaults to None. + tool_calls: Optional structured tool calls. Defaults to None. + Placed after reasoning so wire-format with array_like=True is + backward compatible (missing trailing elements decode as default). 
""" output: OUTPUT_ELEM_TYPE = "" reasoning: OUTPUT_ELEM_TYPE | None = None + tool_calls: tuple[dict[str, Any], ...] | None = None def __post_init__(self): - """Convert list to tuple for output and reasoning to preserve immutability.""" + """Convert list to tuple for output, reasoning, and tool_calls to preserve immutability.""" if isinstance(self.output, list): msgspec.structs.force_setattr(self, "output", tuple(self.output)) if self.reasoning is not None and isinstance(self.reasoning, list): msgspec.structs.force_setattr(self, "reasoning", tuple(self.reasoning)) + if self.tool_calls is not None and isinstance(self.tool_calls, list): + msgspec.structs.force_setattr(self, "tool_calls", tuple(self.tool_calls)) def __str__(self) -> str: """Return the full output as a single string (joins tuple chunks if streaming).""" @@ -123,6 +133,9 @@ def __str__(self) -> str: elif isinstance(self.output, tuple): parts.extend(self.output) + if self.tool_calls: + parts.append(msgspec.json.encode(list(self.tool_calls)).decode()) + # NOTE: Not sure how output is formatted - there *might* need to be a space or separator between # reasoning and output depending on the accumulator / API. return "".join(parts) @@ -136,6 +149,10 @@ def text_after_first_chunk(self) -> str: For non-streaming (str fields), there is no "first chunk" concept so this returns an empty string. + + Tool calls are always included when present — they are accumulated across + multiple deltas and only realized at stream end, so they always contribute + to the TPOT denominator. """ parts: list[str] = [] if self.reasoning: @@ -155,8 +172,57 @@ def text_after_first_chunk(self) -> str: elif len(self.output) > 1: # No reasoning; first chunk is output[0], skip it. parts.extend(self.output[1:]) + if self.tool_calls: + parts.append(msgspec.json.encode(list(self.tool_calls)).decode()) return "".join(parts) + def as_message_parts( + self, + ) -> tuple[str, str | None, tuple[dict[str, Any], ...] | None]: + """Return (content, reasoning, tool_calls) for chat-template tokenization.""" + if isinstance(self.output, str): + content = self.output + else: + content = "".join(self.output) + + reasoning_str: str | None = None + if self.reasoning: + if isinstance(self.reasoning, str): + reasoning_str = self.reasoning + else: + reasoning_str = "".join(self.reasoning) + + return content, reasoning_str, self.tool_calls + + def as_message_parts_after_first_chunk( + self, + ) -> tuple[str, str | None, tuple[dict[str, Any], ...] | None]: + """Return (content_after_first, reasoning_after_first, tool_calls) for TPOT tokenization.""" + reasoning_after: str | None = None + if isinstance(self.reasoning, tuple) and len(self.reasoning) > 1: + reasoning_after = "".join(self.reasoning[1:]) + + # has_reasoning_tail: reasoning[1:] is non-empty (mirrors `parts` logic in text_after_first_chunk) + has_reasoning_tail = reasoning_after is not None + + content_after = "" + if self.output: + if isinstance(self.output, str): + # Include if reasoning is any non-empty tuple (it was the first chunk) + if has_reasoning_tail or ( + self.reasoning and isinstance(self.reasoning, tuple) + ): + content_after = self.output + elif isinstance(self.output, tuple): + if has_reasoning_tail or self.reasoning: + # First chunk was in reasoning; include all output chunks. + content_after = "".join(self.output) + elif len(self.output) > 1: + # No reasoning; first chunk is output[0], skip it. 
+ content_after = "".join(self.output[1:]) + + return content_after, reasoning_after, self.tool_calls + OUTPUT_TYPE = TextModelOutput @@ -245,7 +311,7 @@ class Query( gc=False: Safe because data/headers are simple key-value pairs without cycles. Do NOT store self-referential or cyclic structures in data/headers fields. - array_like=True: Encodes as array instead of object (e.g., ["id", {...}, {...}, 0.0] + array_like=True: Encodes as array instead of object (e.g., ["id", {...}, 0.0] instead of {"id": ..., "data": ..., ...}). Provides ~6-50% size reduction and ~6-29% ser/des speedup for ZMQ transport depending on payload size. diff --git a/src/inference_endpoint/dataset_manager/__init__.py b/src/inference_endpoint/dataset_manager/__init__.py index 4bb6c575..12938f8e 100644 --- a/src/inference_endpoint/dataset_manager/__init__.py +++ b/src/inference_endpoint/dataset_manager/__init__.py @@ -21,6 +21,7 @@ from .dataset import Dataset, EmptyDataset from .factory import DataLoaderFactory +from .multi_turn_dataset import MultiTurnDataset from .predefined.aime25 import AIME25 from .predefined.cnndailymail import CNNDailyMail from .predefined.gpqa import GPQA @@ -58,4 +59,5 @@ "CNNDailyMail", "RandomDataset", "ShopifyProductCatalogue", + "MultiTurnDataset", ] diff --git a/src/inference_endpoint/dataset_manager/factory.py b/src/inference_endpoint/dataset_manager/factory.py index 6ed1674a..96a9174d 100644 --- a/src/inference_endpoint/dataset_manager/factory.py +++ b/src/inference_endpoint/dataset_manager/factory.py @@ -24,6 +24,7 @@ from inference_endpoint.config.schema import Dataset as DatasetConfig from inference_endpoint.dataset_manager.dataset import Dataset, DatasetFormat +from .multi_turn_dataset import MultiTurnDataset from .transforms import ColumnRemap, MakeAdapterCompatible, Transform logger = logging.getLogger(__name__) @@ -95,18 +96,24 @@ def create_loader(config: DatasetConfig, num_repeats: int = 1, **kwargs) -> Data if file_format is not None: format_enum = DatasetFormat(file_format) + dataset_id = None + if config.multi_turn is not None: + dataset_id = MultiTurnDataset.DATASET_ID + transforms: list[Transform] = [] - if remap is not None: + if remap is not None and dataset_id != MultiTurnDataset.DATASET_ID: # Parser convention is {target: source} (e.g. {prompt: article}). # ColumnRemap expects {source: target} — flip it. flipped = {src: dst for dst, src in remap.items()} transforms.append(ColumnRemap(flipped)) # type: ignore[arg-type] - transforms.append(MakeAdapterCompatible()) + if dataset_id != MultiTurnDataset.DATASET_ID: + transforms.append(MakeAdapterCompatible()) assert dataset_path is not None return Dataset.load_from_file( Path(dataset_path), transforms=transforms, format=format_enum, + dataset_id=dataset_id, num_repeats=num_repeats, ) diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py new file mode 100644 index 00000000..7b0484bd --- /dev/null +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -0,0 +1,553 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-turn conversation dataset for conversational AI benchmarking.""" + +import logging +from dataclasses import dataclass, field, replace +from typing import Any + +import pandas as pd + +from ..config.schema import APIType, ModelParams +from ..exceptions import InputValidationError +from .dataset import Dataset +from .transforms import ( + AddStaticColumns, + apply_transforms, + get_transforms_for_api_type, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class ConversationSampleEntry: + """One client-turn entry in ConversationMetadata.samples. + + sample_index is populated after transforms in MultiTurnDataset.load(); + None before load() is called. + """ + + conversation_id: str + turn: int + sample_index: int | None = None + + +@dataclass +class ConversationMetadata: + """Bundle of maps/lists consumed by MultiTurnStrategy. + + Produced by MultiTurnDataset._build_metadata() from the post-transform dataframe. + Keys in the *_by_key dicts are (str(conversation_id), int(turn)). + Populated by load(); None before load() is called. + """ + + samples: list[ConversationSampleEntry] + num_conversations: int + max_turns_per_conv: int + client_turns_per_conversation: dict[str, int] + pre_built_messages_by_key: dict[tuple[str, int], list[dict]] = field( + default_factory=dict + ) + current_turn_messages_by_key: dict[tuple[str, int], list[dict]] = field( + default_factory=dict + ) + system_prompts_by_conv: dict[str, str | None] = field(default_factory=dict) + + +def _expand_tool_results(row: dict) -> list[dict]: + """Expand a tool row into one OpenAI tool message per result. + + All ``role: "tool"`` rows carry a ``tool_results`` array. Each entry expands to + one OpenAI tool message with ``tool_call_id`` and ``content``. + + Returns an empty list if ``tool_results`` is absent or not a list (non-tool rows). + """ + tool_results = row.get("tool_results") + if not isinstance(tool_results, list): + return [] + if not tool_results: + logger.warning( + "Row has empty tool_results list (conversation_id=%s, turn=%s)", + row.get("conversation_id"), + row.get("turn"), + ) + return [] + messages = [] + for i, result in enumerate(tool_results): + if not isinstance(result, dict): + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} must be a dict, got {type(result).__name__}" + ) + tool_call_id = result.get("tool_call_id") + content = result.get("content") + if tool_call_id is None: + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} is missing required field 'tool_call_id'" + ) + if content is None: + raise InputValidationError( + f"tool_results[{i}] in conversation {row.get('conversation_id')!r} " + f"turn {row.get('turn')} is missing required field 'content'" + ) + messages.append( + {"role": "tool", "tool_call_id": tool_call_id, "content": content} + ) + return messages + + +class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): + """Dataset for multi-turn conversations. 
+ + Supports conversational AI benchmarking with turn sequencing and conversation history. + Validates that conversations have proper structure (alternating user/assistant roles) + and builds metadata for the scheduler to enforce turn ordering. + + Dataset format (JSONL): + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "...", "system": "..."} + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "..."} + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "..."} + + Required columns: + - conversation_id: Unique identifier for each conversation + - turn: Turn number within conversation (1-indexed) + - role: Speaker role ("user" or "assistant") + - content: Message content + + Optional columns: + - system: System prompt associated with the conversation (typically set on the first user turn) + - model: Model name override + - max_new_tokens / max_completion_tokens: Max tokens for this turn (alias; mapped to max_completion_tokens) + + Attributes: + conversation_metadata: ConversationMetadata populated by load() (None before). + Validators run at construction; metadata is built once in load() against + the post-transform dataframe so pre_built_messages_by_key is always in sync. + """ + + COLUMN_NAMES = ["conversation_id", "turn", "role", "content"] + + def __init__(self, dataframe: pd.DataFrame, **kwargs): + """Initialize multi-turn dataset. + + Args: + dataframe: DataFrame with conversation data. + **kwargs: Additional arguments passed to Dataset.__init__. + + Raises: + ValueError: If conversation structure is invalid. + """ + super().__init__(dataframe, **kwargs) + assert self.dataframe is not None, "Dataframe must be initialized" + self._conv_groups = dict( + list(self.dataframe.groupby("conversation_id", sort=False)) + ) + self._validate_conversation_grouping() + self._validate_conversation_structure() + self._validate_turn_numbering() + # Populated by load() after transforms; None until then. + self.conversation_metadata: ConversationMetadata | None = None + + def _validate_conversation_grouping(self) -> None: + """Validate that all rows for each conversation_id appear consecutively in file order. + + Raises: + InputValidationError: If rows for a conversation_id are interleaved with other conversations. + """ + assert self.dataframe is not None, "Dataframe must be initialized" + seen: set[str] = set() + last_conv: str | None = None + for row in self.dataframe.to_dict(orient="records"): + conv_id = str(row["conversation_id"]) + if conv_id != last_conv: + if conv_id in seen: + raise InputValidationError( + f"Rows for conversation '{conv_id}' are not consecutive. " + "All rows for a conversation must appear together in the file." + ) + seen.add(conv_id) + last_conv = conv_id + + def _validate_conversation_structure(self): + """Validate conversations are well-formed. + + Accepts plain user/assistant alternation as well as tool sequences: + user → assistant → tool → [assistant → tool]* → assistant → user + + Raises: + ValueError: If any conversation has invalid role sequence. 
+ """ + VALID_NEXT: dict[str, set[str]] = { + "start": {"user"}, + "user": {"assistant"}, + "assistant": {"tool", "user"}, + "tool": {"assistant", "user"}, + } + + for conv_id, group in self._conv_groups.items(): + sorted_group = group.sort_values("turn") + state = "start" + prev_assistant_had_tool_calls = False + + for _, row in sorted_group.iterrows(): + role = row["role"] + + if role not in VALID_NEXT.get(state, set()): + raise ValueError( + f"Conversation {conv_id} has invalid role sequence at turn " + f"{row['turn']}: got '{role}' after state '{state}'" + ) + + if role == "tool": + if state == "assistant" and not prev_assistant_had_tool_calls: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "'tool' row must follow an 'assistant' row that has non-empty 'tool_calls'" + ) + tool_results = row.get("tool_results") + if not isinstance(tool_results, list) or len(tool_results) == 0: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "tool rows must have a non-empty 'tool_results' list" + ) + elif role == "assistant": + content = row.get("content") + is_empty_content = ( + content is None + or (isinstance(content, float) and pd.isna(content)) + or content == "" + ) + tool_calls = row.get("tool_calls") + has_tool_calls = ( + isinstance(tool_calls, list) and len(tool_calls) > 0 + ) + if is_empty_content and not has_tool_calls: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "assistant rows must have non-empty 'content' or non-empty 'tool_calls'" + ) + if ( + tool_calls is not None + and not (isinstance(tool_calls, float) and pd.isna(tool_calls)) + and not has_tool_calls + ): + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "'tool_calls' field is present but is not a non-empty list; " + "omit the field or provide a valid non-empty list" + ) + prev_assistant_had_tool_calls = has_tool_calls + elif role == "user": + content = row.get("content") + is_empty_content = ( + content is None + or (isinstance(content, float) and pd.isna(content)) + or content == "" + ) + if is_empty_content: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "user rows must have non-empty 'content'" + ) + if state == "assistant" and prev_assistant_had_tool_calls: + raise InputValidationError( + f"Conversation {conv_id} turn {row['turn']}: " + "'user' row cannot follow an assistant row with 'tool_calls'; " + "a 'tool' result row is required first" + ) + + if role != "assistant": + prev_assistant_had_tool_calls = False + + state = role + + def _validate_turn_numbering(self): + """Validate turn numbers are consecutive starting at 1. + + Raises: + ValueError: If turn numbers are not exactly 1, 2, 3, …, N. + """ + for conv_id, group in self._conv_groups.items(): + turns = sorted(group["turn"].tolist()) + expected = list(range(1, len(turns) + 1)) + if turns != expected: + raise ValueError( + f"Conversation {conv_id}: Turn numbers must be consecutive starting at 1, " + f"got {turns}" + ) + + def _build_metadata(self) -> ConversationMetadata: + """Build metadata for scheduler (maps sample index to conversation context). + + Pre-computes the complete message list for each client turn so that + conversation history does not need to be accumulated at runtime. + + Returns: + ConversationMetadata with samples, counts, and pre-built message maps. 
+ """ + samples: list[ConversationSampleEntry] = [] + + # Count client turns (user + tool) per conversation for completion tracking + client_turns_per_conv = { + str(conv_id): int(group["role"].isin(["user", "tool"]).sum()) + for conv_id, group in self._conv_groups.items() + } + + # Map (conversation_id, turn) → complete message list ready to send to endpoint. + # Each entry is: [system (optional)] + all prior rows formatted as messages + # + the current client turn message. + # This includes assistant rows (tool dispatches or terminal responses) + # so no runtime injection is required. + pre_built_messages_by_key: dict[tuple, list[dict]] = {} + current_turn_messages_by_key: dict[tuple, list[dict]] = {} + system_prompts_by_conv: dict[str, str | None] = {} + + assert self.dataframe is not None, "Dataframe must be initialized" + for conv_id, group in self._conv_groups.items(): + sorted_group = group.sort_values("turn") + client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])] + + # Extract system prompt from the first row that has it (typically turn 1) + system_content: str | None = None + for _, srow in sorted_group.iterrows(): + val = srow.get("system") + if val and isinstance(val, str): + system_content = val + break + system_prompts_by_conv[str(conv_id)] = system_content + + for _, row in client_rows.iterrows(): + t_n = int(row["turn"]) + + messages: list[dict] = [] + if system_content: + messages.append({"role": "system", "content": system_content}) + + # All dataset rows strictly before this client turn (includes + # assistant rows and prior tool results). + prior_rows = sorted_group[sorted_group["turn"] < t_n] + for _, prior_row in prior_rows.iterrows(): + msg: dict[str, Any] = {} + for key in ( + "role", + "content", + "name", + "tool_calls", + "tool_results", + ): + val = prior_row.get(key) + if val is not None and not ( + isinstance(val, float) and pd.isna(val) + ): + msg[key] = val + if ( + msg.get("role") == "assistant" + and "tool_calls" in msg + and "content" not in msg + ): + msg["content"] = None + if msg.get("role"): + # Expand merged parallel tool results: a single row with + # tool_results: [{tool_call_id, content}, ...] expands into + # one OpenAI tool message per result entry. + expanded = _expand_tool_results(msg) + if expanded: + messages.extend(expanded) + else: + messages.append(msg) + + # Append the current client turn message. + # A merged parallel-tool row carries tool_results instead of a + # single tool_call_id/content pair; expand to one message per result. 
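+ # e.g. tool_results: [{"tool_call_id": "a", ...}, {"tool_call_id": "b", ...}] expands to two {"role": "tool"} messages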
+ current_turn_msgs: list[dict] = [] + expanded = _expand_tool_results(row) + if expanded: + current_turn_msgs = expanded + else: + cur: dict[str, Any] = {} + for key in ("role", "content", "name"): + val = row.get(key) + if val is not None and not ( + isinstance(val, float) and pd.isna(val) + ): + cur[key] = val + current_turn_msgs = [cur] + messages.extend(current_turn_msgs) + + str_conv_id = str(conv_id) + pre_built_messages_by_key[(str_conv_id, t_n)] = messages + current_turn_messages_by_key[(str_conv_id, t_n)] = current_turn_msgs + + samples.append( + ConversationSampleEntry( + conversation_id=str_conv_id, + turn=t_n, + ) + ) + + return ConversationMetadata( + samples=samples, + num_conversations=len(self._conv_groups), + max_turns_per_conv=max(g["turn"].max() for g in self._conv_groups.values()), + client_turns_per_conversation=client_turns_per_conv, + pre_built_messages_by_key=pre_built_messages_by_key, + current_turn_messages_by_key=current_turn_messages_by_key, + system_prompts_by_conv=system_prompts_by_conv, + ) + + def load( + self, + adapter=None, + api_type: APIType | None = None, + model_params: ModelParams | None = None, + force: bool = False, + ): + """Load dataset, apply adapter defaults, and pre-bake client-turn samples. + + Unlike single-turn datasets, multi-turn rows do not have a `prompt` column, + so ColumnFilter (which requires prompt) is skipped. AddStaticColumns entries + from the adapter are applied via AddStaticColumns(..., overwrite=False) + (fill-missing-only) so that + per-row dataset overrides are preserved. + + After transforms, only client turns (user + tool) are stored in self.data as + fully assembled sample dicts (with messages attached). + load_sample() and num_samples() are inherited from the base class. + """ + if not force and self.data is not None: + return + + df = self.dataframe + if df is None: + raise ValueError( + f"Cannot load dataset {self.__class__.__name__}: dataframe is None" + ) + + transforms = [] + if self.transforms is not None: + transforms.extend(self.transforms) + + if transforms: + df = apply_transforms(df, transforms) + + # Extract AddStaticColumns defaults from adapter transforms and apply as + # fill-missing-only (preserves per-row dataset values). + if api_type is not None and model_params is not None: + adapter_transforms = get_transforms_for_api_type(api_type, model_params) + defaults: dict[str, Any] = {} + for t in adapter_transforms: + if isinstance(t, AddStaticColumns): + defaults.update(t.data) + if defaults: + df = AddStaticColumns(defaults, overwrite=False)(df) + + # Rebuild conv_groups + metadata from the final post-transform df so + # pre_built_messages_by_key reflects any transforms applied above. + self.dataframe = df + self._conv_groups = dict(list(df.groupby("conversation_id", sort=False))) + self.conversation_metadata = self._build_metadata() + + all_rows = df.to_dict(orient="records") + + # Pre-bake: assemble one complete sample dict per client turn. + # NaN filtering replaces the GENERATION_PARAMS allowlist — any key whose + # value is float NaN was absent in the original dataset row. + pre_built = self.conversation_metadata.pre_built_messages_by_key + client_turn_samples: list[dict[str, Any]] = [] + # Maps (conv_id, turn) → dense sample_index for metadata backfill. + key_to_sample_index: dict[tuple[str, int], int] = {} + + # Collect per-conversation defaults from the first user row so that + # fields like model/max_completion_tokens propagate to tool rows. 
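+ # e.g. a "model" override set on the first user turn also applies to that conversation's tool-result turns unless a row sets its own value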
+ _PROPAGATED_KEYS = { + "model", + "max_completion_tokens", + "max_new_tokens", + "stream", + "tools", + } + conv_defaults: dict[str, dict[str, Any]] = {} + for row in all_rows: + cid = row.get("conversation_id") + if cid not in conv_defaults and row.get("role") == "user": + conv_defaults[cid] = { + k: row[k] + for k in _PROPAGATED_KEYS + if k in row + and row[k] is not None + and not (isinstance(row[k], float) and pd.isna(row[k])) + } + + for row in all_rows: + if row.get("role") not in ("user", "tool"): + continue + + # Filter NaN values; keep all meaningful fields (extra keys are harmless + # since adapters consume only what they recognize). + sample: dict[str, Any] = { + k: v + for k, v in row.items() + if v is not None and not (isinstance(v, float) and pd.isna(v)) + } + # Strip dataset-internal fields that must not reach the endpoint. + sample.pop("tool_results", None) + sample.pop("tool_calls", None) + + # Fill missing propagated fields from the first user row of this conversation. + for k, v in conv_defaults.get(row.get("conversation_id"), {}).items(): + if k not in sample: + sample[k] = v + + # Normalize max-tokens across all adapter aliases. + max_tokens_val = ( + sample.pop("max_new_tokens", None) + or sample.get("max_completion_tokens") + or sample.get("max_tokens") + or 128 + ) + sample["max_new_tokens"] = max_tokens_val + sample["max_completion_tokens"] = max_tokens_val + sample["max_tokens"] = max_tokens_val + if "stream" not in sample: + sample["stream"] = False + + # Attach pre-built message list (system + history + current turn). + key = (str(row["conversation_id"]), int(row["turn"])) + if key not in pre_built: + logger.warning( + "dropping sample missing pre-built messages: key=%s", key + ) + continue + sample["messages"] = pre_built[key] + + # Record dense 0-based index before appending (matches load_sample() position). + key_to_sample_index[key] = len(client_turn_samples) + client_turn_samples.append(sample) + + # Backfill explicit sample_index into conversation_metadata.samples. + # Drop entries whose key is absent (truncated turns not in client_turn_samples). + updated_samples = [] + for s in self.conversation_metadata.samples: + skey: tuple[str, int] = (str(s.conversation_id), int(s.turn)) + if skey in key_to_sample_index: + updated_samples.append( + replace(s, sample_index=key_to_sample_index[skey]) + ) + self.conversation_metadata.samples = updated_samples + + self.data = client_turn_samples diff --git a/src/inference_endpoint/dataset_manager/transforms.py b/src/inference_endpoint/dataset_manager/transforms.py index 79133796..9ef7b1c3 100644 --- a/src/inference_endpoint/dataset_manager/transforms.py +++ b/src/inference_endpoint/dataset_manager/transforms.py @@ -114,16 +114,29 @@ def process_row(self, row: dict[str, Any]) -> dict[str, Any]: class AddStaticColumns(Transform): - """Transform that adds columns with constant values to a DataFrame.""" + """Transform that adds columns with constant values to a DataFrame. - def __init__(self, data: dict[str, Any]): + When overwrite=False, existing non-null values are preserved — dataset + per-row overrides take precedence over the supplied defaults. 
+ """ + + def __init__(self, data: dict[str, Any], overwrite: bool = True): """Initialize the AddStaticColumns transform.""" self.data = data + self.overwrite = overwrite def __call__(self, df: pd.DataFrame) -> pd.DataFrame: - """Add the static columns to the row.""" + """Add the static columns to the dataframe.""" for key, value in self.data.items(): - df[key] = value + if self.overwrite: + df[key] = value + else: + if value is None: + continue + if key in df.columns: + df[key] = df[key].where(pd.notna(df[key]), value) + else: + df[key] = value return df diff --git a/src/inference_endpoint/endpoint_client/adapter_protocol.py b/src/inference_endpoint/endpoint_client/adapter_protocol.py index 164f71e1..a96faaa0 100644 --- a/src/inference_endpoint/endpoint_client/adapter_protocol.py +++ b/src/inference_endpoint/endpoint_client/adapter_protocol.py @@ -17,16 +17,21 @@ from __future__ import annotations +import logging import re from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +import msgspec + from inference_endpoint.core.types import Query, QueryResult if TYPE_CHECKING: from inference_endpoint.config.schema import ModelParams from inference_endpoint.dataset_manager.transforms import Transform +logger = logging.getLogger(__name__) + class HttpRequestAdapter(ABC): """ @@ -95,40 +100,39 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: @abstractmethod def decode_sse_message(cls, json_bytes: bytes) -> Any: """ - Decode SSE message and extract content. + Decode SSE message and return adapter-specific chunk object. Args: json_bytes: Raw JSON bytes from SSE stream Returns: - Decoded SSE content (type depends on the adapter implementation) + Adapter-specific chunk object passed to accumulator.add_chunk() """ raise NotImplementedError("decode_sse_message not implemented") @classmethod - def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[str]: - """ - Parse SSE chunk and extract all content strings. + def parse_sse_chunk(cls, buffer: bytes, end_pos: int) -> list[Any]: + """Parse SSE chunk and extract all chunk objects. - Extracts JSON documents from SSE stream and decodes them to content strings. - Silently ignores non-content SSE messages (role, finish_reason, etc). + Extracts JSON documents from SSE stream and decodes them to chunk objects. + Filters None returns from decode_sse_message and skips frames that fail + to decode (e.g. role-only or finish_reason-only frames). 
Args: buffer: Byte buffer containing SSE data end_pos: End position in buffer to parse up to Returns: - List of content strings extracted from the SSE chunk + List of chunk objects extracted from the SSE chunk """ json_docs = cls.SSE_DATA_PATTERN.findall(buffer[:end_pos]) - parsed_contents = [] - + parsed: list[Any] = [] + # Note: if one frame is malformed, remaining frames are skipped try: for json_doc in json_docs: content = cls.decode_sse_message(json_doc) - parsed_contents.append(content) - except Exception: - # Normal for non-content SSE messages (role, finish_reason, etc) - pass - - return parsed_contents + if content is not None: + parsed.append(content) + except (msgspec.DecodeError, msgspec.ValidationError) as exc: + logger.warning("skipping malformed SSE batch (%s)", type(exc).__name__) + return parsed diff --git a/src/inference_endpoint/load_generator/conversation_manager.py b/src/inference_endpoint/load_generator/conversation_manager.py new file mode 100644 index 00000000..1b0834bb --- /dev/null +++ b/src/inference_endpoint/load_generator/conversation_manager.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conversation state management for multi-turn benchmarking.""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class ConversationState: + """Per-conversation state for multi-turn benchmarking. + + Attributes: + conversation_id: Unique identifier for this conversation. + message_history: Accumulated message list (populated only when + use_dataset_history=False; empty otherwise). + completed_turns: Turns with responses (success or failure) — observability only. + failed_turns: Turns that failed — observability only. + expected_client_turns: Expected total client turns (for completion detection). + """ + + conversation_id: str + message_history: list[dict[str, Any]] = field(default_factory=list) + completed_turns: int = 0 + failed_turns: int = 0 + expected_client_turns: int | None = None + + def is_complete(self) -> bool: + """Return True when all expected turns have a response.""" + if self.expected_client_turns is None: + return False + return self.completed_turns >= self.expected_client_turns + + +class ConversationManager: + """Manages per-conversation state for multi-turn benchmarking. + + All methods are synchronous. Turn sequencing is driven by MultiTurnStrategy + which calls on_sample_complete() → _issue_next_turn() directly. + + All states are pre-created by MultiTurnStrategy.execute() before any turns + are issued, so get_or_create() requires no locking. 
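+
+    Sketch of the flow MultiTurnStrategy drives (conversation id and turn
+    count are illustrative)::
+
+        manager = ConversationManager()
+        state = manager.get_or_create("c1", expected_client_turns=2)
+        manager.mark_turn_complete("c1", "first answer")
+        manager.mark_turn_complete("c1", "second answer")
+        assert state.is_complete()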
+ """ + + def __init__(self): + """Initialize with empty state.""" + self._conversations: dict[str, ConversationState] = {} + + def get_state(self, conversation_id: str) -> ConversationState | None: + """Return existing state without creating (read-only access).""" + return self._conversations.get(conversation_id) + + def get_or_create( + self, + conversation_id: str, + expected_client_turns: int | None = None, + system_message: dict[str, Any] | None = None, + ) -> ConversationState: + """Return existing state or create a new one. + + Args: + conversation_id: Unique identifier for conversation. + expected_client_turns: Expected number of client turns. + system_message: System message to prepend to message_history + (only used when use_dataset_history=False and state is new). + + Returns: + ConversationState for this conversation. + """ + if conversation_id not in self._conversations: + initial_history: list[dict[str, Any]] = ( + [system_message] if system_message is not None else [] + ) + self._conversations[conversation_id] = ConversationState( + conversation_id=conversation_id, + expected_client_turns=expected_client_turns, + message_history=initial_history, + ) + return self._conversations[conversation_id] + + def _log_if_complete(self, state: ConversationState, conversation_id: str) -> None: + """Log completion status once all expected turns have a response.""" + if not state.is_complete(): + return + if state.failed_turns > 0: + logger.info( + f"Conversation {conversation_id} completed with failures: " + f"{state.completed_turns - state.failed_turns}/" + f"{state.expected_client_turns} successful, " + f"{state.failed_turns} failed" + ) + else: + logger.debug( + f"Conversation {conversation_id} completed: " + f"{state.completed_turns}/{state.expected_client_turns} turns" + ) + + def mark_turn_complete( + self, + conversation_id: str, + response: str, + store_in_history: bool = False, + metadata: dict[str, Any] | None = None, + ) -> None: + """Record a successful response. + + Args: + conversation_id: Conversation ID. + response: Model output (appended to history when store_in_history=True). + store_in_history: When True, append response to message_history. + metadata: Optional response metadata; tool_calls are preserved in history + when present (only used when store_in_history=True). + + Raises: + KeyError: If conversation_id not found. + """ + state = self._conversations.get(conversation_id) + if state is None: + raise KeyError(f"Conversation {conversation_id} not initialized") + if store_in_history: + tool_calls = metadata.get("tool_calls") if metadata else None + if response or tool_calls: + msg: dict[str, Any] = {"role": "assistant", "content": response or None} + if tool_calls: + msg["tool_calls"] = tool_calls + state.message_history.append(msg) + state.completed_turns += 1 + self._log_if_complete(state, conversation_id) + + def mark_turn_failed( + self, + conversation_id: str, + store_in_history: bool = False, + ) -> None: + """Record a failed response. + + Failed turns count toward completion so sequencing progresses under errors. + + Args: + conversation_id: Conversation ID. + store_in_history: When True, append error placeholder to message_history. + + Raises: + KeyError: If conversation_id not found. 
+ """ + state = self._conversations.get(conversation_id) + if state is None: + raise KeyError(f"Conversation {conversation_id} not initialized") + if store_in_history: + state.message_history.append( + {"role": "assistant", "content": "[ERROR: Turn failed or timed out]"} + ) + state.completed_turns += 1 + state.failed_turns += 1 + logger.warning(f"Turn failed for conversation {conversation_id}") + self._log_if_complete(state, conversation_id) diff --git a/src/inference_endpoint/load_generator/multi_turn_strategy.py b/src/inference_endpoint/load_generator/multi_turn_strategy.py new file mode 100644 index 00000000..3803e225 --- /dev/null +++ b/src/inference_endpoint/load_generator/multi_turn_strategy.py @@ -0,0 +1,443 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async multi-turn load strategy implementing the LoadStrategy protocol.""" + +import asyncio +import logging +import time +from collections import defaultdict, deque +from collections.abc import Iterator +from typing import Any + +from ..config.schema import MultiTurnConfig +from ..core.record import ErrorEventType, EventRecord, SampleEventType +from ..core.types import ErrorData, QueryResult, TextModelOutput +from ..dataset_manager.multi_turn_dataset import ConversationMetadata +from ..exceptions import InputValidationError +from .conversation_manager import ConversationManager, ConversationState +from .strategy import PhaseIssuerProtocol + +logger = logging.getLogger(__name__) + +# Default turn timeout when no MultiTurnConfig is provided. +_DEFAULT_TURN_TIMEOUT_S = 300.0 + + +class MultiTurnStrategy: + """Event-driven multi-turn strategy. Completion of each turn triggers the next. + + execute() seeds the first N conversations (issues turn 1 for each), then + awaits _all_done. on_sample_complete() is called synchronously from the + receive coroutine for each response — it issues the next turn immediately + (zero event-loop iterations between response and next issuance), or starts + a new conversation when the current one finishes all turns. + + At most target_concurrency conversations are active simultaneously. When + target_concurrency is None, all conversations start at once. + + A turn-level timeout aborts the remaining client turns of that conversation + because subsequent turns depend on the timed-out response. The timed-out + turn and all downstream turns are marked failed. + + Integration with BenchmarkSession: + - execute(): seeds conversations, awaits completion + - on_query_complete(): no-op (required by LoadStrategy protocol) + - on_sample_complete(): routes completed QueryResult, issues next turn + + The response routing path: + 1. _issue_next_turn issues turn N via phase_issuer.issue(idx) → query_id + 2. _issue_next_turn stores conv_id in _inflight[query_id] + 3. BenchmarkSession calls on_sample_complete(result) with the QueryResult + 4. 
on_sample_complete looks up conv_id from _inflight, calls mark_turn_complete + 5. on_sample_complete calls _issue_next_turn for turn N+1 (synchronously) + """ + + def __init__( + self, + conversation_manager: ConversationManager, + dataset_metadata: ConversationMetadata, + multi_turn_config: MultiTurnConfig | None = None, + target_concurrency: int | None = None, + ): + """Initialize multi-turn strategy. + + Args: + conversation_manager: Manages conversation sequencing state. + dataset_metadata: ConversationMetadata from MultiTurnDataset (after load()). + multi_turn_config: Multi-turn conversation configuration. + target_concurrency: Maximum number of simultaneously active conversations. + None means all conversations run concurrently. + """ + self._conv_manager = conversation_manager + self._dataset_metadata = dataset_metadata + self._multi_turn_config = multi_turn_config + self._turn_timeout_s = ( + multi_turn_config.turn_timeout_s + if multi_turn_config is not None + else _DEFAULT_TURN_TIMEOUT_S + ) + self._target_concurrency = target_concurrency + self._store_in_history = ( + not multi_turn_config.use_dataset_history + if multi_turn_config is not None + else False + ) + + if self._store_in_history: + tool_turn_keys = [ + key + for key, msgs in dataset_metadata.current_turn_messages_by_key.items() + if any(m.get("role") == "tool" for m in msgs) + ] + if tool_turn_keys: + raise InputValidationError( + "Multi-turn with tool turns requires use_dataset_history=True. " + "Live-history mode (use_dataset_history=False) with tool calls " + "is not implemented yet. " + f"Offending turn(s): {tool_turn_keys[:5]}" + + ( + f" (+{len(tool_turn_keys) - 5} more)" + if len(tool_turn_keys) > 5 + else "" + ) + ) + + # Composite on_sample_complete callback set by execute.py; used by + # _handle_timeout to route synthetic failure results. + self._session_on_sample_complete: Any | None = None + self._session_publisher: Any | None = None + + # Maps query_id -> conversation_id for routing completions. + self._inflight: dict[str, str] = {} + # Cached ConversationState refs for O(1) lookup in on_sample_complete. + self._conv_states: dict[str, ConversationState] = {} + + # Event-driven state — populated in execute(). + self._pending_convs: deque[tuple[str, list[tuple[int, int]]]] = deque() + self._active_iters: dict[str, Iterator[tuple[int, int]]] = {} + self._timeout_handles: dict[str, asyncio.TimerHandle] = {} + self._error: BaseException | None = None + self._all_done: asyncio.Event | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._phase_issuer: PhaseIssuerProtocol | None = None + + async def execute(self, phase_issuer: PhaseIssuerProtocol) -> int: + """Drive multi-turn sample issuance. + + Args: + phase_issuer: Interface for issuing samples to the endpoint. + + Returns: + Total count of samples issued. + """ + self._phase_issuer = phase_issuer + self._loop = asyncio.get_running_loop() + self._all_done = asyncio.Event() + self._error = None + + conv_samples: dict[str, list[tuple[int, int]]] = defaultdict(list) + for sample_meta in self._dataset_metadata.samples: + conv_id = sample_meta.conversation_id + assert sample_meta.sample_index is not None + conv_samples[conv_id].append((sample_meta.sample_index, sample_meta.turn)) + + # Pre-create all conversation states before issuing any turns (no locking needed). 
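+        # In live-history mode (use_dataset_history=False) the per-conversation
+        # system prompt seeds each state's message_history; in dataset-history
+        # mode histories stay empty and the pre-built messages are sent verbatim.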
+ sys_prompts = self._dataset_metadata.system_prompts_by_conv + for conv_id, turns in conv_samples.items(): + sys_content = sys_prompts.get(conv_id) if self._store_in_history else None + system_message = ( + {"role": "system", "content": sys_content} + if sys_content is not None + else None + ) + state = self._conv_manager.get_or_create( + conv_id, + expected_client_turns=len(turns), + system_message=system_message, + ) + self._conv_states[conv_id] = state + + # Build pending queue (sorted turns per conversation). + for conv_id, turns in conv_samples.items(): + self._pending_convs.append((conv_id, sorted(turns, key=lambda x: x[1]))) + + n_to_start = ( + min(self._target_concurrency, len(self._pending_convs)) + if self._target_concurrency is not None and self._target_concurrency > 0 + else len(self._pending_convs) + ) + try: + for _ in range(n_to_start): + self._start_conversation() + + if not self._active_iters and not self._inflight: + return phase_issuer.issued_count + + await self._all_done.wait() + if self._error is not None: + raise self._error + return phase_issuer.issued_count + finally: + for handle in self._timeout_handles.values(): + handle.cancel() + self._timeout_handles.clear() + if self._inflight: + logger.warning( + "%d query(ies) never received a response (session stop or transport failure): %s", + len(self._inflight), + list(self._inflight.keys()), + ) + self._inflight.clear() + + def _start_conversation(self) -> None: + """Pop the next conversation from the pending queue and issue its first turn.""" + conv_id, turns = self._pending_convs.popleft() + self._active_iters[conv_id] = iter(turns) + self._issue_next_turn(conv_id) + + def _issue_next_turn(self, conv_id: str) -> None: + """Issue the next turn for conv_id, or mark the conversation done.""" + it = self._active_iters.get(conv_id) + if it is None: + return + + pair = next(it, None) + if pair is None: + del self._active_iters[conv_id] + self._fill_slot() + return + + idx, turn = pair + state = self._conv_states[conv_id] + + data_override: dict[str, Any] | None = None + current_turn_messages: list[dict[str, Any]] | None = None + if self._store_in_history: + current_turn_messages = ( + self._dataset_metadata.current_turn_messages_by_key.get((conv_id, turn)) + ) + if current_turn_messages: + live_messages = state.message_history.copy() + current_turn_messages + data_override = { + "messages": live_messages, + "input_tokens": None, + "token_ids": None, + } + + assert self._phase_issuer is not None + query_id = self._phase_issuer.issue( + idx, + data_override=data_override, + conversation_id=conv_id, + turn=turn, + ) + if query_id is None: + # Session stopping — signal done. 
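+            # (PhaseIssuer.issue returns None when the session is stopping.)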
+ assert self._all_done is not None + self._all_done.set() + return + + self._inflight[query_id] = conv_id + + if self._store_in_history and current_turn_messages: + state.message_history.extend(current_turn_messages) + + assert self._loop is not None + handle = self._loop.call_later( + self._turn_timeout_s, self._handle_timeout, query_id, conv_id + ) + self._timeout_handles[query_id] = handle + + def _fill_slot(self) -> None: + """Start a new conversation from the pending queue, or signal all done.""" + if self._pending_convs: + self._start_conversation() + elif not self._active_iters: + assert self._all_done is not None + self._all_done.set() + + def _handle_timeout(self, query_id: str, conv_id: str) -> None: + """Called by the event loop when a turn response does not arrive in time.""" + if self._inflight.pop(query_id, None) is None: + return + self._timeout_handles.pop(query_id, None) + + if ( + self._phase_issuer is not None + and hasattr(self._phase_issuer, "uuid_to_index") + and query_id in self._phase_issuer.uuid_to_index # type: ignore[attr-defined] + ): + self._phase_issuer.inflight -= 1 # type: ignore[attr-defined] + if hasattr(self._phase_issuer, "completed_uuids"): + self._phase_issuer.completed_uuids.add(query_id) # type: ignore[attr-defined] + if hasattr(self._phase_issuer, "uuid_to_conv_info"): + self._phase_issuer.uuid_to_conv_info.pop(query_id, None) # type: ignore[attr-defined] + + logger.warning( + "Turn timed out for conversation %s (query=%s)", conv_id, query_id + ) + + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + + # Route a synthetic failure result so the accuracy collector and event + # logger see the timed-out turn. + timeout_result = QueryResult( + id=query_id, + error=ErrorData( + error_type="TurnTimeout", + error_message=f"turn timeout after {self._turn_timeout_s}s", + ), + ) + + # Publish ERROR + COMPLETE so the metrics aggregator and event logger + # see the timeout (matches BenchmarkSession._handle_response ordering). + if self._session_publisher is not None: + try: + self._session_publisher.publish( + EventRecord( + event_type=ErrorEventType.GENERIC, + timestamp_ns=time.monotonic_ns(), + sample_uuid=query_id, + data=timeout_result.error, + ) + ) + self._session_publisher.publish( + EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=time.monotonic_ns(), + sample_uuid=query_id, + data=None, + ) + ) + except Exception: + logger.exception( + "Failed to publish timeout EventRecords (query=%s)", query_id + ) + + if self._session_on_sample_complete is not None: + try: + self._session_on_sample_complete(timeout_result) + except Exception: + logger.exception( + "on_sample_complete callback raised for timeout result (query=%s)", + query_id, + ) + + it = self._active_iters.pop(conv_id, None) + dropped = 0 + if it is not None: + for _ in it: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + dropped += 1 + if dropped: + logger.warning( + "turn timeout on conv=%s dropped %d remaining client turn(s)", + conv_id, + dropped, + ) + + self._fill_slot() + + def on_query_complete(self, query_id: str) -> None: + """No-op. Required by LoadStrategy protocol; called by BenchmarkSession.""" + pass + + def on_sample_complete(self, result: QueryResult) -> None: + """Route completed QueryResult to ConversationManager and issue next turn. + + Called synchronously from BenchmarkSession._handle_response(). 
Issues the + next turn immediately (zero event-loop delay) or starts a new conversation + when this one finishes all turns. + + Args: + result: Completed QueryResult from the endpoint. + """ + conv_id = self._inflight.pop(result.id, None) + if conv_id is None: + logger.debug( + "dropping late response result=%s (no matching in-flight entry)", + result.id, + ) + return + + handle = self._timeout_handles.pop(result.id, None) + if handle is not None: + handle.cancel() + + output = result.response_output + if isinstance(output, TextModelOutput): + response_text: str | None = ( + "".join(output.output) + if isinstance(output.output, tuple) + else output.output + ) or None + else: + response_text = output if isinstance(output, str) else None + + try: + if result.error is not None: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + else: + self._conv_manager.mark_turn_complete( + conv_id, + response_text or "", + store_in_history=self._store_in_history, + metadata=result.metadata, + ) + except KeyError: + self._active_iters.pop(conv_id, None) + self._fill_slot() + logger.exception( + "on_sample_complete routing miss for conv=%s result=%s", + conv_id, + result.id, + ) + return + + # If this turn failed, abandon the rest of the conversation: replaying + # later turns against a corrupt history (assistant placeholder / + # missing tool result) is meaningless and matches the timeout path. + if result.error is not None: + it = self._active_iters.pop(conv_id, None) + dropped = 0 + if it is not None: + for _ in it: + self._conv_manager.mark_turn_failed( + conv_id, store_in_history=self._store_in_history + ) + dropped += 1 + if dropped: + logger.warning( + "turn error on conv=%s dropped %d remaining client turn(s)", + conv_id, + dropped, + ) + self._fill_slot() + return + + try: + self._issue_next_turn(conv_id) + except Exception as exc: + logger.exception("Error issuing next turn for %s", conv_id) + self._error = exc + if self._all_done is not None: + self._all_done.set() diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 0ccfe6a0..b860db0c 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -21,14 +21,15 @@ from __future__ import annotations import asyncio +import json import logging import os import time import uuid from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum -from typing import Protocol +from typing import Any, Protocol from ..config.runtime_settings import RuntimeSettings from ..core.record import ( @@ -47,6 +48,29 @@ _WARMUP_ENABLED = os.environ.get("ENABLE_WARMUP") == "1" +def _extract_prompt_text(messages: list[Any]) -> str | None: + """Join text content from an OpenAI messages list; handles list-form multimodal content.""" + parts: list[str] = [] + for m in messages: + if not isinstance(m, dict): + continue + c = m.get("content") + if isinstance(c, str) and c: + parts.append(c) + elif isinstance(c, list): + parts.extend( + p["text"] + for p in c + if isinstance(p, dict) + and p.get("type") == "text" + and isinstance(p.get("text"), str) + ) + tc = m.get("tool_calls") + if tc: + parts.append(json.dumps(tc, separators=(",", ":"))) + return "\n".join(parts) if parts else None + + # --------------------------------------------------------------------------- # Phase configuration # 
--------------------------------------------------------------------------- @@ -68,6 +92,7 @@ class PhaseConfig: runtime_settings: RuntimeSettings dataset: Dataset phase_type: PhaseType = PhaseType.PERFORMANCE + strategy: LoadStrategy | None = field(default=None, compare=False) # --------------------------------------------------------------------------- @@ -153,6 +178,8 @@ class PhaseIssuer: "_publisher", "_stop_check", "uuid_to_index", + "uuid_to_conv_info", + "completed_uuids", "inflight", "issued_count", ) @@ -169,14 +196,28 @@ def __init__( self._publisher = publisher self._stop_check = stop_check self.uuid_to_index: dict[str, int] = {} + self.uuid_to_conv_info: dict[str, tuple[str, int | None]] = {} + self.completed_uuids: set[str] = set() self.inflight: int = 0 self.issued_count: int = 0 - def issue(self, sample_index: int) -> str | None: + def issue( + self, + sample_index: int, + data_override: dict[str, Any] | None = None, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: """Load data, build Query, publish ISSUED, send to endpoint. Returns query_id on success, None if session is stopping. + Args: + sample_index: Index into the dataset. + data_override: If provided, merged over the loaded sample data. + Keys in data_override take precedence. Used by MultiTurnStrategy + to substitute live-accumulated message history. + Note: load_sample() runs synchronously before the ISSUED timestamp. For accurate timing, datasets MUST be pre-loaded into memory. Disk-backed datasets will inflate timing and delay subsequent issues. @@ -185,8 +226,11 @@ def issue(self, sample_index: int) -> str | None: return None query_id = uuid.uuid4().hex data = self._dataset.load_sample(sample_index) + if data_override is not None: + data = {**data, **data_override} query = Query(id=query_id, data=data) self.uuid_to_index[query_id] = sample_index + self.uuid_to_conv_info[query_id] = (conversation_id, turn) ts = time.monotonic_ns() prompt_data: PromptData if isinstance(data, dict): @@ -197,9 +241,11 @@ def issue(self, sample_index: int) -> str | None: # meaningful for ISL reporting on text-only prompts. # Therefore, setting `text=None` for non-string prompts # means that ISL reporting will be unavailable for multimodal samples. 
- prompt = data.get("prompt") + prompt_text = data.get("prompt") + if prompt_text is None and "messages" in data: + prompt_text = _extract_prompt_text(data["messages"]) prompt_data = PromptData( - text=prompt if isinstance(prompt, str) else None, + text=prompt_text if isinstance(prompt_text, str) else None, token_ids=tuple(token_ids) if token_ids is not None else None, ) else: @@ -209,6 +255,8 @@ def issue(self, sample_index: int) -> str | None: event_type=SampleEventType.ISSUED, timestamp_ns=ts, sample_uuid=query_id, + conversation_id=conversation_id, + turn=turn, data=prompt_data, ) ) @@ -313,10 +361,13 @@ async def _run_phase(self, phase: PhaseConfig) -> PhaseResult | None: phase_start = time.monotonic_ns() # Create per-phase state - sample_order = create_sample_order(phase.runtime_settings) - strategy = create_load_strategy( - phase.runtime_settings, self._loop, sample_order - ) + if phase.strategy is not None: + strategy = phase.strategy + else: + sample_order = create_sample_order(phase.runtime_settings) + strategy = create_load_strategy( + phase.runtime_settings, self._loop, sample_order + ) phase_issuer = PhaseIssuer( dataset=phase.dataset, issuer=self._issuer, @@ -407,6 +458,20 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if isinstance(resp, QueryResult): query_id = resp.id + # Drop late responses for queries already terminated (e.g. by + # MultiTurnStrategy._handle_timeout). Without this gate, a real + # response arriving after timeout double-publishes ERROR/COMPLETE + # and double-decrements inflight (no per-request HTTP timeout + # exists in endpoint_client; late arrivals are possible). + if phase_issuer is not None and query_id in phase_issuer.completed_uuids: + return + + conv_id_str, turn_num = ("", None) + if phase_issuer is not None: + conv_id_str, turn_num = phase_issuer.uuid_to_conv_info.pop( + query_id, ("", None) + ) + # Emit ERROR before COMPLETE for failed queries so downstream # consumers (notably the metrics aggregator) see the ERROR # while the in-flight tracked row still exists. 
COMPLETE @@ -425,6 +490,8 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: event_type=ErrorEventType.GENERIC, timestamp_ns=time.monotonic_ns(), sample_uuid=query_id, + conversation_id=conv_id_str, + turn=turn_num, data=resp.error, ) ) @@ -435,10 +502,13 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if isinstance(resp.completed_at, int) else time.monotonic_ns(), sample_uuid=query_id, + conversation_id=conv_id_str, + turn=turn_num, data=resp.response_output, ) ) if phase_issuer is not None and query_id in phase_issuer.uuid_to_index: + phase_issuer.completed_uuids.add(query_id) phase_issuer.inflight -= 1 if phase_issuer.inflight <= 0: self._drain_event.set() @@ -454,11 +524,18 @@ def _handle_response(self, resp: QueryResult | StreamChunk) -> None: if is_first else SampleEventType.RECV_NON_FIRST ) + conv_id_str, turn_num = ("", None) + if phase_issuer is not None: + conv_id_str, turn_num = phase_issuer.uuid_to_conv_info.get( + resp.id, ("", None) + ) self._publisher.publish( EventRecord( event_type=event_type, timestamp_ns=time.monotonic_ns(), sample_uuid=resp.id, + conversation_id=conv_id_str, + turn=turn_num, ) ) diff --git a/src/inference_endpoint/load_generator/strategy.py b/src/inference_endpoint/load_generator/strategy.py index dd311f10..1dd6fa27 100644 --- a/src/inference_endpoint/load_generator/strategy.py +++ b/src/inference_endpoint/load_generator/strategy.py @@ -29,7 +29,7 @@ import logging from collections.abc import Callable, Iterator from time import monotonic_ns -from typing import Protocol +from typing import Any, Protocol from ..config.runtime_settings import RuntimeSettings from ..config.schema import LoadPatternType @@ -47,8 +47,26 @@ class PhaseIssuerProtocol(Protocol): """Minimal interface that strategies see for issuing samples.""" - def issue(self, sample_index: int) -> str | None: - """Issue a sample. Returns query_id, or None if the session is stopping.""" + def issue( + self, + sample_index: int, + data_override: dict[str, Any] | None = None, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: + """Issue a sample. Returns query_id, or None if the session is stopping. + + Args: + sample_index: Index into the dataset. + data_override: If provided, merged over the loaded sample — keys in + data_override take precedence. Used by MultiTurnStrategy to inject + a runtime-assembled `messages` array while still inheriting + `model`/`max_completion_tokens`/`tools`/`stream` from the dataset row. + conversation_id: Conversation identifier (multi-turn). Empty string + for single-turn issues; propagated onto the published EventRecords + so downstream consumers can group by conversation. + turn: Turn number within a conversation (multi-turn), or None. + """ ... 
issued_count: int @@ -297,5 +315,11 @@ def create_load_strategy( ) return ConcurrencyStrategy(lp.target_concurrency, sample_order) + case LoadPatternType.MULTI_TURN: + raise ValueError( + "MULTI_TURN load pattern requires a MultiTurnDataset — " + "use 'inference-endpoint benchmark from-config' with a multi-turn dataset" + ) + case _: raise ValueError(f"Unsupported load pattern type: {lp.type}") diff --git a/src/inference_endpoint/openai/accumulator.py b/src/inference_endpoint/openai/accumulator.py index 6cb23ed8..0c1e88f4 100644 --- a/src/inference_endpoint/openai/accumulator.py +++ b/src/inference_endpoint/openai/accumulator.py @@ -15,15 +15,13 @@ """OpenAI SSE stream accumulator implementation.""" -import logging +from typing import Any from inference_endpoint.core.types import QueryResult, StreamChunk, TextModelOutput from inference_endpoint.endpoint_client.accumulator_protocol import ( SSEAccumulatorProtocol, ) -from inference_endpoint.openai.types import SSEDelta as OpenAISSEDelta - -logger = logging.getLogger(__name__) +from inference_endpoint.openai.types import SSEChoice class OpenAISSEAccumulator(SSEAccumulatorProtocol): @@ -32,22 +30,59 @@ class OpenAISSEAccumulator(SSEAccumulatorProtocol): def __init__(self, query_id: str, stream_all_chunks: bool): self.output_chunks: list[str] = [] self.reasoning_chunks: list[str] = [] + self._tool_calls: dict[int, dict[str, Any]] = {} + self._finish_reason: str | None = None self.first_chunk_sent = False self.query_id = query_id self.stream_all_chunks = stream_all_chunks - def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: - if not isinstance(delta, OpenAISSEDelta): + def add_chunk(self, choice: SSEChoice | None) -> StreamChunk | None: + if not isinstance(choice, SSEChoice): + return None + + if choice.finish_reason: + self._finish_reason = choice.finish_reason + + delta = choice.delta + if delta is None: return None + # Accumulate tool_calls partials (streamed as incremental JSON fragments) + if delta.tool_calls: + for partial in delta.tool_calls: + idx = partial.get("index", 0) + tc = self._tool_calls.setdefault( + idx, {"type": "function", "function": {"arguments": ""}} + ) + if partial.get("id"): + tc["id"] = partial["id"] + if partial.get("type"): + tc["type"] = partial["type"] + fn = partial.get("function") or {} + if fn.get("name"): + tc["function"]["name"] = fn["name"] + if fn.get("arguments"): + tc["function"]["arguments"] += fn["arguments"] + content = None if delta.content: self.output_chunks.append(delta.content) content = delta.content - elif delta.reasoning: - self.reasoning_chunks.append(delta.reasoning) - content = delta.reasoning + elif delta.reasoning_content or delta.reasoning: + rc = delta.reasoning_content or delta.reasoning + self.reasoning_chunks.append(rc) # type: ignore[arg-type] + content = rc + elif delta.tool_calls and not self.first_chunk_sent: + # Pure tool-call delta with no text: emit a zero-length sentinel so + # RECV_FIRST / TTFT fires for agentic responses that have no content. + sentinel = StreamChunk( + id=self.query_id, + response_chunk="", + metadata={"first_chunk": True}, + ) + self.first_chunk_sent = True + return sentinel else: return None @@ -67,31 +102,44 @@ def add_chunk(self, delta: OpenAISSEDelta) -> StreamChunk | None: return None def get_final_output(self) -> QueryResult: + tool_calls_tuple: tuple[dict[str, Any], ...] 
| None = ( + tuple(self._tool_calls[i] for i in sorted(self._tool_calls)) + if self._tool_calls + else None + ) + if self.reasoning_chunks: - # If there are reasoning chunks, then the first chunk received - # is the first reasoning chunk. The rest of the reasoning chunks, - # as well as the output chunks can be joined together. resp_reasoning: list[str] = [self.reasoning_chunks[0]] if len(self.reasoning_chunks) > 1: resp_reasoning.append("".join(self.reasoning_chunks[1:])) text_output = TextModelOutput( output="".join(self.output_chunks), reasoning=resp_reasoning, + tool_calls=tool_calls_tuple, ) elif self.output_chunks: - # If there are only output chunks, the first chunk is used for - # TTFT calculations. The rest are joined together. resp_output: list[str] = [self.output_chunks[0]] if len(self.output_chunks) > 1: resp_output.append("".join(self.output_chunks[1:])) - text_output = TextModelOutput(output=resp_output, reasoning=None) + text_output = TextModelOutput( + output=resp_output, reasoning=None, tool_calls=tool_calls_tuple + ) else: - text_output = TextModelOutput(output=[], reasoning=None) + text_output = TextModelOutput( + output=[], reasoning=None, tool_calls=tool_calls_tuple + ) + + metadata: dict[str, Any] = { + "first_chunk": not self.first_chunk_sent, + "final_chunk": True, + } + if self._finish_reason: + metadata["finish_reason"] = self._finish_reason + if tool_calls_tuple: + metadata["tool_calls"] = list(tool_calls_tuple) + return QueryResult( id=self.query_id, response_output=text_output, - metadata={ - "first_chunk": not self.first_chunk_sent, - "final_chunk": True, - }, + metadata=metadata, ) diff --git a/src/inference_endpoint/openai/completions_adapter.py b/src/inference_endpoint/openai/completions_adapter.py index 0d4cdfa7..de1d1632 100644 --- a/src/inference_endpoint/openai/completions_adapter.py +++ b/src/inference_endpoint/openai/completions_adapter.py @@ -27,6 +27,7 @@ from inference_endpoint.endpoint_client.adapter_protocol import HttpRequestAdapter from .types import ( + SSEChoice, SSEDelta, TextCompletionRequest, TextCompletionResponse, @@ -104,8 +105,12 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: ) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> SSEDelta: + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice: msg = cls._sse_decoder.decode(json_bytes) if not msg.choices: - return SSEDelta() - return SSEDelta(content=msg.choices[0].text) + return SSEChoice() + choice = msg.choices[0] + return SSEChoice( + delta=SSEDelta(content=choice.text), + finish_reason=choice.finish_reason, + ) diff --git a/src/inference_endpoint/openai/openai_adapter.py b/src/inference_endpoint/openai/openai_adapter.py index 5834d6b0..8f516b7b 100644 --- a/src/inference_endpoint/openai/openai_adapter.py +++ b/src/inference_endpoint/openai/openai_adapter.py @@ -14,6 +14,7 @@ # limitations under the License. 
import time +from typing import Any import msgspec from inference_endpoint.core.types import Query, QueryResult, TextModelOutput @@ -36,7 +37,7 @@ Role6, ServiceTier, ) -from .types import SSEMessage +from .types import SSEChoice, SSEMessage class OpenAIAdapter(HttpRequestAdapter): @@ -57,7 +58,7 @@ def dataset_transforms(cls, model_params: ModelParams) -> list[Transform]: return [ ColumnFilter( required_columns=["prompt"], - optional_columns=["system"], + optional_columns=["system", "tools"], ), AddStaticColumns(metadata), ] @@ -75,10 +76,12 @@ def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult: return cls.from_endpoint_response(openai_response, result_id=query_id) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: - """Decode SSE message and extract content string.""" + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice | None: + """Decode SSE message and return SSEChoice (delta + finish_reason).""" msg = msgspec.json.decode(json_bytes, type=SSEMessage) - return msg.choices[0].delta + if not msg.choices: + return None + return msg.choices[0] # ======================================================================== # Internal APIs @@ -86,15 +89,21 @@ def decode_sse_message(cls, json_bytes: bytes) -> str: @classmethod def to_endpoint_request(cls, query: Query) -> CreateChatCompletionRequest: - """Convert a Query to an OpenAI request.""" - if "prompt" not in query.data: - raise ValueError("prompt not found in query.data") - - messages = [{"role": Role5.user.value, "content": query.data["prompt"]}] - if "system" in query.data: - messages.insert( - 0, {"role": Role3.system.value, "content": query.data["system"]} - ) + """Convert a Query to an OpenAI request. + + Supports both single-turn (prompt/system) and multi-turn (messages array) formats. 
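+
+        When a "messages" list is present it is forwarded unchanged and any
+        "prompt"/"system" fields on the query are ignored. Sketch of the two
+        accepted query.data shapes (values illustrative)::
+
+            {"prompt": "Hi", "system": "Be brief"}
+            {"messages": [{"role": "user", "content": "Hi"}]}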
+ """ + if "messages" in query.data and isinstance(query.data["messages"], list): + messages = query.data["messages"] + else: + if "prompt" not in query.data: + raise ValueError("prompt not found in query.data") + + messages = [{"role": Role5.user.value, "content": query.data["prompt"]}] + if "system" in query.data: + messages.insert( + 0, {"role": Role3.system.value, "content": query.data["system"]} + ) request = CreateChatCompletionRequest( model=ModelIdsShared(query.data.get("model", "no-model-name")), @@ -103,6 +112,7 @@ def to_endpoint_request(cls, query: Query) -> CreateChatCompletionRequest: stream=query.data.get("stream", False), max_completion_tokens=query.data.get("max_completion_tokens", 100), temperature=query.data.get("temperature", 0.7), + tools=query.data.get("tools"), ) return request @@ -119,9 +129,26 @@ def from_endpoint_response( if result_id is None: result_id = response.id + choice = response.choices[0] + metadata: dict[str, Any] = {} + if choice.finish_reason: + metadata["finish_reason"] = choice.finish_reason.value + if choice.message.tool_calls: + raw_tool_calls = [ + tc.model_dump(mode="json") for tc in choice.message.tool_calls.root + ] + metadata["tool_calls"] = raw_tool_calls + else: + raw_tool_calls = None + + tool_calls_tuple = tuple(raw_tool_calls) if raw_tool_calls else None return QueryResult( id=result_id, - response_output=TextModelOutput(output=response.choices[0].message.content), + response_output=TextModelOutput( + output=choice.message.content or "", + tool_calls=tool_calls_tuple, + ), + metadata=metadata, ) @classmethod @@ -160,11 +187,11 @@ def decode_endpoint_response( response_dict = msgspec.json.decode(response_bytes) # Set default values for optional fields if missing - response_dict["choices"][0]["message"]["refusal"] = "None" + response_dict["choices"][0]["message"]["refusal"] = "" response_dict["choices"][0]["logprobs"] = {"content": [], "refusal": []} if ( "content" not in response_dict["choices"][0]["message"] or response_dict["choices"][0]["message"]["content"] is None ): - response_dict["choices"][0]["message"]["content"] = "None" + response_dict["choices"][0]["message"]["content"] = "" return CreateChatCompletionResponse(**response_dict, ignore_extra=True) diff --git a/src/inference_endpoint/openai/openai_msgspec_adapter.py b/src/inference_endpoint/openai/openai_msgspec_adapter.py index 6106e1bd..9ef5b3a8 100644 --- a/src/inference_endpoint/openai/openai_msgspec_adapter.py +++ b/src/inference_endpoint/openai/openai_msgspec_adapter.py @@ -18,6 +18,7 @@ """ import time +from typing import Any import msgspec from inference_endpoint.config.schema import ModelParams, StreamingMode @@ -37,6 +38,7 @@ ChatCompletionResponse, ChatCompletionResponseMessage, ChatMessage, + SSEChoice, SSEMessage, ) @@ -45,6 +47,17 @@ # ============================================================================ +def _chat_message_from_dict(msg: dict) -> "ChatMessage": + """Build a ChatMessage from a dict, forwarding all supported fields.""" + return ChatMessage( + role=msg["role"], + content=msg.get("content"), + name=msg.get("name"), + tool_calls=msg.get("tool_calls"), + tool_call_id=msg.get("tool_call_id"), + ) + + class OpenAIMsgspecAdapter(HttpRequestAdapter): """OpenAI adapter using msgspec for serialization/deserialization.""" @@ -82,6 +95,7 @@ def dataset_transforms(cls, model_params: ModelParams) -> list[Transform]: "logit_bias", "user", "chat_template", + "tools", ] return [ ColumnFilter( @@ -105,10 +119,12 @@ def decode_response(cls, response_bytes: 
bytes, query_id: str) -> QueryResult: return cls.from_endpoint_response(openai_response, result_id=query_id) @classmethod - def decode_sse_message(cls, json_bytes: bytes) -> str: - """Decode SSE message and extract content string.""" + def decode_sse_message(cls, json_bytes: bytes) -> SSEChoice | None: + """Decode SSE message and return the SSEChoice (delta + finish_reason).""" msg = cls._sse_decoder.decode(json_bytes) - return msg.choices[0].delta + if not msg.choices: + return None + return msg.choices[0] # ======================================================================== # Internal APIs @@ -129,24 +145,27 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest: Returns: msgspec.Struct ChatCompletionRequest """ - if "prompt" not in query.data: - raise ValueError("prompt not found in query.data") - - messages = [ - ChatMessage( - role="user", - content=query.data["prompt"], - name=query.data.get("name"), - ), - ] - if "system" in query.data: - messages.insert( - 0, + if "messages" in query.data: + messages = [_chat_message_from_dict(m) for m in query.data["messages"]] + else: + if "prompt" not in query.data: + raise ValueError("prompt not found in query.data") + + messages = [ ChatMessage( - role="system", - content=query.data["system"], + role="user", + content=query.data["prompt"], + name=query.data.get("name"), ), - ) + ] + if "system" in query.data: + messages.insert( + 0, + ChatMessage( + role="system", + content=query.data["system"], + ), + ) return ChatCompletionRequest( model=query.data.get("model", "no-model-name"), @@ -164,6 +183,7 @@ def to_endpoint_request(cls, query: Query) -> ChatCompletionRequest: logit_bias=query.data.get("logit_bias"), user=query.data.get("user"), chat_template=query.data.get("chat_template"), + tools=query.data.get("tools"), ) @classmethod @@ -184,9 +204,26 @@ def from_endpoint_response( if not response.choices: raise ValueError("Response must contain at least one choice") + choice = response.choices[0] + metadata: dict[str, Any] = {} + if choice.finish_reason: + metadata["finish_reason"] = choice.finish_reason + if choice.message.tool_calls: + metadata["tool_calls"] = choice.message.tool_calls + if choice.message.reasoning_content: + metadata["reasoning_content"] = choice.message.reasoning_content + + tool_calls_tuple = ( + tuple(choice.message.tool_calls) if choice.message.tool_calls else None + ) return QueryResult( id=result_id or response.id, - response_output=TextModelOutput(output=response.choices[0].message.content), + response_output=TextModelOutput( + output=choice.message.content or "", + reasoning=choice.message.reasoning_content, + tool_calls=tool_calls_tuple, + ), + metadata=metadata, ) @classmethod diff --git a/src/inference_endpoint/openai/types.py b/src/inference_endpoint/openai/types.py index 875656fa..5f4cf2a6 100644 --- a/src/inference_endpoint/openai/types.py +++ b/src/inference_endpoint/openai/types.py @@ -36,16 +36,25 @@ # NOTE(vir): msgspec usage # omit_defaults=True: Fields with static defaults are omitted if value equals default (ie those not using default_factory) -# gc=False: Safe for request/response structs with scalar and nested struct fields only. +# gc=False: audit 2026-05: all container fields are populated at construction and never mutated. 
# frozen=True: Makes structs immutable and hashable, also enables faster struct decoding # (direct attribute access via fixed memory offset vs hash table lookup) +# gc=False: audit 2026-05: tool_calls is set at construction; frozen=True blocks field reassignment. class SSEDelta(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False): # type: ignore[call-arg] - """SSE delta object containing content.""" + """SSE delta object containing content. - content: str = "" - reasoning: str = "" + AT-RISK (gc=False): Has mutable container field `tool_calls`. Any change that + mutates `tool_calls` after construction or stores cyclic references in it + must be audited; if so, remove gc=False. + """ + + role: str | None = None + content: str | None = None + reasoning_content: str | None = None # SGLang / DeepSeek field name + reasoning: str | None = None # vLLM field name + tool_calls: list[dict[str, Any]] | None = None class SSEChoice( @@ -70,23 +79,39 @@ class SSEMessage( # ============================================================================ +# gc=False: audit 2026-05: content/tool_calls set at construction; frozen=True blocks field reassignment. class ChatMessage( msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False ): # type: ignore[call-arg] """Chat message in OpenAI format. - content: str for text-only messages; list[dict] for multimodal (vision). + AT-RISK (gc=False): Has mutable container fields `content` (list[dict] for multimodal) + and `tool_calls`. Any change that mutates these after construction or stores cyclic + references in them must be audited; if so, remove gc=False. + + content: str for text-only messages; list[dict] for multimodal (vision); + None for tool-dispatching assistant messages. + tool_calls: list of tool call objects for assistant messages that invoke tools. + tool_call_id: correlates a tool result message to its tool call. """ role: str - content: ChatMessageContent + content: ChatMessageContent | None = None name: str | None = None + tool_calls: list[dict[str, Any]] | None = None + tool_call_id: str | None = None +# gc=False: audit 2026-05: messages/tools set at construction; frozen=True blocks field reassignment. class ChatCompletionRequest( msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False ): # type: ignore[call-arg] - """OpenAI chat completion request.""" + """OpenAI chat completion request. + + AT-RISK (gc=False): Has mutable container fields `messages`, `tools`, and `logit_bias`. + Any change that mutates these after construction or stores cyclic references in them + must be audited; if so, remove gc=False. + """ model: str messages: list[ChatMessage] @@ -103,8 +128,10 @@ class ChatCompletionRequest( logit_bias: dict[str, float] | None = None user: str | None = None chat_template: str | None = None + tools: list[dict[str, Any]] | None = None +# gc=False: audit 2026-05: tool_calls set at construction; frozen=True blocks field reassignment. class ChatCompletionResponseMessage( msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False ): # type: ignore[call-arg] @@ -118,6 +145,8 @@ class ChatCompletionResponseMessage( role: str content: str | None = None refusal: str | None = None + tool_calls: list[dict[str, Any]] | None = None + reasoning_content: str | None = None class ChatCompletionChoice( @@ -144,6 +173,7 @@ class CompletionUsage( total_tokens: int +# gc=False: audit 2026-05: choices set at construction; frozen=True blocks field reassignment. 
class ChatCompletionResponse( msgspec.Struct, frozen=True, diff --git a/tests/integration/test_multi_turn.py b/tests/integration/test_multi_turn.py new file mode 100644 index 00000000..bd1f82a9 --- /dev/null +++ b/tests/integration/test_multi_turn.py @@ -0,0 +1,835 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for multi-turn benchmarking end-to-end. + +Validates that MultiTurnDataset + MultiTurnStrategy + BenchmarkSession work +correctly together against a real HTTP echo server. + +Tests cover: + 1. Dataset-history mode (use_dataset_history=True): pre-built messages are + issued as-is; each turn is issued sequentially per conversation. + 2. Live-history mode (use_dataset_history=False): messages are built at + runtime from ConversationManager.message_history; the injected messages + grow with each turn. + 3. Multiple concurrent conversations complete successfully. + 4. Turn ordering: turn N+1 is never issued before turn N completes. +""" + +import asyncio +import random +import time +from urllib.parse import urljoin + +import pandas as pd +import pytest +from inference_endpoint import metrics +from inference_endpoint.config.runtime_settings import RuntimeSettings +from inference_endpoint.config.schema import ( + LoadPattern, + LoadPatternType, + MultiTurnConfig, +) +from inference_endpoint.core.record import EventRecord +from inference_endpoint.core.types import QueryResult +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset +from inference_endpoint.endpoint_client.config import HTTPClientConfig +from inference_endpoint.endpoint_client.http_client import HTTPEndpointClient +from inference_endpoint.endpoint_client.http_sample_issuer import HttpClientSampleIssuer +from inference_endpoint.exceptions import InputValidationError +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy +from inference_endpoint.load_generator.session import ( + BenchmarkSession, + PhaseConfig, + PhaseType, +) +from inference_endpoint.testing.echo_server import EchoServer + + +class _NoOpPublisher: + def publish(self, event_record: EventRecord) -> None: + pass + + def flush(self) -> None: + pass + + +def _make_dataset(rows: list[dict]) -> MultiTurnDataset: + """Build a loaded MultiTurnDataset from a list of row dicts.""" + df = pd.DataFrame(rows) + ds = MultiTurnDataset(dataframe=df) + ds.load() + return ds + + +def _make_strategy( + ds: MultiTurnDataset, + use_dataset_history: bool = True, + target_concurrency: int | None = None, +) -> MultiTurnStrategy: + mt_cfg = MultiTurnConfig( + turn_timeout_s=10.0, + use_dataset_history=use_dataset_history, + ) + assert ds.conversation_metadata is not None + return MultiTurnStrategy( + conversation_manager=ConversationManager(), + 
dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + target_concurrency=target_concurrency, + ) + + +async def _run_session( + server_url: str, + ds: MultiTurnDataset, + strategy: MultiTurnStrategy, + responses_out: dict, +) -> int: + """Wire up HTTPEndpointClient + BenchmarkSession and run one phase. + + Populates responses_out[query_id] = response_text for every completed turn. + Returns issued_count. + """ + loop = asyncio.get_running_loop() + + def on_complete(result: QueryResult) -> None: + strategy.on_sample_complete(result) + responses_out[result.id] = result.get_response_output_string() + + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(server_url, "/v1/chat/completions")], + warmup_connections=0, + num_workers=2, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=30_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig( + "perf", + rt, + ds, + PhaseType.PERFORMANCE, + strategy=strategy, + ) + result = await asyncio.wait_for(session.run([phase]), timeout=30.0) + return result.perf_results[0].issued_count + finally: + await http_client.shutdown_async() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def echo_server(): + server = EchoServer(port=0) + server.start() + try: + yield server + finally: + server.stop() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_single_conversation_all_turns_issued(echo_server): + """All turns of a single conversation are issued and completed.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Bye"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # Two user turns (turns 1 and 3); turn 2 is assistant so not a client turn + assert count == 2 + assert len(responses) == 2 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_multiple_conversations_all_issued(echo_server): + """Multiple conversations complete independently and concurrently.""" + rows = [] + for conv_idx in range(3): + conv_id = f"conv_{conv_idx}" + rows.append( + { + "conversation_id": conv_id, + "turn": 1, + "role": "user", + "content": f"Q1 {conv_idx}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 2, + "role": "assistant", + "content": f"A1 {conv_idx}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 3, + "role": "user", + "content": f"Q2 {conv_idx}", + } + ) + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + 
count = await _run_session(echo_server.url, ds, strategy, responses) + + # 3 conversations × 2 user turns each = 6 + assert count == 6 + assert len(responses) == 6 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_dataset_history_messages_present(echo_server): + """Dataset-history mode: each request contains the messages array from the dataset.""" + received_payloads: list[dict] = [] + + # Subclass EchoServer to capture the raw JSON body of each incoming request + # at the HTTP layer before delegating to the parent chat-completions handler. + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass # request body may not be JSON + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "First question", + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "First answer", + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "user", + "content": "Second question", + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=True) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + # Both requests must include a "messages" array + assert len(received_payloads) == 2 + for payload in received_payloads: + assert "messages" in payload + assert len(payload["messages"]) >= 1 + + # Turn 1 should have 1 user message; turn 3 should have 3 messages + # (
no system here — user, assistant, user) + msg_counts = sorted(len(p["messages"]) for p in received_payloads) + assert msg_counts == [1, 3] + finally: + server.stop() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_live_history_messages_grow_each_turn(echo_server): + """Live-history mode: messages array grows with each completed turn.""" + received_payloads: list[dict] = [] + + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass # request body may not be JSON + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Turn one"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Answer one", + }, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Turn two"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=False) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + assert len(received_payloads) == 2 + msg_counts = sorted(len(p["messages"]) for p in received_payloads) + # Turn 1: [user msg] = 1; Turn 3: [user, assistant, user] = 3 + assert msg_counts == [1, 3] + finally: + server.stop() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_turn_ordering_enforced_end_to_end(echo_server): + """Turn N+1 is issued after Turn N's response arrives, verified by timestamps.""" + complete_times: dict[str, float] = {} + + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "First"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "Second"}, + ] + ds = _make_dataset(rows) + mt_cfg = MultiTurnConfig(turn_timeout_s=10.0, use_dataset_history=True) + conv_manager = ConversationManager() + strategy = MultiTurnStrategy( + conversation_manager=conv_manager, + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) + + # Wrap on_sample_complete to record completion timestamps + orig_on_sample_complete = strategy.on_sample_complete + + def tracked_on_sample_complete(result: QueryResult) -> None: + # Map query_id → sample_index via uuid_to_index (set after session runs) + complete_times[result.id] = time.monotonic() + orig_on_sample_complete(result) + + strategy.on_sample_complete = tracked_on_sample_complete + + loop = asyncio.get_running_loop() + responses: dict[str, str] = {} + + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(echo_server.url, "/v1/chat/completions")], + warmup_connections=0, + num_workers=1, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=30_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + + try: + + def on_complete(result: QueryResult) -> None: + tracked_on_sample_complete(result) + responses[result.id] = result.get_response_output_string() + + session = 
BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=on_complete, + ) + phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) + result = await asyncio.wait_for(session.run([phase]), timeout=30.0) + finally: + await http_client.shutdown_async() + + assert result.perf_results[0].issued_count == 2 + + # Build query_id → sample_index from session result + uuid_to_index = result.perf_results[0].uuid_to_index + index_to_query = {v: k for k, v in uuid_to_index.items()} + + # Sample 0 = turn 1, sample 1 = turn 3 + q_turn1 = index_to_query[0] + q_turn3 = index_to_query[1] + + # Turn 3 must complete after turn 1 completes + assert complete_times[q_turn3] >= complete_times[q_turn1] + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_tool_use_conversation_all_turns_issued(echo_server): + """Tool-use conversation: all client turns (user + tool) are issued and completed.""" + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "search result"}] + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + } + ] + + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Find something", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Here is the result", + }, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "Thanks"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # Client turns: turn 1 (user) + turn 3 (tool) + turn 5 (user) = 3 + assert count == 3 + assert len(responses) == 3 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_conversation_ending_with_tool_row(echo_server): + """Conversation ending with a tool row completes normally (matches agentic_coding dataset pattern).""" + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path": "out.py"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "file written"}] + tools = [ + { + "type": "function", + "function": { + "name": "write_file", + "description": "Write a file", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + }, + }, + } + ] + + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Write a file", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # Client turns: turn 1 (user) + turn 3 (tool) = 2 + assert count == 2 + assert len(responses) == 2 + + +@pytest.mark.integration 
+@pytest.mark.asyncio +async def test_concurrent_conversations_stress(echo_server): + """12 conversations × 3 turns each complete with correct counts.""" + num_convs = 12 + turns_per_conv = 3 # 2 user turns + 1 assistant turn each + rows = [] + for i in range(num_convs): + conv_id = f"stress_conv_{i}" + rows.append( + { + "conversation_id": conv_id, + "turn": 1, + "role": "user", + "content": f"Q1-{i}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 2, + "role": "assistant", + "content": f"A1-{i}", + } + ) + rows.append( + { + "conversation_id": conv_id, + "turn": 3, + "role": "user", + "content": f"Q2-{i}", + } + ) + + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + responses: dict = {} + + count = await _run_session(echo_server.url, ds, strategy, responses) + + # 12 conversations × 2 client turns each = 24 + expected_client_turns = num_convs * (turns_per_conv - 1) # 24 + assert count == expected_client_turns + assert len(responses) == expected_client_turns + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_multi_turn_active_conversations_respects_target_concurrency(echo_server): + num_convs = 20 + rows = [] + for i in range(num_convs): + conv_id = f"cap_conv_{i}" + rows += [ + { + "conversation_id": conv_id, + "turn": 1, + "role": "user", + "content": f"Q1-{i}", + }, + { + "conversation_id": conv_id, + "turn": 2, + "role": "assistant", + "content": f"A1-{i}", + }, + { + "conversation_id": conv_id, + "turn": 3, + "role": "user", + "content": f"Q2-{i}", + }, + ] + + ds = _make_dataset(rows) + strategy = _make_strategy(ds, target_concurrency=4) + responses: dict = {} + + observed_max: list[int] = [] + orig_on_sample_complete = strategy.on_sample_complete + + def tracked_on_sample_complete(result) -> None: + observed_max.append(len(strategy._active_iters)) + orig_on_sample_complete(result) + + strategy.on_sample_complete = tracked_on_sample_complete + + await _run_session(echo_server.url, ds, strategy, responses) + + assert len(responses) == num_convs * 2 # 2 client turns per conversation + assert max(observed_max, default=0) <= 4 + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_multi_turn_pipeline_exception_propagates(echo_server): + rows = [ + {"conversation_id": "err_c1", "turn": 1, "role": "user", "content": "Q1"}, + {"conversation_id": "err_c1", "turn": 2, "role": "assistant", "content": "A1"}, + {"conversation_id": "err_c1", "turn": 3, "role": "user", "content": "Q2"}, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds) + + call_count = 0 + orig_issue_next_turn = strategy._issue_next_turn + + def failing_issue_next_turn(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count >= 2: + raise RuntimeError("injected pipeline error") + return orig_issue_next_turn(*args, **kwargs) + + strategy._issue_next_turn = failing_issue_next_turn + + loop = asyncio.get_running_loop() + http_config = HTTPClientConfig( + endpoint_urls=[urljoin(echo_server.url, "/v1/chat/completions")], + warmup_connections=0, + num_workers=2, + ) + http_client = await HTTPEndpointClient.create(http_config, loop) + issuer = HttpClientSampleIssuer(http_client) + + try: + session = BenchmarkSession( + issuer=issuer, + event_publisher=_NoOpPublisher(), + loop=loop, + on_sample_complete=strategy.on_sample_complete, + ) + rt = RuntimeSettings( + metrics.Throughput(1000), + [metrics.Throughput(1000)], + min_duration_ms=0, + max_duration_ms=30_000, + n_samples_from_dataset=ds.num_samples(), + n_samples_to_issue=ds.num_samples(), + 
min_sample_count=1, + rng_sched=random.Random(42), + rng_sample_index=random.Random(42), + load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT), + ) + phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy) + + with pytest.raises(RuntimeError, match="injected pipeline error"): + await asyncio.wait_for(session.run([phase]), timeout=30.0) + + assert strategy._inflight == {} + finally: + await http_client.shutdown_async() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_tools_field_forwarded_to_endpoint(echo_server): + """The 'tools' array from the dataset reaches the endpoint in every request payload. + + TODO: Add a tool-call-aware server that returns dynamic tool_call_ids to + validate live-history mode with real tool_call_id round-tripping. + """ + received_payloads: list[dict] = [] + + class CapturingEchoServer(EchoServer): + async def _handle_echo_chat_completions_request(self, request): + try: + payload = await request.json() + received_payloads.append(payload) + except Exception: + pass # request body may not be JSON + return await super()._handle_echo_chat_completions_request(request) + + server = CapturingEchoServer(port=0) + server.start() + try: + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "hello"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "result"}] + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, + } + ] + + rows = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Search for hello", + "tools": tools, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + "tools": tools, + }, + ] + ds = _make_dataset(rows) + strategy = _make_strategy(ds, use_dataset_history=True) + responses: dict = {} + + count = await _run_session(server.url, ds, strategy, responses) + assert count == 2 + + assert len(received_payloads) == 2 + for payload in received_payloads: + assert "tools" in payload + assert len(payload["tools"]) == 1 + assert payload["tools"][0]["function"]["name"] == "search" + finally: + server.stop() + + +@pytest.mark.integration +def test_live_history_rejects_tool_turns(): + """MultiTurnStrategy raises InputValidationError at __init__ when use_dataset_history=False + and the dataset contains tool-role turns. 
+ """ + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "hello"}'}, + } + ] + tool_results = [{"tool_call_id": "call_1", "content": "result"}] + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Search"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": tool_results, + }, + ] + ds = _make_dataset(rows) + mt_cfg = MultiTurnConfig(turn_timeout_s=10.0, use_dataset_history=False) + with pytest.raises(InputValidationError, match="use_dataset_history=True"): + MultiTurnStrategy( + conversation_manager=ConversationManager(), + dataset_metadata=ds.conversation_metadata, + multi_turn_config=mt_cfg, + ) diff --git a/tests/performance/openai/test_adapter.py b/tests/performance/openai/test_adapter.py index e38ec5e8..6ed12ab2 100644 --- a/tests/performance/openai/test_adapter.py +++ b/tests/performance/openai/test_adapter.py @@ -71,11 +71,7 @@ def make_response_bytes(text: str) -> bytes: def make_sse_bytes(text: str) -> bytes: """Create SSE message JSON bytes.""" return json.dumps( - { - "choices": [ - {"delta": {"content": text, "reasoning": ""}, "finish_reason": None} - ] - } + {"choices": [{"delta": {"content": text}, "finish_reason": None}]} ).encode() diff --git a/tests/performance/openai/test_msgspec_adapter.py b/tests/performance/openai/test_msgspec_adapter.py index 9fc6a10b..21bedd25 100644 --- a/tests/performance/openai/test_msgspec_adapter.py +++ b/tests/performance/openai/test_msgspec_adapter.py @@ -71,11 +71,7 @@ def make_response_bytes(text: str) -> bytes: def make_sse_bytes(text: str) -> bytes: """Create SSE message JSON bytes.""" return json.dumps( - { - "choices": [ - {"delta": {"content": text, "reasoning": ""}, "finish_reason": None} - ] - } + {"choices": [{"delta": {"content": text}, "finish_reason": None}]} ).encode() diff --git a/tests/performance/openai/test_types.py b/tests/performance/openai/test_types.py index 43df3a4a..bdf8aa6b 100644 --- a/tests/performance/openai/test_types.py +++ b/tests/performance/openai/test_types.py @@ -61,7 +61,7 @@ def make_sse_message(text: str) -> SSEMessage: return SSEMessage( choices=( SSEChoice( - delta=SSEDelta(content=text, reasoning=""), + delta=SSEDelta(content=text), finish_reason=None, ), ) diff --git a/tests/unit/async_utils/services/event_logger/test_sql_writer.py b/tests/unit/async_utils/services/event_logger/test_sql_writer.py index 4d6bbcc6..ae443f54 100644 --- a/tests/unit/async_utils/services/event_logger/test_sql_writer.py +++ b/tests/unit/async_utils/services/event_logger/test_sql_writer.py @@ -34,9 +34,14 @@ from sqlalchemy.orm import Session -def _record(event_type, uuid="", ts=0, data=None): +def _record(event_type, uuid="", ts=0, data=None, conversation_id="", turn=None): return EventRecord( - event_type=event_type, timestamp_ns=ts, sample_uuid=uuid, data=data + event_type=event_type, + timestamp_ns=ts, + sample_uuid=uuid, + conversation_id=conversation_id, + turn=turn, + data=data, ) @@ -58,6 +63,9 @@ def test_session_event_topic(self): assert row.event_type == "session.ended" assert row.sample_uuid == "" assert row.timestamp_ns == 42 + # Defaults for non-multi-turn events: empty conversation_id, NULL turn. 
+ assert row.conversation_id == "" + assert row.turn is None def test_error_event_topic(self): row = _record_to_row(_record(ErrorEventType.GENERIC, ts=99)) @@ -74,6 +82,19 @@ def test_none_data_encodes_to_null(self): decoded = msgspec.json.decode(row.data) assert decoded is None + def test_conversation_id_and_turn_copied_to_row(self): + row = _record_to_row( + _record( + SampleEventType.ISSUED, + uuid="q1", + ts=10, + conversation_id="conv-x", + turn=2, + ) + ) + assert row.conversation_id == "conv-x" + assert row.turn == 2 + # --------------------------------------------------------------------------- # EventRowModel schema @@ -260,3 +281,43 @@ def test_mixed_event_types(self, tmp_path): "session.ended", ] engine.dispose() + + def test_conversation_id_and_turn_persisted(self, tmp_path): + writer = SQLWriter(tmp_path / "events", flush_interval=1) + try: + writer.write( + _record( + SampleEventType.ISSUED, + uuid="q1", + ts=10, + conversation_id="conv-a", + turn=1, + ) + ) + writer.write( + _record( + SampleEventType.COMPLETE, + uuid="q1", + ts=20, + conversation_id="conv-a", + turn=1, + ) + ) + # Single-turn / non-conversation event leaves defaults. + writer.write(_record(SessionEventType.STARTED, ts=0)) + finally: + writer.close() + + engine = create_engine(f"sqlite:///{tmp_path / 'events.db'}") + with Session(engine) as session: + rows = ( + session.execute(select(EventRowModel).order_by(EventRowModel.id)) + .scalars() + .all() + ) + assert [(r.conversation_id, r.turn) for r in rows] == [ + ("conv-a", 1), + ("conv-a", 1), + ("", None), + ] + engine.dispose() diff --git a/tests/unit/async_utils/services/metrics_aggregator/conftest.py b/tests/unit/async_utils/services/metrics_aggregator/conftest.py index db6bcfcf..7adbe036 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/conftest.py +++ b/tests/unit/async_utils/services/metrics_aggregator/conftest.py @@ -68,6 +68,22 @@ async def token_count_async( await asyncio.sleep(self._delay) return len(text.split()) + async def token_count_message_async( + self, + content: str, + reasoning: str | None, + tool_calls, + _loop: asyncio.AbstractEventLoop, + ) -> int: + import msgspec + + await asyncio.sleep(self._delay) + tool_calls_str = ( + msgspec.json.encode(list(tool_calls)).decode() if tool_calls else "" + ) + combined = (content or "") + " " + (reasoning or "") + " " + tool_calls_str + return len(combined.split()) + def close(self) -> None: pass diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py index 0669a31e..9877aee5 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py @@ -41,7 +41,7 @@ SampleEventType, SessionEventType, ) -from inference_endpoint.core.types import ErrorData, PromptData +from inference_endpoint.core.types import ErrorData, PromptData, TextModelOutput from .conftest import ( MockTokenizePool, @@ -1120,3 +1120,43 @@ def __exit__(self, *args): ) finally: agg.close() + + @pytest.mark.asyncio + async def test_tpot_osl_for_tool_call_complete(self, tmp_path): + """OSL and TPOT use message-path tokenization when COMPLETE carries tool_calls.""" + loop = asyncio.get_event_loop() + pool = MockTokenizePool(delay=0.0) + with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: + agg, registry, _ = make_aggregator( + ctx, loop, "agg_tpot_osl_tool_call", tokenize_pool=pool + ) + try: + tool_call = { + "id": "c1", + 
"type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + await agg.process( + [ + session_event( + SessionEventType.START_PERFORMANCE_TRACKING, ts=0 + ), + sample_event(SampleEventType.ISSUED, "s1", ts=1000), + sample_event(SampleEventType.RECV_FIRST, "s1", ts=2000), + sample_event( + SampleEventType.COMPLETE, + "s1", + ts=5000, + data=TextModelOutput(output="ok", tool_calls=(tool_call,)), + ), + ] + ) + await agg._table.drain_tasks() + # OSL = token_count("ok" + tool_calls_json) = 2 + assert snapshot_series_total(registry, MetricSeriesKey.OSL.value) == 2 + # tpot = (5000 - 2000) / token_count(tool_calls_json) = 3000 / 1 = 3000 + assert snapshot_series_total( + registry, MetricSeriesKey.TPOT_NS.value + ) == pytest.approx(3000.0) + finally: + agg.close() diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py b/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py index 5924b547..d3781ad9 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py @@ -22,6 +22,8 @@ from __future__ import annotations +import asyncio + import msgspec import pytest from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( @@ -277,3 +279,124 @@ def test_multiple_tracking_windows(self): assert table.tracked_blocks[1].duration_ns == 200 # 1000 - 800 assert table.total_tracked_duration_ns == 800 assert table.total_completed_tracked_samples == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestOslTriggerToolCalls: + """OslTrigger routes to message path when tool_calls are present.""" + + async def test_osl_with_tool_calls_uses_message_path(self): + """OslTrigger stores combined content+tool_calls word count.""" + from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( + OslTrigger, + SampleRow, + ) + from inference_endpoint.core.types import TextModelOutput + + from .conftest import MockTokenizePool, snapshot_series_count + + registry = MetricsRegistry() + registry.register_series("osl", hdr_low=1, hdr_high=100_000) + loop = asyncio.get_running_loop() + pool = MockTokenizePool(delay=0) + trigger = OslTrigger(registry, pool, loop) + + tool_calls = ( + { + "id": "c1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + }, + ) + tmo = TextModelOutput(output="hello world", tool_calls=tool_calls) + ev = EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=1000, + sample_uuid="s1", + data=tmo, + ) + row = SampleRow(sample_uuid="s1") + task = trigger.fire(ev, row, {}) + assert task is not None + await task + + assert snapshot_series_count(registry, "osl") == 1 + + async def test_osl_without_tool_calls_uses_text_path(self): + """OslTrigger uses text path for output with no tool_calls (regression guard).""" + from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( + OslTrigger, + SampleRow, + ) + from inference_endpoint.core.types import TextModelOutput + + from .conftest import MockTokenizePool, snapshot_series_count + + registry = MetricsRegistry() + registry.register_series("osl", hdr_low=1, hdr_high=100_000) + loop = asyncio.get_running_loop() + pool = MockTokenizePool(delay=0) + trigger = OslTrigger(registry, pool, loop) + + tmo = TextModelOutput(output="hello world") + ev = EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=1000, + sample_uuid="s1", + data=tmo, + ) + row = SampleRow(sample_uuid="s1") + 
task = trigger.fire(ev, row, {}) + assert task is not None + await task + + assert snapshot_series_count(registry, "osl") == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestTpotTriggerToolCalls: + """TpotTrigger routes to message path when tool_calls are present.""" + + async def test_tpot_tool_calls_only_response(self): + """TpotTrigger includes tool_calls in TPOT denominator for agentic responses.""" + from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( + SampleField, + SampleRow, + TpotTrigger, + ) + from inference_endpoint.core.types import TextModelOutput + + from .conftest import MockTokenizePool, snapshot_series_count + + registry = MetricsRegistry() + registry.register_series( + "tpot_ns", hdr_low=1, hdr_high=100_000_000_000, dtype=float + ) + loop = asyncio.get_running_loop() + pool = MockTokenizePool(delay=0) + trigger = TpotTrigger(registry, pool, loop) + + tool_calls = ( + { + "id": "c1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + }, + ) + tmo = TextModelOutput(output=[], tool_calls=tool_calls) + ev = EventRecord( + event_type=SampleEventType.COMPLETE, + timestamp_ns=2000, + sample_uuid="s1", + data=tmo, + ) + row = SampleRow(sample_uuid="s1") + # RECV_FIRST_NS was set at t=1000 + pre_change = {SampleField.RECV_FIRST_NS: 1000} + task = trigger.fire(ev, row, pre_change) + assert task is not None + await task + + assert snapshot_series_count(registry, "tpot_ns") == 1 diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_registry.py b/tests/unit/async_utils/services/metrics_aggregator/test_registry.py index d6759973..ce8fa704 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_registry.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_registry.py @@ -207,6 +207,13 @@ def test_float_dtype(self): assert stat.count == 3 assert stat.total == pytest.approx(7.5) + def test_ns_range_sum_sq_is_float(self): + s = self._make() + s.record(_NS_HIGH) + s.record(_NS_HIGH) + assert isinstance(s.build_stat(exact=False).sum_sq, float) + assert isinstance(s.build_stat(exact=True).sum_sq, float) + @pytest.mark.unit class TestMetricsRegistry: diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_snapshot.py b/tests/unit/async_utils/services/metrics_aggregator/test_snapshot.py index ba861094..db07020c 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_snapshot.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_snapshot.py @@ -47,14 +47,22 @@ def test_float_value(self): @pytest.mark.unit class TestSeriesStat: - def test_roundtrip(self): + @pytest.mark.parametrize( + "sum_sq", + [ + 55000, + # Float exceeding uint64 max — would overflow msgpack if encoded as int. 
+ 2.0 * (2**64 - 1), + ], + ) + def test_roundtrip(self, sum_sq): stat = SeriesStat( name="ttft_ns", count=5, total=500, min=50, max=150, - sum_sq=55000, + sum_sq=sum_sq, percentiles={"50": 100.0, "99": 145.0}, histogram=[((50.0, 100.0), 2), ((100.0, 150.0), 3)], ) diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py index 25cad157..51c6e80d 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py @@ -103,3 +103,84 @@ def test_context_manager(self): assert pool.token_count("a b c") == 3 with pytest.raises(RuntimeError, match="closed"): pool.token_count("test") + + +class _FakeTokenizerWithTemplate(_FakeTokenizer): + """Tokenizer that supports apply_chat_template for tool-call testing.""" + + def apply_chat_template( + self, messages, tokenize=False, add_generation_prompt=False + ): + # Simulate 2 wrapper tokens for the template frame. + parts = ["WRAPPER", "WRAPPER"] + for msg in messages: + content = msg.get("content") + if content: + parts.append(content) + if msg.get("reasoning_content"): + parts.append(msg["reasoning_content"]) + if msg.get("tool_calls"): + import msgspec + + parts.append(msgspec.json.encode(msg["tool_calls"]).decode()) + rendered = " ".join(parts) + if tokenize: + return list(range(len(rendered.split()))) + return rendered + + +@pytest.mark.unit +class TestTokenizePoolMessageTokenization: + def test_token_count_message_subtracts_baseline(self): + """token_count_message returns full_tokens - baseline.""" + with patch(_MOCK_TARGET, _FakeTokenizerWithTemplate): + with TokenizePool("fake", n_workers=1) as pool: + # "hello world" -> 2 content words + 2 wrapper = 4; baseline = 0 + 2 = 2; net = 2 + count = pool.token_count_message("hello world", None, None) + assert count == 2 + + def test_token_count_message_includes_tool_calls(self): + """token_count_message includes tool-call JSON tokens.""" + with patch(_MOCK_TARGET, _FakeTokenizerWithTemplate): + with TokenizePool("fake", n_workers=1) as pool: + tool_calls = ( + { + "id": "c1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + }, + ) + count_without = pool.token_count_message("hello", None, None) + count_with = pool.token_count_message("hello", None, tool_calls) + assert count_with > count_without + + def test_token_count_message_fallback_on_exception(self): + """Falls back to whitespace split when apply_chat_template raises.""" + + class _BadTemplateTokenizer(_FakeTokenizer): + def apply_chat_template(self, *args, **kwargs): + raise ValueError("template does not support tool_calls") + + with patch(_MOCK_TARGET, _BadTemplateTokenizer): + with TokenizePool("fake", n_workers=1) as pool: + tool_calls = ( + { + "id": "c1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + }, + ) + # Should not raise; falls back to whitespace tokenizer + count = pool.token_count_message("hello world", None, tool_calls) + assert count > 0 + + @pytest.mark.asyncio + async def test_token_count_message_async(self): + """token_count_message_async returns count without blocking event loop.""" + with patch(_MOCK_TARGET, _FakeTokenizerWithTemplate): + loop = asyncio.get_running_loop() + with TokenizePool("fake", n_workers=1) as pool: + count = await pool.token_count_message_async( + "hello world", None, None, loop + ) + assert count == 2 diff --git a/tests/unit/commands/test_precompute_isl.py 
b/tests/unit/commands/test_precompute_isl.py new file mode 100644 index 00000000..d4ccf852 --- /dev/null +++ b/tests/unit/commands/test_precompute_isl.py @@ -0,0 +1,173 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for _precompute_isl_for_multi_turn.""" + +from unittest.mock import MagicMock, patch + +import pytest +from inference_endpoint.commands.benchmark.execute import _precompute_isl_for_multi_turn + + +def _make_dataloader(samples: list[dict]) -> MagicMock: + dl = MagicMock() + dl.data = samples + return dl + + +class TestPrecomputeIslForMultiTurn: + @pytest.mark.unit + def test_sets_input_tokens_for_samples_with_messages(self): + samples = [ + {"messages": [{"role": "user", "content": "hello"}]}, + {"messages": [{"role": "user", "content": "world"}]}, + ] + dataloader = _make_dataloader(samples) + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.side_effect = lambda msgs, **_: list( + range(len(msgs) * 3) + ) + + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + for sample in samples: + assert "input_tokens" in sample + assert isinstance(sample["input_tokens"], list) + + @pytest.mark.unit + def test_leaves_samples_without_messages_untouched(self): + samples = [ + {"prompt": "no messages here"}, + {"input_tokens": [1, 2, 3]}, + ] + dataloader = _make_dataloader(samples) + mock_tokenizer = MagicMock() + + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + mock_tokenizer.apply_chat_template.assert_not_called() + assert "input_tokens" not in samples[0] + assert samples[1]["input_tokens"] == [1, 2, 3] + + @pytest.mark.unit + def test_skips_failed_template_calls_with_warning(self, caplog): + samples = [ + {"messages": [{"role": "user", "content": "good"}]}, + {"messages": [{"role": "user", "content": "bad"}]}, + ] + dataloader = _make_dataloader(samples) + + def side_effect(msgs, **_): + if msgs[0]["content"] == "bad": + raise ValueError("template error") + return [10, 20, 30] + + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.side_effect = side_effect + + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + with caplog.at_level("WARNING"): + _precompute_isl_for_multi_turn(dataloader, "test-model") + + assert "input_tokens" in samples[0] + assert "input_tokens" not in samples[1] + assert "1 turn(s) skipped" in caplog.text + + @pytest.mark.unit + def test_batch_encoding_return_value_is_unwrapped(self): + """Tokenizers like Qwen3 return BatchEncoding instead of list[int].""" + samples 
= [{"messages": [{"role": "user", "content": "hi"}]}] + dataloader = _make_dataloader(samples) + + batch_encoding = MagicMock() + batch_encoding.input_ids = [1, 2, 3] + + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.return_value = batch_encoding + + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + assert samples[0]["input_tokens"] == [1, 2, 3] + + @pytest.mark.unit + def test_add_generation_prompt_true(self): + samples = [{"messages": [{"role": "user", "content": "hi"}]}] + dataloader = _make_dataloader(samples) + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.return_value = [1, 2, 3] + + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + _, kwargs = mock_tokenizer.apply_chat_template.call_args + assert kwargs.get("add_generation_prompt") is True + assert kwargs.get("tokenize") is True + + @pytest.mark.unit + def test_normalizes_tool_call_arguments_before_apply_chat_template(self): + """_normalize_tool_calls_for_template converts arguments strings to dicts.""" + samples = [ + { + "messages": [ + {"role": "user", "content": "use a tool"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": { + "name": "bash", + "arguments": '{"cmd": "ls"}', + }, + } + ], + }, + ] + }, + ] + dataloader = _make_dataloader(samples) + mock_tokenizer = MagicMock() + mock_tokenizer.apply_chat_template.return_value = [1, 2, 3] + + with patch( + "inference_endpoint.commands.benchmark.execute.AutoTokenizer" + ) as mock_cls: + mock_cls.from_pretrained.return_value = mock_tokenizer + _precompute_isl_for_multi_turn(dataloader, "test-model") + + # Production code builds new dicts, so call_args captures the normalized value. 
+ passed_msgs = mock_tokenizer.apply_chat_template.call_args[0][0] + asst_msg = next(m for m in passed_msgs if m.get("role") == "assistant") + assert asst_msg["tool_calls"][0]["function"]["arguments"] == {"cmd": "ls"} diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index 845711da..ea7dc9ee 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -15,13 +15,19 @@ """Tests for configuration schema models and validation.""" +import random + import pytest +from inference_endpoint import metrics +from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( APIType, BenchmarkConfig, Dataset, DatasetType, EvalMethod, + LoadPattern, + LoadPatternType, ModelParams, OSLDistribution, OSLDistributionType, @@ -463,3 +469,110 @@ def test_openai_completions_endpoint_resolves_adapter(self): assert config.settings.client.api_type is APIType.OPENAI_COMPLETIONS assert config.settings.client.adapter is OpenAITextCompletionsAdapter assert config.settings.client.accumulator is OpenAISSEAccumulator + + +class TestMultiTurnValidation: + """Tests for multi-turn config validation and cross-validation.""" + + def _make_online_multi_turn(self, concurrency: int | None = 4, **ds_kwargs): + lp: dict = {"type": "multi_turn"} + if concurrency is not None: + lp["target_concurrency"] = concurrency + return { + "type": TestType.ONLINE, + "model_params": {"name": "M"}, + "endpoint_config": {"endpoints": ["http://x"]}, + "datasets": [{"path": "D", "multi_turn": {}, **ds_kwargs}], + "settings": {"load_pattern": lp}, + } + + @pytest.mark.unit + def test_multi_turn_valid_config(self): + config = BenchmarkConfig(**self._make_online_multi_turn(concurrency=16)) + assert config.settings.load_pattern.type == LoadPatternType.MULTI_TURN + assert config.settings.load_pattern.target_concurrency == 16 + + @pytest.mark.unit + def test_multi_turn_requires_target_concurrency(self): + with pytest.raises(ValueError, match="Multi-turn requires --concurrency"): + BenchmarkConfig(**self._make_online_multi_turn(concurrency=None)) + + @pytest.mark.unit + def test_multi_turn_without_multi_turn_dataset_rejected(self): + with pytest.raises( + ValueError, + match="requires the performance dataset to have multi_turn config", + ): + BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D"}], + settings={ + "load_pattern": {"type": "multi_turn", "target_concurrency": 4} + }, + ) + + @pytest.mark.unit + def test_multi_turn_dataset_without_multi_turn_load_pattern_rejected(self): + with pytest.raises(ValueError, match="requires load_pattern.type=multi_turn"): + BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D", "multi_turn": {}}], + settings={"load_pattern": {"type": "poisson", "target_qps": 10}}, + ) + + +class TestMultiTurnTotalSamples: + """Tests for total_samples_to_issue() with multi_turn load pattern.""" + + @pytest.mark.unit + def test_multi_turn_uses_dataset_size_ignoring_duration(self): + config = BenchmarkConfig( + type=TestType.ONLINE, + model_params={"name": "M"}, + endpoint_config={"endpoints": ["http://x"]}, + datasets=[{"path": "D", "multi_turn": {}}], + settings={ + "load_pattern": {"type": "multi_turn", "target_concurrency": 4}, + "runtime": {"min_duration_ms": 600000}, + }, + ) + rt = RuntimeSettings.from_config(config, 
dataloader_num_samples=4316) + assert rt.total_samples_to_issue() == 4316 + + @pytest.mark.unit + def test_multi_turn_clamps_to_dataset_size(self): + lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) + rt = RuntimeSettings( + metric_target=metrics.Throughput(10.0), + reported_metrics=[metrics.Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=5, + n_samples_to_issue=None, + min_sample_count=100, + rng_sched=random.Random(0), + rng_sample_index=random.Random(0), + load_pattern=lp, + ) + assert rt.total_samples_to_issue() == 5 + + @pytest.mark.unit + def test_multi_turn_explicit_n_samples_takes_precedence(self): + lp = LoadPattern(type=LoadPatternType.MULTI_TURN, target_concurrency=4) + rt = RuntimeSettings( + metric_target=metrics.Throughput(10.0), + reported_metrics=[metrics.Throughput(10.0)], + min_duration_ms=600000, + max_duration_ms=None, + n_samples_from_dataset=4316, + n_samples_to_issue=200, + min_sample_count=1, + rng_sched=random.Random(0), + rng_sample_index=random.Random(0), + load_pattern=lp, + ) + assert rt.total_samples_to_issue() == 200 diff --git a/tests/unit/core/test_record.py b/tests/unit/core/test_record.py index 86b108a0..cd9173f6 100644 --- a/tests/unit/core/test_record.py +++ b/tests/unit/core/test_record.py @@ -66,6 +66,8 @@ def test_construction_with_only_event_type_uses_defaults(self): after = time.monotonic_ns() assert before <= record.timestamp_ns <= after assert record.sample_uuid == "" + assert record.conversation_id == "" + assert record.turn is None assert record.data is None @@ -192,6 +194,8 @@ def test_record_with_only_event_type_round_trips_with_defaults(self): decoded = _codec.decode(payload) assert decoded.event_type.topic == SessionEventType.ENDED.topic assert decoded.sample_uuid == "" + assert decoded.conversation_id == "" + assert decoded.turn is None assert decoded.data is None assert decoded.timestamp_ns > 0 @@ -205,6 +209,20 @@ def test_explicit_timestamp_ns_preserved_round_trip(self): decoded = _codec.decode(payload) assert decoded.timestamp_ns == ts + def test_conversation_id_and_turn_round_trip(self): + record = EventRecord( + event_type=SampleEventType.ISSUED, + sample_uuid="q-mt", + conversation_id="conv-7", + turn=4, + data=PromptData(text="hi"), + ) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) + assert decoded.conversation_id == "conv-7" + assert decoded.turn == 4 + assert decoded.sample_uuid == "q-mt" + class TestEventRecordCodecOnDecodeError: """Tests for the two branches of EventRecordCodec.on_decode_error. 
diff --git a/tests/unit/core/test_types.py b/tests/unit/core/test_types.py index 52bdbe77..6dc53a28 100644 --- a/tests/unit/core/test_types.py +++ b/tests/unit/core/test_types.py @@ -33,6 +33,7 @@ ) +@pytest.mark.unit class TestErrorData: """Test ErrorData string representation.""" @@ -47,6 +48,7 @@ def test_error_data_str_without_message(self): assert str(err) == "TimeoutError" +@pytest.mark.unit class TestQuerySerialization: """Test Query msgspec.msgpack serialization with various field combinations.""" @@ -186,6 +188,7 @@ def test_query_multiple_roundtrips(self): assert decoded2.created_at == original.created_at +@pytest.mark.unit class TestQueryResultSerialization: """Test QueryResult msgspec.msgpack serialization with various field combinations.""" @@ -467,6 +470,7 @@ def test_query_result_multiple_roundtrips(self): assert decoded2.metadata == original.metadata +@pytest.mark.unit class TestStreamChunkSerialization: """Test StreamChunk msgspec.msgpack serialization with various field combinations.""" @@ -589,6 +593,7 @@ def test_stream_chunk_multiple_roundtrips(self): assert decoded2.metadata == original.metadata +@pytest.mark.unit class TestQueryResultWorkerPatterns: """Test QueryResult serialization patterns used by worker.py (TextModelOutput).""" @@ -690,6 +695,7 @@ def test_query_result_single_output_chunk(self): assert len(decoded.response_output.output) == 1 +@pytest.mark.unit class TestTextAfterFirstChunk: """Test TextModelOutput.text_after_first_chunk() for all reasoning/output combos.""" @@ -744,6 +750,7 @@ def test_text_after_first_chunk(self, reasoning, output, expected): assert tmo.text_after_first_chunk() == expected +@pytest.mark.unit class TestMixedTypeSerialization: """Test serialization of mixed type combinations and edge cases.""" @@ -891,3 +898,86 @@ def test_numeric_types_in_metadata(self): assert decoded.metadata["large_int"] == 9999999999999999 assert decoded.metadata["negative"] == -123.456 assert decoded.metadata["zero"] == 0 + + +@pytest.mark.unit +class TestTextModelOutputToolCalls: + """Test TextModelOutput.tool_calls field: coercion, __str__, text_after_first_chunk.""" + + _TC = [ + {"id": "c1", "type": "function", "function": {"name": "f", "arguments": "{}"}} + ] + + def test_list_coerced_to_tuple(self): + tmo = TextModelOutput(output="hi", tool_calls=self._TC) + assert isinstance(tmo.tool_calls, tuple) + assert len(tmo.tool_calls) == 1 + + def test_none_tool_calls(self): + tmo = TextModelOutput(output="hi", tool_calls=None) + assert tmo.tool_calls is None + + def test_str_includes_tool_calls_json(self): + tmo = TextModelOutput(output="hello", tool_calls=self._TC) + s = str(tmo) + assert "hello" in s + assert '"name"' in s + assert '"function"' in s + + def test_str_without_tool_calls(self): + tmo = TextModelOutput(output="hello") + assert str(tmo) == "hello" + + def test_text_after_first_chunk_includes_tool_calls(self): + # streaming output with tool_calls: first chunk skipped, tool_calls appended + tmo = TextModelOutput(output=("a", "b"), tool_calls=self._TC) + after = tmo.text_after_first_chunk() + assert "b" in after + assert '"function"' in after + + def test_text_after_first_chunk_tool_calls_only_no_content(self): + # tool_calls only (pure agentic response with no text content) + tmo = TextModelOutput(output=[], tool_calls=self._TC) + after = tmo.text_after_first_chunk() + assert '"function"' in after + + def test_as_message_parts_str_output(self): + tmo = TextModelOutput(output="hello", tool_calls=self._TC) + content, reasoning, tc = 
tmo.as_message_parts() + assert content == "hello" + assert reasoning is None + assert tc == tmo.tool_calls + + def test_as_message_parts_tuple_output(self): + tmo = TextModelOutput( + output=("a", "b"), reasoning=("r1", "r2"), tool_calls=self._TC + ) + content, reasoning, tc = tmo.as_message_parts() + assert content == "ab" + assert reasoning == "r1r2" + assert tc == tmo.tool_calls + + def test_as_message_parts_after_first_chunk_str_output(self): + tmo = TextModelOutput(output="hello", tool_calls=self._TC) + content, reasoning, tc = tmo.as_message_parts_after_first_chunk() + # non-streaming str output: no "after first chunk" for content + assert content == "" + assert reasoning is None + assert tc == tmo.tool_calls + + def test_as_message_parts_after_first_chunk_tuple_output(self): + tmo = TextModelOutput(output=("a", "b", "c"), tool_calls=self._TC) + content, reasoning, tc = tmo.as_message_parts_after_first_chunk() + assert content == "bc" + assert reasoning is None + assert tc == tmo.tool_calls + + def test_serialization_roundtrip_with_tool_calls(self): + import msgspec + + tmo = TextModelOutput(output="hello", tool_calls=self._TC) + encoded = msgspec.msgpack.encode(tmo) + decoded = msgspec.msgpack.decode(encoded, type=TextModelOutput) + assert decoded.tool_calls is not None + assert len(decoded.tool_calls) == 1 + assert decoded.tool_calls[0]["function"]["name"] == "f" diff --git a/tests/unit/dataset_manager/test_multi_turn_dataset.py b/tests/unit/dataset_manager/test_multi_turn_dataset.py new file mode 100644 index 00000000..03659307 --- /dev/null +++ b/tests/unit/dataset_manager/test_multi_turn_dataset.py @@ -0,0 +1,1427 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import tempfile +from collections.abc import Generator +from pathlib import Path + +import pandas as pd +import pytest +from inference_endpoint.dataset_manager.dataset import DatasetFormat +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset +from inference_endpoint.exceptions import InputValidationError + + +@pytest.fixture +def valid_multi_turn_jsonl() -> Generator[str, None, None]: + """Create valid multi-turn conversation JSONL data.""" + data = [ + { + "conversation_id": "conv_001", + "turn": 1, + "role": "user", + "content": "Hello, how are you?", + "system": "You are a helpful assistant", + }, + { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "I'm doing well, thank you!", + }, + { + "conversation_id": "conv_001", + "turn": 3, + "role": "user", + "content": "What can you help me with?", + }, + { + "conversation_id": "conv_002", + "turn": 1, + "role": "user", + "content": "What's the weather?", + }, + { + "conversation_id": "conv_002", + "turn": 2, + "role": "assistant", + "content": "I don't have access to real-time weather data.", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.fixture +def invalid_role_sequence_jsonl() -> Generator[str, None, None]: + """Create JSONL with invalid role sequence (not alternating).""" + data = [ + {"conversation_id": "conv_001", "turn": 1, "role": "user", "content": "Hello"}, + { + "conversation_id": "conv_001", + "turn": 2, + "role": "user", + "content": "Another user message", + }, # Invalid - consecutive user + { + "conversation_id": "conv_001", + "turn": 3, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.fixture +def missing_fields_jsonl() -> Generator[str, None, None]: + """Create JSONL with missing required fields.""" + data = [ + {"conversation_id": "conv_001", "turn": 1, "role": "user"}, # Missing content + { + "conversation_id": "conv_001", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + yield temp_path + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_load_valid_data(valid_multi_turn_jsonl): + """Test loading valid multi-turn conversation data.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # data contains only client turns (3 user turns), not all rows + assert len(dataset.data) == 3 + + # Should have 3 user turns (samples) - only user turns are indexed + assert dataset.num_samples() == 3 + + +@pytest.mark.unit +def test_multi_turn_dataset_user_turn_indexing(valid_multi_turn_jsonl): + """Test that only client turns (user + tool) are stored as samples.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # data contains only client turns (fixture has only user turns) + assert dataset.num_samples() == 3 + + # Every sample in data is a client turn + for i in range(dataset.num_samples()): + assert 
dataset.load_sample(i)["role"] in ("user", "tool") + + +@pytest.mark.unit +def test_multi_turn_dataset_load_sample(valid_multi_turn_jsonl): + """Test load_sample returns correct user turns with dense indexing.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Sample 0 should be first user turn + sample_0 = dataset.load_sample(0) + assert sample_0["conversation_id"] == "conv_001" + assert sample_0["turn"] == 1 + assert sample_0["role"] == "user" + assert sample_0["content"] == "Hello, how are you?" + # System prompt is the first message in the messages array + assert sample_0["messages"][0]["role"] == "system" + assert sample_0["messages"][0]["content"] == "You are a helpful assistant" + + # Sample 1 should be second user turn (conv_001 turn 3) + sample_1 = dataset.load_sample(1) + assert sample_1["conversation_id"] == "conv_001" + assert sample_1["turn"] == 3 + assert sample_1["role"] == "user" + assert sample_1["content"] == "What can you help me with?" + + # Sample 2 should be third user turn (conv_002 turn 1) + sample_2 = dataset.load_sample(2) + assert sample_2["conversation_id"] == "conv_002" + assert sample_2["turn"] == 1 + assert sample_2["role"] == "user" + assert sample_2["content"] == "What's the weather?" + + +@pytest.mark.unit +def test_multi_turn_dataset_conversation_metadata(valid_multi_turn_jsonl): + """Test conversation metadata generation.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + metadata = dataset.conversation_metadata + + # Should have 3 client turn samples (fixture has only user turns, no tool turns) + assert len(metadata.samples) == 3 + + # Should have 2 conversations + assert metadata.num_conversations == 2 + + # Max turns per conversation should be 3 (conv_001 has 3 turns) + assert metadata.max_turns_per_conv == 3 + + # Check sample metadata structure + sample_meta = metadata.samples[0] + assert sample_meta.conversation_id is not None + assert sample_meta.turn is not None + + +@pytest.mark.unit +def test_multi_turn_dataset_validation_invalid_role_sequence( + invalid_role_sequence_jsonl, +): + """Test validation rejects invalid role sequences.""" + # Validation happens during load_from_file (in __init__), not during load() + with pytest.raises(ValueError, match="invalid role sequence"): + MultiTurnDataset.load_from_file( + invalid_role_sequence_jsonl, format=DatasetFormat.JSONL + ) + + +@pytest.mark.unit +def test_multi_turn_dataset_validation_missing_fields(missing_fields_jsonl): + """User rows with missing content are rejected at construction time.""" + with pytest.raises( + InputValidationError, match="user rows must have non-empty 'content'" + ): + MultiTurnDataset.load_from_file( + missing_fields_jsonl, format=DatasetFormat.JSONL + ) + + +@pytest.mark.unit +def test_multi_turn_dataset_multiple_conversations(): + """Test dataset with multiple conversations of varying lengths.""" + data = [ + # Conversation 1: 3 turns (user-assistant-user, missing final assistant) + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg1b"}, + # Conversation 2: 4 turns (complete user-assistant alternation) + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "msg2"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": 
"resp2"}, + {"conversation_id": "c2", "turn": 3, "role": "user", "content": "msg3"}, + {"conversation_id": "c2", "turn": 4, "role": "assistant", "content": "resp3"}, + # Conversation 3: 2 turns (complete user-assistant) + {"conversation_id": "c3", "turn": 1, "role": "user", "content": "msg4"}, + {"conversation_id": "c3", "turn": 2, "role": "assistant", "content": "resp4"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # data contains only client turns: 5 user turns (c1:t1, c1:t3, c2:t1, c2:t3, c3:t1) + assert len(dataset.data) == 5 + assert dataset.num_samples() == 5 + + # Metadata checks + metadata = dataset.conversation_metadata + assert metadata.num_conversations == 3 + assert metadata.max_turns_per_conv == 4 # c2 has 4 turns + + # Verify user turns are correctly indexed + samples = [dataset.load_sample(i) for i in range(5)] + + # Check we got all the user turns + user_turns = [(s["conversation_id"], s["turn"]) for s in samples] + expected_turns = [("c1", 1), ("c1", 3), ("c2", 1), ("c2", 3), ("c3", 1)] + assert sorted(user_turns) == sorted(expected_turns) + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_system_prompt_handling(valid_multi_turn_jsonl): + """Test system prompt is included as the first message in the messages array. + + The system prompt is pre-baked into every client turn's message list so the + conversation manager no longer needs to track it separately. + """ + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # First sample: messages starts with system message + sample_0 = dataset.load_sample(0) + assert "messages" in sample_0 + msgs = sample_0["messages"] + assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "You are a helpful assistant" + + # Second sample (same conversation, turn 3): system message still first + sample_1 = dataset.load_sample(1) + msgs_1 = sample_1["messages"] + assert msgs_1[0]["role"] == "system" + assert msgs_1[0]["content"] == "You are a helpful assistant" + + +@pytest.mark.unit +def test_multi_turn_dataset_single_turn_conversations(): + """Test conversations with only one turn.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Single turn"}, + # No assistant response + { + "conversation_id": "c2", + "turn": 1, + "role": "user", + "content": "Another single", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # 2 rows, 2 user turns + assert len(dataset.data) == 2 + assert dataset.num_samples() == 2 + + # Both samples should be user turns + assert dataset.load_sample(0)["role"] == "user" + assert dataset.load_sample(1)["role"] == "user" + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_empty_conversation(): + """Empty JSONL file raises ValueError (no columns to validate against).""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + temp_path = f.name + + try: + with pytest.raises(ValueError): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) 
+ finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_conversation_grouping(): + """Test that properly grouped conversations load correctly.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1t1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "c1t2"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "c1t3"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2t1"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "c2t2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # data contains only client turns: 3 user turns (c1t1, c1t3, c2t1) + assert len(dataset.data) == 3 + assert dataset.num_samples() == 3 + + # Load samples to verify conversation grouping + samples = [dataset.load_sample(i) for i in range(3)] + + # Verify conversation IDs + conv_ids = [s["conversation_id"] for s in samples] + assert conv_ids == ["c1", "c1", "c2"] + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_interleaved_conversations_rejected(): + """Test that interleaved conversation rows raise InputValidationError.""" + from inference_endpoint.exceptions import InputValidationError + + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1t1"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2t1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "c1t2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(InputValidationError, match="not consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +@pytest.mark.parametrize( + "rows", + [ + # assistant-first + [ + {"conversation_id": "c1", "turn": 1, "role": "assistant", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "user", "content": "B"}, + ], + # consecutive assistants + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + {"conversation_id": "c1", "turn": 3, "role": "assistant", "content": "C"}, + ], + # tool directly after user (tool-before-assistant) + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "tool", + "tool_results": [{"tool_call_id": "x", "content": "r"}], + }, + ], + # consecutive users + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "user", "content": "B"}, + ], + ], +) +def test_validation_rejects_invalid_role_sequence(rows): + """Invalid role sequences raise ValueError regardless of turn numbering.""" + with pytest.raises(ValueError, match="invalid role sequence"): + MultiTurnDataset(pd.DataFrame(rows)) + + +@pytest.mark.unit +def test_multi_turn_dataset_additional_fields(): + """Test that additional fields (model, max_new_tokens, etc.) 
are preserved.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hello", + "model": "gpt-4", + "max_new_tokens": 256, + "temperature": 0.7, + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + assert sample["model"] == "gpt-4" + assert sample["max_completion_tokens"] == 256 + assert sample["max_tokens"] == 256 + assert sample["temperature"] == pytest.approx(0.7) + + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_openai_field_forwarding(): + """Test that OpenAI-specific fields are preserved and forwarded.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hello", + # OpenAI fields that should be forwarded + "n": 3, + "name": "Alice", + "user": "user_12345", + "logit_bias": {"50256": -100}, + "chat_template": "custom_template", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + + # Verify OpenAI fields are present + assert sample.get("n") == 3 + assert sample.get("name") == "Alice" + assert sample.get("user") == "user_12345" + assert sample.get("logit_bias") == {"50256": -100} + assert sample.get("chat_template") == "custom_template" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_multi_turn_dataset_all_generation_params(): + """Test that dataset-supplied generation parameters are forwarded to the sample.""" + # Create dataset with a representative set of generation params + row_params = { + "model": "test-model", + "max_completion_tokens": 100, + "stream": True, + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "seed": 42, + "repetition_penalty": 1.1, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + "stop": ["END"], + "n": 2, + "logit_bias": {"100": 10}, + "name": "TestEntity", + "user": "test_user_001", + "chat_template": "test_template", + } + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Test", + **row_params, + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": "Response", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + sample = dataset.load_sample(0) + + # All non-NaN row fields must appear in the pre-baked sample + for param in row_params: + assert param in sample, f"Parameter '{param}' not forwarded to sample" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_non_contiguous_turns(): + """Turn numbers must be consecutive; gaps are rejected.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "a"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "b"}, + {"conversation_id": 
"c1", "turn": 5, "role": "user", "content": "c"}, + {"conversation_id": "c1", "turn": 6, "role": "assistant", "content": "d"}, + ] + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset(pd.DataFrame(rows)) + + +@pytest.mark.unit +def test_validation_rejects_turns_not_starting_at_one(): + """Validation should reject conversations whose turns don't start at 1.""" + data = [ + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg"}, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_accepts_valid_contiguous_turns(): + """Validation should accept contiguous turn sequences.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "msg2"}, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "resp2"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + assert dataset.num_samples() == 2 + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_turn_starting_at_zero(): + """Validation should reject conversations starting at turn 0.""" + data = [ + {"conversation_id": "c1", "turn": 0, "role": "user", "content": "msg"}, + {"conversation_id": "c1", "turn": 1, "role": "assistant", "content": "resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_duplicate_turn_numbers(): + """Duplicate turn numbers within a conversation are rejected.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "msg1"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "resp1"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "msg2"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "resp2"}, + # c2 has duplicate turn 2 — second assistant row with same turn number + {"conversation_id": "c2", "turn": 2, "role": "user", "content": "dup"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="consecutive"): + MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_validation_rejects_assistant_tc_role_literal(): + """role='assistant_tc' literal in dataset is rejected; only 'assistant' is valid.""" + rows = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Q"}, + { + 
"conversation_id": "c1", + "turn": 2, + "role": "assistant_tc", + "tool_calls": [ + { + "id": "c0", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [{"tool_call_id": "c0", "content": "r"}], + }, + {"conversation_id": "c1", "turn": 4, "role": "assistant", "content": "A"}, + ] + with pytest.raises(ValueError, match="invalid role sequence"): + MultiTurnDataset(pd.DataFrame(rows)) + + +# ============================================================================ +# Tool sequence tests +# ============================================================================ + + +def _make_tool_sequence_df(): + """Return a DataFrame with a tool sequence embedded between user turns.""" + return pd.DataFrame( + [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "What is the weather?", + "system": "Be helpful", + }, + # assistant (with tool_calls): dispatches a tool call + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_c1_0", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + # tool result + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "call_c1_0", "content": '{"temp": 22}'} + ], + }, + # terminal assistant + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "The weather is 22°C.", + }, + # second user turn + { + "conversation_id": "c1", + "turn": 5, + "role": "user", + "content": "Thanks!", + }, + ] + ) + + +@pytest.mark.unit +def test_validation_accepts_tool_sequence(): + """user → assistant → tool → assistant → user passes validation.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + assert ds.num_samples() == 3 # user(1), tool(3), user(5) are all client turns + + +@pytest.mark.unit +def test_validation_accepts_parallel_tool_calls(): + """Assistant with two tool_calls + merged tool_results row passes.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + assert ds.num_samples() == 2 # user(1), tool(3) are client turns + + +@pytest.mark.unit +def test_load_sample_merged_tool_row_has_no_content_key(): + """load_sample for a merged tool_results row must not emit content: NaN.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Go"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", 
"content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + + # Sample 1 is the merged tool row (turn 3) + s1 = ds.load_sample(1) + assert s1["role"] == "tool" + assert "content" not in s1 # must NOT emit NaN + assert "messages" in s1 + + +@pytest.mark.unit +def test_build_metadata_pre_built_messages(): + """pre_built_messages_by_key contains complete message arrays for each client turn. + + Dataset: + turn 1: user ← client turn 1 + turn 2: asst_tc ← scripted (assistant with tool_calls) + turn 3: tool ← client turn 2 + turn 4: assistant ← terminal assistant + turn 5: user ← client turn 3 + + Expected pre_built_messages: + client turn 1 (t=1): [system, user(1)] + client turn 2 (t=3): [system, user(1), asst_tc(2), tool(3)] + client turn 3 (t=5): [system, user(1), asst_tc(2), tool(3), asst(4), user(5)] + """ + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + pbm = ds.conversation_metadata.pre_built_messages_by_key + + # Client turn 1 (user, t=1): [system, user(1)] + msgs_t1 = pbm[("c1", 1)] + assert len(msgs_t1) == 2 + assert msgs_t1[0] == {"role": "system", "content": "Be helpful"} + assert msgs_t1[1] == {"role": "user", "content": "What is the weather?"} + + # Client turn 2 (tool, t=3): [system, user(1), asst_tc(2), tool(3)] + msgs_t3 = pbm[("c1", 3)] + assert len(msgs_t3) == 4 + assert msgs_t3[0]["role"] == "system" + assert msgs_t3[1]["role"] == "user" + assert msgs_t3[2]["role"] == "assistant" + assert "tool_calls" in msgs_t3[2] + assert msgs_t3[3]["role"] == "tool" + assert msgs_t3[3]["content"] == '{"temp": 22}' + assert msgs_t3[3]["tool_call_id"] == "call_c1_0" + + # Client turn 3 (user, t=5): [system, user(1), asst_tc(2), tool(3), asst(4), user(5)] + msgs_t5 = pbm[("c1", 5)] + assert len(msgs_t5) == 6 + assert msgs_t5[4] == {"role": "assistant", "content": "The weather is 22°C."} + assert msgs_t5[5] == {"role": "user", "content": "Thanks!"} + + +@pytest.mark.unit +def test_build_metadata_pre_built_messages_no_tools(): + """Plain user/assistant alternation produces correct pre_built_messages.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "A"}, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + {"conversation_id": "c1", "turn": 3, "role": "user", "content": "C"}, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + pbm = ds.conversation_metadata.pre_built_messages_by_key + + # Turn 1: just the user message (no system, no prior rows) + assert pbm[("c1", 1)] == [{"role": "user", "content": "A"}] + + # Turn 3: user(1) + assistant(2) + user(3) + msgs = pbm[("c1", 3)] + assert len(msgs) == 3 + assert msgs[0] == {"role": "user", "content": "A"} + assert msgs[1] == {"role": "assistant", "content": "B"} + assert msgs[2] == {"role": "user", "content": "C"} + + +@pytest.mark.unit +def test_load_sample_includes_messages(): + """load_sample returns messages with the complete message list.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + s0 = ds.load_sample(0) # user turn 1 + assert "messages" in s0 + msgs = s0["messages"] + assert msgs[0]["role"] == "system" + assert msgs[-1] == {"role": "user", "content": "What is the weather?"} + + s1 = ds.load_sample(1) # tool turn 3 + assert s1["role"] == "tool" + msgs_t3 = s1["messages"] + # system + user(1) + asst_tc(2) + tool(3) = 4 messages + assert len(msgs_t3) == 4 + assert 
msgs_t3[-1]["role"] == "tool" + + s2 = ds.load_sample(2) # user turn 5 + msgs_t5 = s2["messages"] + # system + user(1) + asst_tc(2) + tool(3) + asst(4) + user(5) = 6 messages + assert len(msgs_t5) == 6 + + +@pytest.mark.unit +def test_client_turns_include_tool_rows(): + """Tool rows are counted in num_samples() as client turns.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + # 5 rows total: user(1), assistant(2), tool(3), assistant(4), user(5) + # Client turns: user(1), tool(3), user(5) → 3 + assert ds.num_samples() == 3 + + +# ============================================================================ +# Pre-built messages content correctness +# ============================================================================ + + +@pytest.mark.unit +def test_messages_include_prior_assistant_response(valid_multi_turn_jsonl): + """The terminal assistant response before each user turn is included in messages.""" + dataset = MultiTurnDataset.load_from_file( + valid_multi_turn_jsonl, format=DatasetFormat.JSONL + ) + dataset.load() + + # Sample 0: turn 1 (first user) → just [system, user(1)] + s0 = dataset.load_sample(0) + msgs_0 = s0["messages"] + assert msgs_0[0]["role"] == "system" + assert msgs_0[-1]["role"] == "user" + + # Sample 1: turn 3 (second user) → [system, user(1), assistant(2), user(3)] + s1 = dataset.load_sample(1) + msgs_1 = s1["messages"] + assert len(msgs_1) == 4 + assert msgs_1[2] == {"role": "assistant", "content": "I'm doing well, thank you!"} + assert msgs_1[3]["role"] == "user" + + # Sample 2: turn 1 of conv_002 → no prior assistant row + s2 = dataset.load_sample(2) + msgs_2 = s2["messages"] + assert all(m["role"] != "assistant" for m in msgs_2) + + +@pytest.mark.unit +def test_messages_no_cross_conversation_bleed(): + """Messages for conv_001 must not appear in conv_002's messages array.""" + data = [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "c1 user"}, + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "c2 user"}, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "c2 resp"}, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # c1: only its own user message + s_c1 = dataset.load_sample(0) + assert s_c1["messages"] == [{"role": "user", "content": "c1 user"}] + + # c2: only c2 messages (no c1 content) + s_c2 = dataset.load_sample(1) + contents = [m.get("content") for m in s_c2["messages"]] + assert "c1 user" not in contents + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_messages_with_tool_sequence_terminal_assistant(): + """Terminal assistant response (turn 4) appears in messages for user(5).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + s2 = ds.load_sample(2) # user turn 5 + msgs = s2["messages"] + # The terminal assistant at turn 4 should be included + assistant_msgs = [m for m in msgs if m["role"] == "assistant" and m.get("content")] + assert any(m["content"] == "The weather is 22°C." 
for m in assistant_msgs) + + +# ============================================================================ +# Tool-use flat dataset regression tests (BUG 1, BUG 2, BUG 3) +# ============================================================================ + + +@pytest.mark.unit +def test_prior_tool_row_expanded_with_tool_call_id(): + """Prior tool rows must expand to messages with tool_call_id and content (BUG 1).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + pbm = ds.conversation_metadata.pre_built_messages_by_key + + # Client turn 3 (user, t=5) has a prior tool row at t=3. + # msgs_t5[3] should be the expanded tool message with proper fields. + msgs_t5 = pbm[("c1", 5)] + tool_msgs = [m for m in msgs_t5 if m.get("role") == "tool"] + assert len(tool_msgs) == 1 + assert tool_msgs[0]["tool_call_id"] == "call_c1_0" + assert tool_msgs[0]["content"] == '{"temp": 22}' + + +@pytest.mark.unit +def test_prior_parallel_tool_results_expand_to_multiple_messages(): + """Prior turn with 2 parallel tool_results expands to 2 tool messages.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hi"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + {"conversation_id": "c1", "turn": 5, "role": "user", "content": "Ok"}, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + pbm = ds.conversation_metadata.pre_built_messages_by_key + + # user(5) sees prior rows: user(1), assistant(2), tool(3)x2, assistant(4) + msgs_t5 = pbm[("c1", 5)] + tool_msgs = [m for m in msgs_t5 if m.get("role") == "tool"] + assert len(tool_msgs) == 2 + assert tool_msgs[0]["tool_call_id"] == "c_0" + assert tool_msgs[0]["content"] == "r1" + assert tool_msgs[1]["tool_call_id"] == "c_1" + assert tool_msgs[1]["content"] == "r2" + + +@pytest.mark.unit +def test_assistant_content_null_preserved_in_history(): + """Assistant messages with tool_calls and content:null include content key (BUG 2).""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + pbm = ds.conversation_metadata.pre_built_messages_by_key + + # Client turn 2 (tool, t=3): prior includes assistant(2) with tool_calls + content: null + msgs_t3 = pbm[("c1", 3)] + asst_msg = msgs_t3[2] + assert asst_msg["role"] == "assistant" + assert "tool_calls" in asst_msg + assert "content" in asst_msg + assert asst_msg["content"] is None + + # Also verify in user(5)'s history + msgs_t5 = pbm[("c1", 5)] + asst_tc_msg = msgs_t5[2] + assert asst_tc_msg["role"] == "assistant" + assert "tool_calls" in asst_tc_msg + assert "content" in asst_tc_msg + assert asst_tc_msg["content"] is None + + +@pytest.mark.unit +def test_jsonl_round_trip_with_tools_field(): + """Load from JSONL tmpfile with tools field; verify tools survives to sample dict.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Run the test", + "tools": [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Run a bash command", + "parameters": {"type": "object", "properties": 
{}}, + }, + } + ], + }, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tc_0", + "type": "function", + "function": {"name": "bash", "arguments": '{"cmd": "ls"}'}, + } + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [{"tool_call_id": "tc_0", "content": "file1.py"}], + "tools": [ + { + "type": "function", + "function": { + "name": "bash", + "description": "Run a bash command", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "The directory contains file1.py", + }, + ] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f: + for item in data: + f.write(json.dumps(item) + "\n") + temp_path = f.name + + try: + dataset = MultiTurnDataset.load_from_file(temp_path, format=DatasetFormat.JSONL) + dataset.load() + + # user(1) has tools + s0 = dataset.load_sample(0) + assert "tools" in s0 + assert len(s0["tools"]) == 1 + assert s0["tools"][0]["function"]["name"] == "bash" + + # tool(3) also has tools + s1 = dataset.load_sample(1) + assert "tools" in s1 + assert s1["tools"][0]["function"]["name"] == "bash" + finally: + Path(temp_path).unlink() + + +@pytest.mark.unit +def test_current_turn_messages_by_key_parallel_tools(): + """current_turn_messages_by_key stores all expanded messages for a tool turn.""" + df = pd.DataFrame( + [ + {"conversation_id": "c1", "turn": 1, "role": "user", "content": "Go"}, + { + "conversation_id": "c1", + "turn": 2, + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c_0", + "type": "function", + "function": {"name": "f1", "arguments": "{}"}, + }, + { + "id": "c_1", + "type": "function", + "function": {"name": "f2", "arguments": "{}"}, + }, + ], + }, + { + "conversation_id": "c1", + "turn": 3, + "role": "tool", + "tool_results": [ + {"tool_call_id": "c_0", "content": "r1"}, + {"tool_call_id": "c_1", "content": "r2"}, + ], + }, + { + "conversation_id": "c1", + "turn": 4, + "role": "assistant", + "content": "Done", + }, + ] + ) + ds = MultiTurnDataset(df) + ds.load() + ctm = ds.conversation_metadata.current_turn_messages_by_key + + # user(1) current turn is 1 message + assert len(ctm[("c1", 1)]) == 1 + assert ctm[("c1", 1)][0] == {"role": "user", "content": "Go"} + + # tool(3) current turn has 2 expanded messages (parallel tool_results) + assert len(ctm[("c1", 3)]) == 2 + assert ctm[("c1", 3)][0]["tool_call_id"] == "c_0" + assert ctm[("c1", 3)][1]["tool_call_id"] == "c_1" + + +# ============================================================================ +# Fix 1: system_prompts_by_conv in metadata (live-history mode) +# ============================================================================ + + +@pytest.mark.unit +def test_metadata_contains_system_prompts_by_conv(): + """_build_metadata exposes system_prompts_by_conv keyed by conversation_id.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Hi", + "system": "Be concise", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Ok"}, + # c2 has no system prompt + {"conversation_id": "c2", "turn": 1, "role": "user", "content": "Hello"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + ds.load() + + spc = ds.conversation_metadata.system_prompts_by_conv + assert spc["c1"] == "Be concise" + assert spc["c2"] is None + + +@pytest.mark.unit +def test_metadata_system_prompts_multiple_convs(): + 
"""Each conversation gets its own system prompt entry.""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "A", + "system": "Sys1", + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "B"}, + { + "conversation_id": "c2", + "turn": 1, + "role": "user", + "content": "C", + "system": "Sys2", + }, + {"conversation_id": "c2", "turn": 2, "role": "assistant", "content": "D"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + ds.load() + + spc = ds.conversation_metadata.system_prompts_by_conv + assert spc["c1"] == "Sys1" + assert spc["c2"] == "Sys2" + + +# ============================================================================ +# Fix 2: tool_results / tool_calls stripped from sample dicts +# ============================================================================ + + +@pytest.mark.unit +def test_tool_results_not_in_sample_dict(): + """tool_results must not appear in the pre-baked sample dict for tool turns.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + # Sample 1 is the tool turn (turn 3) + s1 = ds.load_sample(1) + assert s1["role"] == "tool" + assert "tool_results" not in s1 + + +@pytest.mark.unit +def test_tool_calls_not_in_sample_dict(): + """tool_calls must not appear in sample dicts (only relevant on assistant rows).""" + data = [ + { + "conversation_id": "c1", + "turn": 1, + "role": "user", + "content": "Go", + "tool_calls": [ + {"id": "bad", "type": "function", "function": {"name": "f"}} + ], + }, + {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Done"}, + ] + df = pd.DataFrame(data) + ds = MultiTurnDataset(df) + ds.load() + + s0 = ds.load_sample(0) + assert "tool_calls" not in s0 + + +# ============================================================================ +# Fix 3: no dead current_turn_message / system_content fields in sample dicts +# ============================================================================ + + +@pytest.mark.unit +def test_no_dead_current_turn_message_field(): + """current_turn_message must not appear in pre-baked sample dicts.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + for i in range(ds.num_samples()): + s = ds.load_sample(i) + assert ( + "current_turn_message" not in s + ), f"Sample {i} has dead field current_turn_message" + + +@pytest.mark.unit +def test_no_dead_system_content_field(): + """system_content must not appear in pre-baked sample dicts.""" + df = _make_tool_sequence_df() + ds = MultiTurnDataset(df) + ds.load() + + for i in range(ds.num_samples()): + s = ds.load_sample(i) + assert "system_content" not in s, f"Sample {i} has dead field system_content" diff --git a/tests/unit/dataset_manager/test_transforms.py b/tests/unit/dataset_manager/test_transforms.py index ab342204..7dee620f 100644 --- a/tests/unit/dataset_manager/test_transforms.py +++ b/tests/unit/dataset_manager/test_transforms.py @@ -822,6 +822,55 @@ def test_no_matching_columns(self): transform = MakeAdapterCompatible() result = transform(df) - # Should not raise error or create prompt column + # Should not raise error, should not create prompt, and must preserve unrelated columns assert "prompt" not in result.columns assert "unrelated" in result.columns + assert list(result["unrelated"]) == ["data"] + + +class TestAddStaticColumnsNoOverwrite: + """Unit tests for AddStaticColumns(overwrite=False) behavior.""" + + @pytest.mark.unit + def test_fills_missing_columns(self): + """New columns are added when absent.""" + df = 
pd.DataFrame({"a": [1, 2]}) + result = AddStaticColumns({"b": 10, "c": "x"}, overwrite=False)(df) + assert list(result["b"]) == [10, 10] + assert list(result["c"]) == ["x", "x"] + + @pytest.mark.unit + def test_preserves_existing_non_null_values(self): + """Existing non-null values are not overwritten.""" + df = pd.DataFrame({"a": [1, 2]}) + result = AddStaticColumns({"a": 99}, overwrite=False)(df) + assert list(result["a"]) == [1, 2] + + @pytest.mark.unit + def test_fills_nan_values_in_existing_column(self): + """NaN cells in an existing column are replaced with the default.""" + + df = pd.DataFrame({"a": [1.0, float("nan"), 3.0]}) + result = AddStaticColumns({"a": 99}, overwrite=False)(df) + assert result["a"].tolist()[0] == 1.0 + assert result["a"].tolist()[1] == 99 + assert result["a"].tolist()[2] == 3.0 + + @pytest.mark.unit + def test_skips_none_default_values(self): + """A None default value is ignored; the column is not modified.""" + df = pd.DataFrame({"a": [1]}) + original_a = df["a"].copy() + result = AddStaticColumns({"a": None, "b": None}, overwrite=False)(df) + assert list(result["a"]) == list(original_a) + assert "b" not in result.columns + + @pytest.mark.unit + def test_mixed_nan_and_real_values(self): + """Only NaN cells are filled; real values in the same column are preserved.""" + + df = pd.DataFrame({"temp": [0.9, float("nan"), 0.5]}) + result = AddStaticColumns({"temp": 0.7}, overwrite=False)(df) + assert result["temp"].tolist()[0] == pytest.approx(0.9) + assert result["temp"].tolist()[1] == pytest.approx(0.7) + assert result["temp"].tolist()[2] == pytest.approx(0.5) diff --git a/tests/unit/load_generator/test_async_session.py b/tests/unit/load_generator/test_async_session.py index 9d24618d..81baff6e 100644 --- a/tests/unit/load_generator/test_async_session.py +++ b/tests/unit/load_generator/test_async_session.py @@ -39,6 +39,7 @@ PhaseResult, PhaseType, SessionResult, + _extract_prompt_text, ) from inference_endpoint.metrics.metric import Throughput @@ -160,11 +161,15 @@ def test_issue_builds_query_and_publishes(self): assert len(issuer.issued_queries) == 1 assert issuer.issued_queries[0].id == result assert 3 in phase_issuer.uuid_to_index.values() + # Single-turn callers omit conv_id/turn — defaults flow through. 
+ assert phase_issuer.uuid_to_conv_info[result] == ("", None) # Should have published ISSUED event issued_events = publisher.events_of_type(SampleEventType.ISSUED) assert len(issued_events) == 1 assert issued_events[0].sample_uuid == result + assert issued_events[0].conversation_id == "" + assert issued_events[0].turn is None def test_issue_returns_none_when_stopped(self): dataset = FakeDataset(5) @@ -187,6 +192,23 @@ def test_uuid_is_unique_per_issue(self): ids = [phase_issuer.issue(i % 5) for i in range(10)] assert len(set(ids)) == 10 + def test_issue_stamps_conversation_id_and_turn_on_issued_event(self): + dataset = FakeDataset(5) + issuer = FakeIssuer() + issuer._auto_respond = False + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: False) + + query_id = phase_issuer.issue(2, conversation_id="conv-1", turn=3) + assert query_id is not None + assert phase_issuer.uuid_to_conv_info[query_id] == ("conv-1", 3) + + issued = publisher.events_of_type(SampleEventType.ISSUED) + assert len(issued) == 1 + assert issued[0].sample_uuid == query_id + assert issued[0].conversation_id == "conv-1" + assert issued[0].turn == 3 + # --------------------------------------------------------------------------- # BenchmarkSession tests @@ -570,6 +592,62 @@ async def inject_error(): f"complete at idx {complete_idx}" ) + @pytest.mark.asyncio + async def test_handle_response_stamps_conversation_id_and_turn(self): + """All event types inherit (conv_id, turn) seeded at issue time; streaming + events use .get() so the entry survives for the terminal QueryResult pop.""" + loop = asyncio.get_running_loop() + issuer = FakeIssuer() + publisher = FakePublisher() + session = BenchmarkSession(issuer, publisher, loop) + + phase_issuer = PhaseIssuer(FakeDataset(3), issuer, publisher, lambda: False) + session._current_phase_issuer = phase_issuer + + # Streaming path: entry stays available for the terminal COMPLETE pop. + phase_issuer.uuid_to_conv_info["q-stream"] = ("conv-s", 7) + session._handle_response( + StreamChunk(id="q-stream", metadata={"first_chunk": True}) + ) + session._handle_response(StreamChunk(id="q-stream", response_chunk="delta")) + assert ( + publisher.events_of_type(SampleEventType.RECV_FIRST)[0].conversation_id, + publisher.events_of_type(SampleEventType.RECV_FIRST)[0].turn, + ) == ("conv-s", 7) + assert ( + publisher.events_of_type(SampleEventType.RECV_NON_FIRST)[0].conversation_id, + publisher.events_of_type(SampleEventType.RECV_NON_FIRST)[0].turn, + ) == ("conv-s", 7) + assert "q-stream" in phase_issuer.uuid_to_conv_info + + # Success path: COMPLETE inherits conv info, entry is popped. + phase_issuer.uuid_to_index["q-ok"] = 0 + phase_issuer.uuid_to_conv_info["q-ok"] = ("conv-9", 5) + phase_issuer.inflight = 1 + session._handle_response( + QueryResult(id="q-ok", response_output="ok", completed_at=12345) + ) + complete = publisher.events_of_type(SampleEventType.COMPLETE) + assert [(e.conversation_id, e.turn) for e in complete] == [("conv-9", 5)] + assert "q-ok" not in phase_issuer.uuid_to_conv_info + + # Error path: ERROR (emitted before COMPLETE) also carries conv info. 
+ phase_issuer.uuid_to_index["q-err"] = 1 + phase_issuer.uuid_to_conv_info["q-err"] = ("conv-err", 2) + phase_issuer.inflight = 1 + session._handle_response( + QueryResult( + id="q-err", + error=ErrorData(error_type="boom", error_message="x"), + ) + ) + error_events = [ + e for e in publisher.events if isinstance(e.event_type, ErrorEventType) + ] + assert [(e.conversation_id, e.turn) for e in error_events] == [("conv-err", 2)] + complete = publisher.events_of_type(SampleEventType.COMPLETE) + assert (complete[-1].conversation_id, complete[-1].turn) == ("conv-err", 2) + @pytest.mark.unit class TestBenchmarkSessionPoissonIntegration: @@ -894,3 +972,104 @@ def test_perf_results_filter(self, enable_warmup): assert len(sr.perf_results) == 2 assert len(sr.accuracy_results) == 1 assert sr.perf_results[0].name == "perf1" + + +@pytest.mark.unit +class TestExtractPromptText: + def test_string_content_extracted(self): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + ] + assert _extract_prompt_text(messages) == "Hello\nHi" + + def test_multimodal_list_content_text_parts_extracted(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + {"type": "image_url"}, + ], + } + ] + assert _extract_prompt_text(messages) == "Describe this image" + + def test_mixed_string_and_list_content(self): + messages = [ + {"role": "system", "content": "You are helpful"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is this?"}, + {"type": "image_url"}, + ], + }, + ] + assert _extract_prompt_text(messages) == "You are helpful\nWhat is this?" + + def test_none_content_skipped(self): + messages = [ + {"role": "assistant", "content": None}, + {"role": "user", "content": "Hello"}, + ] + assert _extract_prompt_text(messages) == "Hello" + + def test_list_content_with_no_text_parts_returns_none(self): + messages = [{"role": "user", "content": [{"type": "image_url"}]}] + assert _extract_prompt_text(messages) is None + + def test_non_dict_messages_skipped(self): + messages = ["not a dict", {"role": "user", "content": "Valid"}] + assert _extract_prompt_text(messages) == "Valid" + + def test_tool_calls_included(self): + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "c1", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + ] + result = _extract_prompt_text(messages) + assert result is not None + assert "What's the weather?" 
in result + assert "get_weather" in result + + +@pytest.mark.unit +class TestBenchmarkSessionHandleResponse: + """Direct invocation of BenchmarkSession._handle_response (no session.run).""" + + @pytest.mark.asyncio + async def test_drops_late_response_after_timeout(self): + """A late QueryResult for a query already in completed_uuids must be a no-op: + no duplicate ERROR/COMPLETE publish and no second inflight decrement.""" + loop = asyncio.get_running_loop() + dataset = FakeDataset(1) + issuer = FakeIssuer() + publisher = FakePublisher() + phase_issuer = PhaseIssuer(dataset, issuer, publisher, lambda: False) + + phase_issuer.uuid_to_index["q-late"] = 0 + phase_issuer.completed_uuids.add("q-late") + phase_issuer.inflight = 1 + + session = BenchmarkSession(issuer, publisher, loop) + session._current_phase_issuer = phase_issuer + + late_resp = QueryResult( + id="q-late", + error=ErrorData(error_type="late", error_message="late arrival"), + ) + session._handle_response(late_resp) + + assert publisher.events == [] + assert phase_issuer.inflight == 1 diff --git a/tests/unit/load_generator/test_multi_turn_conversation_manager.py b/tests/unit/load_generator/test_multi_turn_conversation_manager.py new file mode 100644 index 00000000..c389fb5f --- /dev/null +++ b/tests/unit/load_generator/test_multi_turn_conversation_manager.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import asyncio
+import inspect
+
+import pytest
+from inference_endpoint.load_generator.conversation_manager import (
+    ConversationManager,
+    ConversationState,
+)
+
+
+@pytest.mark.unit
+def test_conversation_state_initialization():
+    """ConversationState initializes with correct defaults."""
+    state = ConversationState(conversation_id="conv_001")
+
+    assert state.conversation_id == "conv_001"
+    assert state.message_history == []
+    assert state.completed_turns == 0
+    assert state.failed_turns == 0
+    assert state.expected_client_turns is None
+
+
+@pytest.mark.unit
+def test_conversation_state_is_complete_without_expected():
+    """is_complete() returns False when expected_client_turns is None."""
+    state = ConversationState(conversation_id="conv_001")
+    assert not state.is_complete()
+    state.completed_turns = 5
+    assert not state.is_complete()
+
+
+@pytest.mark.unit
+def test_conversation_state_is_complete_with_expected():
+    """is_complete() returns True once completed_turns >= expected."""
+    state = ConversationState(conversation_id="conv_001", expected_client_turns=2)
+    assert not state.is_complete()
+    state.completed_turns = 1
+    assert not state.is_complete()
+    state.completed_turns = 2
+    assert state.is_complete()
+
+
+@pytest.mark.unit
+def test_create_is_synchronous():
+    """get_or_create() must be a plain function, not a coroutine."""
+    manager = ConversationManager()
+    result = manager.get_or_create("conv_001")
+    assert not inspect.iscoroutine(result), "get_or_create returned a coroutine"
+    assert isinstance(result, ConversationState)
+
+
+@pytest.mark.unit
+def test_conversation_manager_get_or_create():
+    """get_or_create returns the same state for the same conversation_id."""
+    manager = ConversationManager()
+
+    state1 = manager.get_or_create("conv_001")
+    state2 = manager.get_or_create("conv_001")
+
+    assert state1 is state2
+    assert state1.conversation_id == "conv_001"
+
+
+@pytest.mark.unit
+def test_conversation_manager_multiple_conversations():
+    """Manager tracks multiple conversations independently."""
+    manager = ConversationManager()
+
+    state1 = manager.get_or_create("conv_001")
+    state2 = manager.get_or_create("conv_002")
+
+    assert state1 is not state2
+
+    manager.mark_turn_complete("conv_001", "Response to conv_001")
+
+    assert state1.completed_turns == 1
+    assert state2.completed_turns == 0
+
+
+@pytest.mark.unit
+def test_conversation_manager_mark_turn_complete():
+    """mark_turn_complete increments the counter; history is not stored by default."""
+    manager = ConversationManager()
+    state = manager.get_or_create("conv_001")
+
+    manager.mark_turn_complete("conv_001", "Assistant response")
+
+    assert state.completed_turns == 1
+    assert state.failed_turns == 0
+    assert state.message_history == []  # store_in_history=False by default
+
+
+@pytest.mark.unit
+def test_conversation_manager_mark_turn_complete_stores_history():
+    """mark_turn_complete appends to history when store_in_history=True."""
+    manager = ConversationManager()
+    state = manager.get_or_create("conv_001")
+
+    manager.mark_turn_complete("conv_001", "Hello", store_in_history=True)
+
+    assert state.message_history == [{"role": "assistant", "content": "Hello"}]
+
+
+@pytest.mark.unit
+def test_conversation_manager_mark_turn_failed():
+    """mark_turn_failed increments both counters."""
+    manager = ConversationManager()
+    state = manager.get_or_create("conv_001", expected_client_turns=2)
+
+    manager.mark_turn_failed("conv_001")
+
+    assert state.completed_turns == 1
+    assert state.failed_turns == 1
+
+
+@pytest.mark.unit
+def 
test_conversation_completion_tracking(): + """is_complete() returns True after all expected turns receive responses.""" + manager = ConversationManager() + state = manager.get_or_create("conv_001", expected_client_turns=2) + + assert not state.is_complete() + manager.mark_turn_complete("conv_001", "r1") + assert not state.is_complete() + manager.mark_turn_complete("conv_001", "r2") + assert state.is_complete() + + +@pytest.mark.unit +def test_conversation_completion_without_expected_turns(): + """Completion is never True when expected_client_turns is None.""" + manager = ConversationManager() + state = manager.get_or_create("conv_001", expected_client_turns=None) + + manager.mark_turn_complete("conv_001", "r1") + + assert not state.is_complete() + + +@pytest.mark.unit +def test_conversation_completion_with_failures(): + """Conversations complete even when some turns fail.""" + manager = ConversationManager() + state = manager.get_or_create("conv1", expected_client_turns=3) + + manager.mark_turn_complete("conv1", "Hi") + assert not state.is_complete() + + manager.mark_turn_failed("conv1") + assert not state.is_complete() + + manager.mark_turn_complete("conv1", "Bye") + assert state.is_complete() + assert state.failed_turns == 1 + assert state.completed_turns == 3 + + +@pytest.mark.unit +def test_all_turns_fail(): + """Conversation completes when all turns fail.""" + manager = ConversationManager() + state = manager.get_or_create("conv1", expected_client_turns=2) + + manager.mark_turn_failed("conv1") + manager.mark_turn_failed("conv1") + + assert state.is_complete() + assert state.completed_turns == 2 + assert state.failed_turns == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_manager_concurrent_access(): + """Concurrent pipeline tasks on independent conversations complete without errors.""" + manager = ConversationManager() + num_conversations = 10 + turns_per_conv = 5 + + for i in range(num_conversations): + manager.get_or_create(f"conv_{i:03d}", expected_client_turns=turns_per_conv) + + errors = [] + + async def process_conversation(conv_id: str): + try: + state = manager.get_state(conv_id) + assert state is not None + for _ in range(turns_per_conv): + manager.mark_turn_complete(conv_id, "response") + await asyncio.sleep(0.001) + except Exception as e: + errors.append(f"{conv_id} error: {e}") + + tasks = [ + asyncio.create_task(process_conversation(f"conv_{i:03d}")) + for i in range(num_conversations) + ] + await asyncio.gather(*tasks) + + assert not errors + for i in range(num_conversations): + state = manager._conversations[f"conv_{i:03d}"] + assert state.completed_turns == turns_per_conv + assert state.is_complete() diff --git a/tests/unit/load_generator/test_multi_turn_strategy.py b/tests/unit/load_generator/test_multi_turn_strategy.py new file mode 100644 index 00000000..bd70a2e2 --- /dev/null +++ b/tests/unit/load_generator/test_multi_turn_strategy.py @@ -0,0 +1,723 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for MultiTurnStrategy.""" + +import asyncio +from unittest.mock import MagicMock + +import pytest +from inference_endpoint.core.record import ErrorEventType, SampleEventType +from inference_endpoint.core.types import ErrorData, QueryResult, TextModelOutput +from inference_endpoint.dataset_manager.multi_turn_dataset import ( + ConversationMetadata, + ConversationSampleEntry, +) +from inference_endpoint.load_generator.conversation_manager import ConversationManager +from inference_endpoint.load_generator.multi_turn_strategy import MultiTurnStrategy + + +class FakePhaseIssuer: + """Minimal PhaseIssuerProtocol stub.""" + + def __init__(self, stop_after: int | None = None): + self._count = 0 + self._stop_after = stop_after + self.issued: list[int] = [] + self.issued_count = 0 + self.inflight: int = 0 + self.uuid_to_index: dict[str, int] = {} + self.uuid_to_conv_info: dict[str, tuple[str, int | None]] = {} + self.completed_uuids: set[str] = set() + + def issue( + self, + sample_index: int, + data_override: dict | None = None, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: + if self._stop_after is not None and self._count >= self._stop_after: + return None + self._count += 1 + self.issued_count += 1 + query_id = f"q{sample_index:04d}" + self.issued.append(sample_index) + self.uuid_to_conv_info[query_id] = (conversation_id, turn) + return query_id + + +def _make_dataset_metadata(conversations: dict[str, list[int]]) -> ConversationMetadata: + """Build ConversationMetadata from {conv_id: [turn_numbers]} mapping.""" + samples = [] + sample_index = 0 + for conv_id, turns in conversations.items(): + for turn in turns: + samples.append( + ConversationSampleEntry( + conversation_id=conv_id, + turn=turn, + sample_index=sample_index, + ) + ) + sample_index += 1 + return ConversationMetadata( + samples=samples, + num_conversations=len(conversations), + max_turns_per_conv=max((max(t) for t in conversations.values()), default=0), + client_turns_per_conversation={c: len(t) for c, t in conversations.items()}, + ) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_single_conversation_single_turn(): + """Single conversation, single turn — should issue exactly one sample.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + async def complete_turns(): + await asyncio.sleep(0.01) + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response 1") + ) + strategy.on_sample_complete(result) + + asyncio.create_task(complete_turns()) + count = await strategy.execute(issuer) + + assert count == 1 + assert issuer.issued == [0] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_single_conversation_multi_turn(): + """Single conversation, 3 turns — turns must be issued sequentially.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3, 5]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + issued_order: list[str] = [] + original_issue = issuer.issue + + def tracked_issue(idx, data_override=None, conversation_id="", turn=None): + q = original_issue( + idx, + data_override=data_override, + conversation_id=conversation_id, + turn=turn, + ) + if q: + issued_order.append(q) + return q + + issuer.issue = tracked_issue + + async def 
simulate_responses(): + await asyncio.sleep(0.01) + for turn_q, resp in [("q0000", "r1"), ("q0001", "r2"), ("q0002", "r3")]: + result = QueryResult( + id=turn_q, response_output=TextModelOutput(output=resp) + ) + strategy.on_sample_complete(result) + await asyncio.sleep(0.01) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 3 + assert issuer.issued == [0, 1, 2] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_multiple_conversations_concurrent(): + """Two conversations run concurrently, each with 2 turns.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3], "conv2": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + async def simulate_responses(): + await asyncio.sleep(0.02) + for q_prefix in range(4): + q = f"q{q_prefix:04d}" + result = QueryResult(id=q, response_output=TextModelOutput(output="resp")) + strategy.on_sample_complete(result) + await asyncio.sleep(0.01) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 4 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_turn_ordering_enforced(): + """Turn 2 must not be issued before Turn 1 completes.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + issue_timestamps: dict[int, float] = {} + complete_timestamps: dict[int, float] = {} + + class TimedIssuer: + issued_count = 0 + issued: list[int] = [] + + def issue( + self, + idx: int, + data_override: dict | None = None, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: + import time + + issue_timestamps[idx] = time.monotonic() + self.issued.append(idx) + self.issued_count += 1 + return f"q{idx:04d}" + + issuer = TimedIssuer() + + async def simulate_responses(): + import time + + await asyncio.sleep(0.02) + complete_timestamps[0] = time.monotonic() + result = QueryResult(id="q0000", response_output=TextModelOutput(output="r1")) + strategy.on_sample_complete(result) + await asyncio.sleep(0.05) + complete_timestamps[1] = time.monotonic() + result = QueryResult(id="q0001", response_output=TextModelOutput(output="r2")) + strategy.on_sample_complete(result) + + asyncio.create_task(simulate_responses()) + count = await strategy.execute(issuer) + + assert count == 2 + # Turn 2 (sample index 1) must be issued AFTER turn 1 completes + assert issue_timestamps[1] >= complete_timestamps[0] + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_turn_timeout_triggers_failure(): + """A turn that never completes should timeout and abort remaining turns.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=None) + strategy._turn_timeout_s = 0.1 # Very short timeout for testing + issuer = FakePhaseIssuer() + + # Do NOT simulate any response — turn 1 will timeout + await strategy.execute(issuer) + + # Only turn 1 should be issued (turn 2 never gets to run) + assert issuer.issued_count == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_sample_complete_routes_to_manager(): + """on_sample_complete marks the turn complete in the ConversationManager.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = 
MultiTurnStrategy(conv_manager, metadata) + + # Simulate issuer registering conv_id in _inflight + strategy._inflight["q0001"] = "conv1" + + result = QueryResult(id="q0001", response_output=TextModelOutput(output="hello")) + strategy.on_sample_complete(result) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.completed_turns == 1 + assert state.is_complete() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_error_response_marks_turn_failed(): + """on_sample_complete marks failed when result.error is set.""" + from inference_endpoint.core.types import ErrorData + + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + strategy._inflight["q0001"] = "conv1" + strategy._all_done = asyncio.Event() + + result = QueryResult( + id="q0001", + response_output=None, + error=ErrorData(error_type="timeout", error_message="timed out"), + ) + strategy.on_sample_complete(result) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.failed_turns == 1 + + +def _make_metadata_with_system( + conversations: dict[str, list[int]], + system_prompts: dict[str, str | None] | None = None, +) -> ConversationMetadata: + """Build ConversationMetadata including system_prompts_by_conv.""" + samples = [] + sample_index = 0 + for conv_id, turns in conversations.items(): + for turn in turns: + samples.append( + ConversationSampleEntry( + conversation_id=conv_id, + turn=turn, + sample_index=sample_index, + ) + ) + sample_index += 1 + return ConversationMetadata( + samples=samples, + num_conversations=len(conversations), + max_turns_per_conv=max((max(t) for t in conversations.values()), default=0), + client_turns_per_conversation={c: len(t) for c, t in conversations.items()}, + system_prompts_by_conv=system_prompts or {}, + ) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_live_history_initializes_system_prompt(): + """In live-history mode, ConversationManager.message_history starts with system message.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": "Be helpful"}, + ) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # message_history[0] must be the system message + assert len(state.message_history) >= 1 + assert state.message_history[0] == {"role": "system", "content": "Be helpful"} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_live_history_no_system_prompt_when_none(): + """In live-history mode, no system message is prepended when system_prompt is None.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": None}, + ) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = 
MultiTurnStrategy(conv_manager, metadata, multi_turn_config=mt_cfg) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # No system message should be in history + system_msgs = [m for m in state.message_history if m.get("role") == "system"] + assert len(system_msgs) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_dataset_history_mode_does_not_inject_system_prompt(): + """In dataset-history mode (use_dataset_history=True), system_message is not passed.""" + conv_manager = ConversationManager() + metadata = _make_metadata_with_system( + {"conv1": [1]}, + system_prompts={"conv1": "Some system"}, + ) + # Default: use_dataset_history=True → _store_in_history=False + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + async def complete_turn(): + await asyncio.sleep(0.01) + result = QueryResult( + id="q0000", response_output=TextModelOutput(output="response") + ) + strategy.on_sample_complete(result) + + asyncio.create_task(complete_turn()) + await strategy.execute(issuer) + + state = conv_manager.get_state("conv1") + assert state is not None + # message_history should be empty (dataset-history mode doesn't accumulate) + assert len(state.message_history) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_pipeline_error_propagated(): + """execute() re-raises when _issue_next_turn raises an exception.""" + conv_manager = ConversationManager() + metadata = _make_dataset_metadata({"conv1": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + class ErrorIssuer: + issued_count = 0 + issued: list[int] = [] + + def issue( + self, + idx: int, + data_override: dict | None = None, + conversation_id: str = "", + turn: int | None = None, + ) -> str | None: + raise RuntimeError("simulated pipeline error") + + with pytest.raises(RuntimeError, match="simulated pipeline error"): + await strategy.execute(ErrorIssuer()) + + +@pytest.mark.unit +def test_mark_turn_complete_preserves_tool_calls(): + """mark_turn_complete stores tool_calls in history when metadata contains them.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "bash", "arguments": '{"cmd": "ls"}'}, + } + ] + conv_manager.mark_turn_complete( + "conv1", + response="", + store_in_history=True, + metadata={"tool_calls": tool_calls}, + ) + + state = conv_manager.get_state("conv1") + assert state is not None + assert len(state.message_history) == 1 + msg = state.message_history[0] + assert msg["role"] == "assistant" + assert msg["content"] is None + assert msg["tool_calls"] == tool_calls + + +@pytest.mark.unit +def test_mark_turn_complete_with_response_and_tool_calls(): + """mark_turn_complete stores both content and tool_calls when both are present.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ] + conv_manager.mark_turn_complete( + "conv1", + response="Calling search...", + store_in_history=True, + metadata={"tool_calls": tool_calls}, + ) + + state = 
conv_manager.get_state("conv1") + assert state is not None + msg = state.message_history[0] + assert msg["content"] == "Calling search..." + assert msg["tool_calls"] == tool_calls + + +@pytest.mark.unit +def test_mark_turn_complete_no_history_when_empty(): + """mark_turn_complete does not append when response is empty and no tool_calls.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=1) + + conv_manager.mark_turn_complete("conv1", response="", store_in_history=True) + + state = conv_manager.get_state("conv1") + assert state is not None + assert len(state.message_history) == 0 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_on_sample_complete_passes_metadata(): + """on_sample_complete forwards result.metadata (including tool_calls) to ConversationManager.""" + from inference_endpoint.config.schema import MultiTurnConfig + + conv_manager = ConversationManager() + metadata_dict = _make_metadata_with_system({"conv1": [1]}) + mt_cfg = MultiTurnConfig(use_dataset_history=False, turn_timeout_s=10.0) + strategy = MultiTurnStrategy(conv_manager, metadata_dict, multi_turn_config=mt_cfg) + + conv_manager.get_or_create("conv1", expected_client_turns=1) + strategy._inflight["q0001"] = "conv1" + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "bash", "arguments": "{}"}, + } + ] + result = QueryResult( + id="q0001", + response_output=TextModelOutput(output=""), + metadata={"tool_calls": tool_calls}, + ) + strategy.on_sample_complete(result) + + state = conv_manager.get_state("conv1") + assert state is not None + assert state.completed_turns == 1 + assert len(state.message_history) == 1 + assert state.message_history[0]["tool_calls"] == tool_calls + assert state.message_history[0]["content"] is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_concurrency_limits_active_conversations(): + """target_concurrency=2 starts at most 2 conversation pipelines simultaneously. + + Uses 2-turn conversations so each pipeline has an await point between turns. + With 4 conversations and 2 workers, the 3rd and 4th conversations cannot start + until a worker finishes its current conversation. + """ + conv_manager = ConversationManager() + # 4 two-turn conversations; pipeline awaits turn-1 response before issuing turn-2 + metadata = _make_dataset_metadata( + {"conv1": [1, 2], "conv2": [1, 2], "conv3": [1, 2], "conv4": [1, 2]} + ) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=2) + issuer = FakePhaseIssuer() + + async def auto_respond(): + already_done = 0 + while True: + while already_done < len(issuer.issued): + idx = issuer.issued[already_done] + q = f"q{idx:04d}" + strategy.on_sample_complete( + QueryResult(id=q, response_output=TextModelOutput(output="r")) + ) + already_done += 1 + await asyncio.sleep(0.02) + + responder_task = asyncio.create_task(auto_respond()) + execute_task = asyncio.create_task(strategy.execute(issuer)) + + # Let both seed turns get issued before auto_respond fires + await asyncio.sleep(0.01) + + # Only 2 workers → exactly 2 turn-1 queries issued (conv3/conv4 not started yet) + assert issuer.issued_count == 2 + + await asyncio.wait_for(execute_task, timeout=5.0) + responder_task.cancel() + + assert issuer.issued_count == 8 # 4 conversations × 2 turns + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_conversation_slot_reuse(): + """With target_concurrency=1, worker completes conv1 before starting conv2. 
+ + The single slot must process both turns of conv1 before conv2's turn 1 is issued. + """ + conv_manager = ConversationManager() + # 2 two-turn conversations; sample indices: conv1→[0,1], conv2→[2,3] + metadata = _make_dataset_metadata({"conv1": [1, 2], "conv2": [1, 2]}) + strategy = MultiTurnStrategy(conv_manager, metadata, target_concurrency=1) + issuer = FakePhaseIssuer() + + async def auto_respond(): + already_done = 0 + while True: + while already_done < len(issuer.issued): + idx = issuer.issued[already_done] + q = f"q{idx:04d}" + strategy.on_sample_complete( + QueryResult(id=q, response_output=TextModelOutput(output="r")) + ) + already_done += 1 + await asyncio.sleep(0.02) + + responder_task = asyncio.create_task(auto_respond()) + await strategy.execute(issuer) + responder_task.cancel() + + # Single slot: conv1 turns (samples 0,1) must be issued before conv2 turns (2,3) + assert issuer.issued[:2] == [0, 1], "Conv1 turns should be issued before conv2" + assert issuer.issued[2:] == [2, 3], "Conv2 turns should follow conv1" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_timeout_publishes_error_and_complete_events(): + """_handle_timeout publishes ERROR then COMPLETE EventRecords via session_publisher.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv-x", expected_client_turns=1) + metadata = _make_dataset_metadata({"conv-x": [1]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + + publisher = MagicMock() + strategy._session_publisher = publisher + + # Seed _inflight so _handle_timeout finds the entry + strategy._inflight["q-x"] = "conv-x" + strategy._active_iters["conv-x"] = iter([]) + + issuer = FakePhaseIssuer() + issuer.uuid_to_index["q-x"] = 0 + issuer.uuid_to_conv_info["q-x"] = ("conv-x", 1) + issuer.inflight = 1 + + strategy._all_done = asyncio.Event() + strategy._loop = asyncio.get_running_loop() + strategy._phase_issuer = issuer + + strategy._handle_timeout("q-x", "conv-x") + + assert publisher.publish.call_count == 2 + assert issuer.inflight == 0 + assert "q-x" in issuer.completed_uuids + # Conv info must be cleared so a late real response can't reuse stale state. + assert "q-x" not in issuer.uuid_to_conv_info + first_call, second_call = publisher.publish.call_args_list + first_record = first_call.args[0] + second_record = second_call.args[0] + + assert first_record.event_type == ErrorEventType.GENERIC + assert first_record.sample_uuid == "q-x" + assert second_record.event_type == SampleEventType.COMPLETE + assert second_record.sample_uuid == "q-x" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_issue_passes_conversation_id_and_turn_to_phase_issuer(): + """MultiTurnStrategy must forward (conv_id, turn) to phase_issuer.issue().""" + conv_manager = ConversationManager() + conversations = {"conv-A": [1, 2], "conv-B": [1]} + metadata = _make_dataset_metadata(conversations) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + # Build the expected (query_id -> (conv_id, turn)) map from the same + # sample_index ordering _make_dataset_metadata uses, so the test does not + # encode that ordering as magic numbers. 
+ expected: dict[str, tuple[str, int]] = {} + sample_index = 0 + for conv_id, turns in conversations.items(): + for turn in turns: + expected[f"q{sample_index:04d}"] = (conv_id, turn) + sample_index += 1 + + async def respond_in_order(): + await asyncio.sleep(0.01) + for query_id in expected: + strategy.on_sample_complete( + QueryResult(id=query_id, response_output=TextModelOutput(output="ok")) + ) + await asyncio.sleep(0.005) + + asyncio.create_task(respond_in_order()) + await strategy.execute(issuer) + + assert issuer.uuid_to_conv_info == expected + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_error_turn_aborts_remaining_turns(): + """on_sample_complete with result.error aborts and marks-failed remaining turns.""" + conv_manager = ConversationManager() + conv_manager.get_or_create("conv1", expected_client_turns=3) + metadata = _make_dataset_metadata({"conv1": [1, 2, 3]}) + strategy = MultiTurnStrategy(conv_manager, metadata) + issuer = FakePhaseIssuer() + + strategy._all_done = asyncio.Event() + strategy._loop = asyncio.get_running_loop() + strategy._phase_issuer = issuer + + # Seed: conv1 is active with turns 2 and 3 still pending + remaining_turns = iter([(1, 2), (2, 3)]) + strategy._active_iters["conv1"] = remaining_turns + strategy._inflight["q0001"] = "conv1" + strategy._conv_states["conv1"] = conv_manager.get_state("conv1") + + result = QueryResult( + id="q0001", + response_output=None, + error=ErrorData( + error_type="endpoint_error", error_message="500 Internal Server Error" + ), + ) + strategy.on_sample_complete(result) + + # Conversation should no longer be active + assert "conv1" not in strategy._active_iters + # Remaining 2 turns were marked failed + state = conv_manager.get_state("conv1") + assert state is not None + assert state.failed_turns == 3 # the failing turn + 2 dropped diff --git a/tests/unit/openai/test_accumulator.py b/tests/unit/openai/test_accumulator.py new file mode 100644 index 00000000..ae23f7e6 --- /dev/null +++ b/tests/unit/openai/test_accumulator.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for OpenAISSEAccumulator tool-call handling.""" + +import pytest +from inference_endpoint.core.types import StreamChunk, TextModelOutput +from inference_endpoint.openai.accumulator import OpenAISSEAccumulator +from inference_endpoint.openai.types import SSEChoice, SSEDelta + + +def _make_tc_partial(idx: int, tc_id: str, name: str, args: str) -> SSEChoice: + """Create an SSEChoice with a single partial tool_call delta.""" + return SSEChoice( + delta=SSEDelta( + tool_calls=[ + { + "index": idx, + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": args}, + } + ] + ) + ) + + +def _make_content_choice(content: str, first_chunk: bool = False) -> SSEChoice: + return SSEChoice(delta=SSEDelta(content=content)) + + +def _make_reasoning_choice(reasoning: str) -> SSEChoice: + return SSEChoice(delta=SSEDelta(reasoning_content=reasoning)) + + +def _make_finish_choice() -> SSEChoice: + return SSEChoice(finish_reason="tool_calls") + + +@pytest.mark.unit +class TestAccumulatorPureToolCalls: + """Pure tool-call stream (no content/reasoning text chunks).""" + + def test_tool_calls_in_text_output(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_tc_partial(0, "c1", "search", '{"q":')) + acc.add_chunk( + SSEChoice( + delta=SSEDelta( + tool_calls=[{"index": 0, "function": {"arguments": '"test"}'}}] + ) + ) + ) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert isinstance(result.response_output, TextModelOutput) + assert result.response_output.tool_calls is not None + assert len(result.response_output.tool_calls) == 1 + assert result.response_output.tool_calls[0]["function"]["name"] == "search" + assert ( + result.response_output.tool_calls[0]["function"]["arguments"] + == '{"q":"test"}' + ) + + def test_metadata_tool_calls_preserved(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert "tool_calls" in result.metadata + assert result.metadata["tool_calls"][0]["function"]["name"] == "f" + + def test_first_tool_call_delta_emits_sentinel_stream_chunk(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + sentinel = acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + + assert isinstance(sentinel, StreamChunk) + assert sentinel.id == "qid" + assert sentinel.response_chunk == "" + assert sentinel.metadata.get("first_chunk") is True + + def test_subsequent_tool_call_deltas_return_none(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + # First delta: sentinel emitted + acc.add_chunk(_make_tc_partial(0, "c1", "f", "{")) + # Second delta: no sentinel + second = acc.add_chunk( + SSEChoice( + delta=SSEDelta( + tool_calls=[{"index": 0, "function": {"arguments": "}"}}] + ) + ) + ) + assert second is None + + def test_first_chunk_sent_after_sentinel(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + assert acc.first_chunk_sent is True + + +@pytest.mark.unit +class TestAccumulatorMixedReasoningAndToolCalls: + """Mixed stream: reasoning followed by tool_calls.""" + + def test_reasoning_chunk_is_first_chunk_not_tool_call(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + reasoning_chunk = acc.add_chunk(_make_reasoning_choice("Let me think")) + # Reasoning chunk should be the first chunk + assert isinstance(reasoning_chunk, StreamChunk) + 
assert reasoning_chunk.metadata.get("first_chunk") is True + + # Now a tool_call delta should NOT emit another sentinel (first_chunk_sent=True) + tc_chunk = acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + assert tc_chunk is None + + def test_tool_calls_in_output_after_reasoning(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_reasoning_choice("Thinking...")) + acc.add_chunk(_make_tc_partial(0, "c1", "search", '{"q":"x"}')) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert isinstance(result.response_output, TextModelOutput) + assert result.response_output.reasoning is not None + assert result.response_output.tool_calls is not None + assert result.response_output.tool_calls[0]["function"]["name"] == "search" + + def test_content_then_tool_calls(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_content_choice("Hello")) + acc.add_chunk(_make_tc_partial(0, "c1", "f", "{}")) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert ( + result.response_output.output == ("Hello",) + or result.response_output.output == "Hello" + ) + assert result.response_output.tool_calls is not None + + def test_no_tool_calls_returns_none_field(self): + acc = OpenAISSEAccumulator("qid", stream_all_chunks=False) + acc.add_chunk(_make_content_choice("Hello world")) + acc.add_chunk(_make_finish_choice()) + + result = acc.get_final_output() + assert result.response_output.tool_calls is None + assert "tool_calls" not in result.metadata diff --git a/tests/unit/openai/test_completions_adapter.py b/tests/unit/openai/test_completions_adapter.py index 6045eeb5..c590d7a0 100644 --- a/tests/unit/openai/test_completions_adapter.py +++ b/tests/unit/openai/test_completions_adapter.py @@ -37,7 +37,7 @@ from inference_endpoint.openai.accumulator import OpenAISSEAccumulator from inference_endpoint.openai.completions_adapter import OpenAITextCompletionsAdapter from inference_endpoint.openai.types import ( - SSEDelta, + SSEChoice, TextCompletionSSEChoice, TextCompletionSSEMessage, ) @@ -126,36 +126,38 @@ def test_decode_empty_choices_raises(self): class TestOpenAITextCompletionsAdapterDecodeSSE: @pytest.mark.unit - def test_decode_sse_message_returns_sse_delta(self): + def test_decode_sse_message_returns_sse_choice(self): msg = TextCompletionSSEMessage(choices=(TextCompletionSSEChoice(text="tok"),)) json_bytes = msgspec.json.encode(msg) - delta = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) - assert isinstance(delta, SSEDelta) - assert delta.content == "tok" + choice = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) + assert isinstance(choice, SSEChoice) + assert choice.delta is not None + assert choice.delta.content == "tok" @pytest.mark.unit - def test_decode_sse_empty_choices_returns_empty_delta(self): + def test_decode_sse_empty_choices_returns_empty_choice(self): msg = TextCompletionSSEMessage(choices=()) json_bytes = msgspec.json.encode(msg) - delta = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) - assert isinstance(delta, SSEDelta) - assert delta.content == "" + choice = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) + assert isinstance(choice, SSEChoice) + assert choice.delta is None @pytest.mark.unit - def test_decode_sse_empty_text_returns_empty_delta(self): + def test_decode_sse_empty_text_returns_choice_with_empty_content(self): msg = TextCompletionSSEMessage(choices=(TextCompletionSSEChoice(text=""),)) json_bytes = 
msgspec.json.encode(msg) - delta = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) - assert isinstance(delta, SSEDelta) - assert delta.content == "" + choice = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) + assert isinstance(choice, SSEChoice) + assert choice.delta is not None + assert choice.delta.content == "" @pytest.mark.unit - def test_sse_delta_compatible_with_openai_accumulator(self): + def test_sse_choice_compatible_with_openai_accumulator(self): acc = OpenAISSEAccumulator(query_id="q1", stream_all_chunks=True) msg = TextCompletionSSEMessage(choices=(TextCompletionSSEChoice(text="hello"),)) json_bytes = msgspec.json.encode(msg) - delta = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) - chunk = acc.add_chunk(delta) + choice = OpenAITextCompletionsAdapter.decode_sse_message(json_bytes) + chunk = acc.add_chunk(choice) assert chunk is not None assert chunk.response_chunk == "hello" diff --git a/tests/unit/openai/test_msgspec_adapter.py b/tests/unit/openai/test_msgspec_adapter.py new file mode 100644 index 00000000..7360c304 --- /dev/null +++ b/tests/unit/openai/test_msgspec_adapter.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for OpenAIMsgspecAdapter with tool call fields.""" + +import json + +import msgspec +import pytest +from inference_endpoint.core.types import Query +from inference_endpoint.openai.openai_msgspec_adapter import ( + OpenAIMsgspecAdapter, + _chat_message_from_dict, +) +from inference_endpoint.openai.types import ChatMessage + + +@pytest.mark.unit +def test_chat_message_tool_calls_serialised(): + """tool_calls field is included in the JSON output when non-None.""" + tool_calls = [ + { + "id": "call_0", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ] + msg = ChatMessage(role="assistant", tool_calls=tool_calls) + encoded = msgspec.json.encode(msg) + decoded = json.loads(encoded) + assert decoded["role"] == "assistant" + assert decoded["tool_calls"] == tool_calls + assert "content" not in decoded # omit_defaults=True, None omitted + + +@pytest.mark.unit +def test_chat_message_tool_call_id_serialised(): + """tool_call_id field is included in the JSON output when non-None.""" + msg = ChatMessage(role="tool", content="result", tool_call_id="call_0") + encoded = msgspec.json.encode(msg) + decoded = json.loads(encoded) + assert decoded["role"] == "tool" + assert decoded["content"] == "result" + assert decoded["tool_call_id"] == "call_0" + + +@pytest.mark.unit +def test_to_endpoint_request_preserves_tool_calls(): + """to_endpoint_request forwards tool_calls in the messages array.""" + tool_calls = [ + { + "id": "call_0", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q": "test"}'}, + } + ] + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": None, "tool_calls": tool_calls}, + {"role": "tool", "content": "answer", "tool_call_id": "call_0"}, + {"role": "assistant", "content": "Done"}, + ] + query = Query( + id="q1", + data={ + "model": "test-model", + "messages": messages, + }, + ) + request = OpenAIMsgspecAdapter.to_endpoint_request(query) + encoded = msgspec.json.encode(request) + payload = json.loads(encoded) + + msgs = payload["messages"] + # assistant tool-dispatch row + assert msgs[1]["role"] == "assistant" + assert msgs[1]["tool_calls"] == tool_calls + assert "content" not in msgs[1] + # tool result row + assert msgs[2]["role"] == "tool" + assert msgs[2]["tool_call_id"] == "call_0" + assert msgs[2]["content"] == "answer" + # terminal assistant row + assert msgs[3]["content"] == "Done" + + +@pytest.mark.unit +def test_backward_compat_plain_messages_unchanged(): + """Plain user/assistant messages encode identically to before the change.""" + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello!"}, + ] + query = Query( + id="q2", + data={"model": "m", "messages": messages}, + ) + request = OpenAIMsgspecAdapter.to_endpoint_request(query) + encoded = msgspec.json.encode(request) + payload = json.loads(encoded) + + for i, msg in enumerate(payload["messages"]): + assert msg["role"] == messages[i]["role"] + assert msg["content"] == messages[i]["content"] + assert "tool_calls" not in msg + assert "tool_call_id" not in msg + + +@pytest.mark.unit +def test_chat_message_from_dict_all_fields(): + """_chat_message_from_dict forwards all four optional fields.""" + tool_calls = [ + {"id": "x", "type": "function", "function": {"name": "f", "arguments": "{}"}} + ] + msg = _chat_message_from_dict( + { + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + "tool_call_id": None, + } + ) + assert 
msg.role == "assistant" + assert msg.content is None + assert msg.tool_calls == tool_calls + assert msg.tool_call_id is None + + +@pytest.mark.unit +def test_chat_message_content_optional(): + """ChatMessage accepts content=None for tool-dispatching assistant turns.""" + msg = ChatMessage(role="assistant", tool_calls=[]) + assert msg.content is None + + +@pytest.mark.unit +def test_from_endpoint_response_populates_tool_calls_in_text_output(): + """Non-streaming response with tool_calls populates TextModelOutput.tool_calls.""" + import json + + tool_calls = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } + ] + response_bytes = json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "refusal": None, + "tool_calls": tool_calls, + }, + "finish_reason": "tool_calls", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + "system_fingerprint": None, + } + ).encode() + + result = OpenAIMsgspecAdapter.decode_response(response_bytes, "q1") + + from inference_endpoint.core.types import TextModelOutput + + assert isinstance(result.response_output, TextModelOutput) + assert result.response_output.tool_calls is not None + assert len(result.response_output.tool_calls) == 1 + assert result.response_output.tool_calls[0]["function"]["name"] == "search" + # metadata path unchanged + assert result.metadata.get("tool_calls") == tool_calls diff --git a/tests/unit/openai/test_openai_adapter.py b/tests/unit/openai/test_openai_adapter.py new file mode 100644 index 00000000..82672763 --- /dev/null +++ b/tests/unit/openai/test_openai_adapter.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for OpenAIAdapter tool serialization.""" + +import json + +import msgspec +import pytest +from inference_endpoint.core.types import Query +from inference_endpoint.openai.openai_adapter import OpenAIAdapter + +_TOOL_DEF = { + "type": "function", + "function": { + "name": "search", + "description": "Search the web", + "parameters": { + "type": "object", + "properties": {"q": {"type": "string"}}, + "required": ["q"], + }, + }, +} + +_TOOL_CALLS = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } +] + + +@pytest.mark.unit +def test_tool_definitions_forwarded(): + """tools array in query.data is present in the encoded request.""" + messages = [ + {"role": "user", "content": "Find something"}, + ] + query = Query( + id="q1", + data={ + "model": "test-model", + "messages": messages, + "tools": [_TOOL_DEF], + "max_completion_tokens": 128, + "stream": False, + }, + ) + encoded = OpenAIAdapter.encode_query(query) + payload = json.loads(encoded) + + assert "tools" in payload + assert len(payload["tools"]) == 1 + assert payload["tools"][0]["function"]["name"] == "search" + + +@pytest.mark.unit +def test_tool_use_messages_roundtrip(): + """Full tool-use message sequence encodes and decodes without data loss.""" + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Find something"}, + {"role": "assistant", "content": None, "tool_calls": _TOOL_CALLS}, + {"role": "tool", "content": "search result", "tool_call_id": "call_1"}, + {"role": "assistant", "content": "Here is the answer"}, + ] + query = Query( + id="q1", + data={ + "model": "test-model", + "messages": messages, + "tools": [_TOOL_DEF], + "max_completion_tokens": 128, + "stream": False, + }, + ) + encoded = OpenAIAdapter.encode_query(query) + payload = json.loads(encoded) + + msgs = payload["messages"] + assert msgs[0]["role"] == "system" + assert msgs[1]["role"] == "user" + # assistant tool-dispatch: content is None (Pydantic model_dump includes None fields) + assert msgs[2]["role"] == "assistant" + assert msgs[2]["tool_calls"] == _TOOL_CALLS + assert msgs[2].get("content") is None + # tool result + assert msgs[3]["role"] == "tool" + assert msgs[3]["tool_call_id"] == "call_1" + assert msgs[3]["content"] == "search result" + # terminal assistant + assert msgs[4]["content"] == "Here is the answer" + + +@pytest.mark.unit +def test_encode_request_produces_valid_json_bytes(): + """encode_request returns bytes that msgspec can decode back.""" + messages = [{"role": "user", "content": "Hello"}] + query = Query( + id="q2", + data={ + "model": "m", + "messages": messages, + "max_completion_tokens": 64, + "stream": False, + }, + ) + request = OpenAIAdapter.to_endpoint_request(query) + encoded = OpenAIAdapter.encode_request(request) + + assert isinstance(encoded, bytes) + decoded = msgspec.json.decode(encoded) + assert decoded["messages"][0]["role"] == "user" + + +@pytest.mark.unit +def test_no_tools_key_when_absent(): + """When query.data has no 'tools', the encoded payload has tools=None.""" + messages = [{"role": "user", "content": "Hello"}] + query = Query( + id="q3", + data={ + "model": "m", + "messages": messages, + "max_completion_tokens": 64, + "stream": False, + }, + ) + encoded = OpenAIAdapter.encode_query(query) + payload = json.loads(encoded) + + # Pydantic model_dump includes None fields; tools must be None when not supplied + assert payload.get("tools") is None + + +@pytest.mark.unit +def 
test_from_endpoint_response_populates_tool_calls_in_text_output(): + """Non-streaming response with tool_calls populates TextModelOutput.tool_calls.""" + tool_calls_data = [ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q": "test"}'}, + } + ] + response_bytes = json.dumps( + { + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "test-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "refusal": "", + "tool_calls": tool_calls_data, + }, + "finish_reason": "tool_calls", + "logprobs": {"content": [], "refusal": []}, + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + "system_fingerprint": None, + } + ).encode() + + result = OpenAIAdapter.decode_response(response_bytes, "q1") + + from inference_endpoint.core.types import TextModelOutput + + assert isinstance(result.response_output, TextModelOutput) + assert result.response_output.tool_calls is not None + assert len(result.response_output.tool_calls) == 1 + assert result.response_output.tool_calls[0]["function"]["name"] == "search" + # metadata path preserved + assert result.metadata.get("tool_calls") is not None + assert result.metadata["tool_calls"][0]["function"]["name"] == "search"