diff --git a/renderers/__init__.py b/renderers/__init__.py index 62bc666..287ba33 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -44,6 +44,7 @@ from renderers.kimi_k2 import KimiK2Renderer from renderers.kimi_k25 import KimiK25Renderer from renderers.laguna_xs2 import LagunaXS2Renderer +from renderers.llama_3 import Llama3Renderer from renderers.minimax_m2 import MiniMaxM2Renderer from renderers.nemotron3 import Nemotron3Renderer from renderers.qwen3 import Qwen3Renderer @@ -63,6 +64,7 @@ "KimiK2Renderer", "KimiK25Renderer", "LagunaXS2Renderer", + "Llama3Renderer", "MULTIMODAL_MODELS", "Message", "MiniMaxM2Renderer", diff --git a/renderers/base.py b/renderers/base.py index 06e34b7..db72988 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -637,6 +637,14 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No # Nemotron 3. "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16": "nemotron-3", "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "nemotron-3", + # Llama 3.2 (Instruct). Tested against the gated meta-llama repos and + # the unrestricted unsloth/... mirror, which ships a byte-identical + # chat template. ``Llama3Renderer`` defaults ``date_string`` to + # "26 Jul 2024" — matching the chat template's strftime fallback — + # so the renderer is reproducible. Pass ``date_string=...`` at + # construction to pin a different date. + "meta-llama/Llama-3.2-1B-Instruct": "llama-3", + "meta-llama/Llama-3.2-3B-Instruct": "llama-3", # Poolside Laguna. "poolside/Laguna-XS.2": "laguna-xs.2", # GPT-OSS. 
@@ -776,6 +784,7 @@ def _populate_registry(): from renderers.kimi_k2 import KimiK2Renderer from renderers.kimi_k25 import KimiK25Renderer from renderers.laguna_xs2 import LagunaXS2Renderer + from renderers.llama_3 import Llama3Renderer from renderers.minimax_m2 import MiniMaxM2Renderer from renderers.nemotron3 import Nemotron3Renderer from renderers.qwen3 import Qwen3Renderer @@ -798,6 +807,7 @@ def _populate_registry(): "kimi-k2": KimiK2Renderer, "kimi-k2.5": KimiK25Renderer, "laguna-xs.2": LagunaXS2Renderer, + "llama-3": Llama3Renderer, "nemotron-3": Nemotron3Renderer, "gpt-oss": GptOssRenderer, } @@ -862,8 +872,9 @@ def create_renderer( tokenizer: HuggingFace tokenizer instance. renderer: Renderer name ('qwen3', 'qwen3-vl', 'qwen3.5', 'qwen3.6', 'glm-5', 'glm-5.1', 'glm-4.5', 'minimax-m2', 'deepseek-v3', - 'kimi-k2', 'kimi-k2.5', 'laguna-xs.2', 'nemotron-3', - 'gpt-oss', 'default') or 'auto' to detect from model name. + 'kimi-k2', 'kimi-k2.5', 'laguna-xs.2', 'llama-3', + 'nemotron-3', 'gpt-oss', 'default') or 'auto' to detect + from model name. tool_parser: Name of a tool parser registered in ``renderers.parsers``. Only consumed by DefaultRenderer. Model-specific renderers have their own parsing wired in. diff --git a/renderers/llama_3.py b/renderers/llama_3.py new file mode 100644 index 0000000..df0a508 --- /dev/null +++ b/renderers/llama_3.py @@ -0,0 +1,401 @@ +"""Llama-3 Renderer — hard-coded Python mirroring Meta's Llama-3 chat template. + +Initial scope: Llama-3.2-1B-Instruct and Llama-3.2-3B-Instruct (and the +unrestricted ``unsloth/Llama-3.2-{1B,3B}-Instruct`` mirror, which ships a +byte-identical chat template). Other Llama-3.x sizes ship slightly +different templates and are NOT covered by this renderer until parity is +verified. + +Notable differences from the Qwen / GLM family renderers: + +* No ``<think>`` / reasoning channel — Llama-3 doesn't ship a + reasoning-content concept, so ``preserve_*_thinking`` flags don't + apply. 
+* ``<|begin_of_text|>`` (BOS) is emitted at the very start of every + render. The chat template never omits it. +* The system block is emitted **unconditionally** with a fixed + ``Cutting Knowledge Date: December 2023\\nToday Date: <date_string>\\n\\n`` + preamble — even when no system message is supplied. Empty system + message → block ends with ``\\n\\n<|eot_id|>``. +* Tools default to "first-user-message" mode (matching the chat + template's default ``tools_in_user_message=True``): tool descriptions + + JSON signatures are injected into the first user message rather + than the system block. Pass ``tools_in_user_message=False`` at + construction to flip to system-block mode. +* ``date_string`` is a constructor kwarg pinned at ``"26 Jul 2024"`` by + default to match the chat template's ``strftime`` fallback (and keep + output deterministic). Override per instance for production runs that + want today's date. +* Tool calls: a single ``{"name": "...", "parameters": ...}`` JSON blob + inside the assistant body. The chat template explicitly raises if + ``message.tool_calls | length != 1``; this renderer matches that. +* Tool responses: rendered with role ``ipython`` regardless of whether + the source message used ``role: "tool"`` or ``role: "ipython"``. The + chat template runs ``content | tojson`` on any mapping/iterable + content — and Jinja considers strings iterable, so plain string + contents get JSON-quoted. We mirror that exactly. +""" + +from __future__ import annotations + +import json +from typing import Any + +from transformers.tokenization_utils import PreTrainedTokenizer + +from renderers.base import ( + Message, + ParsedResponse, + RenderedTokens, + ToolSpec, + reject_assistant_in_extension, + trim_to_turn_close, +) +from renderers.parsing import parse_llama_3 + +# --------------------------------------------------------------------------- +# Constants — must match the Jinja chat template's literal strings exactly. 
# ---------------------------------------------------------------------------
# Constants — must match the Jinja chat template's literal strings exactly.
# ---------------------------------------------------------------------------

