diff --git a/pyproject.toml b/pyproject.toml index f511b65..2ff5ddb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,15 +30,19 @@ dependencies = [ source = "vcs" # Tags look like ``renderers-v0.1.8`` (prefix matches the publish.yml # release contract); strip the prefix to get a PEP 440 version. The -# regex accepts any PEP 440-valid suffix after the prefix so we can -# tag pre-releases like ``renderers-v0.2.0rc1`` later if needed. -tag-pattern = '^renderers-v(?P.+)$' +# The regex intentionally ignores historical ``.devN`` tags except +# ``.dev0``. hatch-vcs uses distance-from-tag to calculate dev versions, +# so anchoring at ``.dev1`` or later makes git dependency builds fail. +tag-pattern = '^renderers-v(?P\d+\.\d+\.\d+(?:(?:a|b|rc)\d+|\.dev0)?)$' # Used when building from a context without VCS metadata (e.g. an # sdist consumed by a downstream that doesn't ship .git). Real # builds from a checkout get the resolved version; this fallback # only fires when the resolver has nothing to go on. fallback-version = "0.0.0" +[tool.hatch.version.raw-options] +git_describe_command = "git describe --dirty --tags --long --match renderers-v[0-9]* --exclude renderers-v*.dev[1-9]*" + [tool.hatch.build.hooks.vcs] # Write the resolved version to a Python file so it can be inspected # at runtime via ``renderers.__version__`` without re-parsing the diff --git a/renderers/__init__.py b/renderers/__init__.py index 62bc666..9c0f0c4 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -36,6 +36,7 @@ reject_assistant_in_extension, trim_to_turn_close, ) +from renderers.client import RendererTransport from renderers.deepseek_v3 import DeepSeekV3Renderer from renderers.default import DefaultRenderer from renderers.glm5 import GLM5Renderer @@ -80,6 +81,7 @@ "RenderedTokens", "Renderer", "RendererPool", + "RendererTransport", "TextPart", "ThinkingPart", "ToolCall", diff --git a/renderers/client.py b/renderers/client.py index 902e148..93bcbc3 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -1,7 +1,9 @@ -"""Renderer-based generate client for vLLM 0.20's /inference/v1/generate. +"""Renderer-based generate client for vLLM 0.20 and Dynamo token-in routes. - messages → Renderer.render_ids() → token IDs → POST /inference/v1/generate - → completion tokens → Renderer.parse_response() → structured message +Two transports are selected per call: + + "vllm" → POST /inference/v1/generate + "dynamo" → POST /chat/completions with nvext.token_data When a RendererPool is passed instead of a single Renderer, the sync tokenization and parsing work is offloaded to threads for parallel execution across rollouts. @@ -14,7 +16,7 @@ import asyncio import base64 import logging -from typing import Any, cast +from typing import Any, Literal, cast import numpy as np from openai import AsyncOpenAI, BadRequestError @@ -28,6 +30,8 @@ ToolSpec, ) +RendererTransport = Literal["vllm", "dynamo"] + _request_logger = logging.getLogger("renderers.client") @@ -58,12 +62,13 @@ async def generate( cache_salt: str | None = None, priority: int | None = None, extra_headers: dict[str, str] | None = None, + transport: RendererTransport = "vllm", ) -> dict[str, Any]: - """Tokenize messages, call vLLM /inference/v1/generate, parse the response. + """Tokenize messages, call the selected token-in backend, parse response. - ``sampling_params`` is forwarded to vLLM verbatim. Two fields are always - set by us and override caller values: ``stop_token_ids`` (from the - renderer) and ``logprobs=1`` (we always emit completion_logprobs). Pass + ``sampling_params`` is forwarded to the selected token-in backend. Two + fields are always set by us and override caller values: stop token IDs + from the renderer and ``logprobs=1`` (we always emit completion_logprobs). Pass ``prompt_ids`` to skip rendering and use a prebuilt token sequence — pair it with ``multi_modal_data`` when the prebuilt prompt has image / video placeholders that need engine-side mm payload. @@ -101,31 +106,72 @@ def _prepare(): sp["logprobs"] = 1 sp.setdefault("skip_special_tokens", False) - body: dict[str, Any] = { - "model": model, - "token_ids": prompt_ids, - "sampling_params": sp, - } - features = ( - _build_mm_features(renderer, mm_data) - if mm_data and not mm_data.is_empty() - else None - ) - if features is not None: - body["features"] = features - if cache_salt is not None: - body["cache_salt"] = cache_salt - if priority is not None: - body["priority"] = priority - - # /inference/v1/generate is mounted at the server root, not under /v1 - # like the OpenAI-compatible endpoints. Build an absolute URL so the - # AsyncOpenAI client doesn't prepend its automatic /v1. - base = str(client.base_url).rstrip("/").removesuffix("/v1") - endpoint = f"{base}/inference/v1/generate" + if transport == "vllm": + body: dict[str, Any] = { + "model": model, + "token_ids": prompt_ids, + "sampling_params": sp, + } + features = ( + _build_mm_features(renderer, mm_data) + if mm_data and not mm_data.is_empty() + else None + ) + if features is not None: + body["features"] = features + if cache_salt is not None: + body["cache_salt"] = cache_salt + if priority is not None: + body["priority"] = priority + + # /inference/v1/generate is mounted at the server root, not under /v1 + # like the OpenAI-compatible endpoints. Build an absolute URL so the + # AsyncOpenAI client doesn't prepend its automatic /v1. + base = str(client.base_url).rstrip("/").removesuffix("/v1") + endpoint = f"{base}/inference/v1/generate" + elif transport == "dynamo": + if mm_data and not mm_data.is_empty(): + raise NotImplementedError( + "Multimodal renderers are not yet supported on the dynamo transport." + ) + nvext: dict[str, Any] = { + "token_data": prompt_ids, + "extra_fields": ["engine_data"], + } + if priority is not None: + nvext["agent_hints"] = {"priority": priority} + + body = { + "model": model, + "messages": [{"role": "user", "content": "(token-in mode)"}], + "stream": False, + "logprobs": True, + "nvext": nvext, + } + if tools: + body["tools"] = tools + if stop_token_ids: + body["stop_token_ids"] = stop_token_ids + if cache_salt is not None: + nvext["cache_salt"] = cache_salt + + passthrough = dict(sp) + passthrough.pop("stop_token_ids", None) + passthrough.pop("stop", None) + passthrough.pop("logprobs", None) + passthrough.pop("skip_special_tokens", None) + max_tokens = passthrough.pop("max_tokens", None) + if max_tokens is not None: + body["max_completion_tokens"] = max_tokens + body.update({k: v for k, v in passthrough.items() if v is not None}) + endpoint = "/chat/completions" + else: + raise ValueError(f"Unsupported renderer transport: {transport}") + _request_logger.debug( - "POST %s prompt_len=%d max_tokens=%s", + "POST %s transport=%s prompt_len=%d max_tokens=%s", endpoint, + transport, len(prompt_ids), sp.get("max_tokens"), ) @@ -147,7 +193,31 @@ def _prepare(): raise choice = (data.get("choices") or [{}])[0] - completion_ids = choice.get("token_ids") or [] + if transport == "dynamo": + choice_nvext = choice.get("nvext") or {} + response_nvext = data.get("nvext") or {} + choice_engine_data = choice_nvext.get("engine_data") or {} + response_engine_data = response_nvext.get("engine_data") or {} + completion_ids = ( + choice.get("token_ids") + or choice_nvext.get("completion_token_ids") + or response_nvext.get("completion_token_ids") + or choice_engine_data.get("completion_token_ids") + or response_engine_data.get("completion_token_ids") + or [] + ) + raw_re = ( + choice.get("routed_experts") + or choice_nvext.get("routed_experts") + or response_nvext.get("routed_experts") + or choice_engine_data.get("routed_experts") + or response_engine_data.get("routed_experts") + ) + request_id = data.get("id") or data.get("request_id") or "" + else: + completion_ids = choice.get("token_ids") or [] + raw_re = choice.get("routed_experts") + request_id = data.get("request_id") or "" parsed = await _maybe_offload( renderer, lambda: renderer.parse_response(completion_ids, tools=tools) @@ -157,9 +227,17 @@ def _prepare(): raw_logprobs = choice.get("logprobs") or {} content_lp = raw_logprobs.get("content") if isinstance(raw_logprobs, dict) else None completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []] + if not completion_logprobs and transport == "dynamo": + choice_nvext = choice.get("nvext") or {} + response_nvext = data.get("nvext") or {} + engine_logprobs = ( + (choice_nvext.get("engine_data") or {}).get("completion_logprobs") + or (response_nvext.get("engine_data") or {}).get("completion_logprobs") + or [] + ) + completion_logprobs = [float(logprob) for logprob in engine_logprobs] routed_experts = None - raw_re = choice.get("routed_experts") if isinstance(raw_re, dict) and "data" in raw_re and "shape" in raw_re: routed_experts = ( np.frombuffer(base64.b85decode(raw_re["data"]), dtype=np.int32) @@ -183,7 +261,7 @@ def _prepare(): finish_reason = "tool_calls" return { - "request_id": data.get("request_id") or "", + "request_id": request_id, "prompt_ids": list(prompt_ids), "completion_ids": list(completion_ids), "completion_logprobs": completion_logprobs, diff --git a/tests/test_client.py b/tests/test_client.py index d724234..3ade6f5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -2,8 +2,10 @@ import base64 import numpy as np +import pytest from renderers.base import ( + MultiModalData, ParsedResponse, ParsedToolCall, RenderedTokens, @@ -49,19 +51,27 @@ def parse_response( ) +class _NoStopRenderer(_FakeRenderer): + def get_stop_token_ids(self): + return [] + + class _FakeClient: """Mocks AsyncOpenAI's `.post()`. The renderer client builds an absolute URL off ``client.base_url``, so we expose one that includes the /v1 suffix the OpenAI SDK normally appends.""" - def __init__(self): + def __init__(self, response=None): self.calls = [] self.base_url = "http://fake-host:8000/v1" + self.response = response async def post(self, path, *, cast_to=dict, body=None, options=None): self.calls.append( {"path": path, "cast_to": cast_to, "body": body, "options": options} ) + if self.response is not None: + return self.response routed_experts = np.array([[[1]], [[2]]], dtype=np.int32) return { "request_id": "gen-test", @@ -291,3 +301,157 @@ def test_generate_serializes_multimodal_features_for_qwen3_vl(): # Items are base64 strings (encode_mm_kwargs_item output). for item in features["kwargs_data"]["image"]: assert isinstance(item, str) and len(item) > 0 + + +def test_generate_can_use_dynamo_transport(): + client = _FakeClient( + response={ + "id": "chatcmpl-test", + "model": "test-model", + "nvext": { + "engine_data": { + "completion_token_ids": [7, 8], + } + }, + "choices": [ + { + "logprobs": { + "content": [ + {"token": "token_id:7", "logprob": -0.1}, + {"token": "token_id:8", "logprob": -0.2}, + ] + }, + "finish_reason": "stop", + } + ], + } + ) + + result = asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={ + "temperature": 0.3, + "max_tokens": 7, + "min_tokens": 2, + "stop": "caller-stop", + }, + priority=4, + cache_salt="ckpt-42", + transport="dynamo", + ) + ) + + assert client.calls[0]["path"] == "/chat/completions" + assert client.calls[0]["body"] == { + "model": "test-model", + "messages": [{"role": "user", "content": "(token-in mode)"}], + "stream": False, + "logprobs": True, + "stop_token_ids": [99], + "tools": [{"type": "function", "function": {"name": "echo"}}], + "nvext": { + "token_data": [1, 2, 3], + "extra_fields": ["engine_data"], + "agent_hints": {"priority": 4}, + "cache_salt": "ckpt-42", + }, + "max_completion_tokens": 7, + "temperature": 0.3, + "min_tokens": 2, + } + assert result["request_id"] == "chatcmpl-test" + assert result["prompt_ids"] == [1, 2, 3] + assert result["completion_ids"] == [7, 8] + assert result["completion_logprobs"] == [-0.1, -0.2] + assert result["finish_reason"] == "tool_calls" + + +def test_generate_dynamo_omits_empty_stop_token_ids(): + client = _FakeClient( + response={ + "id": "chatcmpl-test", + "model": "test-model", + "nvext": { + "engine_data": { + "completion_token_ids": [7, 8], + } + }, + "choices": [{"finish_reason": "stop"}], + } + ) + + asyncio.run( + generate( + client=client, + renderer=_NoStopRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + tools=[{"type": "function", "function": {"name": "echo"}}], + sampling_params={ + "max_tokens": 7, + "stop": "caller-stop", + "stop_token_ids": [123], + }, + transport="dynamo", + ) + ) + + body = client.calls[0]["body"] + assert "stop" not in body + assert "stop_token_ids" not in body + + +def test_generate_dynamo_reads_engine_data_fallback(): + client = _FakeClient( + response={ + "id": "chatcmpl-test", + "model": "test-model", + "nvext": { + "engine_data": { + "completion_token_ids": [7, 8], + "completion_logprobs": [-0.3, -0.4], + } + }, + "choices": [{"finish_reason": "stop"}], + } + ) + + result = asyncio.run( + generate( + client=client, + renderer=_FakeRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + prompt_ids=[1, 2, 3], + transport="dynamo", + ) + ) + + assert result["completion_ids"] == [7, 8] + assert result["completion_logprobs"] == [-0.3, -0.4] + + +class _MultiModalRenderer(_FakeRenderer): + def render(self, messages, *, tools=None, add_generation_prompt=False): + return RenderedTokens( + token_ids=[1, 2, 3], + multi_modal_data=MultiModalData(mm_hashes={"image": ["aaa"]}), + ) + + +def test_generate_dynamo_rejects_multimodal_sidecar(): + with pytest.raises(NotImplementedError, match="dynamo transport"): + asyncio.run( + generate( + client=_FakeClient(), + renderer=_MultiModalRenderer(), + messages=[{"role": "user", "content": "hi"}], + model="test-model", + transport="dynamo", + ) + )