Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions tests/test_renderer_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from functools import lru_cache
from unittest.mock import patch

Expand Down Expand Up @@ -46,6 +47,7 @@ def test_renderer_client_honors_configured_renderer_name():
size=1,
tool_parser=None,
reasoning_parser=None,
chat_template_kwargs={},
preserve_all_thinking=False,
preserve_thinking_between_tool_calls=False,
)
Expand Down Expand Up @@ -77,11 +79,63 @@ def test_renderer_client_uses_renderer_model_name_override():
size=1,
tool_parser=None,
reasoning_parser=None,
chat_template_kwargs={},
preserve_all_thinking=False,
preserve_thinking_between_tool_calls=False,
)


def test_renderer_client_consumes_sampling_chat_template_kwargs():
    """``extra_body.chat_template_kwargs`` reaches the pool factory, not ``generate``.

    The request carries both ``chat_template_kwargs`` and ``top_k`` inside
    ``extra_body``. The test checks that the template kwargs are handed to
    ``create_renderer_pool`` and that only ``top_k`` is left in the
    ``sampling_params`` passed to ``generate``.
    """
    # Start from an empty shared-pool cache so the pool factory is called.
    RendererClient._shared_pools.clear()

    # Build the client without running __init__; only the attributes assigned
    # below are set on it.
    client = object.__new__(RendererClient)
    client._renderer = None
    client._pool_size = 1
    client._config = vf.ClientConfig(client_type="renderer", renderer="qwen3")
    client._client = object()  # type: ignore[attr-defined]

    stub_pool = RendererPool.__new__(RendererPool)
    generate_kwargs: dict = {}

    async def _record_generate(**kwargs):
        # Remember every keyword argument the client forwards to generate().
        generate_kwargs.update(kwargs)
        return {"content": "ok"}

    request_sampling_args = {
        "extra_body": {
            "chat_template_kwargs": {"enable_thinking": False},
            "top_k": 20,
        }
    }

    with (
        patch(
            "verifiers.clients.renderer_client.create_renderer_pool",
            return_value=stub_pool,
        ) as pool_factory,
        patch(
            "verifiers.clients.renderer_client.generate",
            side_effect=_record_generate,
        ),
    ):
        response = asyncio.run(
            client.get_native_response(
                prompt=[{"role": "user", "content": "hi"}],
                model="Qwen/Qwen3-8B",
                sampling_args=request_sampling_args,
                tools=None,
            )
        )

    assert response == {"content": "ok"}
    pool_factory.assert_called_once_with(
        "Qwen/Qwen3-8B",
        renderer="qwen3",
        size=1,
        tool_parser=None,
        reasoning_parser=None,
        chat_template_kwargs={"enable_thinking": False},
        preserve_all_thinking=False,
        preserve_thinking_between_tool_calls=False,
    )
    # The template kwargs must have been removed before sampling params were
    # forwarded; only top_k remains.
    assert generate_kwargs["sampling_params"] == {"top_k": 20}


# Provenance: Eli's review on PR #1068, comment 3150580768.
# "RendererClient parses the GPT-OSS assistant tool call into ToolCall(name=...),
# but ToolEnv returns ToolMessage with only content/tool_call_id, and
Expand Down
47 changes: 39 additions & 8 deletions verifiers/clients/renderer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,23 @@ def _parse_finish_reason(raw: str | None) -> FinishReason:
return None


def _freeze_json_like(value: Any) -> Any:
if isinstance(value, Mapping):
return tuple(sorted((str(k), _freeze_json_like(v)) for k, v in value.items()))
if isinstance(value, list):
return tuple(_freeze_json_like(v) for v in value)
return value


def _pop_chat_template_kwargs(sampling_params: dict[str, Any]) -> dict[str, Any]:
raw = sampling_params.pop("chat_template_kwargs", None)
if raw is None:
return {}
if not isinstance(raw, Mapping):
raise ValueError("extra_body.chat_template_kwargs must be a mapping")
return dict(raw)


class RendererClient(
Client[AsyncOpenAI, list[RendererMessage], dict[str, Any], ToolSpec]
):
Expand All @@ -418,13 +435,22 @@ class RendererClient(
"""

# Cache key is (renderer_model_name, renderer_name, tool_parser,
# reasoning_parser, pool_size, preserve_all_thinking,
# preserve_thinking_between_tool_calls) so that different parser configs,
# pool sizes, or preserve-thinking bindings for the same model don't
# collide.
# reasoning_parser, pool_size, chat_template_kwargs,
# preserve_all_thinking, preserve_thinking_between_tool_calls) so that
# different parser configs, pool sizes, template kwargs, or
# preserve-thinking bindings for the same model don't collide.
_shared_pools: ClassVar[
dict[
tuple[str, str, str | None, str | None, int, bool, bool],
tuple[
str,
str,
str | None,
str | None,
int,
Any,
bool,
bool,
],
RendererPool,
]
] = {}
Expand All @@ -451,7 +477,9 @@ async def close(self) -> None:

# ── Renderer management ─────────────────────────────────────────

def _get_renderer_or_pool(self, model: str) -> Renderer | RendererPool:
def _get_renderer_or_pool(
self, model: str, chat_template_kwargs: dict[str, Any] | None = None
) -> Renderer | RendererPool:
if self._renderer is not None:
return self._renderer

Expand All @@ -473,12 +501,14 @@ def _get_renderer_or_pool(self, model: str) -> Renderer | RendererPool:
if self._config is not None
else False
)
renderer_chat_template_kwargs = dict(chat_template_kwargs or {})
cache_key = (
renderer_model,
renderer_name,
tool_parser,
reasoning_parser,
self._pool_size,
_freeze_json_like(renderer_chat_template_kwargs),
preserve_all_thinking,
preserve_thinking_between_tool_calls,
)
Expand All @@ -491,6 +521,7 @@ def _get_renderer_or_pool(self, model: str) -> Renderer | RendererPool:
size=self._pool_size,
tool_parser=tool_parser,
reasoning_parser=reasoning_parser,
chat_template_kwargs=renderer_chat_template_kwargs,
preserve_all_thinking=preserve_all_thinking,
preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls,
)
Expand Down Expand Up @@ -528,10 +559,10 @@ async def get_native_response(
tools: list[ToolSpec] | None = None,
**kwargs: Any,
) -> dict[str, Any]:
renderer = self._get_renderer_or_pool(model)

args = dict(sampling_args)
sampling_params: dict[str, Any] = dict(args.pop("extra_body", None) or {})
chat_template_kwargs = _pop_chat_template_kwargs(sampling_params)
renderer = self._get_renderer_or_pool(model, chat_template_kwargs)
for key in (
"temperature",
"top_p",
Expand Down
Loading