_DEFAULT_DATE_STRING = "26 Jul 2024"

_CUTTING_KNOWLEDGE_DATE = "December 2023"

# Tools-in-system intro: emitted into the system block when tools is set
# AND tools_in_user_message=False. Note: the chat template puts these
# string literals back-to-back with no separator, so there is no space
# before "Do not use variables.".
_TOOLS_IN_SYSTEM_INTRO = (
    "You have access to the following functions. To call a function, "
    "please respond with JSON for a function call."
    'Respond in the format {"name": function name, "parameters": '
    "dictionary of argument name and its value}."
    "Do not use variables.\n\n"
)

# Tools-in-user intro: emitted into the first user message when tools is
# set AND tools_in_user_message=True (the default).
_TOOLS_IN_USER_INTRO = (
    "Given the following functions, please respond with a JSON for a "
    "function call with its proper arguments that best answers the given "
    "prompt.\n\n"
    'Respond in the format {"name": function name, "parameters": '
    "dictionary of argument name and its value}."
    "Do not use variables.\n\n"
)


class Llama3Renderer:
    """Deterministic message → token renderer for Llama-3.x Instruct models.

    Every turn is laid out as ``<|start_header_id|>`` + role +
    ``<|end_header_id|>`` + ``"\\n\\n"`` + body + ``<|eot_id|>``. Special
    tokens are appended by id; role names and bodies are encoded with
    ``add_special_tokens=False``.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        *,
        date_string: str = _DEFAULT_DATE_STRING,
        tools_in_user_message: bool = True,
        preserve_all_thinking: bool = False,
        preserve_thinking_between_tool_calls: bool = False,
    ):
        """Resolve the special-token ids and store the rendering options.

        Args:
            tokenizer: Tokenizer providing ``convert_tokens_to_ids``,
                ``unk_token_id`` and ``encode``.
            date_string: Text placed after ``Today Date: `` in the system
                block.
            tools_in_user_message: When true, tool specs are injected into
                the first user message; otherwise into the system block.
            preserve_all_thinking: Must be false; accepted only so the
                constructor signature lines up with other renderers.
            preserve_thinking_between_tool_calls: Must be false, as above.

        Raises:
            NotImplementedError: If either ``preserve_*`` flag is set.
            AssertionError: If a required special token is missing from
                the tokenizer vocabulary.
        """
        if preserve_all_thinking or preserve_thinking_between_tool_calls:
            raise NotImplementedError(
                "Llama-3 doesn't ship a reasoning_content channel — the chat "
                "template has no <think> block to preserve or drop. "
                "preserve_*_thinking flags are not applicable."
            )
        self._tokenizer = tokenizer
        self._date_string = date_string
        self._tools_in_user_message = tools_in_user_message

        self._bos = self._token_id("<|begin_of_text|>")
        self._start_header = self._token_id("<|start_header_id|>")
        self._end_header = self._token_id("<|end_header_id|>")
        self._eot = self._token_id("<|eot_id|>")
        self._end_of_text = self._token_id("<|end_of_text|>")
        # ``<|eom_id|>`` shows up in some Llama-3 tool-calling traces (the
        # "ipython" / python-tag flow) but the standard 3.2 chat template
        # closes turns with ``<|eot_id|>``. We still treat eom as a stop
        # token so models that emit it terminate cleanly.
        self._eom = self._token_id("<|eom_id|>")

    # ------------------------------------------------------------------
    # tokenizer helpers
    # ------------------------------------------------------------------

    def _token_id(self, token: str) -> int:
        """Return the vocabulary id of special *token*.

        Raises:
            AssertionError: If the tokenizer does not return an ``int`` or
                returns its unknown-token id. Raised explicitly (rather
                than via ``assert``) so the check survives ``python -O``
                while keeping the exception type unchanged.
        """
        tid = self._tokenizer.convert_tokens_to_ids(token)
        if not isinstance(tid, int) or tid == self._tokenizer.unk_token_id:
            raise AssertionError(
                f"Special token {token!r} not found in tokenizer vocabulary"
            )
        return tid

    def _encode(self, text: str) -> list[int]:
        """Encode *text* without special tokens; empty text yields ``[]``."""
        if not text:
            return []
        return self._tokenizer.encode(text, add_special_tokens=False)

    def _header_ids(self, role: str) -> list[int]:
        """Token ids for ``<|start_header_id|>`` + role + ``<|end_header_id|>``."""
        return [self._start_header, *self._encode(role), self._end_header]

    def _turn_ids(self, role: str, body: str) -> list[int]:
        """Token ids for one complete turn: header, ``"\\n\\n" + body``, eot.

        The role and the body are encoded in separate tokenizer calls, so
        no token can span the ``<|end_header_id|>`` boundary.
        """
        return [*self._header_ids(role), *self._encode("\n\n" + body), self._eot]

    def _generation_prompt_ids(self) -> list[int]:
        """Token ids for the open assistant header that prompts generation."""
        return [*self._header_ids("assistant"), *self._encode("\n\n")]

    # ------------------------------------------------------------------
    # content formatting
    # ------------------------------------------------------------------

    @staticmethod
    def _content_str(content: Any) -> str:
        """Flatten message content to a plain string.

        ``None`` becomes ``""``; a string is returned unchanged; a list is
        joined with no separator, each item being either a string or a
        dict with a ``"text"`` key.

        Raises:
            ValueError: If a list item is neither of those.
            TypeError: If *content* is not ``None``, ``str`` or ``list``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for item in content:
                if isinstance(item, str):
                    parts.append(item)
                elif isinstance(item, dict) and "text" in item:
                    parts.append(item["text"])
                else:
                    raise ValueError(f"Unexpected content item: {item}")
            return "".join(parts)
        raise TypeError(f"Unexpected content type: {type(content)}")

    @staticmethod
    def _tool_response_str(content: Any) -> str:
        """Format tool-response content as the chat template does.

        The template branch is ``{% if message.content is mapping or
        message.content is iterable %}{{ message.content | tojson }}
        {% else %}{{ message.content }}``. Strings are iterable, so plain
        string contents are JSON-quoted too. Tuples are iterable as well
        and therefore serialise like lists.

        ``None`` is rendered as an empty body. Other scalars (numbers,
        bools) fall through to ``str()``.
        """
        if content is None:
            return ""
        if isinstance(content, (dict, list, tuple, str)):
            return json.dumps(content, ensure_ascii=False)
        return str(content)

    @staticmethod
    def _tools_block(tools: list[Any]) -> str:
        """Serialise each tool spec as 4-space-indented JSON followed by a blank line."""
        return "".join(
            json.dumps(tool, indent=4, ensure_ascii=False) + "\n\n" for tool in tools
        )

    @staticmethod
    def _tool_call_str(tool_calls: list[Any]) -> str:
        """Build the ``{"name": ..., "parameters": ...}`` assistant body.

        Accepts both the nested ``{"function": {...}}`` shape and a flat
        ``{"name": ..., "arguments": ...}`` dict.

        Raises:
            ValueError: If *tool_calls* does not hold exactly one call.
        """
        if len(tool_calls) != 1:
            raise ValueError(
                "Llama-3 chat template only supports a single tool call "
                "per assistant message."
            )
        call = tool_calls[0]
        func = call.get("function") or call
        name = func.get("name", "")
        arguments = func.get("arguments", {})
        if isinstance(arguments, str):
            # NOTE(review): string arguments are emitted verbatim —
            # presumably they are already JSON text; confirm against callers.
            args_str = arguments
        else:
            args_str = json.dumps(arguments, ensure_ascii=False)
        return '{"name": "' + name + '", "parameters": ' + args_str + "}"

    # ------------------------------------------------------------------
    # render
    # ------------------------------------------------------------------

    def render(
        self,
        messages: list[Message],
        *,
        tools: list[ToolSpec] | None = None,
        add_generation_prompt: bool = False,
    ) -> RenderedTokens:
        """Render *messages* to token ids plus a per-token message index.

        Tokens that belong to no input message (BOS, a synthesised system
        block, the generation prompt) carry index ``-1``.

        Args:
            messages: Conversation to render; must be non-empty.
            tools: Optional tool specs. ``None`` disables every tool-related
                section.
            add_generation_prompt: Append an open assistant header.

        Raises:
            ValueError: On empty *messages*, on a missing or non-user first
                message when tools go into the user turn, or on an
                assistant message with more than one tool call.
        """
        if not messages:
            raise ValueError("No messages provided.")

        tokens: list[int] = []
        indices: list[int] = []

        def emit(ids: list[int], msg_idx: int) -> None:
            tokens.extend(ids)
            indices.extend([msg_idx] * len(ids))

        # ── 0. BOS ──────────────────────────────────────────────────
        emit([self._bos], -1)

        # ── 1. System block (always emitted) ────────────────────────
        first_is_system = messages[0].get("role") == "system"
        sys_idx = 0 if first_is_system else -1
        sys_text = (
            self._content_str(messages[0].get("content")).strip()
            if first_is_system
            else ""
        )

        sys_parts: list[str] = []
        if tools is not None:
            sys_parts.append("Environment: ipython\n")
        sys_parts.append(f"Cutting Knowledge Date: {_CUTTING_KNOWLEDGE_DATE}\n")
        sys_parts.append(f"Today Date: {self._date_string}\n\n")
        if tools is not None and not self._tools_in_user_message:
            sys_parts.append(_TOOLS_IN_SYSTEM_INTRO)
            sys_parts.append(self._tools_block(tools))
        sys_parts.append(sys_text)
        emit(self._turn_ids("system", "".join(sys_parts)), sys_idx)

        # ── 2. Body messages ────────────────────────────────────────
        body_messages = messages[1:] if first_is_system else messages
        # Index shift between ``body_messages`` and the caller's ``messages``.
        offset = 1 if first_is_system else 0

        start = 0
        # 2a. tools_in_user_message mode pulls the first user message
        # into a special block with the tools description prepended.
        if tools is not None and self._tools_in_user_message:
            if not body_messages:
                raise ValueError(
                    "Cannot place tools in the first user message — no user "
                    "message was provided."
                )
            first_user = body_messages[0]
            if first_user.get("role") != "user":
                raise ValueError(
                    "tools_in_user_message=True requires the first non-system "
                    f"message to be 'user'; got {first_user.get('role')!r}."
                )
            user_body = (
                _TOOLS_IN_USER_INTRO
                + self._tools_block(tools)
                + self._content_str(first_user.get("content")).strip()
            )
            emit(self._turn_ids("user", user_body), offset)
            start = 1

        # 2b. Remaining messages — plain user/assistant/tool/assistant-with-tool-calls.
        for msg_idx, msg in enumerate(body_messages[start:], start=start + offset):
            role = msg.get("role")
            tool_calls = msg.get("tool_calls")

            if role in ("tool", "ipython"):
                # Tool responses always render under the ``ipython`` role.
                turn = self._turn_ids(
                    "ipython", self._tool_response_str(msg.get("content"))
                )
            elif tool_calls:
                turn = self._turn_ids("assistant", self._tool_call_str(tool_calls))
            else:
                turn = self._turn_ids(
                    role or "", self._content_str(msg.get("content")).strip()
                )
            emit(turn, msg_idx)

        # ── 3. Generation prompt ────────────────────────────────────
        if add_generation_prompt:
            emit(self._generation_prompt_ids(), -1)

        return RenderedTokens(token_ids=tokens, message_indices=indices)

    def render_ids(
        self,
        messages: list[Message],
        *,
        tools: list[ToolSpec] | None = None,
        add_generation_prompt: bool = False,
    ) -> list[int]:
        """Return only the token ids of :meth:`render`."""
        return self.render(
            messages,
            tools=tools,
            add_generation_prompt=add_generation_prompt,
        ).token_ids

    def parse_response(self, token_ids: list[int]) -> ParsedResponse:
        """Parse completion tokens via ``parse_llama_3`` using this renderer's stop ids."""
        return parse_llama_3(
            self._tokenizer,
            token_ids,
            stop_ids=set(self.get_stop_token_ids()),
        )

    def get_stop_token_ids(self) -> list[int]:
        """Ids that end a turn: ``<|eot_id|>``, ``<|end_of_text|>``, ``<|eom_id|>``."""
        return [self._eot, self._end_of_text, self._eom]

    # ------------------------------------------------------------------
    # bridge_to_next_turn
    # ------------------------------------------------------------------

    def bridge_to_next_turn(
        self,
        previous_prompt_ids: list[int],
        previous_completion_ids: list[int],
        new_messages: list[Message],
        *,
        tools: list[ToolSpec] | None = None,
    ) -> list[int] | None:
        """Extend an already-tokenised conversation with *new_messages*.

        Args:
            previous_prompt_ids: Token ids of the previous prompt.
            previous_completion_ids: Token ids sampled for that prompt.
            new_messages: ``system``, ``user``, ``tool`` or ``ipython``
                messages to append.
            tools: Accepted for interface compatibility; not read here.

        Returns:
            The trimmed previous ids followed by the new turns and an open
            assistant header. ``None`` if there is no previous prompt, no
            new message, ``reject_assistant_in_extension`` is truthy,
            ``trim_to_turn_close`` returns ``None``, or a message has any
            other role.
        """
        if (
            not previous_prompt_ids
            or not new_messages
            or reject_assistant_in_extension(new_messages)
        ):
            return None

        # Presumably cuts prompt+completion at the last stop token and adds
        # ``<|eot_id|>`` when the completion was truncated — verify against
        # ``renderers.base.trim_to_turn_close``.
        previous_ids = trim_to_turn_close(
            previous_prompt_ids,
            previous_completion_ids,
            set(self.get_stop_token_ids()),
            synthesize_close=self._eot,
        )
        if previous_ids is None:
            return None

        ext: list[int] = []
        for msg in new_messages:
            role = msg.get("role")
            if role in ("system", "user"):
                body = self._content_str(msg.get("content")).strip()
                ext.extend(self._turn_ids(role, body))
            elif role in ("tool", "ipython"):
                ext.extend(
                    self._turn_ids(
                        "ipython", self._tool_response_str(msg.get("content"))
                    )
                )
            else:
                return None

        # Generation prompt — matches the gen-prompt branch of ``render()``.
        ext.extend(self._generation_prompt_ids())

        return previous_ids + ext
