Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@ dependencies = [
source = "vcs"
# Tags look like ``renderers-v0.1.8`` (prefix matches the publish.yml
# release contract); strip the prefix to get a PEP 440 version. The
# regex accepts any PEP 440-valid suffix after the prefix so we can
# tag pre-releases like ``renderers-v0.2.0rc1`` later if needed.
tag-pattern = '^renderers-v(?P<version>.+)$'
# The regex intentionally ignores historical ``.devN`` tags except
# ``.dev0``. hatch-vcs uses distance-from-tag to calculate dev versions,
# so anchoring at ``.dev1`` or later makes git dependency builds fail.
tag-pattern = '^renderers-v(?P<version>\d+\.\d+\.\d+(?:(?:a|b|rc)\d+|\.dev0)?)$'
# Used when building from a context without VCS metadata (e.g. an
# sdist consumed by a downstream that doesn't ship .git). Real
# builds from a checkout get the resolved version; this fallback
# only fires when the resolver has nothing to go on.
fallback-version = "0.0.0"

[tool.hatch.version.raw-options]
git_describe_command = "git describe --dirty --tags --long --match renderers-v[0-9]* --exclude renderers-v*.dev[1-9]*"

[tool.hatch.build.hooks.vcs]
# Write the resolved version to a Python file so it can be inspected
# at runtime via ``renderers.__version__`` without re-parsing the
Expand Down
2 changes: 2 additions & 0 deletions renderers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
reject_assistant_in_extension,
trim_to_turn_close,
)
from renderers.client import RendererTransport
from renderers.deepseek_v3 import DeepSeekV3Renderer
from renderers.default import DefaultRenderer
from renderers.glm5 import GLM5Renderer
Expand Down Expand Up @@ -80,6 +81,7 @@
"RenderedTokens",
"Renderer",
"RendererPool",
"RendererTransport",
"TextPart",
"ThinkingPart",
"ToolCall",
Expand Down
146 changes: 112 additions & 34 deletions renderers/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Renderer-based generate client for vLLM 0.20's /inference/v1/generate.
"""Renderer-based generate client for vLLM 0.20 and Dynamo token-in routes.

messages → Renderer.render_ids() → token IDs → POST /inference/v1/generate
→ completion tokens → Renderer.parse_response() → structured message
One of two transports is selected per call:

"vllm" → POST /inference/v1/generate
"dynamo" → POST /chat/completions with nvext.token_data

When a RendererPool is passed instead of a single Renderer, the sync tokenization
and parsing work is offloaded to threads for parallel execution across rollouts.
Expand All @@ -14,7 +16,7 @@
import asyncio
import base64
import logging
from typing import Any, cast
from typing import Any, Literal, cast

import numpy as np
from openai import AsyncOpenAI, BadRequestError
Expand All @@ -28,6 +30,8 @@
ToolSpec,
)

RendererTransport = Literal["vllm", "dynamo"]

_request_logger = logging.getLogger("renderers.client")


Expand Down Expand Up @@ -58,12 +62,13 @@ async def generate(
cache_salt: str | None = None,
priority: int | None = None,
extra_headers: dict[str, str] | None = None,
transport: RendererTransport = "vllm",
) -> dict[str, Any]:
"""Tokenize messages, call vLLM /inference/v1/generate, parse the response.
"""Tokenize messages, call the selected token-in backend, parse response.

``sampling_params`` is forwarded to vLLM verbatim. Two fields are always
set by us and override caller values: ``stop_token_ids`` (from the
renderer) and ``logprobs=1`` (we always emit completion_logprobs). Pass
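``sampling_params`` is forwarded to the selected token-in backend. The
vLLM transport sends it under the ``sampling_params`` key; the dynamo
transport flattens it into the chat-completions body (``max_tokens``
becomes ``max_completion_tokens``, ``stop`` and ``skip_special_tokens``
are dropped, and ``None`` values are omitted). Two fields are always set
by us and override caller values: stop token IDs from the renderer and
``logprobs=1`` (sent as ``logprobs: True`` on dynamo; we always emit
completion_logprobs). Pass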
``prompt_ids`` to skip rendering and use a prebuilt token sequence —
pair it with ``multi_modal_data`` when the prebuilt prompt has image /
video placeholders that need engine-side mm payload.
Expand Down Expand Up @@ -101,31 +106,72 @@ def _prepare():
sp["logprobs"] = 1
sp.setdefault("skip_special_tokens", False)

body: dict[str, Any] = {
"model": model,
"token_ids": prompt_ids,
"sampling_params": sp,
}
features = (
_build_mm_features(renderer, mm_data)
if mm_data and not mm_data.is_empty()
else None
)
if features is not None:
body["features"] = features
if cache_salt is not None:
body["cache_salt"] = cache_salt
if priority is not None:
body["priority"] = priority

# /inference/v1/generate is mounted at the server root, not under /v1
# like the OpenAI-compatible endpoints. Build an absolute URL so the
# AsyncOpenAI client doesn't prepend its automatic /v1.
base = str(client.base_url).rstrip("/").removesuffix("/v1")
endpoint = f"{base}/inference/v1/generate"
if transport == "vllm":
body: dict[str, Any] = {
"model": model,
"token_ids": prompt_ids,
"sampling_params": sp,
}
features = (
_build_mm_features(renderer, mm_data)
if mm_data and not mm_data.is_empty()
else None
)
if features is not None:
body["features"] = features
if cache_salt is not None:
body["cache_salt"] = cache_salt
if priority is not None:
body["priority"] = priority

# /inference/v1/generate is mounted at the server root, not under /v1
# like the OpenAI-compatible endpoints. Build an absolute URL so the
# AsyncOpenAI client doesn't prepend its automatic /v1.
base = str(client.base_url).rstrip("/").removesuffix("/v1")
endpoint = f"{base}/inference/v1/generate"
elif transport == "dynamo":
if mm_data and not mm_data.is_empty():
raise NotImplementedError(
"Multimodal renderers are not yet supported on the dynamo transport."
)
nvext: dict[str, Any] = {
"token_data": prompt_ids,
"extra_fields": ["engine_data"],
}
if priority is not None:
nvext["agent_hints"] = {"priority": priority}

body = {
"model": model,
"messages": [{"role": "user", "content": "(token-in mode)"}],
"stream": False,
"logprobs": True,
"nvext": nvext,
}
if tools:
body["tools"] = tools
if stop_token_ids:
body["stop_token_ids"] = stop_token_ids
if cache_salt is not None:
nvext["cache_salt"] = cache_salt

passthrough = dict(sp)
passthrough.pop("stop_token_ids", None)
passthrough.pop("stop", None)
passthrough.pop("logprobs", None)
passthrough.pop("skip_special_tokens", None)
max_tokens = passthrough.pop("max_tokens", None)
if max_tokens is not None:
body["max_completion_tokens"] = max_tokens
body.update({k: v for k, v in passthrough.items() if v is not None})
endpoint = "/chat/completions"
else:
raise ValueError(f"Unsupported renderer transport: {transport}")

_request_logger.debug(
"POST %s prompt_len=%d max_tokens=%s",
"POST %s transport=%s prompt_len=%d max_tokens=%s",
endpoint,
transport,
len(prompt_ids),
sp.get("max_tokens"),
)
Expand All @@ -147,7 +193,31 @@ def _prepare():
raise

choice = (data.get("choices") or [{}])[0]
completion_ids = choice.get("token_ids") or []
if transport == "dynamo":
choice_nvext = choice.get("nvext") or {}
response_nvext = data.get("nvext") or {}
choice_engine_data = choice_nvext.get("engine_data") or {}
response_engine_data = response_nvext.get("engine_data") or {}
completion_ids = (
choice.get("token_ids")
or choice_nvext.get("completion_token_ids")
or response_nvext.get("completion_token_ids")
or choice_engine_data.get("completion_token_ids")
or response_engine_data.get("completion_token_ids")
or []
)
raw_re = (
choice.get("routed_experts")
or choice_nvext.get("routed_experts")
or response_nvext.get("routed_experts")
or choice_engine_data.get("routed_experts")
or response_engine_data.get("routed_experts")
)
request_id = data.get("id") or data.get("request_id") or ""
else:
completion_ids = choice.get("token_ids") or []
raw_re = choice.get("routed_experts")
request_id = data.get("request_id") or ""

parsed = await _maybe_offload(
renderer, lambda: renderer.parse_response(completion_ids, tools=tools)
Expand All @@ -157,9 +227,17 @@ def _prepare():
raw_logprobs = choice.get("logprobs") or {}
content_lp = raw_logprobs.get("content") if isinstance(raw_logprobs, dict) else None
completion_logprobs = [float(c.get("logprob") or 0.0) for c in content_lp or []]
if not completion_logprobs and transport == "dynamo":
choice_nvext = choice.get("nvext") or {}
response_nvext = data.get("nvext") or {}
engine_logprobs = (
(choice_nvext.get("engine_data") or {}).get("completion_logprobs")
or (response_nvext.get("engine_data") or {}).get("completion_logprobs")
or []
)
completion_logprobs = [float(logprob) for logprob in engine_logprobs]

routed_experts = None
raw_re = choice.get("routed_experts")
if isinstance(raw_re, dict) and "data" in raw_re and "shape" in raw_re:
routed_experts = (
np.frombuffer(base64.b85decode(raw_re["data"]), dtype=np.int32)
Expand All @@ -183,7 +261,7 @@ def _prepare():
finish_reason = "tool_calls"

return {
"request_id": data.get("request_id") or "",
"request_id": request_id,
"prompt_ids": list(prompt_ids),
"completion_ids": list(completion_ids),
"completion_logprobs": completion_logprobs,
Expand Down
Loading
Loading