def parse_llama_3(
    tokenizer,
    token_ids: list[int],
    *,
    stop_ids: set[int],
) -> ParsedResponse:
    """Parse Llama-3 completion tokens into content or a single tool call.

    The Llama-3 chat template emits tool calls as a single JSON blob in
    the assistant body — ``{"name": "...", "parameters": {...}}`` — with
    no surrounding XML tags or special tokens. Plain replies are just
    text.

    A completion counts as a tool call only when its stripped text starts
    with ``{``, ends with ``}``, parses as a JSON object, and carries a
    non-empty string ``"name"``. Arguments come from ``"parameters"``,
    then ``"arguments"``; a missing or ``null`` value becomes ``{}``.
    Anything else is returned as content.

    Args:
        tokenizer: Tokenizer handed to ``_decode``.
        token_ids: Completion token ids.
        stop_ids: Ids handed to ``_strip_stop_tokens`` (presumably removed
            from the end of the completion — verify against that helper).

    Returns:
        A ``ParsedResponse`` whose ``reasoning_content`` is always
        ``None``. For a tool call ``content`` is ``""`` and ``tool_calls``
        holds one ``{"function": {"name", "arguments"}}`` dict; otherwise
        ``content`` is the stripped text and ``tool_calls`` is ``None``.
    """
    ids = _strip_stop_tokens(token_ids, stop_ids)
    text = _decode(tokenizer, ids).strip()

    if text.startswith("{") and text.endswith("}"):
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None

        # A dispatchable call needs a real function name; a JSON answer such
        # as {"name": null} or {"name": {...}} is ordinary content.
        name = parsed.get("name") if isinstance(parsed, dict) else None
        if isinstance(name, str) and name:
            arguments = parsed.get("parameters")
            if arguments is None:
                arguments = parsed.get("arguments")
            if arguments is None:
                arguments = {}
            tool_call = {
                "function": {
                    "name": name,
                    "arguments": arguments,
                }
            }
            return ParsedResponse(
                content="",
                reasoning_content=None,
                tool_calls=[tool_call],
            )

    return ParsedResponse(content=text, reasoning_content=None, tool_calls=None)
# Date pinned for the byte-parity tests below.
_PINNED_DATE = "26 Jul 2024"

# Each entry pairs the canonical meta-llama id (the MODEL_RENDERER_MAP key)
# with the unrestricted mirror the tests actually load the tokenizer from.
_MODEL_PAIRS = [
    ("meta-llama/Llama-3.2-1B-Instruct", "unsloth/Llama-3.2-1B-Instruct"),
    ("meta-llama/Llama-3.2-3B-Instruct", "unsloth/Llama-3.2-3B-Instruct"),
]


@pytest.fixture(
    scope="module",
    params=_MODEL_PAIRS,
    ids=[pair[0] for pair in _MODEL_PAIRS],
)
def llama_pair(request):
    """Yield ``(canonical, mirror, tokenizer, renderer)`` for one model pair."""
    canonical, mirror = request.param
    tokenizer = load_tokenizer(mirror)
    return (
        canonical,
        mirror,
        tokenizer,
        Llama3Renderer(tokenizer, date_string=_PINNED_DATE),
    )


# ---------------------------------------------------------------------------
# MODEL_RENDERER_MAP shape
# ---------------------------------------------------------------------------


def test_canonical_meta_llama_paths_route_to_llama_3():
    for canonical, _mirror in _MODEL_PAIRS:
        routed = MODEL_RENDERER_MAP.get(canonical)
        assert routed == "llama-3", f"{canonical}: expected to route to 'llama-3'"


def test_create_renderer_via_explicit_name(llama_pair):
    """The registry maps the name 'llama-3' to Llama3Renderer."""
    tok = llama_pair[2]
    created = create_renderer(tok, renderer="llama-3")
    assert isinstance(created, Llama3Renderer)


# ---------------------------------------------------------------------------
# Constructor contract
# ---------------------------------------------------------------------------


def test_default_date_matches_chat_template_strftime_fallback(llama_pair):
    """Without an override, ``date_string`` is the pinned ``"26 Jul 2024"``."""
    tok = llama_pair[2]
    assert Llama3Renderer(tok)._date_string == _PINNED_DATE


def test_preserve_all_thinking_rejected(llama_pair):
    tok = llama_pair[2]
    with pytest.raises(NotImplementedError, match="reasoning_content"):
        Llama3Renderer(tok, preserve_all_thinking=True)


def test_preserve_thinking_between_tool_calls_rejected(llama_pair):
    tok = llama_pair[2]
    with pytest.raises(NotImplementedError, match="reasoning_content"):
        Llama3Renderer(tok, preserve_thinking_between_tool_calls=True)


# ---------------------------------------------------------------------------
# Byte parity vs apply_chat_template
# ---------------------------------------------------------------------------


def _expected(tok, messages, **kwargs):
    """Reference ids from ``apply_chat_template``; caller kwargs win over defaults."""
    options = {
        "add_generation_prompt": False,
        "date_string": _PINNED_DATE,
        **kwargs,
    }
    rendered = tok.apply_chat_template(
        messages, tokenize=True, return_dict=False, **options
    )
    return list(rendered)


def test_parity_minimal_user(llama_pair):
    *_, tok, renderer = llama_pair
    conversation = [{"role": "user", "content": "Hi."}]
    assert renderer.render_ids(conversation) == _expected(tok, conversation)


def test_parity_system_and_user(llama_pair):
    *_, tok, renderer = llama_pair
    conversation = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi."},
    ]
    assert renderer.render_ids(conversation) == _expected(tok, conversation)


def test_parity_system_user_assistant(llama_pair):
    *_, tok, renderer = llama_pair
    conversation = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi."},
        {"role": "assistant", "content": "Hello!"},
    ]
    assert renderer.render_ids(conversation) == _expected(tok, conversation)


def test_parity_no_system_with_gen_prompt(llama_pair):
    *_, tok, renderer = llama_pair
    conversation = [{"role": "user", "content": "Hi."}]
    actual = renderer.render_ids(conversation, add_generation_prompt=True)
    assert actual == _expected(tok, conversation, add_generation_prompt=True)


def test_parity_multi_turn(llama_pair):
    *_, tok, renderer = llama_pair
    conversation = [
        {"role": "user", "content": "A"},
        {"role": "assistant", "content": "B"},
        {"role": "user", "content": "C"},
        {"role": "assistant", "content": "D"},
    ]
    assert renderer.render_ids(conversation) == _expected(tok, conversation)


def test_parity_trims_whitespace(llama_pair):
    *_, tok, renderer = llama_pair
    conversation = [
        {"role": "user", "content": " hello "},
        {"role": "assistant", "content": "\n\nworld\n"},
    ]
    assert renderer.render_ids(conversation) == _expected(tok, conversation)


def test_parity_custom_date(llama_pair):
    """A ``date_string`` override is applied the same way on both sides."""
    tok = llama_pair[2]
    renderer = Llama3Renderer(tok, date_string="01 Jan 2026")
    conversation = [{"role": "user", "content": "Hi."}]
    reference = tok.apply_chat_template(
        conversation, tokenize=True, return_dict=False, date_string="01 Jan 2026"
    )
    assert renderer.render_ids(conversation) == list(reference)


def test_parity_tools_in_user_default(llama_pair):
    *_, tok, renderer = llama_pair
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        },
    }
    tools = [weather_tool]
    conversation = [
        {"role": "system", "content": "Be terse."},
        {"role": "user", "content": "Weather?"},
    ]
    actual = renderer.render_ids(conversation, tools=tools)
    assert actual == _expected(tok, conversation, tools=tools)


def test_parity_tools_in_system_mode(llama_pair):
    """A renderer built with ``tools_in_user_message=False`` matches the
    template rendered with the same flag."""
    tok = llama_pair[2]
    renderer = Llama3Renderer(
        tok, date_string=_PINNED_DATE, tools_in_user_message=False
    )
    tools = [
        {
            "type": "function",
            "function": {"name": "get_weather", "parameters": {}},
        }
    ]
    conversation = [
        {"role": "system", "content": "Be terse."},
        {"role": "user", "content": "Weather?"},
    ]
    reference = tok.apply_chat_template(
        conversation,
        tokenize=True,
        return_dict=False,
        tools=tools,
        tools_in_user_message=False,
        date_string=_PINNED_DATE,
    )
    assert renderer.render_ids(conversation, tools=tools) == list(reference)


def test_parity_tool_call_round_trip(llama_pair):
    """Tool call, tool response, then a final assistant reply: exercises the
    JSON tool-call body and the ``ipython`` response role."""
    *_, tok, renderer = llama_pair
    call = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "arguments": {"city": "NYC"},
        },
    }
    conversation = [
        {"role": "user", "content": "Weather?"},
        {"role": "assistant", "tool_calls": [call]},
        {"role": "tool", "content": '{"temp": 72}'},
        {"role": "assistant", "content": "It's 72."},
    ]
    assert renderer.render_ids(conversation) == _expected(tok, conversation)


def test_parity_tool_response_dict_content(llama_pair):
    """A mapping as tool-response content is JSON-serialised on both sides."""
    *_, tok, renderer = llama_pair
    conversation = [
        {"role": "user", "content": "x"},
        {
            "role": "assistant",
            "tool_calls": [{"function": {"name": "f", "arguments": {}}}],
        },
        {"role": "tool", "content": {"k": "v", "n": 42}},
        {"role": "assistant", "content": "ok"},
    ]
    assert renderer.render_ids(conversation) == _expected(tok, conversation)


def test_render_raises_on_multiple_tool_calls(llama_pair):
    """More than one tool call in a single assistant turn is rejected."""
    renderer = llama_pair[3]
    two_calls = [
        {"function": {"name": "f", "arguments": {}}},
        {"function": {"name": "g", "arguments": {}}},
    ]
    conversation = [
        {"role": "user", "content": "x"},
        {"role": "assistant", "tool_calls": two_calls},
    ]
    with pytest.raises(ValueError, match="single tool call"):
        renderer.render_ids(conversation)


# ---------------------------------------------------------------------------
# parse_response
# ---------------------------------------------------------------------------


def _tokens_for(tok, text: str) -> list[int]:
    return tok.encode(text, add_special_tokens=False)


def test_parse_response_plain_content(llama_pair):
    *_, tok, renderer = llama_pair
    completion = _tokens_for(tok, "Hello, world!") + [renderer._eot]
    parsed = renderer.parse_response(completion)
    assert isinstance(parsed, ParsedResponse)
    assert parsed.content == "Hello, world!"
    assert parsed.tool_calls is None
    assert parsed.reasoning_content is None


def test_parse_response_tool_call(llama_pair):
    *_, tok, renderer = llama_pair
    body = '{"name": "get_weather", "parameters": {"city": "NYC"}}'
    parsed = renderer.parse_response(_tokens_for(tok, body) + [renderer._eot])
    assert parsed.content == ""
    assert parsed.tool_calls == [
        {"function": {"name": "get_weather", "arguments": {"city": "NYC"}}}
    ]


def test_parse_response_malformed_tool_call_falls_through_to_content(llama_pair):
    """Text that resembles a tool call but is not valid JSON must come back
    as content instead of being dropped."""
    *_, tok, renderer = llama_pair
    body = '{"name": "x", broken'
    parsed = renderer.parse_response(_tokens_for(tok, body) + [renderer._eot])
    assert parsed.tool_calls is None
    assert "{" in parsed.content


# ---------------------------------------------------------------------------
# Bridge contract
# ---------------------------------------------------------------------------


def _simulate_prior_turn(r):
    """Return ``(prompt_ids, completion_ids)`` for a system/user/assistant exchange.

    The completion is the tail of the full render past the prompt, cut
    just after its last stop token when it contains one.
    """
    prior = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi."},
    ]
    reply = [{"role": "assistant", "content": "Hello!"}]

    prev_prompt = r.render_ids(prior, add_generation_prompt=True)
    everything = r.render_ids(prior + reply, add_generation_prompt=False)
    prev_completion = list(everything[len(prev_prompt) :])

    stop_ids = set(r.get_stop_token_ids())
    last_stop = next(
        (
            pos
            for pos in reversed(range(len(prev_completion)))
            if prev_completion[pos] in stop_ids
        ),
        -1,
    )
    if last_stop >= 0:
        prev_completion = prev_completion[: last_stop + 1]
    return prev_prompt, prev_completion


def test_bridge_extends_prev_verbatim_on_clean_stop(llama_pair):
    renderer = llama_pair[3]
    prev_prompt, prev_completion = _simulate_prior_turn(renderer)
    bridged = renderer.bridge_to_next_turn(
        prev_prompt, prev_completion, [{"role": "user", "content": "What's 2+2?"}]
    )
    assert bridged is not None
    history = prev_prompt + prev_completion
    assert bridged[: len(history)] == history
    assert len(bridged) > len(history)


def test_bridge_matches_fresh_render_on_clean_stop(llama_pair):
    """After a clean stop, bridging must yield exactly the ids a fresh render
    of the whole conversation would."""
    renderer = llama_pair[3]
    prior = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi."},
    ]
    reply = [{"role": "assistant", "content": "Hello!"}]
    follow_up = [{"role": "user", "content": "What's 2+2?"}]

    prev_prompt, prev_completion = _simulate_prior_turn(renderer)
    bridged = renderer.bridge_to_next_turn(prev_prompt, prev_completion, follow_up)
    fresh = renderer.render_ids(
        prior + reply + follow_up, add_generation_prompt=True
    )
    assert bridged == fresh


def test_bridge_rejects_assistant_in_extension(llama_pair):
    renderer = llama_pair[3]
    prev_prompt, prev_completion = _simulate_prior_turn(renderer)
    extension = [{"role": "assistant", "content": "forbidden"}]
    assert renderer.bridge_to_next_turn(prev_prompt, prev_completion, extension) is None


def test_bridge_synthesises_close_on_truncation(llama_pair):
    renderer = llama_pair[3]
    prev_prompt, prev_completion = _simulate_prior_turn(renderer)
    truncated = prev_completion[:-1]
    if not truncated:
        pytest.skip("simulated prior had no completion tokens to truncate")
    bridged = renderer.bridge_to_next_turn(
        prev_prompt, truncated, [{"role": "user", "content": "ping"}]
    )
    assert bridged is not None
    history = prev_prompt + truncated
    assert bridged[: len(history)] == history
    assert len(bridged) > len(history)