From dc2b964b7ecc76438a1ce34adf9a00b73820264d Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Thu, 12 Feb 2026 01:09:14 +0900 Subject: [PATCH 01/19] feat(oai): add vLLM-style reasoning parsers and robust tool-call compatibility --- LICENSES/Apache-2.0.txt | 202 ++++++++++ THIRD_PARTY_LICENSES.md | 34 ++ backends/exllamav2/model.py | 3 + backends/exllamav3/model.py | 15 + common/config_models.py | 30 +- common/templating.py | 19 + config_sample.yml | 5 + endpoints/OAI/reasoning/__init__.py | 91 +++++ .../OAI/reasoning/abs_reasoning_parsers.py | 89 +++++ endpoints/OAI/reasoning/basic_parsers.py | 94 +++++ .../reasoning/deepseek_r1_reasoning_parser.py | 51 +++ .../reasoning/deepseek_v3_reasoning_parser.py | 58 +++ .../OAI/reasoning/ernie45_reasoning_parser.py | 98 +++++ .../OAI/reasoning/exaone4_reasoning_parser.py | 117 ++++++ .../reasoning/glm4_moe_reasoning_parser.py | 10 + .../OAI/reasoning/gptoss_reasoning_parser.py | 119 ++++++ .../OAI/reasoning/granite_reasoning_parser.py | 376 ++++++++++++++++++ .../OAI/reasoning/holo2_reasoning_parser.py | 59 +++ .../hunyuan_a13b_reasoning_parser.py | 241 +++++++++++ .../reasoning/identity_reasoning_parser.py | 36 ++ .../OAI/reasoning/kimi_k2_reasoning_parser.py | 57 +++ .../reasoning/minimax_m2_reasoning_parser.py | 73 ++++ .../OAI/reasoning/mistral_reasoning_parser.py | 73 ++++ .../OAI/reasoning/olmo3_reasoning_parser.py | 166 ++++++++ .../OAI/reasoning/qwen3_reasoning_parser.py | 28 ++ .../OAI/reasoning/seedoss_reasoning_parser.py | 16 + .../OAI/reasoning/step3_reasoning_parser.py | 67 ++++ .../OAI/reasoning/step3p5_reasoning_parser.py | 118 ++++++ endpoints/OAI/types/chat_completion.py | 9 +- endpoints/OAI/types/tools.py | 9 + endpoints/OAI/utils/chat_completion.py | 169 +++++++- endpoints/OAI/utils/tools.py | 92 ++++- 32 files changed, 2603 insertions(+), 21 deletions(-) create mode 100644 LICENSES/Apache-2.0.txt create mode 100644 THIRD_PARTY_LICENSES.md create mode 100644 endpoints/OAI/reasoning/__init__.py create mode 
100644 endpoints/OAI/reasoning/abs_reasoning_parsers.py create mode 100644 endpoints/OAI/reasoning/basic_parsers.py create mode 100644 endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/ernie45_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/exaone4_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/gptoss_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/granite_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/holo2_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/identity_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/kimi_k2_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/minimax_m2_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/mistral_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/olmo3_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/qwen3_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/seedoss_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/step3_reasoning_parser.py create mode 100644 endpoints/OAI/reasoning/step3p5_reasoning_parser.py diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/LICENSES/Apache-2.0.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/THIRD_PARTY_LICENSES.md b/THIRD_PARTY_LICENSES.md new file mode 100644 index 00000000..962f9d54 --- /dev/null +++ b/THIRD_PARTY_LICENSES.md @@ -0,0 +1,34 @@ +# Third-Party License Notices + +This repository is licensed under AGPL-3.0 (see `LICENSE`). 
+ +The following files include code adapted from the vLLM project and are +licensed under Apache License 2.0: + +- `endpoints/OAI/reasoning/abs_reasoning_parsers.py` +- `endpoints/OAI/reasoning/basic_parsers.py` +- `endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py` +- `endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py` +- `endpoints/OAI/reasoning/ernie45_reasoning_parser.py` +- `endpoints/OAI/reasoning/exaone4_reasoning_parser.py` +- `endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py` +- `endpoints/OAI/reasoning/gptoss_reasoning_parser.py` +- `endpoints/OAI/reasoning/granite_reasoning_parser.py` +- `endpoints/OAI/reasoning/holo2_reasoning_parser.py` +- `endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py` +- `endpoints/OAI/reasoning/identity_reasoning_parser.py` +- `endpoints/OAI/reasoning/kimi_k2_reasoning_parser.py` +- `endpoints/OAI/reasoning/minimax_m2_reasoning_parser.py` +- `endpoints/OAI/reasoning/mistral_reasoning_parser.py` +- `endpoints/OAI/reasoning/olmo3_reasoning_parser.py` +- `endpoints/OAI/reasoning/qwen3_reasoning_parser.py` +- `endpoints/OAI/reasoning/seedoss_reasoning_parser.py` +- `endpoints/OAI/reasoning/step3_reasoning_parser.py` +- `endpoints/OAI/reasoning/step3p5_reasoning_parser.py` +- `endpoints/OAI/reasoning/__init__.py` + +Source project: +- vLLM: https://github.com/vllm-project/vllm + +The Apache-2.0 license text is provided at: +- `LICENSES/Apache-2.0.txt` diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6e59dbe3..d9341f10 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1426,12 +1426,15 @@ async def generate_gen( full_response += chunk chunk_tokens = result.get("token_ids") + token_ids = [] if chunk_tokens is not None: + token_ids = chunk_tokens.flatten().tolist() generated_tokens += chunk_tokens.size(dim=0) generation = { "request_id": request_id, "text": chunk, + "token_ids": token_ids, "prompt_tokens": context_len, "generated_tokens": generated_tokens, 
"offset": len(full_response), diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 50c30450..33f5957d 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -1013,8 +1013,22 @@ async def generate_gen( if chunk: chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk)) full_response += chunk + token_ids: list[int] = [] + if isinstance(chunk_tokens, torch.Tensor): + token_ids = chunk_tokens.flatten().tolist() + elif isinstance(chunk_tokens, tuple): + # Some kernels may return tuple[token_ids, ...] + first = chunk_tokens[0] + if isinstance(first, torch.Tensor): + token_ids = first.flatten().tolist() + else: + token_ids = list(first) + else: + token_ids = list(chunk_tokens) if isinstance(chunk_tokens, torch.Tensor): generated_tokens += chunk_tokens.size(dim=0) + elif token_ids: + generated_tokens += len(token_ids) # Increase penalty range to generated token amount # TODO: @@ -1024,6 +1038,7 @@ async def generate_gen( generation = { "request_id": request_id, "text": chunk, + "token_ids": token_ids, "prompt_tokens": context_len, "generated_tokens": generated_tokens, "offset": len(full_response), diff --git a/common/config_models.py b/common/config_models.py index 0e71734c..cdfa603b 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -284,20 +284,28 @@ class ModelConfig(BaseConfigModel): ), ge=1, ) - prompt_template: Optional[str] = Field( - None, - description=( - "Set the prompt template for this model. (default: None)\n" + prompt_template: Optional[str] = Field( + None, + description=( + "Set the prompt template for this model. (default: None)\n" "If empty, attempts to look for the model's chat template.\n" "If a model contains multiple templates in its tokenizer_config.json,\n" "set prompt_template to the name of the template you want to use.\n" - "NOTE: Only works with chat completion message lists!" 
- ), - ) - vision: Optional[bool] = Field( - False, - description=( - "Enables vision support if the model supports it. (default: False)" + "NOTE: Only works with chat completion message lists!" + ), + ) + reasoning_parser: Optional[str] = Field( + None, + description=( + "Reasoning parser key used to split output into reasoning/content.\n" + "Compatible with vLLM parser naming (e.g. exaone4, deepseek_r1).\n" + "If omitted, defaults to 'basic'." + ), + ) + vision: Optional[bool] = Field( + False, + description=( + "Enables vision support if the model supports it. (default: False)" ), ) diff --git a/common/templating.py b/common/templating.py index cc0cceb1..1a65a73b 100644 --- a/common/templating.py +++ b/common/templating.py @@ -12,6 +12,7 @@ from jinja2.ext import loopcontrols from jinja2.sandbox import ImmutableSandboxedEnvironment from loguru import logger +from markupsafe import Markup from packaging import version @@ -46,6 +47,23 @@ class PromptTemplate: ) metadata: Optional[TemplateMetadata] = None + @staticmethod + def _tojson_compat(value, indent=None, ensure_ascii=True): + """ + Compatibility JSON filter for chat templates. + + Some model templates call `tojson(ensure_ascii=False)` while the + bundled Jinja filter may not accept that keyword in this environment. + """ + return Markup( + json.dumps( + value, + indent=indent, + ensure_ascii=ensure_ascii, + separators=(",", ": "), + ) + ) + async def extract_metadata(self, template_vars: dict): """ Returns deserialized template metadata from a chat template. 
@@ -107,6 +125,7 @@ def raise_exception(message):
 
         self.environment.globals["strftime_now"] = strftime_now
         self.environment.globals["raise_exception"] = raise_exception
+        self.environment.filters["tojson"] = self._tojson_compat
 
         return self.environment.from_string(template_str)
 
diff --git a/config_sample.yml b/config_sample.yml
index 0b65f9e8..b13a4cd9 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -153,6 +153,11 @@ model:
   # NOTE: Only works with chat completion message lists!
   prompt_template:
 
+  # Reasoning parser key for splitting hidden reasoning and final content.
+  # Compatible keys include: basic, exaone4, deepseek_r1, deepseek_v3.
+  # If omitted, TabbyAPI defaults to `basic`.
+  reasoning_parser:
+
   # Enables vision support if the model supports it. (default: False)
   vision: false
 
diff --git a/endpoints/OAI/reasoning/__init__.py b/endpoints/OAI/reasoning/__init__.py
new file mode 100644
index 00000000..7aaa3630
--- /dev/null
+++ b/endpoints/OAI/reasoning/__init__.py
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from endpoints.OAI.reasoning.abs_reasoning_parsers import (
+    DeltaMessage,
+    ReasoningParser,
+    ReasoningParserManager,
+)
+from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser
+from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+from endpoints.OAI.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser
+from endpoints.OAI.reasoning.ernie45_reasoning_parser import Ernie45ReasoningParser
+from endpoints.OAI.reasoning.exaone4_reasoning_parser import Exaone4ReasoningParser
+from endpoints.OAI.reasoning.glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
+from endpoints.OAI.reasoning.gptoss_reasoning_parser import GptOssReasoningParser
+from endpoints.OAI.reasoning.granite_reasoning_parser import GraniteReasoningParser
+from endpoints.OAI.reasoning.holo2_reasoning_parser import
Holo2ReasoningParser +from endpoints.OAI.reasoning.hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser +from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser +from endpoints.OAI.reasoning.kimi_k2_reasoning_parser import KimiK2ReasoningParser +from endpoints.OAI.reasoning.minimax_m2_reasoning_parser import ( + MiniMaxM2AppendThinkReasoningParser, + MiniMaxM2ReasoningParser, +) +from endpoints.OAI.reasoning.mistral_reasoning_parser import MistralReasoningParser +from endpoints.OAI.reasoning.olmo3_reasoning_parser import Olmo3ReasoningParser +from endpoints.OAI.reasoning.qwen3_reasoning_parser import Qwen3ReasoningParser +from endpoints.OAI.reasoning.seedoss_reasoning_parser import SeedOSSReasoningParser +from endpoints.OAI.reasoning.step3_reasoning_parser import Step3ReasoningParser +from endpoints.OAI.reasoning.step3p5_reasoning_parser import Step3p5ReasoningParser + + +@ReasoningParserManager.register_module("identity") +class _IdentityParser(IdentityReasoningParser): + pass + + +@ReasoningParserManager.register_module("basic") +class _BasicParser(DeepSeekR1ReasoningParser): + pass + + +ReasoningParserManager.reasoning_parsers.update( + { + "deepseek_r1": DeepSeekR1ReasoningParser, + "deepseek_v3": DeepSeekV3ReasoningParser, + "ernie45": Ernie45ReasoningParser, + "exaone4": Exaone4ReasoningParser, + "glm45": Glm4MoeModelReasoningParser, + "openai_gptoss": GptOssReasoningParser, + "granite": GraniteReasoningParser, + "holo2": Holo2ReasoningParser, + "hunyuan_a13b": HunyuanA13BReasoningParser, + "kimi_k2": KimiK2ReasoningParser, + "minimax_m2": MiniMaxM2ReasoningParser, + "minimax_m2_append_think": MiniMaxM2AppendThinkReasoningParser, + "mistral": MistralReasoningParser, + "olmo3": Olmo3ReasoningParser, + "qwen3": Qwen3ReasoningParser, + "seed_oss": SeedOSSReasoningParser, + "step3": Step3ReasoningParser, + "step3p5": Step3p5ReasoningParser, + } +) + + +__all__ = [ + "BaseThinkingReasoningParser", + "DeltaMessage", + 
"DeepSeekR1ReasoningParser", + "DeepSeekV3ReasoningParser", + "Ernie45ReasoningParser", + "Exaone4ReasoningParser", + "Glm4MoeModelReasoningParser", + "GptOssReasoningParser", + "GraniteReasoningParser", + "Holo2ReasoningParser", + "HunyuanA13BReasoningParser", + "IdentityReasoningParser", + "KimiK2ReasoningParser", + "MiniMaxM2AppendThinkReasoningParser", + "MiniMaxM2ReasoningParser", + "MistralReasoningParser", + "Olmo3ReasoningParser", + "Qwen3ReasoningParser", + "ReasoningParser", + "ReasoningParserManager", + "SeedOSSReasoningParser", + "Step3ReasoningParser", + "Step3p5ReasoningParser", +] diff --git a/endpoints/OAI/reasoning/abs_reasoning_parsers.py b/endpoints/OAI/reasoning/abs_reasoning_parsers.py new file mode 100644 index 00000000..b81983d0 --- /dev/null +++ b/endpoints/OAI/reasoning/abs_reasoning_parsers.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Any + + +@dataclass +class DeltaMessage: + content: str | None = None + reasoning: str | None = None + + +class ReasoningParser(ABC): + def __init__(self, tokenizer: Any, *args, **kwargs): + self.model_tokenizer = tokenizer + + @property + def vocab(self) -> dict[str, int]: + return self.model_tokenizer.get_vocab() + + @abstractmethod + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + pass + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Sequence[int] + ) -> bool: + return self.is_reasoning_end(input_ids) + + @abstractmethod + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + pass + + @abstractmethod + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + pass + + @abstractmethod + def extract_reasoning_streaming( + self, + 
previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + pass + + def prepare_structured_tag(self, original_tag: str | None, tool_server: Any | None): + return original_tag + + +class ReasoningParserManager: + reasoning_parsers: dict[str, type[ReasoningParser]] = {} + + @classmethod + def list_registered(cls) -> list[str]: + return sorted(cls.reasoning_parsers.keys()) + + @classmethod + def get_reasoning_parser(cls, name: str) -> type[ReasoningParser]: + parser = cls.reasoning_parsers.get(name) + if parser is None: + registered = ", ".join(cls.list_registered()) + raise KeyError( + f"Reasoning parser '{name}' not found. Available parsers: {registered}" + ) + return parser + + @classmethod + def register_module( + cls, + module_name: str, + ) -> Callable[[type[ReasoningParser]], type[ReasoningParser]]: + def _decorator(module: type[ReasoningParser]) -> type[ReasoningParser]: + cls.reasoning_parsers[module_name] = module + return module + + return _decorator diff --git a/endpoints/OAI/reasoning/basic_parsers.py b/endpoints/OAI/reasoning/basic_parsers.py new file mode 100644 index 00000000..f2dfc0c7 --- /dev/null +++ b/endpoints/OAI/reasoning/basic_parsers.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser + + +class BaseThinkingReasoningParser(ReasoningParser): + @property + @abstractmethod + def start_token(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def end_token(self) -> str: + raise NotImplementedError + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + self.start_token_id = 
self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + if self.start_token_id is None or self.end_token_id is None: + raise RuntimeError( + f"{self.__class__.__name__} could not locate think tokens in tokenizer" + ) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + for token_id in reversed(input_ids): + if token_id == self.start_token_id: + return False + if token_id == self.end_token_id: + return True + return False + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + if self.end_token_id not in input_ids[:-1]: + return [] + return input_ids[input_ids.index(self.end_token_id) + 1 :] + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + model_output_parts = model_output.partition(self.start_token) + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) + + if self.end_token not in model_output: + return model_output or None, None + + reasoning, _, content = model_output.partition(self.end_token) + return reasoning or None, content or None + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] in [self.start_token_id, self.end_token_id] + ): + return None + + if self.start_token_id in previous_token_ids: + if self.end_token_id in delta_token_ids: + end_index = delta_text.find(self.end_token) + reasoning = delta_text[:end_index] or None + content = delta_text[end_index + len(self.end_token) :] or None + return DeltaMessage(reasoning=reasoning, content=content) + if self.end_token_id in previous_token_ids: + return DeltaMessage(content=delta_text or None) + return DeltaMessage(reasoning=delta_text or None) + + if self.start_token_id in delta_token_ids: + if self.end_token_id 
in delta_token_ids:
+                start_index = delta_text.find(self.start_token)
+                end_index = delta_text.find(self.end_token)
+                reasoning = delta_text[start_index + len(self.start_token) : end_index]
+                content = delta_text[end_index + len(self.end_token) :]
+                return DeltaMessage(reasoning=reasoning or None, content=content or None)
+            return DeltaMessage(reasoning=delta_text or None)
+
+        return DeltaMessage(content=delta_text or None)
diff --git a/endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py b/endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py
new file mode 100644
index 00000000..3b93bb17
--- /dev/null
+++ b/endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage
+from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser
+
+
+class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser):
+    @property
+    def start_token(self) -> str:
+        return "<think>"
+
+    @property
+    def end_token(self) -> str:
+        return "</think>"
+
+    def extract_reasoning_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: list[int],
+        current_token_ids: list[int],
+        delta_token_ids: list[int],
+    ) -> DeltaMessage | None:
+        ret = super().extract_reasoning_streaming(
+            previous_text,
+            current_text,
+            delta_text,
+            previous_token_ids,
+            current_token_ids,
+            delta_token_ids,
+        )
+
+        if (
+            ret is not None
+            and self.start_token_id not in previous_token_ids
+            and self.start_token_id not in delta_token_ids
+        ):
+            if self.end_token_id in delta_token_ids:
+                end_index = delta_text.find(self.end_token)
+                reasoning = delta_text[:end_index] or None
+                content = delta_text[end_index + len(self.end_token) :] or None
+                return DeltaMessage(reasoning=reasoning, content=content)
+            if self.end_token_id in previous_token_ids:
+ return DeltaMessage(content=delta_text or None) + return DeltaMessage(reasoning=delta_text or None) + + return ret diff --git a/endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py b/endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py new file mode 100644 index 00000000..5e8deb73 --- /dev/null +++ b/endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser +from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import ( + DeepSeekR1ReasoningParser, +) +from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser + + +class DeepSeekV3ReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.get("thinking", False)) + enable_thinking = bool(chat_kwargs.get("enable_thinking", False)) + thinking = thinking or enable_thinking + + if thinking: + self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + else: + self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> 
DeltaMessage | None: + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/endpoints/OAI/reasoning/ernie45_reasoning_parser.py b/endpoints/OAI/reasoning/ernie45_reasoning_parser.py new file mode 100644 index 00000000..bff91166 --- /dev/null +++ b/endpoints/OAI/reasoning/ernie45_reasoning_parser.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage +from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser + + +class Ernie45ReasoningParser(BaseThinkingReasoningParser): + response_start_token: str = "<response>" + response_end_token: str = "</response>" + newline_token: str = "<0x0A>" + + @property + def start_token(self) -> str: + return "<think>" + + @property + def end_token(self) -> str: + return "</think>" + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + self.response_start_token_id = self.vocab.get(self.response_start_token) + self.response_end_token_id = self.vocab.get(self.response_end_token) + self.newline_token_id = self.vocab.get(self.newline_token) + self.parser_token_ids = [self.end_token_id, self.response_end_token_id] + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] + in [ + self.start_token_id, + self.end_token_id, + self.response_start_token_id, + self.response_end_token_id, + ] + ): + return None + + if self.end_token_id in delta_token_ids: + think_end_index = delta_text.find(self.end_token) +
reasoning = delta_text[:think_end_index] + content = delta_text[think_end_index + len(self.end_token) :].lstrip("\n") + response_start_idx = content.find(self.response_start_token) + response_end_idx = content.rfind(self.response_end_token) + if response_start_idx != -1: + content = content[response_start_idx + len(self.response_start_token) :] + if response_end_idx != -1: + content = content[:response_end_idx] + return DeltaMessage(reasoning=reasoning, content=content or None) + + if self.end_token_id in previous_token_ids: + content = delta_text + if self.response_start_token_id in delta_token_ids: + content = content.lstrip("\n") + response_start_idx = content.find(self.response_start_token) + content = content[response_start_idx + len(self.response_start_token) :] + response_end_idx = content.rfind(self.response_end_token) + if response_end_idx != -1: + content = content[:response_end_idx] + elif self.response_end_token_id in delta_token_ids: + response_end_idx = content.rfind(self.response_end_token) + content = content[:response_end_idx] + + if previous_token_ids and previous_token_ids[-1] in self.parser_token_ids: + if delta_token_ids and delta_token_ids[0] == self.newline_token_id: + content = content.lstrip("\n") + if len(previous_token_ids) > 1 and previous_token_ids[-2] == self.end_token_id: + if delta_token_ids and delta_token_ids[0] == self.newline_token_id: + content = content.lstrip("\n") + + return DeltaMessage(content=content or None) + + return DeltaMessage(reasoning=delta_text) + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + reasoning, content = super().extract_reasoning(model_output, request) + if content: + start_idx = content.find(self.response_start_token) + end_idx = content.rfind(self.response_end_token) + if start_idx != -1 and end_idx != -1 and start_idx < end_idx: + content = content[start_idx + len(self.response_start_token) : end_idx] + return reasoning, content or None diff 
--git a/endpoints/OAI/reasoning/exaone4_reasoning_parser.py b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py new file mode 100644 index 00000000..b4394c13 --- /dev/null +++ b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import ( + DeltaMessage, + ReasoningParserManager, +) +from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import ( + DeepSeekR1ReasoningParser, +) +from endpoints.OAI.reasoning.deepseek_v3_reasoning_parser import ( + DeepSeekV3ReasoningParser, +) +from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser + + +@ReasoningParserManager.register_module("exaone4") +class Exaone4ReasoningParser(DeepSeekV3ReasoningParser): + """ + EXAONE-specific parser. + + Important model behavior: + - Uses only `enable_thinking` to toggle reasoning mode. + - Chat template may prefill `<think>` so model output often starts directly + with reasoning body and closes with `</think>`.
+ """ + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + enable_thinking = bool(chat_kwargs.get("enable_thinking", False)) + self._reasoning_ended = False + + if enable_thinking: + self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + else: + self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) + + def _thinking_enabled(self) -> bool: + return isinstance(self._parser, DeepSeekR1ReasoningParser) + + def _strip_stray_end_token(self, text: str) -> str: + if not text or not self._thinking_enabled(): + return text + + end_token = self._parser.end_token + start_token = self._parser.start_token + if start_token not in text and end_token in text: + return text.replace(end_token, "") + return text + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + if not self._thinking_enabled(): + return None, model_output + + start_token = self._parser.start_token + end_token = self._parser.end_token + + if start_token in model_output: + _, _, model_output = model_output.partition(start_token) + + if end_token in model_output: + reasoning, _, content = model_output.partition(end_token) + return reasoning or None, content or None + + return model_output or None, None + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + if not self._thinking_enabled(): + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) + + if self._reasoning_ended: + content = self._strip_stray_end_token(delta_text) + return DeltaMessage(content=content) if content else None + + start_token = self._parser.start_token + end_token = 
self._parser.end_token + + if end_token in delta_text: + reasoning_part, _, content_part = delta_text.partition(end_token) + if start_token in reasoning_part: + _, _, reasoning_part = reasoning_part.partition(start_token) + + self._reasoning_ended = True + return DeltaMessage( + reasoning=reasoning_part or None, + content=content_part or None, + ) + + reasoning_part = delta_text + if start_token in reasoning_part: + _, _, reasoning_part = reasoning_part.partition(start_token) + return DeltaMessage(reasoning=reasoning_part or None) if reasoning_part else None diff --git a/endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py b/endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py new file mode 100644 index 00000000..9368f2c3 --- /dev/null +++ b/endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from endpoints.OAI.reasoning.holo2_reasoning_parser import Holo2ReasoningParser + + +class Glm4MoeModelReasoningParser(Holo2ReasoningParser): + pass diff --git a/endpoints/OAI/reasoning/gptoss_reasoning_parser.py b/endpoints/OAI/reasoning/gptoss_reasoning_parser.py new file mode 100644 index 00000000..3e454bed --- /dev/null +++ b/endpoints/OAI/reasoning/gptoss_reasoning_parser.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser + + +NO_FUNC_REASONING_TAG = { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "tags": [ + { + "begin": "<|channel|>analysis<|message|>", + "content": {"type": "any_text"}, + "end": "<|end|>", + } + ], + "triggers": ["<|channel|>analysis"], + "stop_after_first": False, + }, +} + + 
+class GptOssReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + def _split_harmony(self, text: str) -> tuple[str | None, str | None]: + # Minimal harmony-compatible splitter without vLLM parser dependency. + analysis_tag = "<|channel|>analysis<|message|>" + final_tag = "<|channel|>final<|message|>" + end_tag = "<|end|>" + + a_idx = text.find(analysis_tag) + f_idx = text.find(final_tag) + if a_idx == -1 and f_idx == -1: + return None, text or None + + reasoning = None + content = None + + if a_idx != -1: + a_start = a_idx + len(analysis_tag) + a_end = text.find(end_tag, a_start) + if a_end == -1: + a_end = f_idx if f_idx != -1 else len(text) + reasoning = text[a_start:a_end] or None + + if f_idx != -1: + f_start = f_idx + len(final_tag) + f_end = text.find(end_tag, f_start) + if f_end == -1: + f_end = len(text) + content = text[f_start:f_end] or None + + return reasoning, content + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + text = self.model_tokenizer.decode(input_ids) + return "<|channel|>final<|message|>" in text + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + _, content = self._split_harmony(self.model_tokenizer.decode(input_ids)) + if content is None: + return [] + return self.model_tokenizer.encode(content) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + prev_reasoning, prev_content = self._split_harmony(previous_text) + cur_reasoning, cur_content = self._split_harmony(current_text) + + reasoning_delta = None + content_delta = None + if cur_reasoning is not None: + prev_r = prev_reasoning or "" + reasoning_delta = ( + cur_reasoning[len(prev_r) :] + if cur_reasoning.startswith(prev_r) + else cur_reasoning + ) or None + if cur_content is 
not None: + prev_c = prev_content or "" + content_delta = ( + cur_content[len(prev_c) :] + if cur_content.startswith(prev_c) + else cur_content + ) or None + + if reasoning_delta is None and content_delta is None: + return None + return DeltaMessage(reasoning=reasoning_delta, content=content_delta) + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + return self._split_harmony(model_output) + + def prepare_structured_tag( + self, original_tag: str | None, tool_server: Any | None + ) -> str | None: + if original_tag is not None: + return original_tag + return json.dumps(NO_FUNC_REASONING_TAG) diff --git a/endpoints/OAI/reasoning/granite_reasoning_parser.py b/endpoints/OAI/reasoning/granite_reasoning_parser.py new file mode 100644 index 00000000..c60c3fac --- /dev/null +++ b/endpoints/OAI/reasoning/granite_reasoning_parser.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +try: + import regex as re +except ImportError: + import re + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser + + + +class GraniteReasoningParser(ReasoningParser): + """ + Reasoning parser for IBM Granite. + + IBM granite models currently use "Here is my thought process:" + and "Here is my response:" to separate its thinking / response outputs. + """ + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + # NOTE: There have been some observed occurrences of quantized + # instances of the current models using "Here's" instead of "Here is", + # so to be safe, we match on both. 
+ self.think_start_expr = r"(?:Here's|Here is) my thought process:" + self.response_start_expr = r"(?:Here's|Here is) my response:" + + self.reasoning_regex = re.compile( + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL + ) + + self.valid_think_starts = [ + "Here's my thought process:", + "Here is my thought process:", + ] + self.valid_response_starts = ["Here's my response:", "Here is my response:"] + + # Substrings to match for sequence boundaries on raw text + self.seq_boundary_end = ":" + self.seq_boundary_start = "Here" + + # The longest any thinking / start of response message can be + self.longest_think_start = max( + len(think_start) for think_start in self.valid_think_starts + ) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + text = self.model_tokenizer.decode(input_ids) + return any(resp in text for resp in self.valid_response_starts) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + text = self.model_tokenizer.decode(input_ids) + _, content = self.extract_reasoning(text, None) + if not content: + return [] + return self.model_tokenizer.encode(content) + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + """Extract the reasoning content & content sections, respectively. + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + Args: + model_output (str): Output of the model to be parsed. + request (ChatCompletionRequest): Request being processed. + + Returns: + tuple[Optional[str], Optional[str]]: Tuple pair containing the + reasoning content and non-reasoning content.
+ """ + re_match = self.reasoning_regex.findall(model_output) + if not re_match: + return None, model_output + reasoning, response_content = re_match[0] + if not response_content: + return reasoning, None + return reasoning, response_content + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + """Extract the reasoning content / content emitted by granite models; + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + NOTE: Granite models do not use a special token to start their reasoning + and response sections; instead they have token sequences, e.g., + + Here is my thought process: Foo Here is my response: Bar + + This increases the complexity of correctly handling streams, since we + need to watch for specific sequences and correctly parse them without + dropping content that is potentially overlapping & spanning multiple + delta messages. + + Args: + previous_text (str): Previous text outside of this delta message. + current_text (str): Previous text + delta text. + delta_text (str): Text to consider and parse content from. + previous_token_ids (Sequence[int]): Token IDs of previous_text. + current_token_ids (Sequence[int]): Token IDs of current_text. + delta_token_ids (Sequence[int]): Token IDs of delta_text. + + Returns: + Union[DeltaMessage, None] + DeltaMessage with either reasoning content or content, or None. + """ + reasoning, resp_seq_len, content = self._get_content_sections(current_text) + # Either we haven't finished the start of the reasoning sequence, + # or the model is generating something unexpected. 
+ if not reasoning: + delta_message = self._get_delta_message_with_no_reasoning_bounds( + current_text, delta_text + ) + # We have a start of reasoning message, but have not yet finished + # the start of response sequence. + elif not content: + delta_message = self._get_delta_message_with_no_response_bounds( + current_text, reasoning, delta_text + ) + # We've finished both the start of reasoning and start of response seq. + else: + # This should never happen since we matched on the response + assert resp_seq_len is not None + delta_message = self._get_delta_message_with_both_bounds( + delta_text, reasoning, content, current_text, resp_seq_len + ) + if not delta_message.content and not delta_message.reasoning: + return None + return delta_message + + #### Implementation details of stream parsing for granite models + def _is_reasoning_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start reasoning seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible reasoning start seqs match. + """ + return any( + think_start.startswith(text) for think_start in self.valid_think_starts + ) + + def _is_response_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start response seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible response start seqs match. + """ + return any( + response_start.startswith(text) + for response_start in self.valid_response_starts + ) + + def _get_delta_message_with_no_reasoning_bounds( + self, + current_text: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has not yet completed + its start of reasoning sequence. + + Args: + current_text (str): The full previous + delta text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. 
+ """ + prev_longest_length = len(current_text) - len(delta_text) + is_substr = self._is_reasoning_start_substr(current_text) + was_substr = self._is_reasoning_start_substr(current_text[:prev_longest_length]) + + # Check if we just generated something NOT in the special token seq; + # if so, add everything that we previously skipped with this delta + # message and append everything to content in the future. + if was_substr and not is_substr: + return DeltaMessage( + reasoning=None, + content=current_text, + ) + if is_substr: + # Might still be in the special token sequence; return nothing + return DeltaMessage(reasoning=None, content=None) + # Otherwise the sequence has already been broken and we already + # corrected; just return the delta text as normal content. + return DeltaMessage(reasoning=None, content=delta_text) + + def _get_delta_message_with_no_response_bounds( + self, + current_text: str, + reasoning: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content with no (response) content. NOTE that we may have overlapping + tokens with the start of reasoning / start of response sequences on + either side of the delta text. + + Args: + current_text (str): The full previous + delta text. + reasoning (str): reasoning content from current_text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. + """ + # If we have no reasoning content or explicitly end with the start of + # response sequence, we are in transition to the response; need to be + # careful here, since the final token (:) will match the reasoning + # content and fully parse it out; we should not pass the : back. 
+ ends_with_start_response_seq = any( + current_text.endswith(response_start) + for response_start in self.valid_response_starts + ) + if reasoning is None or ends_with_start_response_seq: + return DeltaMessage(reasoning=None, content=None) + + # Consider previous / current text only within context of the reasoning + previous_text = reasoning[: -len(delta_text)] + current_text = reasoning + + # We need to be careful about adding unfinished response sequences; + # Find the place at which we MIGHT be starting a response sequence + prev_idx = previous_text.rfind(self.seq_boundary_start) + delta_idx = delta_text.rfind(self.seq_boundary_start) + + # Check the state of potential start of response substring matches. + prev_was_substr = ( + self._is_response_start_substr(previous_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_continues_substr = ( + self._is_response_start_substr(current_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_new_substr = ( + self._is_response_start_substr(delta_text[delta_idx:]) + if delta_idx >= 0 + else False + ) + + # Delta only contains potential continued response sequence text. + if delta_continues_substr: + return DeltaMessage(reasoning=None, content=None) + + if not prev_was_substr: + # Delta may be starting a new response seq but has other text too. + if delta_new_substr: + return DeltaMessage(reasoning=delta_text[:delta_idx], content=None) + # Normal case for most reasoning text (no potential special seqs). 
+ return DeltaMessage(reasoning=delta_text, content=None) + # The substring that previously seemed to be a potential response + # seq wasn't one; we need to add the content to the delta message, + # and also slice off the potential response sequence + elif delta_new_substr: + reasoning = previous_text[prev_idx:] + delta_text[:delta_idx] + return DeltaMessage(reasoning=reasoning, content=None) + # No new substring yet, and we broke our old one; take the whole delta + return DeltaMessage( + reasoning=previous_text[prev_idx:] + delta_text, + content=None, + ) + + def _get_delta_message_with_both_bounds( + self, + delta_text: str, + reasoning: str, + response_content: str, + current_text: str, + response_seq_len: int, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content and normal (response) content. + + Args: + delta_text: Text to consider and parse content from. + reasoning: reasoning content from current_text. + response_content: response content from current_text. + current_text: The full previous + delta text. + response_seq_len: Len of the complete response sequence used. + + Returns: + DeltaMessage: Message containing the parsed content. 
+ """ + # Always have content; take length to the end + delta_content = delta_text[-len(response_content) :] + reasoning_end_idx = len(delta_text) - (len(response_content) + response_seq_len) + + if reasoning_end_idx < 0: + delta_reasoning = None + else: + # Get the starting offset + start_reasoning_idx = ( + len(reasoning) + response_seq_len + len(response_content) - 1 + ) + delta_offset = len(current_text) - len(delta_text) + start_offset = start_reasoning_idx - delta_offset + if start_offset < 0: + start_offset = 0 + delta_reasoning = delta_text[start_offset:reasoning_end_idx] + + return DeltaMessage( + reasoning=delta_reasoning, + content=delta_content, + ) + + def _get_content_sections( + self, current_text: str + ) -> tuple[str | None, int | None, str | None]: + """Parse the text to extract the reasoning content / content + if we have them. + + Args: + current_text (str): The full previous + delta text. + + Returns: + tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3 + containing the reasoning content, the length of the response seq + (if there is one) and the non-reasoning content. 
+ """ + current_chunk_start = 0 + start_reasoning = None + parsed_content = False + delimiter_idxs = [ + idx + for idx, char in enumerate(current_text) + if char == self.seq_boundary_end + ] + + for current_chunk_end in delimiter_idxs: + current_chunk = current_text[current_chunk_start:current_chunk_end] + # Check to see if the start of reasoning seq if complete + if start_reasoning is None: + for think_start in self.valid_think_starts: + if current_chunk == think_start[:-1]: + start_reasoning = current_chunk_end + 1 + current_chunk_start = current_chunk_end + 1 + break + + # Check to see if the start of response seq if complete + elif not parsed_content: + for response_start in self.valid_response_starts: + if current_chunk[-len(response_start) + 1 :] == response_start[:-1]: + # Mark end of reasoning and start response content + # after the start of response sequence. + end_reasoning = current_chunk_end - len(response_start) + reasoning = current_text[start_reasoning:end_reasoning] + response_content = current_text[current_chunk_end + 1 :] + return reasoning, len(response_start), response_content + + if start_reasoning and not parsed_content: + return current_text[start_reasoning:], None, None + return None, None, None diff --git a/endpoints/OAI/reasoning/holo2_reasoning_parser.py b/endpoints/OAI/reasoning/holo2_reasoning_parser.py new file mode 100644 index 00000000..cdd4d356 --- /dev/null +++ b/endpoints/OAI/reasoning/holo2_reasoning_parser.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser +from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser + + +class 
Holo2ReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.get("thinking", True)) + enable_thinking = bool(chat_kwargs.get("enable_thinking", True)) + thinking = thinking and enable_thinking + self._parser = ( + DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + if thinking + else IdentityReasoningParser(tokenizer, *args, **kwargs) + ) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Sequence[int] + ) -> bool: + return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py b/endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py new file mode 100644 index 00000000..b3ae3c72 --- /dev/null +++ b/endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +try: + 
from collections.abc import Sequence
from typing import Any

try:
    # Prefer the faster third-party `regex` module; fall back to stdlib `re`.
    import regex as re
except ImportError:
    import re

from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser


class HunyuanA13BReasoningParser(ReasoningParser):
    r"""Reasoning parser for the Hunyuan A13B model.

    Extracts reasoning ("<think>...</think>") and answer
    ("<answer>...</answer>") segments from model output.

    - Non-streaming output is recognized with regular expressions over the
      full text.
    - Streaming output needs token-id sequences to switch parsing state, so
      a small state machine (idle -> think -> response -> idle) is
      maintained across deltas.

    Token-id sequences (tied to the Hunyuan tokenizer):
      think start   "<think>\n":              [14023, 771, 397]
      think end     "\n</think>\n<answer>\n": [198, 524, 27963, 397, 27, 9399, 397]
      response end  "\n</answer>":            [198, 524, 9399, 29]
    """

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # NOTE(review): these sentinel strings must match the model's chat
        # template; the hard-coded id lists below match the Hunyuan tokenizer
        # only — confirm if the tokenizer ever changes.
        self.think_start_expr = r"<think>\n"
        self.think_end_expr = r"\n</think>\n"

        self.response_start_expr = r"\n</think>\n<answer>\n"
        self.response_end_expr = r"\n</answer>"

        # Full layout: optional "<think>…</think>" block followed by the
        # answer body and its closing tag.
        self.full_match_reasoning_regex = re.compile(
            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?"
            rf"(.*?){self.response_end_expr}",
            re.DOTALL,
        )

        # Fallback for truncated output: think block present but the answer
        # was cut off before its closing tag.
        self.half_match_reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL
        )

        self.think_start_ids = [14023, 771, 397]
        self.think_start_ids_fast = [14023, 771, 1363]
        self.response_start_ids = [198, 524, 27963, 397, 27, 9399, 397]
        self.response_start_ids_fast = [524, 27963, 397, 27, 9399, 397]
        self.response_end_ids = [198, 524, 9399, 29]
        self.fast_think_ids = [14023, 771, 1363, 524, 27963, 397, 27, 9399, 397]

        # Text/ids buffered while a sentinel sequence is being matched; they
        # are flushed when the match breaks or dropped when it completes.
        self.buffered_text = []
        self.buffered_ids = []

        self.all_states = ["reasoning", "response"]

        # Streaming state machine: idle -> think -> response -> idle.
        self.current_state = "idle"
        self.expected_sequence = self.think_start_ids
        # The think start has two possible tokenizations, so a side sequence
        # is matched in parallel while in the idle/think states.
        self.expected_sequence_side = self.think_start_ids_fast
        self.sequence_index = 0
        self.token_buffer = []
        self.text_buffer = ""

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        # The streaming state machine is authoritative: reasoning has ended
        # once the parser transitioned into the answer section.
        return self.current_state == "response"

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        # For Hunyuan streaming, the stream parse runs first and the same
        # token is then passed to is_reasoning_end and extract_content_ids.
        # That token is never part of the content, so just return [] here.
        return []

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        """Extract the reasoning content & content sections, respectively.

        If the sequence doesn't match what we expect, i.e., the model
        generates something else, all content is considered non-reasoning
        content.

        Args:
            model_output: Output of the model to be parsed.
            request: Request being processed (unused; kept for interface).

        Returns:
            Tuple pair containing the reasoning content and the
            non-reasoning content; either element may be None when empty.
        """
        re_match = self.full_match_reasoning_regex.findall(model_output)
        if re_match:
            reasoning, response_content = re_match[0]
            return reasoning or None, response_content or None

        fallback_match = self.half_match_reasoning_regex.findall(model_output)
        if fallback_match:
            reasoning, response_content = fallback_match[0]

            # The half match may still carry a trailing close tag.
            if response_content.endswith(self.response_end_expr):
                response_content = response_content[: -len(self.response_end_expr)]

            return reasoning or None, response_content or None

        return None, model_output

    def _is_strict_increasing_subsequence(
        self, subsequence: Sequence[int], sequence: Sequence[int]
    ) -> bool:
        """Return True if `subsequence` occurs, in order, within `sequence`."""
        if not subsequence:
            return False

        sub_idx = 0
        for num in sequence:
            if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
                sub_idx += 1
        return sub_idx == len(subsequence)

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract streamed content using a token-id sequence state machine.

        NOTE(review): only the last token of the delta is inspected, which
        assumes deltas arrive one token at a time — confirm against caller.
        """
        think_start_sequence = self.think_start_ids
        response_start_sequence = self.response_start_ids
        response_end_sequence = self.response_end_ids

        if not delta_token_ids:
            return None

        token = delta_token_ids[-1]

        def check_token_with_sequence(token: int) -> bool:
            # idle/think states may match either the canonical or the "fast"
            # (alternate tokenization) sentinel sequence.
            if self.current_state in ("idle", "think"):
                return (
                    token == self.expected_sequence[self.sequence_index]
                    or token == self.expected_sequence_side[self.sequence_index]
                )
            return token == self.expected_sequence[self.sequence_index]

        def check_last_token(token: int) -> bool:
            if self.current_state in ("idle", "think"):
                # Only judge by the side sequence length when this token
                # matched via the side sequence.
                if (
                    self.sequence_index - 1 < len(self.expected_sequence_side)
                    and token == self.expected_sequence_side[self.sequence_index - 1]
                ):
                    return self.sequence_index == len(self.expected_sequence_side)
                return self.sequence_index == len(self.expected_sequence)
            return self.sequence_index == len(self.expected_sequence)

        if check_token_with_sequence(token):
            # Token continues a sentinel match: buffer it instead of emitting.
            self.token_buffer.append(token)
            self.text_buffer += delta_text
            self.sequence_index += 1

            # Sentinel fully matched: idle -> think -> response -> idle.
            if check_last_token(token):
                if self.current_state == "idle":
                    self.current_state = "think"
                    self.expected_sequence = response_start_sequence
                    self.expected_sequence_side = self.response_start_ids_fast
                elif self.current_state == "think":
                    self.current_state = "response"
                    self.expected_sequence = response_end_sequence
                elif self.current_state == "response":
                    self.current_state = "idle"
                    self.expected_sequence = think_start_sequence
                    self.expected_sequence_side = self.think_start_ids_fast

                # Reset matching state; sentinel text is never emitted.
                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""
        else:
            # Sequence broken: flush anything buffered plus this delta.
            if self.token_buffer:
                buffered_content = self.text_buffer + delta_text
                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""

                if self.current_state == "think":
                    return DeltaMessage(reasoning=buffered_content, content=None)
                return DeltaMessage(reasoning=None, content=buffered_content)

            # Nothing buffered: emit the delta according to current state.
            if self.current_state == "think":
                return DeltaMessage(reasoning=delta_text, content=None)
            return DeltaMessage(reasoning=None, content=delta_text)

        # No content to send in this delta (sentinel in progress/completed).
        return None


class IdentityReasoningParser(ReasoningParser):
    """Pass-through parser: treats all model output as plain content."""

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return True

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return input_ids

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: list[int],
        current_token_ids: list[int],
        delta_token_ids: list[int],
    ) -> DeltaMessage | None:
        if not delta_text:
            return None
        return DeltaMessage(content=delta_text)

    def extract_reasoning(
        self,
        model_output: str,
        request: Any,
    ) -> tuple[str | None, str | None]:
        return None, model_output
from collections.abc import Sequence
from typing import Any

import dataclasses as dt
import enum

try:
    # Prefer the faster third-party `regex` module; fall back to stdlib `re`.
    import regex as re
except ImportError:  # pragma: no cover
    import re

from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser
from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser
from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser


class KimiK2ReasoningParser(ReasoningParser):
    """Kimi K2: delegates to DeepSeek-R1 parsing when the chat template's
    `thinking` flag is on, otherwise passes output through unchanged."""

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        thinking = bool(chat_kwargs.get("thinking", True))
        self._parser = (
            DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
            if thinking
            else IdentityReasoningParser(tokenizer, *args, **kwargs)
        )

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        return self._parser.is_reasoning_end(input_ids)

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Sequence[int]
    ) -> bool:
        return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return self._parser.extract_content_ids(input_ids)

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        return self._parser.extract_reasoning(model_output, request)

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        return self._parser.extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )


class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
    """MiniMax M2: reasoning delimited by <think>...</think>."""

    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        # A delta that is exactly the end token produces nothing.
        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.end_token_id:
            return None

        # Already past the reasoning section: everything is content.
        if self.end_token_id in previous_token_ids:
            return DeltaMessage(content=delta_text)

        # Delta straddles the end token: split reasoning from content.
        if self.end_token_id in delta_token_ids:
            end_index = delta_text.find(self.end_token)
            reasoning = delta_text[:end_index]
            content = delta_text[end_index + len(self.end_token) :]
            return DeltaMessage(reasoning=reasoning or None, content=content or None)

        return DeltaMessage(reasoning=delta_text)


class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
    """Variant for templates that consume the opening <think> tag: the tag
    is prepended back so consumers see the full reasoning block as content."""

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # May be None when the token is not in the vocab; `in` checks below
        # then simply never match.
        self.end_token_id = self.vocab.get("</think>")

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        # Plain membership test; scanning in reverse gains nothing with any().
        return self.end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return input_ids

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        if len(previous_token_ids) == 0:
            # First delta: restore the opening tag eaten by the template.
            delta_text = "<think>" + delta_text
        return DeltaMessage(content=delta_text)

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        return None, "<think>" + model_output


class MistralReasoningParser(BaseThinkingReasoningParser):
    """Mistral: reasoning delimited by [THINK]...[/THINK]."""

    @property
    def start_token(self) -> str:
        return "[THINK]"

    @property
    def end_token(self) -> str:
        return "[/THINK]"

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        # Scan backwards: reasoning ended only if the most recent [THINK]
        # is followed later in the sequence by a [/THINK].
        has_eot = False
        for token_id in reversed(input_ids):
            if token_id == self.start_token_id:
                return has_eot
            if token_id == self.end_token_id:
                has_eot = True
        return False

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Drop the ids between the first [THINK] and the first [/THINK]
        (tags included); everything else is content."""
        has_bot = False
        has_eot = False
        bot_idx = -1
        eot_idx = -1
        for i, token_id in enumerate(input_ids):
            if token_id == self.start_token_id and not has_bot:
                has_bot = True
                bot_idx = i
            elif token_id == self.end_token_id:
                has_eot = True
                eot_idx = i
                break

        if has_bot and not has_eot:
            # Unterminated think block: everything from [THINK] on is dropped.
            return input_ids[:bot_idx]
        if not has_bot and not has_eot:
            return input_ids
        if has_bot and has_eot:
            return input_ids[:bot_idx] + input_ids[eot_idx + 1 :]
        # Stray [/THINK] with no opener: drop just that token.
        return input_ids[:eot_idx] + input_ids[eot_idx + 1 :]

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        if not model_output:
            return None, ""

        prefix, bot, post_bot = model_output.partition(self.start_token)
        has_bot = bool(bot)
        has_valid_eot = has_bot and self.end_token in post_bot

        if has_bot and has_valid_eot:
            reasoning, _, post_eot = post_bot.partition(self.end_token)
            content = prefix + post_eot
            return reasoning or None, content or None
        if has_bot:
            # Unterminated think block: everything after [THINK] is reasoning.
            return post_bot or None, prefix or None

        if self.end_token in prefix:
            # Stray [/THINK] with no opener: strip the tag, keep the text.
            pre_eot, _, post_eot = prefix.partition(self.end_token)
            return None, (pre_eot + post_eot) or None

        return None, prefix


class Olmo3ReasoningState(enum.Enum):
    """Which section of the output the streaming buffer is currently in."""

    REASONING = 1
    CONTENT = 2


@dt.dataclass(frozen=True)
class Indices:
    # Half-open [start, end) character span.
    start: int
    end: int

    def __len__(self):
        return self.end - self.start


def string_overlap(a: str, b: str) -> tuple[Indices | None, Indices | None]:
    """Return the overlapping spans of `a` and `b` as (span_in_a, span_in_b).

    Checks full containment first, then suffix-of-one/prefix-of-other in
    both directions; returns (None, None) when there is no overlap.
    """
    # Normalize so `a` is the shorter string; `swap` restores the order.
    a, b, swap = (a, b, False) if len(a) < len(b) else (b, a, True)

    if a in b:
        start = b.index(a)  # hoisted: avoid scanning b twice
        ind_a = Indices(0, len(a))
        ind_b = Indices(start, start + len(a))
        return (ind_b, ind_a) if swap else (ind_a, ind_b)

    for i in range(len(a) - 1, 0, -1):
        if a[-i:] == b[:i]:
            ind_a = Indices(len(a) - i, len(a))
            ind_b = Indices(0, i)
            return (ind_b, ind_a) if swap else (ind_a, ind_b)

    for i in range(len(a) - 1, 0, -1):
        if b[-i:] == a[:i]:
            ind_a = Indices(0, i)
            ind_b = Indices(len(b) - i, len(b))
            return (ind_b, ind_a) if swap else (ind_a, ind_b)

    return None, None


@dt.dataclass
class Olmo3ReasoningBuffer:
    """Streaming buffer that withholds partial <think>/</think> tags until
    they can be classified, then emits reasoning or content deltas."""

    think_start: str = "<think>"
    think_end: str = "</think>"
    buffer: str = ""
    state: Olmo3ReasoningState = Olmo3ReasoningState.REASONING

    def process_buffer(self) -> DeltaMessage | None:
        """Consume complete tags from the buffer and emit the text so far."""
        start_think_idx = self.buffer.find(self.think_start)
        if start_think_idx >= 0:
            self.state = Olmo3ReasoningState.REASONING
            pretext, self.buffer = (
                self.buffer[:start_think_idx],
                self.buffer[start_think_idx + len(self.think_start) :],
            )
            # Text before the opening tag is content, not reasoning.
            if start_think_idx > 0:
                return DeltaMessage(content=pretext)

        end_think_idx = self.buffer.rfind(self.think_end)
        if end_think_idx >= 0:
            self.state = Olmo3ReasoningState.CONTENT
            pretext, self.buffer = (
                self.buffer[:end_think_idx],
                self.buffer[end_think_idx + len(self.think_end) :],
            )
            if end_think_idx > 0:
                return DeltaMessage(reasoning=pretext)

        # No tags pending: flush the buffer according to current state.
        if self.state == Olmo3ReasoningState.REASONING:
            text_buffer, self.buffer = self.buffer, ""
            return DeltaMessage(reasoning=text_buffer)

        if self.state == Olmo3ReasoningState.CONTENT:
            text_buffer, self.buffer = self.buffer, ""
            return DeltaMessage(content=text_buffer)

        return None

    def add_text(self, delta_text: str) -> DeltaMessage | None:
        """Append a delta and emit what can be safely classified."""
        self.buffer += delta_text
        delta_message: DeltaMessage | None = None

        _, overlap_think_start = string_overlap(delta_text, self.think_start)
        _, overlap_think_end = string_overlap(delta_text, self.think_end)

        # A partial overlap means a tag may be arriving split across deltas;
        # hold the buffer until it resolves.
        partial_overlap_start = overlap_think_start is not None and len(
            overlap_think_start
        ) < len(self.think_start)
        partial_overlap_end = overlap_think_end is not None and len(
            overlap_think_end
        ) < len(self.think_end)

        if (
            partial_overlap_start
            and self.think_start in self.buffer
            and not partial_overlap_end
        ):
            delta_message = self.process_buffer()
        elif partial_overlap_end and self.think_end in self.buffer:
            delta_message = self.process_buffer()
        elif partial_overlap_start or partial_overlap_end:
            # Tag still incomplete: emit nothing yet.
            return None
        else:
            delta_message = self.process_buffer()

        return delta_message


class Olmo3ReasoningParser(ReasoningParser):
    """Olmo 3: reasoning delimited by <think>...</think>; the opening tag
    may be absent (output can begin directly inside the think block)."""

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        self.think_start = r"<think>"
        self.think_end = r"</think>"

        # Opening tag optional, closing tag required for a reasoning match.
        reasoning_expr = (
            rf"^(?:{self.think_start})?(?P<reasoning>.*?)"
            rf"{self.think_end}(?P<content>.*)$"
        )
        self.reasoning_regex = re.compile(reasoning_expr, re.DOTALL)
        self.buffer = Olmo3ReasoningBuffer(
            think_start=self.think_start, think_end=self.think_end
        )

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        # No single end-token id exists, so decode and look for the tag text.
        text = self.model_tokenizer.decode(input_ids)
        return self.think_end in text

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return []

    def extract_reasoning(
        self,
        model_output: str,
        request: Any,
    ) -> tuple[str | None, str | None]:
        re_match = self.reasoning_regex.match(model_output)
        if re_match:
            reasoning = re_match.group("reasoning") or None
            content = re_match.group("content") or None
            return reasoning, content
        # No closing tag: treat everything as content.
        return None, model_output

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        delta_message = self.buffer.add_text(delta_text)
        # A completed end tag may still be sitting in the buffer: drain it.
        if delta_message is None and self.buffer.think_end in self.buffer.buffer:
            delta_message = self.buffer.process_buffer()
        return delta_message
from collections.abc import Sequence
from typing import Any

from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser
from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser


class Qwen3ReasoningParser(BaseThinkingReasoningParser):
    """Qwen3: reasoning delimited by <think>...</think>.

    Unlike the base parser, BOTH tags must appear for the section to count
    as reasoning; otherwise the whole output is treated as content.
    """

    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        if self.start_token not in model_output or self.end_token not in model_output:
            return None, model_output

        _, _, tail = model_output.partition(self.start_token)
        reasoning, _, content = tail.partition(self.end_token)
        return reasoning or None, content or None


class SeedOSSReasoningParser(BaseThinkingReasoningParser):
    """Seed-OSS: reasoning delimited by <seed:think>...</seed:think>."""

    @property
    def start_token(self) -> str:
        return "<seed:think>"

    @property
    def end_token(self) -> str:
        return "</seed:think>"


class Step3ReasoningParser(ReasoningParser):
    """Step3: output has no opening tag — everything before </think> is
    reasoning, everything after it is content."""

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        self.think_end_token = "</think>"
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        # Fail fast: this parser cannot work without the end token id.
        if self.think_end_token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end token in tokenizer"
            )

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        # A delta that is exactly the end token produces nothing.
        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
            return None

        # Delta straddles the end token: split reasoning from content.
        if self.think_end_token_id in delta_token_ids:
            end_index = delta_text.find(self.think_end_token)
            reasoning = delta_text[:end_index]
            content = delta_text[end_index + len(self.think_end_token) :]
            return DeltaMessage(reasoning=reasoning, content=content or None)

        # Already past the reasoning section.
        if self.think_end_token_id in previous_token_ids:
            return DeltaMessage(content=delta_text)

        return DeltaMessage(reasoning=delta_text)

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        # No end tag: the whole (possibly truncated) output is reasoning.
        if self.think_end_token not in model_output:
            return model_output or None, None

        end_index = model_output.find(self.think_end_token)
        reasoning = model_output[:end_index]
        content = model_output[end_index + len(self.think_end_token) :]
        return reasoning or None, content or None

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        return self.think_end_token_id in input_ids

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Sequence[int]
    ) -> bool:
        return self.think_end_token_id in delta_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        # An end token in the final position means content hasn't started yet.
        if self.think_end_token_id not in input_ids[:-1]:
            return []
        return input_ids[input_ids.index(self.think_end_token_id) + 1 :]


class Step3p5ReasoningParser(BaseThinkingReasoningParser):
    """Step3.5: <think>...</think> reasoning where the chat template places
    single newlines around the closing tag; those newlines are stripped from
    the emitted reasoning/content."""

    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # A trailing "\n" in reasoning is withheld and re-emitted only if
        # more reasoning follows (so the final newline before </think> is
        # never streamed out).
        self._pending_reasoning_newline = False
        # Number of </think> sightings still considered part of reasoning.
        self.end_offset = 1

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        # NOTE(review): stateful — decrements end_offset the first time the
        # end token is seen, so repeated calls are not idempotent; this
        # relies on the caller invoking it once per step.
        if self.end_token_id in input_ids and self.end_offset > 0:
            self.end_offset -= 1
            return False
        return self.end_offset < 1

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Sequence[int]
    ) -> bool:
        if self.end_token_id in input_ids and self.end_offset > 0:
            self.end_offset -= 1
            return False
        return self.end_offset < 1

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        reasoning, content = super().extract_reasoning(model_output, request)
        # Strip the single template newline on each side of </think>.
        if reasoning is not None:
            reasoning = reasoning.removesuffix("\n")
        if content is not None:
            content = content.removeprefix("\n")
        return reasoning or None, content or None

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        # Swallow the newline that immediately follows </think>.
        if previous_text.endswith(self.end_token) and delta_text:
            if delta_text == "\n":
                return None
            if delta_text.startswith("\n"):
                remaining = delta_text.removeprefix("\n")
                return DeltaMessage(content=remaining) if remaining else None

        ret = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if ret is None:
            return None

        # No opening tag seen at all: treat text before </think> as
        # reasoning (the template may omit <think> entirely).
        if (
            self.start_token_id not in previous_token_ids
            and self.start_token_id not in delta_token_ids
        ):
            if self.end_token_id in delta_token_ids:
                end_index = delta_text.find(self.end_token)
                reasoning = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token) :]
                ret = DeltaMessage(reasoning=reasoning, content=content or None)
            elif self.end_token_id in previous_token_ids:
                ret = DeltaMessage(content=delta_text)
            else:
                ret = DeltaMessage(reasoning=delta_text)

        reasoning_to_output = ret.reasoning
        content_to_output = ret.content

        if reasoning_to_output is not None:
            if self._pending_reasoning_newline:
                # More reasoning arrived: restore the withheld newline.
                reasoning_to_output = "\n" + reasoning_to_output
                self._pending_reasoning_newline = False

            if reasoning_to_output.endswith("\n"):
                # Withhold a trailing newline unless the block just closed.
                reasoning_to_output = reasoning_to_output.removesuffix("\n")
                self._pending_reasoning_newline = self.end_token not in delta_text

        if content_to_output is not None:
            self.end_offset -= 1
            self._pending_reasoning_newline = False
            if self.end_token in delta_text and content_to_output.startswith("\n"):
                content_to_output = content_to_output.removeprefix("\n")

        reasoning_to_output = reasoning_to_output or None
        content_to_output = content_to_output or None
        if reasoning_to_output is None and content_to_output is None:
            return None

        return DeltaMessage(reasoning=reasoning_to_output, content=content_to_output)
diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py index b5b9611f..33b0496a 100644 --- a/endpoints/OAI/types/tools.py +++ b/endpoints/OAI/types/tools.py @@ -33,3 +33,12 @@ class ToolCall(BaseModel): id: str = Field(default_factory=lambda: str(uuid4()).replace("-", "")[:9]) function: Tool type: Literal["function"] = "function" + + +class NamedToolFunction(BaseModel): + name: str + + +class NamedToolChoice(BaseModel): + function: NamedToolFunction + type: Literal["function"] = "function" diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index b559bb2b..944621d8 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -3,6 +3,7 @@ import asyncio import pathlib from asyncio import CancelledError +from dataclasses import dataclass from typing import List, Optional from fastapi import HTTPException, Request from jinja2 import TemplateError @@ -16,7 +17,9 @@ handle_request_error, request_disconnect_loop, ) +from common.tabby_config import config from common.utils import unwrap +from endpoints.OAI.reasoning import ReasoningParserManager from endpoints.OAI.types.chat_completion import ( ChatCompletionLogprobs, ChatCompletionLogprob, @@ -32,6 +35,59 @@ from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA +@dataclass +class _StreamReasoningState: + text: str = "" + token_ids: List[int] = None + + def __post_init__(self): + if self.token_ids is None: + self.token_ids = [] + + +class _TokenizerAdapter: + """Expose a tiny tokenizer interface required by reasoning parsers.""" + + def __init__(self): + self._vocab = None + + def get_vocab(self) -> dict[str, int]: + if self._vocab is not None: + return self._vocab + + pieces = model.container.tokenizer.get_id_to_piece_list(True) + vocab: dict[str, int] = {} + for token_id, piece in enumerate(pieces): + if piece not in vocab: + vocab[piece] = token_id + self._vocab = vocab + return vocab + + +def 
_token_ids_from_generation(generation: dict) -> List[int]: + token_ids = generation.get("token_ids") + if token_ids is None: + return [] + if isinstance(token_ids, list): + return token_ids + # Covers tensor-like objects (PyTorch) + if hasattr(token_ids, "flatten"): + return token_ids.flatten().tolist() + return list(token_ids) + + +def _build_reasoning_parser(data: ChatCompletionRequest): + parser_key = unwrap(config.model.reasoning_parser, "basic") or "basic" + try: + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_key) + except KeyError as exc: + raise HTTPException(400, str(exc)) from exc + + chat_template_kwargs = unwrap(data.template_vars, {}) + tokenizer_adapter = _TokenizerAdapter() + return parser_cls(tokenizer_adapter, chat_template_kwargs=chat_template_kwargs) + + def _create_response( request_id: str, generations: List[dict], model_name: Optional[str] ): @@ -39,13 +95,20 @@ def _create_response( choices = [] for index, generation in enumerate(generations): + reasoning = generation.get("reasoning") + reasoning_content = generation.get("reasoning_content") + message = ChatCompletionMessage( - role="assistant", content=unwrap(generation.get("text"), "") + role="assistant", + content=generation.get("text"), + reasoning=reasoning, + reasoning_content=reasoning_content, ) tool_calls = generation["tool_calls"] if tool_calls: message.tool_calls = ToolCallProcessor.from_json(tool_calls) + message.content = None logprob_response = None @@ -152,8 +215,13 @@ def _create_stream_chunk( choices.append(choice) else: + reasoning = generation.get("reasoning") + reasoning_content = generation.get("reasoning_content") message = ChatCompletionMessage( - role="assistant", content=unwrap(generation.get("text"), "") + role="assistant", + content=generation.get("text"), + reasoning=reasoning, + reasoning_content=reasoning_content, ) logprob_response = None @@ -320,6 +388,8 @@ async def stream_generate_chat_completion( gen_tasks: List[asyncio.Task] = [] 
tool_start = model.container.prompt_template.metadata.tool_start disconnect_task = asyncio.create_task(request_disconnect_loop(request)) + reasoning_parser = _build_reasoning_parser(data) + reasoning_states = [_StreamReasoningState() for _ in range(0, data.n)] try: logger.info(f"Received chat completion streaming request {request.state.id}") @@ -372,6 +442,41 @@ async def stream_generate_chat_completion( if isinstance(generation, Exception): raise generation + if "text" in generation and generation.get("finish_reason") is None: + idx = generation["index"] + state = reasoning_states[idx] + + delta_text = unwrap(generation.get("text"), "") + delta_token_ids = _token_ids_from_generation(generation) + + current_text = state.text + delta_text + current_token_ids = state.token_ids + delta_token_ids + + delta_message = reasoning_parser.extract_reasoning_streaming( + state.text, + current_text, + delta_text, + state.token_ids, + current_token_ids, + delta_token_ids, + ) + + state.text = current_text + state.token_ids = current_token_ids + + if delta_message is None: + continue + + generation["text"] = delta_message.content + if data.include_reasoning: + generation["reasoning"] = delta_message.reasoning + generation["reasoning_content"] = delta_message.reasoning + else: + generation["reasoning"] = None + generation["reasoning_content"] = None + if generation["text"] is None: + continue + response = _create_stream_chunk( request.state.id, generation, model_path.name ) @@ -436,12 +541,30 @@ async def generate_chat_completion( generations = await asyncio.gather(*gen_tasks) - # Check all the generations and see if a tool call is required - if tool_start: + # Run tool pass when template exposes tool_start OR caller explicitly + # requires a tool call (required / named function). 
+ force_tool_pass = data.tool_choice == "required" or hasattr( + data.tool_choice, "function" + ) + if tool_start or force_tool_pass: generations = await generate_tool_calls( prompt, embeddings, data, generations, request ) + reasoning_parser = _build_reasoning_parser(data) + for generation in generations: + reasoning, content = reasoning_parser.extract_reasoning( + unwrap(generation.get("text"), ""), + data, + ) + + if not data.include_reasoning: + reasoning = None + + generation["reasoning"] = reasoning + generation["reasoning_content"] = reasoning + generation["text"] = content + response = _create_response(request.state.id, generations, model_path.name) logger.info(f"Finished chat completion request {request.state.id}") @@ -467,6 +590,10 @@ async def generate_tool_calls( ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + tool_choice = data.tool_choice + + if tool_choice == "none": + return generations # Tracks which generations asked for a tool call tool_idx: List[int] = [] @@ -475,16 +602,31 @@ async def generate_tool_calls( tool_data = data.model_copy(deep=True) tool_data.json_schema = TOOL_CALL_SCHEMA + named_tool_name = None + if hasattr(tool_choice, "function") and tool_choice: + named_tool_name = tool_choice.function.name + for idx, gen in enumerate(generations): - if gen["stop_str"] != tool_start: + stop_str = gen.get("stop_str") + should_generate = stop_str == tool_start + + # If tool calls are required, run tool schema pass even when stop token + # was not emitted, using the existing completion as precursor. 
+ if tool_choice == "required" and not should_generate: + should_generate = True + if named_tool_name and not should_generate: + should_generate = True + + if not should_generate: continue logger.info(f"Detected tool call in chat completion request {request.state.id}") # Append the existing generation text if present + prompt_with_precursor = prompt precursor_text = gen.get("full_text") if precursor_text: - prompt = prompt + precursor_text + prompt_with_precursor = prompt + precursor_text gen_request_id = gen.get("request_id") tool_request_id = f"{gen_request_id}-tool" @@ -493,7 +635,7 @@ async def generate_tool_calls( asyncio.create_task( model.container.generate( tool_request_id, - prompt, + prompt_with_precursor, tool_data, mm_embeddings=embeddings, ) @@ -507,6 +649,17 @@ async def generate_tool_calls( # Map tool calls to their appropriate generation for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True): - generations[gen_idx]["tool_calls"] = tool_call["text"] + tool_calls_json = tool_call["text"] + if named_tool_name: + try: + tool_calls_json = ToolCallProcessor.filter_by_name( + tool_calls_json, named_tool_name + ) + except Exception: + logger.warning( + "Could not filter tool calls by name '%s'", named_tool_name + ) + + generations[gen_idx]["tool_calls"] = tool_calls_json return generations diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index c1ebdedf..76a81841 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -1,4 +1,5 @@ import json +import re from loguru import logger from typing import List @@ -29,11 +30,90 @@ class ToolCallProcessor: + @staticmethod + def _normalize_tool_calls(raw) -> list[dict]: + """ + Normalize model-emitted tool call payloads into OAI-like objects. 
+ + Accepted forms: + - [{"type":"function","function":{"name":...,"arguments":{...}}}] + - [{"name":...,"arguments":{...}}] + - {"name":...,"arguments":{...}} + """ + if isinstance(raw, dict): + raw = [raw] + if not isinstance(raw, list): + raise json.JSONDecodeError("tool_calls payload is not list/dict", str(raw), 0) + + normalized: list[dict] = [] + for item in raw: + if not isinstance(item, dict): + continue + + if "function" in item and isinstance(item["function"], dict): + fn = item["function"] + name = fn.get("name") + arguments = fn.get("arguments", {}) + else: + name = item.get("name") + arguments = item.get("arguments", {}) + + if name is None: + continue + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {"input": arguments} + + normalized.append( + { + "type": "function", + "function": { + "name": name, + "arguments": arguments if isinstance(arguments, dict) else {}, + }, + } + ) + return normalized + + @staticmethod + def _safe_loads(payload: str): + """Best-effort JSON parse for model-emitted tool payloads.""" + try: + return ToolCallProcessor._normalize_tool_calls(json.loads(payload)) + except json.JSONDecodeError: + # Common model artifacts: fenced code blocks or leading text. 
+ cleaned = payload.strip() + cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned) + cleaned = re.sub(r"\s*```$", "", cleaned) + + # EXAONE-like tool tags: {...} + tag_matches = re.findall(r"\s*(\{.*?\})\s*", cleaned, re.DOTALL) + if tag_matches: + parsed = [] + for candidate in tag_matches: + parsed.append(json.loads(candidate)) + return ToolCallProcessor._normalize_tool_calls(parsed) + + start = cleaned.find("[") + end = cleaned.rfind("]") + if start != -1 and end != -1 and end > start: + cleaned = cleaned[start : end + 1] + return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned)) + + # Fallback: single JSON object + obj_start = cleaned.find("{") + obj_end = cleaned.rfind("}") + if obj_start != -1 and obj_end != -1 and obj_end > obj_start: + cleaned = cleaned[obj_start : obj_end + 1] + return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned)) + @staticmethod def from_json(tool_calls_str: str) -> List[ToolCall]: """Postprocess tool call JSON to a parseable class""" - tool_calls = json.loads(tool_calls_str) + tool_calls = ToolCallProcessor._safe_loads(tool_calls_str) for tool_call in tool_calls: tool_call["function"]["arguments"] = json.dumps( tool_call["function"]["arguments"] @@ -83,3 +163,13 @@ def to_json(tool_calls: List[ToolCall]) -> str: # Serialize the dumped array return json.dumps(dumped_tool_calls, indent=2) + + @staticmethod + def filter_by_name(tool_calls_str: str, function_name: str) -> str: + tool_calls = ToolCallProcessor._safe_loads(tool_calls_str) + filtered = [ + item + for item in tool_calls + if item.get("function", {}).get("name") == function_name + ] + return json.dumps(filtered) From 2d58b3f5492400795cd9d30828631655d52dffc8 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Thu, 12 Feb 2026 09:43:05 +0900 Subject: [PATCH 02/19] fix(oai): map top-level thinking flags into chat template kwargs --- endpoints/OAI/types/chat_completion.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git 
a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 5d8913d4..a5f45ee1 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -1,4 +1,4 @@ -from pydantic import AliasChoices, BaseModel, Field, field_validator +from pydantic import AliasChoices, BaseModel, Field, field_validator, model_validator from time import time from typing import Literal, Union, List, Optional, Dict from uuid import uuid4 @@ -65,6 +65,8 @@ class ChatCompletionRequest(CommonCompletionRequest): validation_alias=AliasChoices("template_vars", "chat_template_kwargs"), description="Aliases: chat_template_kwargs", ) + enable_thinking: Optional[bool] = None + thinking: Optional[bool] = None response_prefix: Optional[str] = None model: Optional[str] = None include_reasoning: Optional[bool] = True @@ -88,6 +90,20 @@ def force_bos_token(cls, v): """Always disable add_bos_token with chat completions.""" return None + @model_validator(mode="after") + def apply_thinking_aliases(self): + """Support clients that send thinking flags at the top-level.""" + template_vars = dict(self.template_vars or {}) + + if self.enable_thinking is not None and "enable_thinking" not in template_vars: + template_vars["enable_thinking"] = self.enable_thinking + + if self.thinking is not None and "thinking" not in template_vars: + template_vars["thinking"] = self.thinking + + self.template_vars = template_vars + return self + class ChatCompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}") From 355bac9c58315934056d00b4bbbf18ec2bf35821 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Wed, 18 Feb 2026 22:58:36 +0900 Subject: [PATCH 03/19] feat(oai): add vllm-compatible parser options and reasoning/tool parsing --- THIRD_PARTY_LICENSES.md | 3 + backends/exllamav3/model.py | 24 +- common/config_models.py | 27 +- common/templating.py | 28 +- config_sample.yml | 14 + docs/02.-Server-options.md | 46 +-- 
docs/04.-Chat-Completions.md | 9 +- docs/10.-Tool-Calling.md | 13 +- endpoints/OAI/reasoning/__init__.py | 12 +- endpoints/OAI/types/chat_completion.py | 4 +- endpoints/OAI/types/tools.py | 17 +- endpoints/OAI/utils/chat_completion.py | 483 ++++++++++++++++++++----- endpoints/OAI/utils/completion.py | 28 +- endpoints/OAI/utils/parser_options.py | 60 +++ endpoints/OAI/utils/tools.py | 473 +++++++++++++++++++++--- templates/tool_calls/qwen3_coder.jinja | 125 +++++++ tests/parser_options_test.py | 27 ++ tests/tool_parser_test.py | 82 +++++ 18 files changed, 1272 insertions(+), 203 deletions(-) create mode 100644 endpoints/OAI/utils/parser_options.py create mode 100644 templates/tool_calls/qwen3_coder.jinja create mode 100644 tests/parser_options_test.py create mode 100644 tests/tool_parser_test.py diff --git a/THIRD_PARTY_LICENSES.md b/THIRD_PARTY_LICENSES.md index 962f9d54..f732bb73 100644 --- a/THIRD_PARTY_LICENSES.md +++ b/THIRD_PARTY_LICENSES.md @@ -26,6 +26,9 @@ licensed under Apache License 2.0: - `endpoints/OAI/reasoning/step3_reasoning_parser.py` - `endpoints/OAI/reasoning/step3p5_reasoning_parser.py` - `endpoints/OAI/reasoning/__init__.py` +- `endpoints/OAI/utils/parser_options.py` +- `endpoints/OAI/utils/tools.py` +- `templates/tool_calls/qwen3_coder.jinja` Source project: - vLLM: https://github.com/vllm-project/vllm diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 33f5957d..9780f940 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -996,6 +996,7 @@ async def generate_gen( max_rq_tokens=self.max_rq_tokens, filters=grammar_handler.filters, ) + self.active_job_ids[request_id] = job generated_tokens = 0 full_response = "" @@ -1013,22 +1014,21 @@ async def generate_gen( if chunk: chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk)) full_response += chunk - token_ids: list[int] = [] + + # Extract token IDs as a plain list for downstream consumers if isinstance(chunk_tokens, torch.Tensor): - 
token_ids = chunk_tokens.flatten().tolist() + token_id_list = chunk_tokens.flatten().tolist() + generated_tokens += chunk_tokens.size(dim=0) elif isinstance(chunk_tokens, tuple): - # Some kernels may return tuple[token_ids, ...] first = chunk_tokens[0] if isinstance(first, torch.Tensor): - token_ids = first.flatten().tolist() + token_id_list = first.flatten().tolist() else: - token_ids = list(first) + token_id_list = list(first) + generated_tokens += len(token_id_list) else: - token_ids = list(chunk_tokens) - if isinstance(chunk_tokens, torch.Tensor): - generated_tokens += chunk_tokens.size(dim=0) - elif token_ids: - generated_tokens += len(token_ids) + token_id_list = list(chunk_tokens) + generated_tokens += len(token_id_list) # Increase penalty range to generated token amount # TODO: @@ -1038,7 +1038,7 @@ async def generate_gen( generation = { "request_id": request_id, "text": chunk, - "token_ids": token_ids, + "token_ids": token_id_list, "prompt_tokens": context_len, "generated_tokens": generated_tokens, "offset": len(full_response), @@ -1059,8 +1059,6 @@ async def generate_gen( yield finish_chunk break - # Assign the active job to the request ID - self.active_job_ids[request_id] = job except asyncio.CancelledError: await job.cancel() diff --git a/common/config_models.py b/common/config_models.py index cdfa603b..5d89b58a 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -288,7 +288,7 @@ class ModelConfig(BaseConfigModel): None, description=( "Set the prompt template for this model. (default: None)\n" - "If empty, attempts to look for the model's chat template.\n" + "If empty, attempts to look for the model's chat template.\n" "If a model contains multiple templates in its tokenizer_config.json,\n" "set prompt_template to the name of the template you want to use.\n" "NOTE: Only works with chat completion message lists!" @@ -302,11 +302,34 @@ class ModelConfig(BaseConfigModel): "If omitted, defaults to 'basic'." 
), ) + enable_auto_tool_choice: Optional[bool] = Field( + False, + description=( + "Enable auto tool choice for chat completions (default: False).\n" + "Equivalent to vLLM's --enable-auto-tool-choice.\n" + "Requires tool_call_parser to be set." + ), + ) + tool_call_parser: Optional[str] = Field( + None, + description=( + "Tool parser key for model-generated tool call output.\n" + "Equivalent to vLLM's --tool-call-parser.\n" + "Example values: qwen3_coder, qwen3_xml, mistral, hermes, openai." + ), + ) + exclude_tools_when_tool_choice_none: Optional[bool] = Field( + False, + description=( + "Exclude tool definitions from prompt when tool_choice='none'.\n" + "Equivalent to vLLM's --exclude-tools-when-tool-choice-none." + ), + ) vision: Optional[bool] = Field( False, description=( "Enables vision support if the model supports it. (default: False)" - ), + ), ) _metadata: Metadata = PrivateAttr(Metadata()) diff --git a/common/templating.py b/common/templating.py index 1a65a73b..dda06d85 100644 --- a/common/templating.py +++ b/common/templating.py @@ -25,12 +25,17 @@ class TemplateLoadError(Exception): pass +VALID_TOOL_CALL_FORMATS = {"json", "xml", "auto"} + + @dataclass class TemplateMetadata: """Represents the parsed metadata from a template.""" stop_strings: List[str] = field(default_factory=list) tool_start: Optional[str] = None + tool_end: Optional[str] = None + tool_call_format: str = "json" class PromptTemplate: @@ -49,11 +54,10 @@ class PromptTemplate: @staticmethod def _tojson_compat(value, indent=None, ensure_ascii=True): - """ - Compatibility JSON filter for chat templates. + """Compatibility JSON filter for chat templates. - Some model templates call `tojson(ensure_ascii=False)` while the - bundled Jinja filter may not accept that keyword in this environment. + Some model templates call ``tojson(ensure_ascii=False)`` while the + bundled Jinja filter may not accept that keyword in sandboxed mode. 
""" return Markup( json.dumps( @@ -94,6 +98,22 @@ async def extract_metadata(self, template_vars: dict): if isinstance(template_module.tool_start, str): template_metadata.tool_start = template_module.tool_start + if hasattr(template_module, "tool_end"): + if isinstance(template_module.tool_end, str): + template_metadata.tool_end = template_module.tool_end + + if hasattr(template_module, "tool_call_format"): + fmt = template_module.tool_call_format + if isinstance(fmt, str) and fmt in VALID_TOOL_CALL_FORMATS: + template_metadata.tool_call_format = fmt + logger.debug(f"Template tool_call_format: {fmt}") + else: + logger.warning( + f"Invalid tool_call_format '{fmt}' in template, " + f"defaulting to 'json'. " + f"Valid values: {VALID_TOOL_CALL_FORMATS}" + ) + self.metadata = template_metadata return template_metadata diff --git a/config_sample.yml b/config_sample.yml index b13a4cd9..f6dfab66 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -158,6 +158,20 @@ model: # If omitted, TabbyAPI defaults to `basic`. reasoning_parser: + # Enable automatic tool selection (default: False). + # Equivalent to vLLM --enable-auto-tool-choice. + # Requires tool_call_parser to be set. + enable_auto_tool_choice: false + + # Tool parser key for model-generated tool call text. + # Equivalent to vLLM --tool-call-parser. + # Common values: qwen3_coder, qwen3_xml, mistral, hermes, openai. + tool_call_parser: + + # Exclude tool definitions from prompt when tool_choice='none'. + # Equivalent to vLLM --exclude-tools-when-tool-choice-none. + exclude_tools_when_tool_choice_none: false + # Enables vision support if the model supports it. (default: False) vision: false diff --git a/docs/02.-Server-options.md b/docs/02.-Server-options.md index 98cee556..4da7766b 100644 --- a/docs/02.-Server-options.md +++ b/docs/02.-Server-options.md @@ -53,27 +53,31 @@ Note: These are experimental flags that may be removed at any point. 
Note: Most of the options here will only apply on initial model load/startup (ephemeral). They will not persist unless you add the option name to `use_as_default`. -| Config Option | Type (Default) | Description | -| --------------------- | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| model_dir | String ("models") | Directory to look for models.

Note: Persisted across subsequent load requests | -| inline_model_loading | Bool (False) | Enables ability to switch models using the `model` argument in a generation request. More info in [Usage](https://github.com/theroyallab/tabbyAPI/wiki/03.-Usage#inline-loading) | -| use_dummy_models | Bool (False) | Send a dummy OAI model card when calling the `/v1/models` endpoint. Used for clients which enforce specific OAI models.

Note: Persisted across subsequent load requests | -| dummy_model_names | List[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests | -| model_name | String (None) | Folder name of a model to load. The below parameters will not apply unless this is filled out. | -| use_as_default | List[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.

Note: Also applies to the `draft` sub-block | -| max_seq_len | Float (None) | Maximum sequence length of the model. Uses the value from config.json if not specified here. Also called the max context length. | -| tensor_parallel | Bool (False) | Enables tensor parallelism. Automatically falls back to autosplit if GPU split isn't provided.

Note: `gpu_split_auto` is ignored when this is enabled. | -| gpu_split_auto | Bool (True) | Automatically split the model across multiple GPUs. Manual GPU split isn't used if this is enabled. | -| autosplit_reserve | List[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.

Represented as an array of MB per GPU used. | -| gpu_split | List[Float] ([]) | Float array of GBs to split a model between GPUs. | -| rope_scale | Float (1.0) | Adjustment for rope scale (or compress_pos_emb)

Note: If the model has YaRN support, this option will not apply. | -| rope_alpha | Float (None) | Adjustment for rope alpha. Leave blank to automatically calculate based on the max_seq_len.

Note: If the model has YaRN support, this option will not apply. | -| cache_mode | String ("FP16") | Cache mode for the model.

Options: FP16, Q8, Q6, Q4 | -| cache_size | Int (max_seq_len) | Size of the K/V cache

Note: If using CFG, the cache size should be 2 * max_seq_len. | -| chunk_size | Int (2048) | Amount of tokens per chunk with ingestion. A lower value reduces VRAM usage at the cost of ingestion speed. | -| max_batch_size | Int (None) | The absolute maximum amount of prompts to process at one time. This value is automatically adjusted based on cache size. | -| prompt_template | String (None) | Name of a jinja2 chat template to apply for this model. Must be located in the `templates` directory. | -| vision | Bool (False) | Enable vision support for the provided model (if it exists). | +| Config Option | Type (Default) | Description | +| ------------------------------------ | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| model_dir | String ("models") | Directory to look for models.

Note: Persisted across subsequent load requests | +| inline_model_loading | Bool (False) | Enables ability to switch models using the `model` argument in a generation request. More info in [Usage](https://github.com/theroyallab/tabbyAPI/wiki/03.-Usage#inline-loading) | +| use_dummy_models | Bool (False) | Send a dummy OAI model card when calling the `/v1/models` endpoint. Used for clients which enforce specific OAI models.

Note: Persisted across subsequent load requests | +| dummy_model_names | List[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests | +| model_name | String (None) | Folder name of a model to load. The below parameters will not apply unless this is filled out. | +| use_as_default | List[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.

Note: Also applies to the `draft` sub-block | +| max_seq_len | Float (None) | Maximum sequence length of the model. Uses the value from config.json if not specified here. Also called the max context length. | +| tensor_parallel | Bool (False) | Enables tensor parallelism. Automatically falls back to autosplit if GPU split isn't provided.

Note: `gpu_split_auto` is ignored when this is enabled. | +| gpu_split_auto | Bool (True) | Automatically split the model across multiple GPUs. Manual GPU split isn't used if this is enabled. | +| autosplit_reserve | List[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.

Represented as an array of MB per GPU used. | +| gpu_split | List[Float] ([]) | Float array of GBs to split a model between GPUs. | +| rope_scale | Float (1.0) | Adjustment for rope scale (or compress_pos_emb)

Note: If the model has YaRN support, this option will not apply. | +| rope_alpha | Float (None) | Adjustment for rope alpha. Leave blank to automatically calculate based on the max_seq_len.

Note: If the model has YaRN support, this option will not apply. | +| cache_mode | String ("FP16") | Cache mode for the model.

Options: FP16, Q8, Q6, Q4 | +| cache_size | Int (max_seq_len) | Size of the K/V cache

Note: If using CFG, the cache size should be 2 * max_seq_len. | +| chunk_size | Int (2048) | Amount of tokens per chunk with ingestion. A lower value reduces VRAM usage at the cost of ingestion speed. | +| max_batch_size | Int (None) | The absolute maximum amount of prompts to process at one time. This value is automatically adjusted based on cache size. | +| prompt_template | String (None) | Name of a jinja2 chat template to apply for this model. Must be located in the `templates` directory. | +| reasoning_parser | String (None) | Reasoning parser key used to split reasoning and final answer text (vLLM-compatible names, default parser behavior is `basic`). | +| enable_auto_tool_choice | Bool (False) | Enables vLLM-style automatic tool choice handling. Equivalent to `--enable-auto-tool-choice` and requires `tool_call_parser`. | +| tool_call_parser | String (None) | vLLM-compatible tool parser key used to parse model-emitted tool calls. Equivalent to `--tool-call-parser`. | +| exclude_tools_when_tool_choice_none | Bool (False) | Excludes tool definitions from the prompt when `tool_choice` is `"none"`. Equivalent to `--exclude-tools-when-tool-choice-none`. | +| vision | Bool (False) | Enable vision support for the provided model (if it exists). | ### Draft Model Options diff --git a/docs/04.-Chat-Completions.md b/docs/04.-Chat-Completions.md index 647ee92d..96044599 100644 --- a/docs/04.-Chat-Completions.md +++ b/docs/04.-Chat-Completions.md @@ -31,4 +31,11 @@ Now let's pass the custom var in the following template: I'm going to say {{ test_var }} ``` -Running render on this template will now result in: `I'm going to say hello!` \ No newline at end of file +Running render on this template will now result in: `I'm going to say hello!` + +### Reasoning controls + +TabbyAPI supports reasoning parser output separation with vLLM-compatible parser keys via `model.reasoning_parser` in `config.yml`. 
+ +- `include_reasoning` request field: include or suppress reasoning output in responses +- `enable_thinking` / `thinking` request fields: accepted as top-level aliases and forwarded to template vars (`template_vars.enable_thinking`, `template_vars.thinking`) diff --git a/docs/10.-Tool-Calling.md b/docs/10.-Tool-Calling.md index 83e379a5..84ffa90f 100644 --- a/docs/10.-Tool-Calling.md +++ b/docs/10.-Tool-Calling.md @@ -12,11 +12,22 @@ TabbyAPI's tool calling implementation aligns with the [OpenAI Standard](https:/ TabbyAPI's tool implementation supports: - Tool calling when streaming - Calling multiple tools per turn +- `tool_choice` values: `none`, `auto`, `required`, and named function choice +- vLLM-style parser selection via `model.tool_call_parser` Current limitations: -- No support for `tool_choice` parameter (always assumed to be auto) - `strict` parameter not yet supported (OAI format ensured, but dtype and argument name choices not yet enforced) +### vLLM-compatible options + +The following model config options are available to align behavior with vLLM: + +- `enable_auto_tool_choice`: equivalent to `--enable-auto-tool-choice` +- `tool_call_parser`: equivalent to `--tool-call-parser` +- `exclude_tools_when_tool_choice_none`: equivalent to `--exclude-tools-when-tool-choice-none` + +`tool_choice="auto"` requires both `enable_auto_tool_choice: true` and `tool_call_parser` to be set. + ## Model Support TabbyAPI exposes controls within the `prompt_template` to accommodate models specifically tuned for tool calling and those that aren't. By default, TabbyAPI includes `chatml_with_headers_tool_calling.jinja`, a generic template built to support the Llama 3.1 family and other models following the ChatML (with headers) format. 
diff --git a/endpoints/OAI/reasoning/__init__.py b/endpoints/OAI/reasoning/__init__.py index 7aaa3630..d7df6ee1 100644 --- a/endpoints/OAI/reasoning/__init__.py +++ b/endpoints/OAI/reasoning/__init__.py @@ -1,7 +1,7 @@ -from endpoints.OAI.reasoning.abs_reasoning_parsers import ( - # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from endpoints.OAI.reasoning.abs_reasoning_parsers import ( DeltaMessage, ReasoningParser, ReasoningParserManager, @@ -11,11 +11,15 @@ from endpoints.OAI.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser from endpoints.OAI.reasoning.ernie45_reasoning_parser import Ernie45ReasoningParser from endpoints.OAI.reasoning.exaone4_reasoning_parser import Exaone4ReasoningParser -from endpoints.OAI.reasoning.glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser +from endpoints.OAI.reasoning.glm4_moe_reasoning_parser import ( + Glm4MoeModelReasoningParser, +) from endpoints.OAI.reasoning.gptoss_reasoning_parser import GptOssReasoningParser from endpoints.OAI.reasoning.granite_reasoning_parser import GraniteReasoningParser from endpoints.OAI.reasoning.holo2_reasoning_parser import Holo2ReasoningParser -from endpoints.OAI.reasoning.hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser +from endpoints.OAI.reasoning.hunyuan_a13b_reasoning_parser import ( + HunyuanA13BReasoningParser, +) from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser from endpoints.OAI.reasoning.kimi_k2_reasoning_parser import KimiK2ReasoningParser from endpoints.OAI.reasoning.minimax_m2_reasoning_parser import ( diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index a5f45ee1..61112d4b 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -51,7 +51,7 @@ class ChatCompletionStreamChoice(BaseModel): # Index is 0 since we aren't using multiple choices index: int = 0 
finish_reason: Optional[str] = None - delta: Union[ChatCompletionMessage, dict] = {} + delta: Union[ChatCompletionMessage, dict] = Field(default_factory=dict) logprobs: Optional[ChatCompletionLogprobs] = None @@ -61,7 +61,7 @@ class ChatCompletionRequest(CommonCompletionRequest): prompt_template: Optional[str] = None add_generation_prompt: Optional[bool] = True template_vars: Optional[dict] = Field( - default={}, + default_factory=dict, validation_alias=AliasChoices("template_vars", "chat_template_kwargs"), description="Aliases: chat_template_kwargs", ) diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py index 33b0496a..1e572663 100644 --- a/endpoints/OAI/types/tools.py +++ b/endpoints/OAI/types/tools.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Dict, Literal +from typing import Dict, Literal, Optional from uuid import uuid4 @@ -28,17 +28,28 @@ class Tool(BaseModel): class ToolCall(BaseModel): - """Represents an OAI tool description.""" + """Represents an OAI tool call. + + The ``index`` field is optional so it can be omitted in non-streaming + responses (where OpenAI does not include it) via ``exclude_none=True``, + while being set explicitly for streaming deltas where it is required + by strict validators like the Vercel AI SDK. 
+ """ - id: str = Field(default_factory=lambda: str(uuid4()).replace("-", "")[:9]) + id: str = Field(default_factory=lambda: f"call_{uuid4().hex[:24]}") function: Tool type: Literal["function"] = "function" + index: Optional[int] = None class NamedToolFunction(BaseModel): + """Represents a named function reference for tool_choice.""" + name: str class NamedToolChoice(BaseModel): + """Represents a named tool choice (forces a specific function call).""" + function: NamedToolFunction type: Literal["function"] = "function" diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 944621d8..9e081090 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -1,9 +1,10 @@ """Chat completion utilities for OAI server.""" import asyncio +import json import pathlib from asyncio import CancelledError -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Optional from fastapi import HTTPException, Request from jinja2 import TemplateError @@ -31,22 +32,23 @@ ChatCompletionStreamChoice, ) from endpoints.OAI.types.common import UsageStats +from endpoints.OAI.types.tools import NamedToolChoice, ToolCall from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector +from endpoints.OAI.utils.parser_options import ( + list_tool_call_parsers, + resolve_tool_call_format, +) from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA @dataclass class _StreamReasoningState: text: str = "" - token_ids: List[int] = None - - def __post_init__(self): - if self.token_ids is None: - self.token_ids = [] + token_ids: List[int] = field(default_factory=list) class _TokenizerAdapter: - """Expose a tiny tokenizer interface required by reasoning parsers.""" + """Expose the minimal tokenizer interface required by reasoning parsers.""" def __init__(self): self._vocab = None @@ -55,7 +57,12 @@ def get_vocab(self) -> dict[str, int]: if 
self._vocab is not None: return self._vocab - pieces = model.container.tokenizer.get_id_to_piece_list(True) + tokenizer = model.container.tokenizer + if hasattr(tokenizer, "get_vocab"): + self._vocab = tokenizer.get_vocab() + return self._vocab + + pieces = tokenizer.get_id_to_piece_list(True) vocab: dict[str, int] = {} for token_id, piece in enumerate(pieces): if piece not in vocab: @@ -70,26 +77,99 @@ def _token_ids_from_generation(generation: dict) -> List[int]: return [] if isinstance(token_ids, list): return token_ids - # Covers tensor-like objects (PyTorch) if hasattr(token_ids, "flatten"): return token_ids.flatten().tolist() return list(token_ids) -def _build_reasoning_parser(data: ChatCompletionRequest): +def _build_reasoning_parser(request_data: ChatCompletionRequest): parser_key = unwrap(config.model.reasoning_parser, "basic") or "basic" try: parser_cls = ReasoningParserManager.get_reasoning_parser(parser_key) except KeyError as exc: raise HTTPException(400, str(exc)) from exc - chat_template_kwargs = unwrap(data.template_vars, {}) - tokenizer_adapter = _TokenizerAdapter() - return parser_cls(tokenizer_adapter, chat_template_kwargs=chat_template_kwargs) + template_kwargs = unwrap(request_data.template_vars, {}) + try: + return parser_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) + except RuntimeError as exc: + # Keep compatibility for models that do not expose thinking tags. + if parser_key == "basic": + logger.warning( + "Reasoning parser 'basic' could not initialize ({}). 
" + "Falling back to identity parser.", + str(exc), + ) + identity_cls = ReasoningParserManager.get_reasoning_parser("identity") + return identity_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) + raise HTTPException(400, str(exc)) from exc + + +def _validate_and_get_tool_call_format( + request_data: ChatCompletionRequest, default_format: str +) -> str: + tool_choice = request_data.tool_choice + parser_key = config.model.tool_call_parser + enable_auto = bool(config.model.enable_auto_tool_choice) + parser_names = list_tool_call_parsers() + + if parser_key and parser_key not in parser_names: + parsers_str = ", ".join(sorted(parser_names)) + raise HTTPException( + 400, + f"invalid tool call parser: {parser_key} (choose from {{{parsers_str}}})", + ) + + if tool_choice == "auto" and (not enable_auto or not parser_key): + raise HTTPException( + 400, + '"auto" tool choice requires --enable-auto-tool-choice and ' + "--tool-call-parser to be set", + ) + + if tool_choice not in (None, "none", "auto") and parser_key is None: + raise HTTPException( + 400, + f'tool_choice="{tool_choice}" requires --tool-call-parser to be set', + ) + + if ( + tool_choice == "none" + and config.model.exclude_tools_when_tool_choice_none + and request_data.tools + ): + request_data.tools = None + + resolved_format = resolve_tool_call_format(parser_key, default_format) + if not resolved_format: + raise HTTPException( + 400, + f"Could not resolve format for tool_call_parser={parser_key}", + ) + return resolved_format + + +def _serialize_stream_chunk(chunk) -> str: + """Serialize a streaming chunk with OpenAI-compatible field handling. + + Uses exclude_none=True to strip irrelevant null fields (tool_calls, + tool_call_id, logprobs, usage) while ensuring finish_reason is always + present on each choice (as null when not set), matching OpenAI's + observed streaming behavior. 
+ """ + d = chunk.model_dump(exclude_none=True) + for choice in d.get("choices", []): + if "finish_reason" not in choice: + choice["finish_reason"] = None + return json.dumps(d, ensure_ascii=False) def _create_response( - request_id: str, generations: List[dict], model_name: Optional[str] + request_id: str, + generations: List[dict], + model_name: Optional[str], + tool_call_format: str = "json", + tool_choice=None, ): """Create a chat completion response from the provided text.""" @@ -97,7 +177,6 @@ def _create_response( for index, generation in enumerate(generations): reasoning = generation.get("reasoning") reasoning_content = generation.get("reasoning_content") - message = ChatCompletionMessage( role="assistant", content=generation.get("text"), @@ -105,10 +184,40 @@ def _create_response( reasoning_content=reasoning_content, ) - tool_calls = generation["tool_calls"] - if tool_calls: - message.tool_calls = ToolCallProcessor.from_json(tool_calls) - message.content = None + tool_calls_raw = generation.get("tool_calls") + if tool_calls_raw: + parsed = ToolCallProcessor.parse(tool_calls_raw, format=tool_call_format) + if parsed and isinstance(tool_choice, NamedToolChoice): + parsed = ToolCallProcessor.filter_by_name( + parsed, tool_choice.function.name + ) + if parsed: + message.tool_calls = parsed + message.content = None + else: + logger.warning( + "Tool call text present but parsing returned no results " + f"(format={tool_call_format})" + ) + + # Fallback: detect bare XML tool calls in content that were not + # caught by the two-pass system (model never emitted tool_start) + if ( + tool_call_format in ("xml", "auto") + and not message.tool_calls + and message.content + and " List[ChatCompletionStreamChunk]: + """Build the OpenAI-standard streaming sequence for tool calls. + + Emits two chunks: + 1. Tool-call chunk: role="assistant", complete tool_calls with + index/id/type/name/arguments (all data in one chunk). + 2. 
Finish chunk: empty delta, finish_reason="tool_calls". + + Complete arguments are sent in a single chunk rather than streamed + incrementally, which is valid per OpenAI's spec (clients concatenate + argument strings across deltas) and maximizes compatibility with + clients that may not implement multi-chunk tool-call assembly. + + The tool_calls are placed directly into a ChatCompletionMessage + (not a raw dict) so Pydantic validates them as ToolCall objects + with the index field preserved (ToolCall declares index as Optional[int]). + """ + chunk_id = f"chatcmpl-{request_id}" + + # Set index on each tool call for streaming + for idx, tc in enumerate(tool_calls): + tc.index = idx + + # Chunk 1: Complete tool call data + tool_call_message = ChatCompletionMessage( + role="assistant", + tool_calls=tool_calls, + ) + tool_chunk = ChatCompletionStreamChunk( + id=chunk_id, + choices=[ + ChatCompletionStreamChoice( + index=choice_index, + delta=tool_call_message, + finish_reason=None, + ) + ], + model=model_name, + ) + + # Chunk 2: Finish signal + # Use model_construct to prevent Pydantic's smart Union from + # coercing the empty dict {} into ChatCompletionMessage(role="user") + finish_choice = ChatCompletionStreamChoice.model_construct( + index=choice_index, + delta={}, + finish_reason="tool_calls", + logprobs=None, + ) + finish_chunk = ChatCompletionStreamChunk( + id=chunk_id, + choices=[finish_choice], + model=model_name, + ) + + return [tool_chunk, finish_chunk] + + async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict): """Adding metadata is a one-time process.""" @@ -305,6 +472,24 @@ async def format_messages_with_template( message_dicts.append(message.model_dump(exclude_none=True)) + # Pre-template: convert tool_call arguments from JSON strings to dicts. + # OpenAI-compatible clients (Kilo, Roo, etc.) 
send arguments as JSON + # strings per the OAI spec, but Qwen3-Coder's template calls + # .items() on arguments which requires a dict/mapping. + for msg in message_dicts: + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + func = tc.get("function", {}) + args = func.get("arguments") + if isinstance(args, str): + try: + func["arguments"] = json.loads(args) + except (json.JSONDecodeError, ValueError): + logger.warning( + "Failed to parse tool_call arguments JSON " + "string to dict, keeping as string" + ) + # Get all special tokens special_tokens_dict = model.container.get_special_tokens() @@ -387,12 +572,19 @@ async def stream_generate_chat_completion( gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + default_tool_call_format = model.container.prompt_template.metadata.tool_call_format disconnect_task = asyncio.create_task(request_disconnect_loop(request)) - reasoning_parser = _build_reasoning_parser(data) - reasoning_states = [_StreamReasoningState() for _ in range(0, data.n)] try: logger.info(f"Received chat completion streaming request {request.state.id}") + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) + reasoning_parser = _build_reasoning_parser(data) + reasoning_states = [_StreamReasoningState() for _ in range(0, data.n)] + force_tool_pass = data.tool_choice == "required" or isinstance( + data.tool_choice, NamedToolChoice + ) for idx in range(0, data.n): task_gen_params = data.model_copy(deep=True) @@ -412,32 +604,26 @@ async def stream_generate_chat_completion( gen_tasks.append(gen_task) - # Text accumulation for tool calls - current_generation_text = "" - # Consumer loop while True: + # Fast path: items already queued — no task overhead + if not gen_queue.empty(): + generation = gen_queue.get_nowait() + else: + # Slow path: queue empty — race get against disconnect + get_task = asyncio.create_task(gen_queue.get()) + done, _ = 
await asyncio.wait( + [get_task, disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if disconnect_task in done: + get_task.cancel() + raise CancelledError() + generation = get_task.result() + if disconnect_task.done(): raise CancelledError() - generation = await gen_queue.get() - - # Handle options if a tool model is present - if tool_start: - if "stop_str" in generation: - generations = await generate_tool_calls( - prompt, - embeddings, - data, - [generation], - request, - ) - - # Only one generation present in this case - generation = generations[0] - elif "text" in generation: - current_generation_text += generation["text"] - # Stream collector will push an exception to the queue if it fails if isinstance(generation, Exception): raise generation @@ -477,10 +663,71 @@ async def stream_generate_chat_completion( if generation["text"] is None: continue + # Handle options if a tool model is present + if (tool_start or force_tool_pass) and data.tool_choice != "none": + if "stop_str" in generation: + generations = await generate_tool_calls( + prompt, + embeddings, + data, + [generation], + request, + tool_call_format=tool_call_format, + ) + + # Only one generation present in this case + generation = generations[0] + + # Emit proper three-phase tool-call streaming sequence + if "tool_calls" in generation: + tool_calls_raw = generation["tool_calls"] + parsed = ToolCallProcessor.parse( + tool_calls_raw, format=tool_call_format + ) + if parsed and isinstance(data.tool_choice, NamedToolChoice): + parsed = ToolCallProcessor.filter_by_name( + parsed, data.tool_choice.function.name + ) + if parsed: + for tc_chunk in _build_tool_call_chunks( + parsed, + request.state.id, + model_path.name, + choice_index=generation.get("index", 0), + ): + yield _serialize_stream_chunk(tc_chunk) + + # Handle completion and usage after tool calls + if ( + all(task.done() for task in gen_tasks) + and gen_queue.empty() + ): + if ( + data.stream_options + and 
data.stream_options.include_usage + ): + usage_chunk = _create_stream_chunk( + request.state.id, + generation, + model_path.name, + is_usage_chunk=True, + ) + yield _serialize_stream_chunk(usage_chunk) + + logger.info( + "Finished chat completion streaming " + f"request {request.state.id}" + ) + yield "[DONE]" + break + continue + response = _create_stream_chunk( - request.state.id, generation, model_path.name + request.state.id, + generation, + model_path.name, ) - yield response.model_dump_json() + yield _serialize_stream_chunk(response) # Check if all tasks are completed if all(task.done() for task in gen_tasks) and gen_queue.empty(): @@ -492,7 +739,7 @@ async def stream_generate_chat_completion( model_path.name, is_usage_chunk=True, ) - yield usage_chunk.model_dump_json() + yield _serialize_stream_chunk(usage_chunk) logger.info( f"Finished chat completion streaming request {request.state.id}" @@ -503,13 +750,16 @@ async def stream_generate_chat_completion( except CancelledError: # Get out if the request gets disconnected - if not abort_event.is_set(): - abort_event.set() - handle_request_disconnect("Chat completion generation cancelled by user.") + handle_request_disconnect("Chat completion generation cancelled by user.") + except HTTPException as exc: + yield get_generator_error(str(exc.detail)) except Exception: yield get_generator_error( "Chat completion aborted. Please check the server console." 
) + finally: + abort_event.set() + disconnect_task.cancel() async def generate_chat_completion( @@ -521,6 +771,10 @@ async def generate_chat_completion( ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + default_tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) try: logger.info(f"Received chat completion request {request.state.id}") @@ -541,14 +795,18 @@ async def generate_chat_completion( generations = await asyncio.gather(*gen_tasks) - # Run tool pass when template exposes tool_start OR caller explicitly - # requires a tool call (required / named function). - force_tool_pass = data.tool_choice == "required" or hasattr( - data.tool_choice, "function" + # Check all the generations and see if a tool call is required + force_tool_pass = data.tool_choice == "required" or isinstance( + data.tool_choice, NamedToolChoice ) if tool_start or force_tool_pass: generations = await generate_tool_calls( - prompt, embeddings, data, generations, request + prompt, + embeddings, + data, + generations, + request, + tool_call_format=tool_call_format, ) reasoning_parser = _build_reasoning_parser(data) @@ -565,11 +823,19 @@ async def generate_chat_completion( generation["reasoning_content"] = reasoning generation["text"] = content - response = _create_response(request.state.id, generations, model_path.name) + response = _create_response( + request.state.id, + generations, + model_path.name, + tool_call_format=tool_call_format, + tool_choice=data.tool_choice, + ) logger.info(f"Finished chat completion request {request.state.id}") return response + except HTTPException: + raise except Exception as exc: error_message = handle_request_error( f"Chat completion {request.state.id} aborted. 
" @@ -585,11 +851,17 @@ async def generate_tool_calls( prompt: str, embeddings: MultimodalEmbeddingWrapper, data: ChatCompletionRequest, - generations: List[str], + generations: List[dict], request: Request, + tool_call_format: Optional[str] = None, ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + if tool_call_format is None: + default_tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) tool_choice = data.tool_choice if tool_choice == "none": @@ -600,33 +872,61 @@ async def generate_tool_calls( # Copy to make sure the parent JSON schema doesn't get modified tool_data = data.model_copy(deep=True) - tool_data.json_schema = TOOL_CALL_SCHEMA - named_tool_name = None - if hasattr(tool_choice, "function") and tool_choice: - named_tool_name = tool_choice.function.name + if tool_call_format in ("xml", "auto"): + # XML / auto mode: let the model generate its natural output + # without JSON schema constraint + logger.debug( + f"generate_tool_calls: Using '{tool_call_format}' mode " + f"(no JSON schema constraint)" + ) + + # Remove tool_start from stop strings so the model can emit + # multiple sequential blocks without stopping early + if ( + tool_start + and isinstance(tool_data.stop, list) + and tool_start in tool_data.stop + ): + tool_data.stop = [s for s in tool_data.stop if s != tool_start] + logger.debug( + f"generate_tool_calls: Removed '{tool_start}' from " + f"second-pass stop strings" + ) + else: + # JSON mode: constrained generation (existing behavior) + tool_data.json_schema = TOOL_CALL_SCHEMA for idx, gen in enumerate(generations): stop_str = gen.get("stop_str") should_generate = stop_str == tool_start - # If tool calls are required, run tool schema pass even when stop token - # was not emitted, using the existing completion as precursor. 
- if tool_choice == "required" and not should_generate: - should_generate = True - if named_tool_name and not should_generate: + # Force tool generation if tool_choice requires it + if not should_generate and ( + tool_choice == "required" or isinstance(tool_choice, NamedToolChoice) + ): should_generate = True if not should_generate: continue - logger.info(f"Detected tool call in chat completion request {request.state.id}") + logger.info( + f"Detected tool call in chat completion request " + f"{request.state.id} (format={tool_call_format})" + ) - # Append the existing generation text if present - prompt_with_precursor = prompt + # Build per-generation prompt (avoid mutating shared prompt) + tool_prompt = prompt precursor_text = gen.get("full_text") if precursor_text: - prompt_with_precursor = prompt + precursor_text + tool_prompt = tool_prompt + precursor_text + + # For XML/auto mode: append tool_start back to prompt. + # The stop string was consumed by the first pass and not included + # in full_text, but the model expects to continue after . + # Include a trailing newline to match the canonical template format. 
+ if tool_call_format in ("xml", "auto") and tool_start: + tool_prompt = tool_prompt + tool_start + "\n" gen_request_id = gen.get("request_id") tool_request_id = f"{gen_request_id}-tool" @@ -635,7 +935,7 @@ async def generate_tool_calls( asyncio.create_task( model.container.generate( tool_request_id, - prompt_with_precursor, + tool_prompt, tool_data, mm_embeddings=embeddings, ) @@ -649,17 +949,12 @@ async def generate_tool_calls( # Map tool calls to their appropriate generation for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True): - tool_calls_json = tool_call["text"] - if named_tool_name: - try: - tool_calls_json = ToolCallProcessor.filter_by_name( - tool_calls_json, named_tool_name - ) - except Exception: - logger.warning( - "Could not filter tool calls by name '%s'", named_tool_name - ) + raw_text = tool_call["text"] + + if tool_call_format in ("xml", "auto") and tool_start: + # Prepend tool_start to reconstruct complete XML for parser + raw_text = tool_start + "\n" + raw_text - generations[gen_idx]["tool_calls"] = tool_calls_json + generations[gen_idx]["tool_calls"] = raw_text return generations diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index f66d381d..c11a25bf 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -225,11 +225,24 @@ async def stream_generate_completion( # Consumer loop while True: + # Fast path: items already queued — no task overhead + if not gen_queue.empty(): + generation = gen_queue.get_nowait() + else: + # Slow path: queue empty — race get against disconnect + get_task = asyncio.create_task(gen_queue.get()) + done, _ = await asyncio.wait( + [get_task, disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if disconnect_task in done: + get_task.cancel() + raise CancelledError() + generation = get_task.result() + if disconnect_task.done(): raise CancelledError() - generation = await gen_queue.get() - # Stream collector will push an exception to 
the queue if it fails if isinstance(generation, Exception): raise generation @@ -245,15 +258,16 @@ async def stream_generate_completion( except CancelledError: # Get out if the request gets disconnected - if not abort_event.is_set(): - abort_event.set() - handle_request_disconnect( - f"Completion generation {request.state.id} cancelled by user." - ) + handle_request_disconnect( + f"Completion generation {request.state.id} cancelled by user." + ) except Exception: yield get_generator_error( f"Completion {request.state.id} aborted. Please check the server console." ) + finally: + abort_event.set() + disconnect_task.cancel() async def generate_completion( diff --git a/endpoints/OAI/utils/parser_options.py b/endpoints/OAI/utils/parser_options.py new file mode 100644 index 00000000..99fe5ad3 --- /dev/null +++ b/endpoints/OAI/utils/parser_options.py @@ -0,0 +1,60 @@ +"""Parser option helpers for vLLM-compatible chat settings.""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Dict, Set + + +# Mirrors vLLM parser keys to keep CLI/config ergonomics familiar. +# Source of truth: vllm/tool_parsers/__init__.py::_TOOL_PARSERS_TO_REGISTER +# Format is the parsing mode supported by ToolCallProcessor. 
+TOOL_CALL_PARSER_FORMATS: Dict[str, str] = { + "deepseek_v3": "json", + "deepseek_v31": "json", + "deepseek_v32": "json", + "ernie45": "json", + "glm45": "json", + "glm47": "json", + "granite-20b-fc": "json", + "granite": "json", + "hermes": "json", + "hunyuan_a13b": "json", + "internlm": "json", + "jamba": "json", + "kimi_k2": "json", + "llama3_json": "json", + "llama4_json": "json", + "llama4_pythonic": "json", + "longcat": "json", + "minimax_m2": "json", + "minimax": "json", + "mistral": "json", + "olmo3": "json", + "openai": "json", + "phi4_mini_json": "json", + "pythonic": "json", + "qwen3_coder": "xml", + "qwen3_xml": "xml", + "seed_oss": "json", + "step3": "json", + "step3p5": "json", + "xlam": "json", + "gigachat3": "json", + "functiongemma": "json", + # Convenience alias for mixed/inferred content + "auto": "auto", +} + + +def list_tool_call_parsers() -> Set[str]: + return set(TOOL_CALL_PARSER_FORMATS.keys()) + + +def resolve_tool_call_format( + tool_call_parser: str | None, fallback_format: str +) -> str: + """Resolve effective parser format from configured parser key.""" + if not tool_call_parser: + return fallback_format + return TOOL_CALL_PARSER_FORMATS.get(tool_call_parser, "") diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index 76a81841..aa09b49e 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -1,9 +1,14 @@ +"""Tool call processing utilities for OAI server.""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import json import re from loguru import logger -from typing import List +from typing import Any, List, Tuple -from endpoints.OAI.types.tools import ToolCall +from endpoints.OAI.types.tools import ToolCall, Tool TOOL_CALL_SCHEMA = { @@ -28,12 +33,107 @@ }, } +# --------------------------------------------------------------------------- +# XML parsing regex patterns +# Derived from vLLM's Qwen3CoderToolParser and the official 
Qwen parser. +# These handle both complete and partially-closed tags. +# --------------------------------------------------------------------------- + +# Matches complete ... blocks +TOOL_CALL_BLOCK_RE = re.compile( + r"(.*?)", + re.DOTALL, +) + +# Matches BODY blocks +FUNCTION_RE = re.compile( + r"(.*?)", + re.DOTALL, +) + +# Matches VALUE +# Terminates on: , next , or +PARAMETER_RE = re.compile( + r"(.*?)" + r"(?:|(?=)|(?=))", + re.DOTALL, +) + +# Think block patterns +THINK_BLOCK_RE = re.compile(r".*?\s*", re.DOTALL) +THINK_UNCLOSED_RE = re.compile(r"(?!.*).*$", re.DOTALL) + +# Markdown code fence patterns +CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE) +CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE) + + +def _strip_think_blocks(text: str) -> str: + """Strip ... blocks from text. + + Handles both complete and unclosed blocks (quantization can cause + the model to never close a think tag). + """ + original = text + + # Complete blocks first + text = THINK_BLOCK_RE.sub("", text) + + # Unclosed block (think started but never closed — strip to end) + text = THINK_UNCLOSED_RE.sub("", text) + + if text != original: + if THINK_UNCLOSED_RE.search(original): + logger.warning( + "XML Parser: Stripped unclosed block " + "(possible quantization degradation)" + ) + else: + logger.debug("XML Parser: Stripped block(s) from output") + + return text + + +def _coerce_param_value(raw: str) -> Any: + """Coerce a raw parameter value string to the appropriate Python type. + + Strategy (safe, no eval()): + 1. Strip leading/trailing newlines (official template emits \\n + after opening tag and before closing tag). + 2. Try json.loads — handles objects, arrays, numbers, bools, null. + 3. Fall back to plain string. 
+ """ + # Strip template-inserted newlines around values + if raw.startswith("\n"): + raw = raw[1:] + if raw.endswith("\n"): + raw = raw[:-1] + + stripped = raw.strip() + + # Empty string + if not stripped: + return "" + + # Try JSON parse (handles objects, arrays, numbers, booleans, null) + try: + return json.loads(stripped) + except (json.JSONDecodeError, ValueError): + pass + + # Fall back to string — never eval() + return stripped + class ToolCallProcessor: + + # ------------------------------------------------------------------ + # JSON normalization helpers + # ------------------------------------------------------------------ + @staticmethod - def _normalize_tool_calls(raw) -> list[dict]: - """ - Normalize model-emitted tool call payloads into OAI-like objects. + def _normalize_tool_calls(raw) -> list: + """Normalize model-emitted tool call payloads into OAI-like objects. Accepted forms: - [{"type":"function","function":{"name":...,"arguments":{...}}}] @@ -43,9 +143,9 @@ def _normalize_tool_calls(raw) -> list[dict]: if isinstance(raw, dict): raw = [raw] if not isinstance(raw, list): - raise json.JSONDecodeError("tool_calls payload is not list/dict", str(raw), 0) + raise ValueError("tool_calls payload is not list/dict") - normalized: list[dict] = [] + normalized: list = [] for item in raw: if not isinstance(item, dict): continue @@ -60,6 +160,7 @@ def _normalize_tool_calls(raw) -> list[dict]: if name is None: continue + if isinstance(arguments, str): try: arguments = json.loads(arguments) @@ -78,53 +179,334 @@ def _normalize_tool_calls(raw) -> list[dict]: return normalized @staticmethod - def _safe_loads(payload: str): - """Best-effort JSON parse for model-emitted tool payloads.""" + def _safe_json_loads(payload: str) -> list: + """Best-effort JSON parse for model-emitted tool payloads. + + Handles: clean JSON, markdown-fenced JSON, JSON substrings in + surrounding text, flat {name, arguments} dicts, and single objects. 
+ """ + # Direct parse try: return ToolCallProcessor._normalize_tool_calls(json.loads(payload)) - except json.JSONDecodeError: - # Common model artifacts: fenced code blocks or leading text. - cleaned = payload.strip() - cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned) - cleaned = re.sub(r"\s*```$", "", cleaned) - - # EXAONE-like tool tags: {...} - tag_matches = re.findall(r"\s*(\{.*?\})\s*", cleaned, re.DOTALL) - if tag_matches: - parsed = [] - for candidate in tag_matches: - parsed.append(json.loads(candidate)) - return ToolCallProcessor._normalize_tool_calls(parsed) - - start = cleaned.find("[") - end = cleaned.rfind("]") - if start != -1 and end != -1 and end > start: - cleaned = cleaned[start : end + 1] - return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned)) - - # Fallback: single JSON object - obj_start = cleaned.find("{") - obj_end = cleaned.rfind("}") - if obj_start != -1 and obj_end != -1 and obj_end > obj_start: - cleaned = cleaned[obj_start : obj_end + 1] + except (json.JSONDecodeError, ValueError): + pass + + # Clean up common model artifacts (markdown fences, whitespace) + cleaned = payload.strip() + cleaned = CODE_FENCE_RE.sub("", cleaned) + cleaned = CODE_FENCE_END_RE.sub("", cleaned) + cleaned = cleaned.strip() + + # Try cleaned + try: return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned)) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON array substring + start = cleaned.find("[") + end = cleaned.rfind("]") + if start != -1 and end != -1 and end > start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[start : end + 1]) + ) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON object substring + obj_start = cleaned.find("{") + obj_end = cleaned.rfind("}") + if obj_start != -1 and obj_end != -1 and obj_end > obj_start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[obj_start : obj_end + 1]) + ) + except (json.JSONDecodeError, 
ValueError): + pass + + raise json.JSONDecodeError( + "Could not extract valid JSON from payload", payload, 0 + ) + + # ------------------------------------------------------------------ + # JSON parsing + # ------------------------------------------------------------------ @staticmethod def from_json(tool_calls_str: str) -> List[ToolCall]: - """Postprocess tool call JSON to a parseable class""" + """Postprocess tool call JSON to a parseable class. + + Handles clean JSON arrays, markdown-fenced output, flat dicts, + and other common model output variations via _safe_json_loads. + """ + logger.debug(f"JSON Parser: Parsing tool calls ({len(tool_calls_str)} chars)") - tool_calls = ToolCallProcessor._safe_loads(tool_calls_str) + tool_calls = ToolCallProcessor._safe_json_loads(tool_calls_str) for tool_call in tool_calls: tool_call["function"]["arguments"] = json.dumps( tool_call["function"]["arguments"] ) - return [ToolCall(**tool_call) for tool_call in tool_calls] + result = [ToolCall(**tool_call) for tool_call in tool_calls] + logger.debug(f"JSON Parser: Successfully parsed {len(result)} tool call(s)") + return result + + # ------------------------------------------------------------------ + # XML parsing (Qwen3-Coder / GLM-4.5 style) + # ------------------------------------------------------------------ @staticmethod - def dump(tool_calls: List[ToolCall]) -> List[dict]: + def from_xml(raw_text: str) -> List[ToolCall]: + """Parse Qwen3-Coder XML-format tool calls into ToolCall objects. + + Handles: + - Wrapped: ... + - Bare: ... (missing wrapper) + - Multiple sequential tool call blocks + - blocks (stripped) + - Multi-line parameter values + - Missing closing tags """ - Convert ToolCall objects to a list of dictionaries. 
+ logger.debug(f"XML Parser: Parsing tool calls ({len(raw_text)} chars)") + + # Stage 1: Strip think blocks + text = _strip_think_blocks(raw_text) + + # Stage 2: Check for incomplete XML at end (generation cutoff) + stripped_end = text.rstrip() + if stripped_end.endswith(("<", "]*$", "", text) + + # Stage 3: Extract function blocks + # First, find all wrapped ... blocks + wrapped_positions = [ + (m.start(), m.end()) for m in TOOL_CALL_BLOCK_RE.finditer(text) + ] + + # Collect function blocks from inside wrapped regions + function_blocks = [] + for match in TOOL_CALL_BLOCK_RE.finditer(text): + inner = match.group(1) + for func_match in FUNCTION_RE.finditer(inner): + function_blocks.append((func_match.group(1), func_match.group(2))) + + # Find bare blocks NOT inside any wrapped region + for func_match in FUNCTION_RE.finditer(text): + pos = func_match.start() + is_wrapped = any(start <= pos < end for start, end in wrapped_positions) + if not is_wrapped: + logger.debug( + "XML Parser: Found bare block without " + " wrapper" + ) + function_blocks.append((func_match.group(1), func_match.group(2))) + + if not function_blocks: + logger.warning("XML Parser: No blocks found") + return [] + + # Stage 4: Parse each function block into a ToolCall + tool_calls = [] + for func_name_raw, func_body in function_blocks: + func_name = func_name_raw.strip() + + # Extract parameters + params = {} + for param_match in PARAMETER_RE.finditer(func_body): + key = param_match.group(1).strip() + value_raw = param_match.group(2) + value = _coerce_param_value(value_raw) + params[key] = value + + arguments_json = json.dumps(params, ensure_ascii=False) + + tool_call = ToolCall( + function=Tool(name=func_name, arguments=arguments_json) + ) + tool_calls.append(tool_call) + + logger.debug(f"XML Parser: Successfully parsed {len(tool_calls)} tool call(s)") + return tool_calls + + # ------------------------------------------------------------------ + # Auto-detect parsing (JSON → JSON-in-tool_call → 
XML) + # ------------------------------------------------------------------ + + @staticmethod + def from_auto(raw_text: str) -> List[ToolCall]: + """Auto-detect format and parse. + + Tries in order: + 1. Pure JSON (standard TabbyAPI / Llama) + 2. JSON inside wrappers (Qwen3-Instruct style) + 3. XML with tags (Qwen3-Coder style) + """ + logger.debug("Auto Parser: Attempting format auto-detection") + + # Attempt 1: Pure JSON array + try: + result = ToolCallProcessor.from_json(raw_text) + logger.debug("Auto Parser: Detected JSON format") + return result + except (json.JSONDecodeError, ValueError, KeyError) as e: + logger.debug(f"Auto Parser: Not JSON ({e}), trying next format") + + # Attempt 2: JSON inside wrappers (Qwen3-Instruct) + try: + all_tool_calls = [] + for match in TOOL_CALL_BLOCK_RE.finditer(raw_text): + inner = match.group(1).strip() + if inner.startswith("{") or inner.startswith("["): + parsed = json.loads(inner) + if isinstance(parsed, dict): + parsed = [parsed] + if isinstance(parsed, list): + for tc in parsed: + name = tc.get("name", "") + arguments = tc.get("arguments", {}) + if isinstance(arguments, dict): + arguments = json.dumps(arguments) + elif not isinstance(arguments, str): + arguments = json.dumps(arguments) + all_tool_calls.append( + ToolCall(function=Tool(name=name, arguments=arguments)) + ) + if all_tool_calls: + logger.debug( + "Auto Parser: Detected JSON-inside-tool_call " + f"format ({len(all_tool_calls)} call(s))" + ) + return all_tool_calls + except (json.JSONDecodeError, ValueError, KeyError) as e: + logger.debug(f"Auto Parser: Not JSON-in-tool_call ({e}), trying XML") + + # Attempt 3: XML format (Qwen3-Coder style) + result = ToolCallProcessor.from_xml(raw_text) + if result: + logger.debug("Auto Parser: Detected XML format") + else: + logger.warning("Auto Parser: All format detection attempts failed") + return result + + # ------------------------------------------------------------------ + # Dispatcher + # 
------------------------------------------------------------------ + + @staticmethod + def parse(tool_calls_str: str, format: str = "json") -> List[ToolCall]: + """Dispatch tool call parsing to the appropriate format handler. + + Args: + tool_calls_str: Raw tool call text from model generation. + format: One of ``"json"``, ``"xml"``, ``"auto"``. + + Returns: + List of parsed ToolCall objects. Empty list on parse failure + (never raises). + """ + try: + if format == "xml": + return ToolCallProcessor.from_xml(tool_calls_str) + elif format == "auto": + return ToolCallProcessor.from_auto(tool_calls_str) + else: + return ToolCallProcessor.from_json(tool_calls_str) + except Exception as e: + logger.error( + f"ToolCallProcessor.parse: Failed to parse tool calls " + f"(format={format}): {e}" + ) + return [] + + # ------------------------------------------------------------------ + # Filtering + # ------------------------------------------------------------------ + + @staticmethod + def filter_by_name( + tool_calls: List[ToolCall], function_name: str + ) -> List[ToolCall]: + """Filter parsed tool calls to only those matching a function name.""" + filtered = [tc for tc in tool_calls if tc.function.name == function_name] + if not filtered: + logger.warning( + f"filter_by_name: No tool calls matched '{function_name}' " + f"(had {len(tool_calls)} call(s))" + ) + return filtered + + # ------------------------------------------------------------------ + # Content / tool-call separation + # ------------------------------------------------------------------ + + @staticmethod + def extract_content_and_tools( + raw_text: str, + ) -> Tuple[str, List[ToolCall]]: + """Separate plain text content from XML tool call blocks. + + Used when the model mixes reasoning text with tool calls, e.g.: + ``"I'll help with that: ...`` + + Returns: + Tuple of (remaining_content, tool_calls). 
+ """ + text = _strip_think_blocks(raw_text) + + # Collect all XML regions to exclude from content + xml_regions = [] + + # Wrapped tool call blocks + for match in TOOL_CALL_BLOCK_RE.finditer(text): + xml_regions.append((match.start(), match.end())) + + # Bare function blocks not inside wrappers + for match in FUNCTION_RE.finditer(text): + pos = match.start() + is_wrapped = any(start <= pos < end for start, end in xml_regions) + if not is_wrapped: + xml_regions.append((match.start(), match.end())) + + # Sort and extract content (everything outside XML regions) + xml_regions.sort() + content_parts = [] + last_end = 0 + for start, end in xml_regions: + if start > last_end: + part = text[last_end:start].strip() + if part: + content_parts.append(part) + last_end = end + if last_end < len(text): + part = text[last_end:].strip() + if part: + content_parts.append(part) + + content = " ".join(content_parts).strip() + + # Parse tool calls from the full text + tool_calls = ToolCallProcessor.from_xml(text) + + logger.debug( + f"extract_content_and_tools: Found {len(tool_calls)} tool " + f"call(s), content={'yes' if content else 'no'} " + f"({len(content)} chars)" + ) + + return content, tool_calls + + # ------------------------------------------------------------------ + # Serialisation helpers (unchanged from original) + # ------------------------------------------------------------------ + + @staticmethod + def dump(tool_calls: List[ToolCall]) -> List[dict]: + """Convert ToolCall objects to a list of dictionaries. Args: tool_calls (List[ToolCall]): List of ToolCall objects to convert @@ -145,8 +527,7 @@ def dump(tool_calls: List[ToolCall]) -> List[dict]: @staticmethod def to_json(tool_calls: List[ToolCall]) -> str: - """ - Convert ToolCall objects to JSON string representation. + """Convert ToolCall objects to JSON string representation. 
Args: tool_calls (List[ToolCall]): List of ToolCall objects to convert @@ -163,13 +544,3 @@ def to_json(tool_calls: List[ToolCall]) -> str: # Serialize the dumped array return json.dumps(dumped_tool_calls, indent=2) - - @staticmethod - def filter_by_name(tool_calls_str: str, function_name: str) -> str: - tool_calls = ToolCallProcessor._safe_loads(tool_calls_str) - filtered = [ - item - for item in tool_calls - if item.get("function", {}).get("name") == function_name - ] - return json.dumps(filtered) diff --git a/templates/tool_calls/qwen3_coder.jinja b/templates/tool_calls/qwen3_coder.jinja new file mode 100644 index 00000000..0df78172 --- /dev/null +++ b/templates/tool_calls/qwen3_coder.jinja @@ -0,0 +1,125 @@ +{# SPDX-License-Identifier: Apache-2.0 #} +{# SPDX-FileCopyrightText: Copyright contributors to the vLLM project #} +{# TabbyAPI Metadata #} +{%- set tool_call_format = "xml" -%} +{%- set tool_start = "" -%} +{%- set tool_end = "" -%} +{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%} + +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is string %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- else %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{%- endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." 
}} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {%- set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function 
calls\n' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value if args_value is string else args_value | tojson | safe %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/tests/parser_options_test.py b/tests/parser_options_test.py new file mode 100644 index 00000000..250e0dba --- /dev/null +++ b/tests/parser_options_test.py @@ 
-0,0 +1,27 @@ +"""Tests for vLLM-compatible parser option mapping.""" + +from endpoints.OAI.utils.parser_options import ( + list_tool_call_parsers, + resolve_tool_call_format, +) + + +def test_parser_key_registry_contains_core_vllm_keys(): + parser_keys = list_tool_call_parsers() + + assert "openai" in parser_keys + assert "qwen3_coder" in parser_keys + assert "qwen3_xml" in parser_keys + assert "mistral" in parser_keys + assert "deepseek_v3" in parser_keys + + +def test_resolve_tool_call_format_uses_vllm_mapping(): + assert resolve_tool_call_format("openai", "json") == "json" + assert resolve_tool_call_format("qwen3_coder", "json") == "xml" + assert resolve_tool_call_format("auto", "json") == "auto" + + +def test_resolve_tool_call_format_falls_back_and_rejects_unknown(): + assert resolve_tool_call_format(None, "json") == "json" + assert resolve_tool_call_format("unknown_parser", "json") == "" diff --git a/tests/tool_parser_test.py b/tests/tool_parser_test.py new file mode 100644 index 00000000..7a200fdc --- /dev/null +++ b/tests/tool_parser_test.py @@ -0,0 +1,82 @@ +"""Tests for tool call parsing helpers.""" + +import json + +from endpoints.OAI.utils.tools import ToolCallProcessor + + +def _arguments_dict(tool_call): + return json.loads(tool_call.function.arguments) + + +def test_from_json_handles_markdown_fences_and_flat_shape(): + payload = """```json +[{"name": "get_weather", "arguments": {"city": "Seoul"}}] +```""" + + parsed = ToolCallProcessor.from_json(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul"} + + +def test_from_xml_parses_qwen3_coder_style_blocks(): + payload = ( + "internal reasoning" + "" + "\nSeoul\n" + "\n3\n" + "" + ) + + parsed = ToolCallProcessor.from_xml(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 3} + + +def 
test_from_auto_parses_json_inside_tool_call_wrapper(): + payload = ( + "" + '{"name": "search", "arguments": {"query": "tabbyapi"}}' + "" + ) + + parsed = ToolCallProcessor.from_auto(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "search" + assert _arguments_dict(parsed[0]) == {"query": "tabbyapi"} + + +def test_extract_content_and_tools_splits_content_from_xml_calls(): + payload = ( + "I will call a tool now. " + "\ntabby\n" + "" + " Done." + ) + + content, parsed = ToolCallProcessor.extract_content_and_tools(payload) + + assert "I will call a tool now." in content + assert "Done." in content + assert len(parsed) == 1 + assert parsed[0].function.name == "search" + + +def test_filter_by_name_keeps_only_requested_function(): + payload = ( + "[" + '{"name": "a", "arguments": {}},' + '{"name": "b", "arguments": {}}' + "]" + ) + parsed = ToolCallProcessor.from_json(payload) + + filtered = ToolCallProcessor.filter_by_name(parsed, "b") + + assert len(filtered) == 1 + assert filtered[0].function.name == "b" From ec790802ce29b11788e4da7806537466d4417aaf Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Wed, 18 Feb 2026 23:32:23 +0900 Subject: [PATCH 04/19] fix(reasoning): make exaone4 parser independent and add tests --- .../OAI/reasoning/exaone4_reasoning_parser.py | 166 ++++++++++-------- tests/exaone4_reasoning_parser_test.py | 114 ++++++++++++ 2 files changed, 205 insertions(+), 75 deletions(-) create mode 100644 tests/exaone4_reasoning_parser_test.py diff --git a/endpoints/OAI/reasoning/exaone4_reasoning_parser.py b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py index b4394c13..6df47840 100644 --- a/endpoints/OAI/reasoning/exaone4_reasoning_parser.py +++ b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py @@ -3,76 +3,83 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence from typing import Any from endpoints.OAI.reasoning.abs_reasoning_parsers 
import ( DeltaMessage, + ReasoningParser, ReasoningParserManager, ) -from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import ( - DeepSeekR1ReasoningParser, -) -from endpoints.OAI.reasoning.deepseek_v3_reasoning_parser import ( - DeepSeekV3ReasoningParser, -) -from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser @ReasoningParserManager.register_module("exaone4") -class Exaone4ReasoningParser(DeepSeekV3ReasoningParser): +class Exaone4ReasoningParser(ReasoningParser): """ - EXAONE-specific parser. + Reasoning parser for EXAONE 4.x models. - Important model behavior: - - Uses only `enable_thinking` to toggle reasoning mode. - - Chat template may prefill `` so model output often starts directly - with reasoning body and closes with ``. + Behavior notes: + - EXAONE uses `enable_thinking` (not `thinking`) to control reasoning mode. + - Templates may prefill ``, so streamed/output text can start directly + with reasoning text and close at ``. """ + start_token = "" + end_token = "" + def __init__(self, tokenizer: Any, *args, **kwargs): super().__init__(tokenizer, *args, **kwargs) chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} - enable_thinking = bool(chat_kwargs.get("enable_thinking", False)) - self._reasoning_ended = False - - if enable_thinking: - self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) - else: - self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) - - def _thinking_enabled(self) -> bool: - return isinstance(self._parser, DeepSeekR1ReasoningParser) - - def _strip_stray_end_token(self, text: str) -> str: - if not text or not self._thinking_enabled(): - return text - - end_token = self._parser.end_token - start_token = self._parser.start_token - if start_token not in text and end_token in text: - return text.replace(end_token, "") - return text + self.thinking_enabled = bool(chat_kwargs.get("enable_thinking", False)) + self.start_token_id = self.vocab.get(self.start_token) + 
self.end_token_id = self.vocab.get(self.end_token) + + def _contains_token( + self, token_id: int | None, token_text: str, token_ids: Sequence[int], text: str + ) -> bool: + if token_id is not None and token_id in token_ids: + return True + return token_text in text if text else False + + def _strip_reasoning_tokens(self, text: str) -> str: + if not text: + return "" + return text.replace(self.start_token, "").replace(self.end_token, "") + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + if not self.thinking_enabled: + return True + if self.end_token_id is None: + return False + return any(token_id == self.end_token_id for token_id in reversed(input_ids)) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + if not self.thinking_enabled: + return input_ids + if self.end_token_id is None or self.end_token_id not in input_ids[:-1]: + return [] + return input_ids[input_ids.index(self.end_token_id) + 1 :] def extract_reasoning( self, model_output: str, request: Any, ) -> tuple[str | None, str | None]: - if not self._thinking_enabled(): - return None, model_output - - start_token = self._parser.start_token - end_token = self._parser.end_token + if not self.thinking_enabled: + content = self._strip_reasoning_tokens(model_output) + return None, content or None - if start_token in model_output: - _, _, model_output = model_output.partition(start_token) + if self.start_token in model_output: + _, _, model_output = model_output.partition(self.start_token) - if end_token in model_output: - reasoning, _, content = model_output.partition(end_token) + if self.end_token in model_output: + reasoning, _, content = model_output.partition(self.end_token) + content = self._strip_reasoning_tokens(content) return reasoning or None, content or None - return model_output or None, None + reasoning = model_output.replace(self.start_token, "") + return reasoning or None, None def extract_reasoning_streaming( self, @@ -83,35 +90,44 @@ def 
extract_reasoning_streaming( current_token_ids: list[int], delta_token_ids: list[int], ) -> DeltaMessage | None: - if not self._thinking_enabled(): - return self._parser.extract_reasoning_streaming( - previous_text, - current_text, - delta_text, - previous_token_ids, - current_token_ids, - delta_token_ids, - ) - - if self._reasoning_ended: - content = self._strip_stray_end_token(delta_text) - return DeltaMessage(content=content) if content else None - - start_token = self._parser.start_token - end_token = self._parser.end_token - - if end_token in delta_text: - reasoning_part, _, content_part = delta_text.partition(end_token) - if start_token in reasoning_part: - _, _, reasoning_part = reasoning_part.partition(start_token) - - self._reasoning_ended = True - return DeltaMessage( - reasoning=reasoning_part or None, - content=content_part or None, - ) - - reasoning_part = delta_text - if start_token in reasoning_part: - _, _, reasoning_part = reasoning_part.partition(start_token) - return DeltaMessage(reasoning=reasoning_part or None) if reasoning_part else None + if not delta_text and not delta_token_ids: + return None + + if not self.thinking_enabled: + content = self._strip_reasoning_tokens(delta_text) + return DeltaMessage(content=content or None) if content else None + + end_in_prev = self._contains_token( + self.end_token_id, self.end_token, previous_token_ids, previous_text + ) + end_in_delta = self._contains_token( + self.end_token_id, self.end_token, delta_token_ids, delta_text + ) + + if len(delta_token_ids) == 1 and ( + (self.start_token_id is not None and delta_token_ids[0] == self.start_token_id) + or (self.end_token_id is not None and delta_token_ids[0] == self.end_token_id) + ): + return None + + if end_in_prev: + content = self._strip_reasoning_tokens(delta_text) + return DeltaMessage(content=content or None) if content else None + + if end_in_delta: + reasoning_part, _, content_part = delta_text.partition(self.end_token) + if self.start_token in 
reasoning_part: + _, _, reasoning_part = reasoning_part.partition(self.start_token) + + reasoning = reasoning_part or None + content = self._strip_reasoning_tokens(content_part) or None + if reasoning is None and content is None: + return None + return DeltaMessage(reasoning=reasoning, content=content) + + # EXAONE can omit start token in stream when template prefills . + # While thinking mode is enabled, treat pre-end chunks as reasoning. + reasoning_part = delta_text.replace(self.start_token, "") + + reasoning = reasoning_part or None + return DeltaMessage(reasoning=reasoning) if reasoning else None diff --git a/tests/exaone4_reasoning_parser_test.py b/tests/exaone4_reasoning_parser_test.py new file mode 100644 index 00000000..d04cf6f1 --- /dev/null +++ b/tests/exaone4_reasoning_parser_test.py @@ -0,0 +1,114 @@ +"""Tests for Exaone4 reasoning parser behavior.""" + +from endpoints.OAI.reasoning.exaone4_reasoning_parser import Exaone4ReasoningParser + + +class _FakeTokenizer: + def get_vocab(self): + return { + "": 101, + "": 102, + } + + +def _parser(enable_thinking: bool) -> Exaone4ReasoningParser: + return Exaone4ReasoningParser( + _FakeTokenizer(), + chat_template_kwargs={"enable_thinking": enable_thinking}, + ) + + +def test_non_thinking_mode_emits_content_only(): + parser = _parser(enable_thinking=False) + + reasoning, content = parser.extract_reasoning("hello", request=None) + assert reasoning is None + assert content == "hello" + + reasoning, content = parser.extract_reasoning("hello", request=None) + assert reasoning is None + assert content == "hello" + + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="hello", + delta_text="hello", + previous_token_ids=[], + current_token_ids=[1], + delta_token_ids=[1], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "hello" + + +def test_thinking_mode_extract_reasoning_and_content_non_streaming(): + parser = _parser(enable_thinking=True) + + 
reasoning, content = parser.extract_reasoning( + "reasonanswer", request=None + ) + assert reasoning == "reason" + assert content == "answer" + + reasoning, content = parser.extract_reasoning("reason
answer", request=None) + assert reasoning == "reason" + assert content == "answer" + + +def test_thinking_mode_without_end_token_is_reasoning_only(): + parser = _parser(enable_thinking=True) + + reasoning, content = parser.extract_reasoning("reasoning only", request=None) + assert reasoning == "reasoning only" + assert content is None + + +def test_thinking_streaming_prefill_flow_without_start_token(): + parser = _parser(enable_thinking=True) + + first = parser.extract_reasoning_streaming( + previous_text="", + current_text="reason ", + delta_text="reason ", + previous_token_ids=[], + current_token_ids=[11], + delta_token_ids=[11], + ) + assert first is not None + assert first.reasoning == "reason " + assert first.content is None + + second = parser.extract_reasoning_streaming( + previous_text="reason ", + current_text="reason morefinal", + delta_text="morefinal", + previous_token_ids=[11], + current_token_ids=[11, 12, 102, 13], + delta_token_ids=[12, 102, 13], + ) + assert second is not None + assert second.reasoning == "more" + assert second.content == "final" + + third = parser.extract_reasoning_streaming( + previous_text="reason morefinal", + current_text="reason morefinal!", + delta_text="!", + previous_token_ids=[11, 12, 102, 13], + current_token_ids=[11, 12, 102, 13, 14], + delta_token_ids=[14], + ) + assert third is not None + assert third.reasoning is None + assert third.content == "!" 
+ + +def test_thinking_mode_content_ids_and_end_detection(): + parser = _parser(enable_thinking=True) + + assert parser.is_reasoning_end([1, 2, 102]) is True + assert parser.is_reasoning_end([1, 2, 3]) is False + + assert parser.extract_content_ids([10, 101, 20, 102, 30, 31]) == [30, 31] + assert parser.extract_content_ids([10, 101, 20]) == [] From a51138ab33fdc9a375db3e396459861649a4322b Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Thu, 19 Feb 2026 00:42:53 +0900 Subject: [PATCH 05/19] feat(oai): add vLLM-style tool parser registry and native parser flows --- common/config_models.py | 4 +- config_sample.yml | 5 +- docs/10.-Tool-Calling.md | 8 + endpoints/OAI/utils/chat_completion.py | 38 ++- endpoints/OAI/utils/parser_options.py | 45 +++- endpoints/OAI/utils/tools.py | 333 ++++++++++++++++++++++++- tests/parser_options_test.py | 12 + tests/tool_parser_test.py | 129 ++++++++++ 8 files changed, 555 insertions(+), 19 deletions(-) diff --git a/common/config_models.py b/common/config_models.py index 5d89b58a..c13434d5 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -315,7 +315,9 @@ class ModelConfig(BaseConfigModel): description=( "Tool parser key for model-generated tool call output.\n" "Equivalent to vLLM's --tool-call-parser.\n" - "Example values: qwen3_coder, qwen3_xml, mistral, hermes, openai." + "Built-in parser keys include: hermes, llama/llama3_json/llama4_json,\n" + "openai, pythonic, qwen3_coder, qwen3_xml,\n" + "deepseek_v3, deepseek_v31, deepseek_v32." ), ) exclude_tools_when_tool_choice_none: Optional[bool] = Field( diff --git a/config_sample.yml b/config_sample.yml index f6dfab66..bd03afb2 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -165,7 +165,10 @@ model: # Tool parser key for model-generated tool call text. # Equivalent to vLLM --tool-call-parser. - # Common values: qwen3_coder, qwen3_xml, mistral, hermes, openai. 
+ # Built-in values include: + # hermes, llama (alias of llama3_json), llama3_json, llama4_json, + # openai, pythonic, qwen3_coder, qwen3_xml, + # deepseek_v3, deepseek_v31, deepseek_v32. tool_call_parser: # Exclude tool definitions from prompt when tool_choice='none'. diff --git a/docs/10.-Tool-Calling.md b/docs/10.-Tool-Calling.md index 84ffa90f..aaed88cf 100644 --- a/docs/10.-Tool-Calling.md +++ b/docs/10.-Tool-Calling.md @@ -28,6 +28,14 @@ The following model config options are available to align behavior with vLLM: `tool_choice="auto"` requires both `enable_auto_tool_choice: true` and `tool_call_parser` to be set. +Supported parser keys include: +- `hermes` +- `llama` (alias of `llama3_json`), `llama3_json`, `llama4_json` +- `openai` +- `pythonic` +- `qwen3_coder`, `qwen3_xml` +- `deepseek_v3`, `deepseek_v31`, `deepseek_v32` + ## Model Support TabbyAPI exposes controls within the `prompt_template` to accommodate models specifically tuned for tool calling and those that aren't. By default, TabbyAPI includes `chatml_with_headers_tool_calling.jinja`, a generic template built to support the Llama 3.1 family and other models following the ChatML (with headers) format. 
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 9e081090..a3f302ef 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -36,6 +36,7 @@ from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector from endpoints.OAI.utils.parser_options import ( list_tool_call_parsers, + parser_uses_native_tool_generation, resolve_tool_call_format, ) from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA @@ -174,6 +175,7 @@ def _create_response( """Create a chat completion response from the provided text.""" choices = [] + parser_key = config.model.tool_call_parser for index, generation in enumerate(generations): reasoning = generation.get("reasoning") reasoning_content = generation.get("reasoning_content") @@ -186,7 +188,11 @@ def _create_response( tool_calls_raw = generation.get("tool_calls") if tool_calls_raw: - parsed = ToolCallProcessor.parse(tool_calls_raw, format=tool_call_format) + parsed = ToolCallProcessor.parse( + tool_calls_raw, + format=tool_call_format, + parser_key=parser_key, + ) if parsed and isinstance(tool_choice, NamedToolChoice): parsed = ToolCallProcessor.filter_by_name( parsed, tool_choice.function.name @@ -682,7 +688,9 @@ async def stream_generate_chat_completion( if "tool_calls" in generation: tool_calls_raw = generation["tool_calls"] parsed = ToolCallProcessor.parse( - tool_calls_raw, format=tool_call_format + tool_calls_raw, + format=tool_call_format, + parser_key=config.model.tool_call_parser, ) if parsed and isinstance(data.tool_choice, NamedToolChoice): parsed = ToolCallProcessor.filter_by_name( @@ -863,6 +871,10 @@ async def generate_tool_calls( data, default_tool_call_format ) tool_choice = data.tool_choice + parser_key = config.model.tool_call_parser + use_native_generation = parser_uses_native_tool_generation( + parser_key, tool_call_format + ) if tool_choice == "none": return generations @@ -873,12 +885,14 @@ 
async def generate_tool_calls( # Copy to make sure the parent JSON schema doesn't get modified tool_data = data.model_copy(deep=True) - if tool_call_format in ("xml", "auto"): - # XML / auto mode: let the model generate its natural output - # without JSON schema constraint + if use_native_generation: + # Native syntax mode: let the model generate its natural tool-call + # representation without JSON schema constraint. logger.debug( - f"generate_tool_calls: Using '{tool_call_format}' mode " - f"(no JSON schema constraint)" + "generate_tool_calls: Using parser '{}' in native mode " + "(format={}, no JSON schema constraint)", + parser_key or "template-default", + tool_call_format, ) # Remove tool_start from stop strings so the model can emit @@ -921,11 +935,11 @@ async def generate_tool_calls( if precursor_text: tool_prompt = tool_prompt + precursor_text - # For XML/auto mode: append tool_start back to prompt. + # For native generation mode: append tool_start back to prompt. # The stop string was consumed by the first pass and not included - # in full_text, but the model expects to continue after . + # in full_text, but the model expects to continue after tool_start. # Include a trailing newline to match the canonical template format. - if tool_call_format in ("xml", "auto") and tool_start: + if use_native_generation and tool_start: tool_prompt = tool_prompt + tool_start + "\n" gen_request_id = gen.get("request_id") @@ -951,8 +965,8 @@ async def generate_tool_calls( for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True): raw_text = tool_call["text"] - if tool_call_format in ("xml", "auto") and tool_start: - # Prepend tool_start to reconstruct complete XML for parser + if use_native_generation and tool_start: + # Prepend tool_start to reconstruct complete native payload. 
raw_text = tool_start + "\n" + raw_text generations[gen_idx]["tool_calls"] = raw_text diff --git a/endpoints/OAI/utils/parser_options.py b/endpoints/OAI/utils/parser_options.py index 99fe5ad3..e1da446c 100644 --- a/endpoints/OAI/utils/parser_options.py +++ b/endpoints/OAI/utils/parser_options.py @@ -8,7 +8,7 @@ # Mirrors vLLM parser keys to keep CLI/config ergonomics familiar. # Source of truth: vllm/tool_parsers/__init__.py::_TOOL_PARSERS_TO_REGISTER -# Format is the parsing mode supported by ToolCallProcessor. +# Format is the fallback parsing mode supported by ToolCallProcessor. TOOL_CALL_PARSER_FORMATS: Dict[str, str] = { "deepseek_v3": "json", "deepseek_v31": "json", @@ -46,9 +46,35 @@ "auto": "auto", } +# Compatibility aliases accepted by this server. +# Keys are user-facing parser names, values are canonical parser keys. +TOOL_CALL_PARSER_ALIASES: Dict[str, str] = { + "llama": "llama3_json", +} + +# Parsers that should generate tool calls in their native syntax on tool pass +# (no JSON schema constraint). Most JSON-style parsers should stay constrained. 
+NATIVE_TOOL_GENERATION_PARSERS: Set[str] = { + "auto", + "deepseek_v3", + "deepseek_v31", + "deepseek_v32", + "llama4_pythonic", + "pythonic", + "qwen3_coder", + "qwen3_xml", +} + + +def resolve_tool_call_parser_key(tool_call_parser: str | None) -> str | None: + """Normalize a user parser key to its canonical key.""" + if not tool_call_parser: + return None + return TOOL_CALL_PARSER_ALIASES.get(tool_call_parser, tool_call_parser) + def list_tool_call_parsers() -> Set[str]: - return set(TOOL_CALL_PARSER_FORMATS.keys()) + return set(TOOL_CALL_PARSER_FORMATS.keys()).union(TOOL_CALL_PARSER_ALIASES.keys()) def resolve_tool_call_format( @@ -57,4 +83,17 @@ def resolve_tool_call_format( """Resolve effective parser format from configured parser key.""" if not tool_call_parser: return fallback_format - return TOOL_CALL_PARSER_FORMATS.get(tool_call_parser, "") + parser_key = resolve_tool_call_parser_key(tool_call_parser) + return TOOL_CALL_PARSER_FORMATS.get(parser_key, "") + + +def parser_uses_native_tool_generation( + tool_call_parser: str | None, fallback_format: str +) -> bool: + """Whether tool pass should use native model format (unconstrained).""" + if not tool_call_parser: + return fallback_format in ("xml", "auto") + parser_key = resolve_tool_call_parser_key(tool_call_parser) + if parser_key in NATIVE_TOOL_GENERATION_PARSERS: + return True + return resolve_tool_call_format(parser_key, fallback_format) in ("xml", "auto") diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index aa09b49e..bf6e2a0d 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -3,12 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast import json import re from loguru import logger -from typing import Any, List, Tuple +from typing import Any, Callable, Dict, List, Tuple from endpoints.OAI.types.tools import ToolCall, Tool +from endpoints.OAI.utils.parser_options import 
resolve_tool_call_parser_key TOOL_CALL_SCHEMA = { @@ -67,6 +69,24 @@ CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE) CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE) +# DeepSeek family patterns +DEEPSEEK_V31_CALL_RE = re.compile( + r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)<|tool▁call▁end|>", + re.DOTALL, +) +DEEPSEEK_V3_CALL_RE = re.compile( + r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)\n```json\n(?P.*?)\n```(?:\s*)<|tool▁call▁end|>", # noqa: E501 + re.DOTALL, +) +DEEPSEEK_V32_INVOKE_RE = re.compile( + r'<|DSML|invoke\s+name="(?P[^"]+)"\s*>(?P.*?)', + re.DOTALL, +) +DEEPSEEK_V32_PARAM_RE = re.compile( + r'<|DSML|parameter\s+name="(?P[^"]+)"(?:\s+string="(?Ptrue|false)")?\s*>(?P.*?)', # noqa: E501 + re.DOTALL, +) + def _strip_think_blocks(text: str) -> str: """Strip ... blocks from text. @@ -126,6 +146,7 @@ def _coerce_param_value(raw: str) -> Any: class ToolCallProcessor: + _PARSER_DISPATCHER: Dict[str, Callable[[str], List[ToolCall]]] = {} # ------------------------------------------------------------------ # JSON normalization helpers @@ -229,10 +250,279 @@ def _safe_json_loads(payload: str) -> list: "Could not extract valid JSON from payload", payload, 0 ) + @staticmethod + def _build_tool_calls_from_normalized(raw: Any) -> List[ToolCall]: + """Normalize dict/list payload and build ToolCall models.""" + normalized = ToolCallProcessor._normalize_tool_calls(raw) + for tool_call in normalized: + tool_call["function"]["arguments"] = json.dumps( + tool_call["function"]["arguments"], ensure_ascii=False + ) + return [ToolCall(**tool_call) for tool_call in normalized] + + @staticmethod + def _decode_json_sequence(text: str) -> List[Any]: + """Decode multiple JSON values from a single string.""" + decoder = json.JSONDecoder() + values: List[Any] = [] + idx = 0 + while idx < len(text): + while idx < len(text) and text[idx] in " \t\r\n,;": + idx += 1 + if idx >= len(text): + break + if text.startswith("<|python_tag|>", idx): + idx += 
len("<|python_tag|>") + continue + try: + value, end = decoder.raw_decode(text[idx:]) + except json.JSONDecodeError: + break + values.append(value) + idx += end + return values + + @staticmethod + def _coerce_argument_payload(arguments_raw: str) -> str: + """Normalize raw argument payload to a JSON string where possible.""" + payload = arguments_raw.strip() + if not payload: + return "{}" + try: + return json.dumps(json.loads(payload), ensure_ascii=False) + except (json.JSONDecodeError, ValueError, TypeError): + return payload + + @staticmethod + def _ast_to_literal(node: ast.AST) -> Any: + """Safely convert AST literal nodes to Python primitives.""" + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.List): + return [ToolCallProcessor._ast_to_literal(item) for item in node.elts] + if isinstance(node, ast.Tuple): + return [ToolCallProcessor._ast_to_literal(item) for item in node.elts] + if isinstance(node, ast.Dict): + result = {} + for key, value in zip(node.keys, node.values): + literal_key = ToolCallProcessor._ast_to_literal(key) # type: ignore[arg-type] + if not isinstance(literal_key, str): + raise ValueError("pythonic parser requires string dict keys") + result[literal_key] = ToolCallProcessor._ast_to_literal(value) + return result + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): + return -ToolCallProcessor._ast_to_literal(node.operand) + raise ValueError(f"unsupported pythonic AST node: {type(node).__name__}") + # ------------------------------------------------------------------ # JSON parsing # ------------------------------------------------------------------ + @staticmethod + def from_hermes(raw_text: str) -> List[ToolCall]: + """Parse Hermes-style JSON tool calls (often wrapped in ).""" + text = _strip_think_blocks(raw_text) + wrapped_calls = [] + for match in TOOL_CALL_BLOCK_RE.finditer(text): + inner = match.group(1).strip() + if not inner: + continue + try: + parsed = json.loads(inner) + except 
(json.JSONDecodeError, ValueError): + continue + wrapped_calls.extend(ToolCallProcessor._build_tool_calls_from_normalized(parsed)) + + if wrapped_calls: + return wrapped_calls + + return ToolCallProcessor.from_json(text) + + @staticmethod + def from_llama(raw_text: str) -> List[ToolCall]: + """Parse Llama JSON tool calls (single/multiple JSON objects).""" + text = _strip_think_blocks(raw_text).strip() + if text.startswith("<|python_tag|>"): + text = text[len("<|python_tag|>") :].lstrip() + + try: + parsed = ToolCallProcessor.from_json(text) + if parsed: + return parsed + except (json.JSONDecodeError, ValueError, KeyError): + pass + + decoded = ToolCallProcessor._decode_json_sequence(text) + if not decoded: + return [] + + flattened = [] + for item in decoded: + if isinstance(item, list): + flattened.extend(item) + else: + flattened.append(item) + + return ToolCallProcessor._build_tool_calls_from_normalized(flattened) + + @staticmethod + def from_openai(raw_text: str) -> List[ToolCall]: + """Best-effort parser for OpenAI/Harmony-style text payloads.""" + text = _strip_think_blocks(raw_text).strip() + try: + parsed = ToolCallProcessor.from_json(text) + if parsed: + return parsed + except (json.JSONDecodeError, ValueError, KeyError): + pass + + decoded = ToolCallProcessor._decode_json_sequence(text) + tool_calls: List[ToolCall] = [] + normalized_items = [] + for value in decoded: + candidates = value if isinstance(value, list) else [value] + for item in candidates: + if not isinstance(item, dict): + continue + + nested = item.get("tool_calls") + if nested: + try: + tool_calls.extend( + ToolCallProcessor._build_tool_calls_from_normalized(nested) + ) + except (ValueError, KeyError, TypeError): + pass + + recipient = item.get("recipient") + content = item.get("content") + if isinstance(recipient, str) and recipient.startswith("functions."): + fn_name = recipient.split("functions.", 1)[1] + if isinstance(content, str): + payload = 
ToolCallProcessor._coerce_argument_payload(content) + elif content is None: + payload = "{}" + else: + payload = json.dumps(content, ensure_ascii=False) + tool_calls.append( + ToolCall(function=Tool(name=fn_name, arguments=payload)) + ) + continue + + if "name" in item: + normalized_items.append(item) + + if normalized_items: + tool_calls.extend( + ToolCallProcessor._build_tool_calls_from_normalized(normalized_items) + ) + + return tool_calls + + @staticmethod + def from_pythonic(raw_text: str) -> List[ToolCall]: + """Parse Pythonic list-of-calls tool syntax.""" + text = _strip_think_blocks(raw_text).strip() + if text.startswith("<|python_tag|>"): + text = text[len("<|python_tag|>") :].lstrip() + if not text: + return [] + + if not text.startswith("[") and re.match(r"^[A-Za-z_]\w*\s*\(", text): + text = f"[{text}]" + + expression = ast.parse(text, mode="eval").body + call_nodes = expression.elts if isinstance(expression, ast.List) else [expression] + + tool_calls = [] + for node in call_nodes: + if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name): + continue + args_dict: Dict[str, Any] = {} + if node.args: + args_dict["_args"] = [ + ToolCallProcessor._ast_to_literal(argument) + for argument in node.args + ] + for keyword in node.keywords: + if keyword.arg is None: + continue + args_dict[keyword.arg] = ToolCallProcessor._ast_to_literal(keyword.value) + + tool_calls.append( + ToolCall( + function=Tool( + name=node.func.id, + arguments=json.dumps(args_dict, ensure_ascii=False), + ) + ) + ) + + return tool_calls + + @staticmethod + def from_deepseek_v31(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3.1 tool call syntax.""" + tool_calls = [] + for match in DEEPSEEK_V31_CALL_RE.finditer(raw_text): + name = match.group("name").strip() + if not name: + continue + arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) + return tool_calls + + 
@staticmethod + def from_deepseek_v3(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3 tool call syntax.""" + tool_calls = [] + for match in DEEPSEEK_V3_CALL_RE.finditer(raw_text): + name = match.group("name").strip() + if not name: + continue + arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) + + if tool_calls: + return tool_calls + + return ToolCallProcessor.from_deepseek_v31(raw_text) + + @staticmethod + def from_deepseek_v32(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3.2 DSML tool call syntax.""" + tool_calls = [] + for invoke in DEEPSEEK_V32_INVOKE_RE.finditer(raw_text): + function_name = invoke.group("name").strip() + if not function_name: + continue + + params: Dict[str, Any] = {} + body = invoke.group("body") + for param in DEEPSEEK_V32_PARAM_RE.finditer(body): + key = param.group("name").strip() + value_raw = param.group("value") + is_string = param.group("string") == "true" + if is_string: + value = value_raw.strip("\n") + else: + value = _coerce_param_value(value_raw) + params[key] = value + + tool_calls.append( + ToolCall( + function=Tool( + name=function_name, + arguments=json.dumps(params, ensure_ascii=False), + ) + ) + ) + + if tool_calls: + return tool_calls + + return ToolCallProcessor.from_deepseek_v31(raw_text) + @staticmethod def from_json(tool_calls_str: str) -> List[ToolCall]: """Postprocess tool call JSON to a parseable class. 
@@ -398,18 +688,57 @@ def from_auto(raw_text: str) -> List[ToolCall]: # ------------------------------------------------------------------ @staticmethod - def parse(tool_calls_str: str, format: str = "json") -> List[ToolCall]: + def _parser_dispatcher() -> Dict[str, Callable[[str], List[ToolCall]]]: + """Registry for parser-key-specific handlers.""" + if not ToolCallProcessor._PARSER_DISPATCHER: + ToolCallProcessor._PARSER_DISPATCHER = { + "deepseek_v3": ToolCallProcessor.from_deepseek_v3, + "deepseek_v31": ToolCallProcessor.from_deepseek_v31, + "deepseek_v32": ToolCallProcessor.from_deepseek_v32, + "hermes": ToolCallProcessor.from_hermes, + "llama": ToolCallProcessor.from_llama, + "llama3_json": ToolCallProcessor.from_llama, + "llama4_json": ToolCallProcessor.from_llama, + "openai": ToolCallProcessor.from_openai, + "pythonic": ToolCallProcessor.from_pythonic, + "qwen3_coder": ToolCallProcessor.from_xml, + "qwen3_xml": ToolCallProcessor.from_xml, + } + return ToolCallProcessor._PARSER_DISPATCHER + + @staticmethod + def parse( + tool_calls_str: str, format: str = "json", parser_key: str | None = None + ) -> List[ToolCall]: """Dispatch tool call parsing to the appropriate format handler. Args: tool_calls_str: Raw tool call text from model generation. format: One of ``"json"``, ``"xml"``, ``"auto"``. + parser_key: Optional vLLM-compatible parser key. Returns: List of parsed ToolCall objects. Empty list on parse failure (never raises). """ try: + if parser_key: + canonical_key = resolve_tool_call_parser_key(parser_key) or parser_key + parser = ToolCallProcessor._parser_dispatcher().get(canonical_key) + if parser: + try: + parsed = parser(tool_calls_str) + except Exception as exc: + logger.warning( + "Parser '{}' failed: {}. 
Falling back to format '{}'.", + canonical_key, + str(exc), + format, + ) + else: + if parsed: + return parsed + if format == "xml": return ToolCallProcessor.from_xml(tool_calls_str) elif format == "auto": diff --git a/tests/parser_options_test.py b/tests/parser_options_test.py index 250e0dba..0d328bc0 100644 --- a/tests/parser_options_test.py +++ b/tests/parser_options_test.py @@ -2,6 +2,8 @@ from endpoints.OAI.utils.parser_options import ( list_tool_call_parsers, + parser_uses_native_tool_generation, + resolve_tool_call_parser_key, resolve_tool_call_format, ) @@ -14,14 +16,24 @@ def test_parser_key_registry_contains_core_vllm_keys(): assert "qwen3_xml" in parser_keys assert "mistral" in parser_keys assert "deepseek_v3" in parser_keys + assert "llama" in parser_keys def test_resolve_tool_call_format_uses_vllm_mapping(): assert resolve_tool_call_format("openai", "json") == "json" assert resolve_tool_call_format("qwen3_coder", "json") == "xml" assert resolve_tool_call_format("auto", "json") == "auto" + assert resolve_tool_call_format("llama", "json") == "json" + assert resolve_tool_call_parser_key("llama") == "llama3_json" def test_resolve_tool_call_format_falls_back_and_rejects_unknown(): assert resolve_tool_call_format(None, "json") == "json" assert resolve_tool_call_format("unknown_parser", "json") == "" + + +def test_native_generation_flags_cover_native_syntax_parsers(): + assert parser_uses_native_tool_generation("qwen3_coder", "json") is True + assert parser_uses_native_tool_generation("deepseek_v31", "json") is True + assert parser_uses_native_tool_generation("pythonic", "json") is True + assert parser_uses_native_tool_generation("hermes", "json") is False diff --git a/tests/tool_parser_test.py b/tests/tool_parser_test.py index 7a200fdc..6d8b9372 100644 --- a/tests/tool_parser_test.py +++ b/tests/tool_parser_test.py @@ -80,3 +80,132 @@ def test_filter_by_name_keeps_only_requested_function(): assert len(filtered) == 1 assert filtered[0].function.name == "b" + 
+ +def test_parse_with_hermes_parser_handles_wrapped_json(): + payload = ( + "" + '{"name":"weather","arguments":{"city":"Seoul"}}' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="hermes") + + assert len(parsed) == 1 + assert parsed[0].function.name == "weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul"} + + +def test_parse_with_llama_parser_handles_sequential_json(): + payload = ( + "<|python_tag|>" + '{"name":"a","arguments":{"x":1}};' + '{"name":"b","arguments":{"y":2}}' + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="llama") + + assert len(parsed) == 2 + assert parsed[0].function.name == "a" + assert _arguments_dict(parsed[0]) == {"x": 1} + assert parsed[1].function.name == "b" + assert _arguments_dict(parsed[1]) == {"y": 2} + + +def test_parse_with_pythonic_parser_extracts_function_calls(): + payload = "[get_weather(city='San Francisco', days=3)]" + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="pythonic") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "San Francisco", "days": 3} + + +def test_parse_with_deepseek_v31_parser(): + payload = ( + "<|tool▁calls▁begin|>" + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>' + "<|tool▁calls▁end|>" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="deepseek_v31") + + assert len(parsed) == 1 + assert parsed[0].function.name == "foo" + assert _arguments_dict(parsed[0]) == {"x": 1} + + +def test_parse_with_deepseek_v3_parser(): + payload = ( + "<|tool▁calls▁begin|>" + "<|tool▁call▁begin|>function<|tool▁sep|>lookup\n" + "```json\n" + '{"q":"tabbyapi"}' + "\n```\n" + "<|tool▁call▁end|>" + "<|tool▁calls▁end|>" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="deepseek_v3") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert _arguments_dict(parsed[0]) == 
{"q": "tabbyapi"} + + +def test_parse_with_deepseek_v32_parser(): + payload = ( + "<|DSML|function_calls>" + '<|DSML|invoke name="get_weather">' + '<|DSML|parameter name="location" string="true">Seoul' + '<|DSML|parameter name="days" string="false">3' + "" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="deepseek_v32") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"location": "Seoul", "days": 3} + + +def test_parse_with_openai_parser_handles_functions_recipient(): + payload = ( + '[{"recipient":"functions.get_weather","content":"{\\"city\\":\\"Seoul\\"}"}]' + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="openai") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul"} + + +def test_parser_key_dispatch_overrides_format_for_qwen3_xml(): + payload = ( + "" + "\ntabby\n" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="qwen3_xml") + + assert len(parsed) == 1 + assert parsed[0].function.name == "search" + assert _arguments_dict(parsed[0]) == {"q": "tabby"} + + +def test_parser_failure_falls_back_to_format_parser(): + payload = ( + "" + "\n42\n" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="xml", parser_key="openai") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert _arguments_dict(parsed[0]) == {"id": 42} From fbc65d808ffa0ede3a8b3e1f0b8d4541fde40c43 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Thu, 19 Feb 2026 01:29:01 +0900 Subject: [PATCH 06/19] Fix exaone4 streaming reasoning split across think/tool boundaries --- .../OAI/reasoning/exaone4_reasoning_parser.py | 142 +++++++++++++----- tests/exaone4_reasoning_parser_test.py | 89 +++++++++++ 2 files changed, 196 insertions(+), 35 deletions(-) diff --git a/endpoints/OAI/reasoning/exaone4_reasoning_parser.py 
b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py index 6df47840..5e2dd43f 100644 --- a/endpoints/OAI/reasoning/exaone4_reasoning_parser.py +++ b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py @@ -26,6 +26,18 @@ class Exaone4ReasoningParser(ReasoningParser): start_token = "" end_token = "" + # Tool-call starts supported by ToolCallProcessor parser families. + # We use these as fallback reasoning boundaries when a model emits + # tool syntax without closing . + tool_start_markers = ( + "", + "", + "<|tool▁call▁begin|>", + "<|DSML|function_calls>", + "<|DSML|invoke", + "<|python_tag|>", + ) def __init__(self, tokenizer: Any, *args, **kwargs): super().__init__(tokenizer, *args, **kwargs) @@ -35,18 +47,89 @@ def __init__(self, tokenizer: Any, *args, **kwargs): self.start_token_id = self.vocab.get(self.start_token) self.end_token_id = self.vocab.get(self.end_token) - def _contains_token( - self, token_id: int | None, token_text: str, token_ids: Sequence[int], text: str - ) -> bool: - if token_id is not None and token_id in token_ids: - return True - return token_text in text if text else False - def _strip_reasoning_tokens(self, text: str) -> str: if not text: return "" return text.replace(self.start_token, "").replace(self.end_token, "") + def _trailing_overlap_len(self, text: str, token: str) -> int: + """Longest suffix overlap of text with token prefix.""" + max_len = min(len(text), len(token) - 1) + for size in range(max_len, 0, -1): + if text.endswith(token[:size]): + return size + return 0 + + def _find_first_marker(self, text: str, markers: Sequence[str]) -> tuple[int, str] | None: + first_idx = -1 + first_marker = "" + for marker in markers: + idx = text.find(marker) + if idx == -1: + continue + if first_idx == -1 or idx < first_idx: + first_idx = idx + first_marker = marker + if first_idx == -1: + return None + return first_idx, first_marker + + def _max_trailing_overlap_len(self, text: str, markers: Sequence[str]) -> int: + overlap = 0 + for marker in 
markers: + overlap = max(overlap, self._trailing_overlap_len(text, marker)) + return overlap + + def _split_reasoning_content_streaming( + self, text: str + ) -> tuple[str | None, str | None]: + """Split text into reasoning/content for streaming-safe diffing. + + Important: when end token is not yet complete, withhold a trailing + overlap with `` or tool-call prefixes to avoid leaking + partial control-tag bytes into reasoning output. This prevents + boundary-split regressions such as `answer` and + `{...}`. + """ + if not self.thinking_enabled: + content = self._strip_reasoning_tokens(text) + return None, content or None + + body = text + if self.start_token in body: + _, _, body = body.partition(self.start_token) + + if self.end_token in body: + reasoning, _, content = body.partition(self.end_token) + return reasoning or None, self._strip_reasoning_tokens(content) or None + + marker_match = self._find_first_marker(body, self.tool_start_markers) + if marker_match is not None: + marker_index, _ = marker_match + reasoning = body[:marker_index] + content = body[marker_index:] + return reasoning or None, self._strip_reasoning_tokens(content) or None + + reasoning = body.replace(self.start_token, "") + overlap = max( + self._trailing_overlap_len(reasoning, self.end_token), + self._max_trailing_overlap_len(reasoning, self.tool_start_markers), + ) + if overlap: + reasoning = reasoning[:-overlap] + return reasoning or None, None + + def _delta_from_previous(self, previous: str | None, current: str | None) -> str | None: + if current is None: + return None + previous_text = previous or "" + if current.startswith(previous_text): + delta = current[len(previous_text) :] + else: + # Fallback for recovery paths where prefix alignment breaks. 
+ delta = current + return delta or None + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: if not self.thinking_enabled: return True @@ -94,15 +177,16 @@ def extract_reasoning_streaming( return None if not self.thinking_enabled: - content = self._strip_reasoning_tokens(delta_text) - return DeltaMessage(content=content or None) if content else None - - end_in_prev = self._contains_token( - self.end_token_id, self.end_token, previous_token_ids, previous_text - ) - end_in_delta = self._contains_token( - self.end_token_id, self.end_token, delta_token_ids, delta_text - ) + prev_reasoning, prev_content = self._split_reasoning_content_streaming( + previous_text + ) + cur_reasoning, cur_content = self._split_reasoning_content_streaming( + current_text + ) + content_delta = self._delta_from_previous(prev_content, cur_content) + if content_delta is None: + return None + return DeltaMessage(content=content_delta) if len(delta_token_ids) == 1 and ( (self.start_token_id is not None and delta_token_ids[0] == self.start_token_id) @@ -110,24 +194,12 @@ def extract_reasoning_streaming( ): return None - if end_in_prev: - content = self._strip_reasoning_tokens(delta_text) - return DeltaMessage(content=content or None) if content else None - - if end_in_delta: - reasoning_part, _, content_part = delta_text.partition(self.end_token) - if self.start_token in reasoning_part: - _, _, reasoning_part = reasoning_part.partition(self.start_token) + prev_reasoning, prev_content = self._split_reasoning_content_streaming(previous_text) + cur_reasoning, cur_content = self._split_reasoning_content_streaming(current_text) - reasoning = reasoning_part or None - content = self._strip_reasoning_tokens(content_part) or None - if reasoning is None and content is None: - return None - return DeltaMessage(reasoning=reasoning, content=content) + reasoning_delta = self._delta_from_previous(prev_reasoning, cur_reasoning) + content_delta = self._delta_from_previous(prev_content, cur_content) - # 
EXAONE can omit start token in stream when template prefills . - # While thinking mode is enabled, treat pre-end chunks as reasoning. - reasoning_part = delta_text.replace(self.start_token, "") - - reasoning = reasoning_part or None - return DeltaMessage(reasoning=reasoning) if reasoning else None + if reasoning_delta is None and content_delta is None: + return None + return DeltaMessage(reasoning=reasoning_delta, content=content_delta) diff --git a/tests/exaone4_reasoning_parser_test.py b/tests/exaone4_reasoning_parser_test.py index d04cf6f1..135b37cb 100644 --- a/tests/exaone4_reasoning_parser_test.py +++ b/tests/exaone4_reasoning_parser_test.py @@ -104,6 +104,95 @@ def test_thinking_streaming_prefill_flow_without_start_token(): assert third.content == "!" +def test_thinking_streaming_handles_split_end_token_boundary(): + parser = _parser(enable_thinking=True) + + first = parser.extract_reasoning_streaming( + previous_text="", + current_text="analysis {"name":"lookup","arguments":{}}' + + +def test_thinking_streaming_handles_split_deepseek_tool_boundary_without_end_token(): + parser = _parser(enable_thinking=True) + + first = parser.extract_reasoning_streaming( + previous_text="", + current_text="analysis <|tool▁call▁b", + delta_text="analysis <|tool▁call▁b", + previous_token_ids=[], + current_token_ids=[11], + delta_token_ids=[11], + ) + assert first is not None + assert first.reasoning == "analysis " + assert first.content is None + + second = parser.extract_reasoning_streaming( + previous_text="analysis <|tool▁call▁b", + current_text=( + "analysis <|tool▁call▁begin|>lookup<|tool▁sep|>{\"q\":\"tabby\"}" + "<|tool▁call▁end|>" + ), + delta_text='egin|>lookup<|tool▁sep|>{"q":"tabby"}<|tool▁call▁end|>', + previous_token_ids=[11], + current_token_ids=[11, 12], + delta_token_ids=[12], + ) + assert second is not None + assert second.reasoning is None + assert second.content == ( + '<|tool▁call▁begin|>lookup<|tool▁sep|>{"q":"tabby"}<|tool▁call▁end|>' + ) + + def 
test_thinking_mode_content_ids_and_end_detection(): parser = _parser(enable_thinking=True) From d26bbcaa994f66dc99709cfd88b2dd29e424b56c Mon Sep 17 00:00:00 2001 From: devnen Date: Sat, 14 Feb 2026 14:26:57 +0100 Subject: [PATCH 07/19] Full tool-calling support: XML parsing, streaming compliance, Pydantic fix, inference abort fix --- common/templating.py | 5 - endpoints/OAI/types/tools.py | 13 - endpoints/OAI/utils/chat_completion.py | 277 ++++--------- endpoints/OAI/utils/tools.py | 550 +++---------------------- templates/tool_calls/qwen3_coder.jinja | 6 +- 5 files changed, 134 insertions(+), 717 deletions(-) diff --git a/common/templating.py b/common/templating.py index dda06d85..e353a30a 100644 --- a/common/templating.py +++ b/common/templating.py @@ -34,7 +34,6 @@ class TemplateMetadata: stop_strings: List[str] = field(default_factory=list) tool_start: Optional[str] = None - tool_end: Optional[str] = None tool_call_format: str = "json" @@ -98,10 +97,6 @@ async def extract_metadata(self, template_vars: dict): if isinstance(template_module.tool_start, str): template_metadata.tool_start = template_module.tool_start - if hasattr(template_module, "tool_end"): - if isinstance(template_module.tool_end, str): - template_metadata.tool_end = template_module.tool_end - if hasattr(template_module, "tool_call_format"): fmt = template_module.tool_call_format if isinstance(fmt, str) and fmt in VALID_TOOL_CALL_FORMATS: diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py index 1e572663..1428fa99 100644 --- a/endpoints/OAI/types/tools.py +++ b/endpoints/OAI/types/tools.py @@ -40,16 +40,3 @@ class ToolCall(BaseModel): function: Tool type: Literal["function"] = "function" index: Optional[int] = None - - -class NamedToolFunction(BaseModel): - """Represents a named function reference for tool_choice.""" - - name: str - - -class NamedToolChoice(BaseModel): - """Represents a named tool choice (forces a specific function call).""" - - function: NamedToolFunction - 
type: Literal["function"] = "function" diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index a3f302ef..8a3593a3 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -32,7 +32,7 @@ ChatCompletionStreamChoice, ) from endpoints.OAI.types.common import UsageStats -from endpoints.OAI.types.tools import NamedToolChoice, ToolCall +from endpoints.OAI.types.tools import ToolCall from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector from endpoints.OAI.utils.parser_options import ( list_tool_call_parsers, @@ -42,114 +42,6 @@ from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA -@dataclass -class _StreamReasoningState: - text: str = "" - token_ids: List[int] = field(default_factory=list) - - -class _TokenizerAdapter: - """Expose the minimal tokenizer interface required by reasoning parsers.""" - - def __init__(self): - self._vocab = None - - def get_vocab(self) -> dict[str, int]: - if self._vocab is not None: - return self._vocab - - tokenizer = model.container.tokenizer - if hasattr(tokenizer, "get_vocab"): - self._vocab = tokenizer.get_vocab() - return self._vocab - - pieces = tokenizer.get_id_to_piece_list(True) - vocab: dict[str, int] = {} - for token_id, piece in enumerate(pieces): - if piece not in vocab: - vocab[piece] = token_id - self._vocab = vocab - return vocab - - -def _token_ids_from_generation(generation: dict) -> List[int]: - token_ids = generation.get("token_ids") - if token_ids is None: - return [] - if isinstance(token_ids, list): - return token_ids - if hasattr(token_ids, "flatten"): - return token_ids.flatten().tolist() - return list(token_ids) - - -def _build_reasoning_parser(request_data: ChatCompletionRequest): - parser_key = unwrap(config.model.reasoning_parser, "basic") or "basic" - try: - parser_cls = ReasoningParserManager.get_reasoning_parser(parser_key) - except KeyError as exc: - raise HTTPException(400, 
str(exc)) from exc - - template_kwargs = unwrap(request_data.template_vars, {}) - try: - return parser_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) - except RuntimeError as exc: - # Keep compatibility for models that do not expose thinking tags. - if parser_key == "basic": - logger.warning( - "Reasoning parser 'basic' could not initialize ({}). " - "Falling back to identity parser.", - str(exc), - ) - identity_cls = ReasoningParserManager.get_reasoning_parser("identity") - return identity_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) - raise HTTPException(400, str(exc)) from exc - - -def _validate_and_get_tool_call_format( - request_data: ChatCompletionRequest, default_format: str -) -> str: - tool_choice = request_data.tool_choice - parser_key = config.model.tool_call_parser - enable_auto = bool(config.model.enable_auto_tool_choice) - parser_names = list_tool_call_parsers() - - if parser_key and parser_key not in parser_names: - parsers_str = ", ".join(sorted(parser_names)) - raise HTTPException( - 400, - f"invalid tool call parser: {parser_key} (choose from {{{parsers_str}}})", - ) - - if tool_choice == "auto" and (not enable_auto or not parser_key): - raise HTTPException( - 400, - '"auto" tool choice requires --enable-auto-tool-choice and ' - "--tool-call-parser to be set", - ) - - if tool_choice not in (None, "none", "auto") and parser_key is None: - raise HTTPException( - 400, - f'tool_choice="{tool_choice}" requires --tool-call-parser to be set', - ) - - if ( - tool_choice == "none" - and config.model.exclude_tools_when_tool_choice_none - and request_data.tools - ): - request_data.tools = None - - resolved_format = resolve_tool_call_format(parser_key, default_format) - if not resolved_format: - raise HTTPException( - 400, - f"Could not resolve format for tool_call_parser={parser_key}", - ) - return resolved_format - - def _serialize_stream_chunk(chunk) -> str: """Serialize a streaming chunk with OpenAI-compatible field 
handling. @@ -170,7 +62,6 @@ def _create_response( generations: List[dict], model_name: Optional[str], tool_call_format: str = "json", - tool_choice=None, ): """Create a chat completion response from the provided text.""" @@ -188,18 +79,9 @@ def _create_response( tool_calls_raw = generation.get("tool_calls") if tool_calls_raw: - parsed = ToolCallProcessor.parse( - tool_calls_raw, - format=tool_call_format, - parser_key=parser_key, - ) - if parsed and isinstance(tool_choice, NamedToolChoice): - parsed = ToolCallProcessor.filter_by_name( - parsed, tool_choice.function.name - ) + parsed = ToolCallProcessor.parse(tool_calls_raw, format=tool_call_format) if parsed: message.tool_calls = parsed - message.content = None else: logger.warning( "Tool call text present but parsing returned no results " @@ -375,7 +257,6 @@ def _build_tool_call_chunks( tool_calls: List[ToolCall], request_id: str, model_name: str, - choice_index: int, ) -> List[ChatCompletionStreamChunk]: """Build the OpenAI-standard streaming sequence for tool calls. 
@@ -408,7 +289,7 @@ def _build_tool_call_chunks( id=chunk_id, choices=[ ChatCompletionStreamChoice( - index=choice_index, + index=0, delta=tool_call_message, finish_reason=None, ) @@ -420,7 +301,7 @@ def _build_tool_call_chunks( # Use model_construct to prevent Pydantic's smart Union from # coercing the empty dict {} into ChatCompletionMessage(role="user") finish_choice = ChatCompletionStreamChoice.model_construct( - index=choice_index, + index=0, delta={}, finish_reason="tool_calls", logprobs=None, @@ -578,7 +459,7 @@ async def stream_generate_chat_completion( gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start - default_tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_call_format = model.container.prompt_template.metadata.tool_call_format disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: @@ -630,45 +511,6 @@ async def stream_generate_chat_completion( if disconnect_task.done(): raise CancelledError() - # Stream collector will push an exception to the queue if it fails - if isinstance(generation, Exception): - raise generation - - if "text" in generation and generation.get("finish_reason") is None: - idx = generation["index"] - state = reasoning_states[idx] - - delta_text = unwrap(generation.get("text"), "") - delta_token_ids = _token_ids_from_generation(generation) - - current_text = state.text + delta_text - current_token_ids = state.token_ids + delta_token_ids - - delta_message = reasoning_parser.extract_reasoning_streaming( - state.text, - current_text, - delta_text, - state.token_ids, - current_token_ids, - delta_token_ids, - ) - - state.text = current_text - state.token_ids = current_token_ids - - if delta_message is None: - continue - - generation["text"] = delta_message.content - if data.include_reasoning: - generation["reasoning"] = delta_message.reasoning - generation["reasoning_content"] = delta_message.reasoning - 
else: - generation["reasoning"] = None - generation["reasoning_content"] = None - if generation["text"] is None: - continue - # Handle options if a tool model is present if (tool_start or force_tool_pass) and data.tool_choice != "none": if "stop_str" in generation: @@ -684,6 +526,48 @@ async def stream_generate_chat_completion( # Only one generation present in this case generation = generations[0] + # Emit proper three-phase tool-call streaming sequence + if "tool_calls" in generation: + tool_calls_raw = generation["tool_calls"] + parsed = ToolCallProcessor.parse( + tool_calls_raw, format=tool_call_format + ) + if parsed: + for tc_chunk in _build_tool_call_chunks( + parsed, + request.state.id, + model_path.name, + ): + yield _serialize_stream_chunk(tc_chunk) + + # Handle completion and usage after tool calls + if ( + all(task.done() for task in gen_tasks) + and gen_queue.empty() + ): + if ( + data.stream_options + and data.stream_options.include_usage + ): + usage_chunk = _create_stream_chunk( + request.state.id, + generation, + model_path.name, + is_usage_chunk=True, + ) + yield _serialize_stream_chunk(usage_chunk) + + logger.info( + "Finished chat completion streaming " + f"request {request.state.id}" + ) + yield "[DONE]" + break + continue + + elif "text" in generation: + current_generation_text += generation["text"] + # Emit proper three-phase tool-call streaming sequence if "tool_calls" in generation: tool_calls_raw = generation["tool_calls"] @@ -759,8 +643,6 @@ async def stream_generate_chat_completion( # Get out if the request gets disconnected handle_request_disconnect("Chat completion generation cancelled by user.") - except HTTPException as exc: - yield get_generator_error(str(exc.detail)) except Exception: yield get_generator_error( "Chat completion aborted. Please check the server console." 
@@ -779,10 +661,7 @@ async def generate_chat_completion( ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start - default_tool_call_format = model.container.prompt_template.metadata.tool_call_format - tool_call_format = _validate_and_get_tool_call_format( - data, default_tool_call_format - ) + tool_call_format = model.container.prompt_template.metadata.tool_call_format try: logger.info(f"Received chat completion request {request.state.id}") @@ -817,26 +696,11 @@ async def generate_chat_completion( tool_call_format=tool_call_format, ) - reasoning_parser = _build_reasoning_parser(data) - for generation in generations: - reasoning, content = reasoning_parser.extract_reasoning( - unwrap(generation.get("text"), ""), - data, - ) - - if not data.include_reasoning: - reasoning = None - - generation["reasoning"] = reasoning - generation["reasoning_content"] = reasoning - generation["text"] = content - response = _create_response( request.state.id, generations, model_path.name, tool_call_format=tool_call_format, - tool_choice=data.tool_choice, ) logger.info(f"Finished chat completion request {request.state.id}") @@ -865,19 +729,7 @@ async def generate_tool_calls( ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start - if tool_call_format is None: - default_tool_call_format = model.container.prompt_template.metadata.tool_call_format - tool_call_format = _validate_and_get_tool_call_format( - data, default_tool_call_format - ) - tool_choice = data.tool_choice - parser_key = config.model.tool_call_parser - use_native_generation = parser_uses_native_tool_generation( - parser_key, tool_call_format - ) - - if tool_choice == "none": - return generations + tool_call_format = model.container.prompt_template.metadata.tool_call_format # Tracks which generations asked for a tool call tool_idx: List[int] = [] @@ -885,14 +737,12 @@ async def generate_tool_calls( # Copy to make sure the parent 
JSON schema doesn't get modified tool_data = data.model_copy(deep=True) - if use_native_generation: - # Native syntax mode: let the model generate its natural tool-call - # representation without JSON schema constraint. + if tool_call_format in ("xml", "auto"): + # XML / auto mode: let the model generate its natural output + # without JSON schema constraint logger.debug( - "generate_tool_calls: Using parser '{}' in native mode " - "(format={}, no JSON schema constraint)", - parser_key or "template-default", - tool_call_format, + f"generate_tool_calls: Using '{tool_call_format}' mode " + f"(no JSON schema constraint)" ) # Remove tool_start from stop strings so the model can emit @@ -942,6 +792,17 @@ async def generate_tool_calls( if use_native_generation and tool_start: tool_prompt = tool_prompt + tool_start + "\n" + # For XML/auto mode: append tool_start back to prompt. + # The stop string was consumed by the first pass and not included + # in full_text, but the model expects to continue after . + # Include a trailing newline to match the canonical template format. + if tool_call_format in ("xml", "auto"): + prompt = prompt + tool_start + "\n" + logger.debug( + f"generate_tool_calls: Appended '{tool_start}\\n' " + f"to prompt for XML continuation" + ) + gen_request_id = gen.get("request_id") tool_request_id = f"{gen_request_id}-tool" @@ -965,9 +826,13 @@ async def generate_tool_calls( for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True): raw_text = tool_call["text"] - if use_native_generation and tool_start: - # Prepend tool_start to reconstruct complete native payload. + if tool_call_format in ("xml", "auto"): + # Prepend tool_start to reconstruct complete XML for parser raw_text = tool_start + "\n" + raw_text + logger.debug( + f"generate_tool_calls: Raw XML tool call output " + f"({len(raw_text)} chars): {raw_text[:500]}..." 
+ ) generations[gen_idx]["tool_calls"] = raw_text diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index bf6e2a0d..f0114e11 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -1,16 +1,11 @@ """Tool call processing utilities for OAI server.""" -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import ast import json import re from loguru import logger -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, List, Tuple from endpoints.OAI.types.tools import ToolCall, Tool -from endpoints.OAI.utils.parser_options import resolve_tool_call_parser_key TOOL_CALL_SCHEMA = { @@ -65,32 +60,10 @@ THINK_BLOCK_RE = re.compile(r".*?\s*", re.DOTALL) THINK_UNCLOSED_RE = re.compile(r"(?!.*).*$", re.DOTALL) -# Markdown code fence patterns -CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE) -CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE) - -# DeepSeek family patterns -DEEPSEEK_V31_CALL_RE = re.compile( - r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)<|tool▁call▁end|>", - re.DOTALL, -) -DEEPSEEK_V3_CALL_RE = re.compile( - r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)\n```json\n(?P.*?)\n```(?:\s*)<|tool▁call▁end|>", # noqa: E501 - re.DOTALL, -) -DEEPSEEK_V32_INVOKE_RE = re.compile( - r'<|DSML|invoke\s+name="(?P[^"]+)"\s*>(?P.*?)', - re.DOTALL, -) -DEEPSEEK_V32_PARAM_RE = re.compile( - r'<|DSML|parameter\s+name="(?P[^"]+)"(?:\s+string="(?Ptrue|false)")?\s*>(?P.*?)', # noqa: E501 - re.DOTALL, -) - def _strip_think_blocks(text: str) -> str: - """Strip ... blocks from text. - + """ + Strip ... blocks from text. Handles both complete and unclosed blocks (quantization can cause the model to never close a think tag). """ @@ -115,7 +88,8 @@ def _strip_think_blocks(text: str) -> str: def _coerce_param_value(raw: str) -> Any: - """Coerce a raw parameter value string to the appropriate Python type. 
+ """ + Coerce a raw parameter value string to the appropriate Python type. Strategy (safe, no eval()): 1. Strip leading/trailing newlines (official template emits \\n @@ -146,387 +120,20 @@ def _coerce_param_value(raw: str) -> Any: class ToolCallProcessor: - _PARSER_DISPATCHER: Dict[str, Callable[[str], List[ToolCall]]] = {} # ------------------------------------------------------------------ - # JSON normalization helpers + # JSON parsing (existing behavior, unchanged) # ------------------------------------------------------------------ @staticmethod - def _normalize_tool_calls(raw) -> list: - """Normalize model-emitted tool call payloads into OAI-like objects. - - Accepted forms: - - [{"type":"function","function":{"name":...,"arguments":{...}}}] - - [{"name":...,"arguments":{...}}] - - {"name":...,"arguments":{...}} - """ - if isinstance(raw, dict): - raw = [raw] - if not isinstance(raw, list): - raise ValueError("tool_calls payload is not list/dict") - - normalized: list = [] - for item in raw: - if not isinstance(item, dict): - continue - - if "function" in item and isinstance(item["function"], dict): - fn = item["function"] - name = fn.get("name") - arguments = fn.get("arguments", {}) - else: - name = item.get("name") - arguments = item.get("arguments", {}) - - if name is None: - continue - - if isinstance(arguments, str): - try: - arguments = json.loads(arguments) - except json.JSONDecodeError: - arguments = {"input": arguments} - - normalized.append( - { - "type": "function", - "function": { - "name": name, - "arguments": arguments if isinstance(arguments, dict) else {}, - }, - } - ) - return normalized - - @staticmethod - def _safe_json_loads(payload: str) -> list: - """Best-effort JSON parse for model-emitted tool payloads. - - Handles: clean JSON, markdown-fenced JSON, JSON substrings in - surrounding text, flat {name, arguments} dicts, and single objects. 
- """ - # Direct parse - try: - return ToolCallProcessor._normalize_tool_calls(json.loads(payload)) - except (json.JSONDecodeError, ValueError): - pass - - # Clean up common model artifacts (markdown fences, whitespace) - cleaned = payload.strip() - cleaned = CODE_FENCE_RE.sub("", cleaned) - cleaned = CODE_FENCE_END_RE.sub("", cleaned) - cleaned = cleaned.strip() - - # Try cleaned - try: - return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned)) - except (json.JSONDecodeError, ValueError): - pass - - # Find JSON array substring - start = cleaned.find("[") - end = cleaned.rfind("]") - if start != -1 and end != -1 and end > start: - try: - return ToolCallProcessor._normalize_tool_calls( - json.loads(cleaned[start : end + 1]) - ) - except (json.JSONDecodeError, ValueError): - pass - - # Find JSON object substring - obj_start = cleaned.find("{") - obj_end = cleaned.rfind("}") - if obj_start != -1 and obj_end != -1 and obj_end > obj_start: - try: - return ToolCallProcessor._normalize_tool_calls( - json.loads(cleaned[obj_start : obj_end + 1]) - ) - except (json.JSONDecodeError, ValueError): - pass + def from_json(tool_calls_str: str) -> List[ToolCall]: + """Postprocess tool call JSON to a parseable class.""" - raise json.JSONDecodeError( - "Could not extract valid JSON from payload", payload, 0 + logger.debug( + f"JSON Parser: Parsing tool calls from JSON " + f"({len(tool_calls_str)} chars)" ) - @staticmethod - def _build_tool_calls_from_normalized(raw: Any) -> List[ToolCall]: - """Normalize dict/list payload and build ToolCall models.""" - normalized = ToolCallProcessor._normalize_tool_calls(raw) - for tool_call in normalized: - tool_call["function"]["arguments"] = json.dumps( - tool_call["function"]["arguments"], ensure_ascii=False - ) - return [ToolCall(**tool_call) for tool_call in normalized] - - @staticmethod - def _decode_json_sequence(text: str) -> List[Any]: - """Decode multiple JSON values from a single string.""" - decoder = json.JSONDecoder() - 
values: List[Any] = [] - idx = 0 - while idx < len(text): - while idx < len(text) and text[idx] in " \t\r\n,;": - idx += 1 - if idx >= len(text): - break - if text.startswith("<|python_tag|>", idx): - idx += len("<|python_tag|>") - continue - try: - value, end = decoder.raw_decode(text[idx:]) - except json.JSONDecodeError: - break - values.append(value) - idx += end - return values - - @staticmethod - def _coerce_argument_payload(arguments_raw: str) -> str: - """Normalize raw argument payload to a JSON string where possible.""" - payload = arguments_raw.strip() - if not payload: - return "{}" - try: - return json.dumps(json.loads(payload), ensure_ascii=False) - except (json.JSONDecodeError, ValueError, TypeError): - return payload - - @staticmethod - def _ast_to_literal(node: ast.AST) -> Any: - """Safely convert AST literal nodes to Python primitives.""" - if isinstance(node, ast.Constant): - return node.value - if isinstance(node, ast.List): - return [ToolCallProcessor._ast_to_literal(item) for item in node.elts] - if isinstance(node, ast.Tuple): - return [ToolCallProcessor._ast_to_literal(item) for item in node.elts] - if isinstance(node, ast.Dict): - result = {} - for key, value in zip(node.keys, node.values): - literal_key = ToolCallProcessor._ast_to_literal(key) # type: ignore[arg-type] - if not isinstance(literal_key, str): - raise ValueError("pythonic parser requires string dict keys") - result[literal_key] = ToolCallProcessor._ast_to_literal(value) - return result - if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): - return -ToolCallProcessor._ast_to_literal(node.operand) - raise ValueError(f"unsupported pythonic AST node: {type(node).__name__}") - - # ------------------------------------------------------------------ - # JSON parsing - # ------------------------------------------------------------------ - - @staticmethod - def from_hermes(raw_text: str) -> List[ToolCall]: - """Parse Hermes-style JSON tool calls (often wrapped in ).""" - 
text = _strip_think_blocks(raw_text) - wrapped_calls = [] - for match in TOOL_CALL_BLOCK_RE.finditer(text): - inner = match.group(1).strip() - if not inner: - continue - try: - parsed = json.loads(inner) - except (json.JSONDecodeError, ValueError): - continue - wrapped_calls.extend(ToolCallProcessor._build_tool_calls_from_normalized(parsed)) - - if wrapped_calls: - return wrapped_calls - - return ToolCallProcessor.from_json(text) - - @staticmethod - def from_llama(raw_text: str) -> List[ToolCall]: - """Parse Llama JSON tool calls (single/multiple JSON objects).""" - text = _strip_think_blocks(raw_text).strip() - if text.startswith("<|python_tag|>"): - text = text[len("<|python_tag|>") :].lstrip() - - try: - parsed = ToolCallProcessor.from_json(text) - if parsed: - return parsed - except (json.JSONDecodeError, ValueError, KeyError): - pass - - decoded = ToolCallProcessor._decode_json_sequence(text) - if not decoded: - return [] - - flattened = [] - for item in decoded: - if isinstance(item, list): - flattened.extend(item) - else: - flattened.append(item) - - return ToolCallProcessor._build_tool_calls_from_normalized(flattened) - - @staticmethod - def from_openai(raw_text: str) -> List[ToolCall]: - """Best-effort parser for OpenAI/Harmony-style text payloads.""" - text = _strip_think_blocks(raw_text).strip() - try: - parsed = ToolCallProcessor.from_json(text) - if parsed: - return parsed - except (json.JSONDecodeError, ValueError, KeyError): - pass - - decoded = ToolCallProcessor._decode_json_sequence(text) - tool_calls: List[ToolCall] = [] - normalized_items = [] - for value in decoded: - candidates = value if isinstance(value, list) else [value] - for item in candidates: - if not isinstance(item, dict): - continue - - nested = item.get("tool_calls") - if nested: - try: - tool_calls.extend( - ToolCallProcessor._build_tool_calls_from_normalized(nested) - ) - except (ValueError, KeyError, TypeError): - pass - - recipient = item.get("recipient") - content = 
item.get("content") - if isinstance(recipient, str) and recipient.startswith("functions."): - fn_name = recipient.split("functions.", 1)[1] - if isinstance(content, str): - payload = ToolCallProcessor._coerce_argument_payload(content) - elif content is None: - payload = "{}" - else: - payload = json.dumps(content, ensure_ascii=False) - tool_calls.append( - ToolCall(function=Tool(name=fn_name, arguments=payload)) - ) - continue - - if "name" in item: - normalized_items.append(item) - - if normalized_items: - tool_calls.extend( - ToolCallProcessor._build_tool_calls_from_normalized(normalized_items) - ) - - return tool_calls - - @staticmethod - def from_pythonic(raw_text: str) -> List[ToolCall]: - """Parse Pythonic list-of-calls tool syntax.""" - text = _strip_think_blocks(raw_text).strip() - if text.startswith("<|python_tag|>"): - text = text[len("<|python_tag|>") :].lstrip() - if not text: - return [] - - if not text.startswith("[") and re.match(r"^[A-Za-z_]\w*\s*\(", text): - text = f"[{text}]" - - expression = ast.parse(text, mode="eval").body - call_nodes = expression.elts if isinstance(expression, ast.List) else [expression] - - tool_calls = [] - for node in call_nodes: - if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name): - continue - args_dict: Dict[str, Any] = {} - if node.args: - args_dict["_args"] = [ - ToolCallProcessor._ast_to_literal(argument) - for argument in node.args - ] - for keyword in node.keywords: - if keyword.arg is None: - continue - args_dict[keyword.arg] = ToolCallProcessor._ast_to_literal(keyword.value) - - tool_calls.append( - ToolCall( - function=Tool( - name=node.func.id, - arguments=json.dumps(args_dict, ensure_ascii=False), - ) - ) - ) - - return tool_calls - - @staticmethod - def from_deepseek_v31(raw_text: str) -> List[ToolCall]: - """Parse DeepSeek v3.1 tool call syntax.""" - tool_calls = [] - for match in DEEPSEEK_V31_CALL_RE.finditer(raw_text): - name = match.group("name").strip() - if not name: - continue - 
arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) - tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) - return tool_calls - - @staticmethod - def from_deepseek_v3(raw_text: str) -> List[ToolCall]: - """Parse DeepSeek v3 tool call syntax.""" - tool_calls = [] - for match in DEEPSEEK_V3_CALL_RE.finditer(raw_text): - name = match.group("name").strip() - if not name: - continue - arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) - tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) - - if tool_calls: - return tool_calls - - return ToolCallProcessor.from_deepseek_v31(raw_text) - - @staticmethod - def from_deepseek_v32(raw_text: str) -> List[ToolCall]: - """Parse DeepSeek v3.2 DSML tool call syntax.""" - tool_calls = [] - for invoke in DEEPSEEK_V32_INVOKE_RE.finditer(raw_text): - function_name = invoke.group("name").strip() - if not function_name: - continue - - params: Dict[str, Any] = {} - body = invoke.group("body") - for param in DEEPSEEK_V32_PARAM_RE.finditer(body): - key = param.group("name").strip() - value_raw = param.group("value") - is_string = param.group("string") == "true" - if is_string: - value = value_raw.strip("\n") - else: - value = _coerce_param_value(value_raw) - params[key] = value - - tool_calls.append( - ToolCall( - function=Tool( - name=function_name, - arguments=json.dumps(params, ensure_ascii=False), - ) - ) - ) - - if tool_calls: - return tool_calls - - return ToolCallProcessor.from_deepseek_v31(raw_text) - - @staticmethod - def from_json(tool_calls_str: str) -> List[ToolCall]: - """Postprocess tool call JSON to a parseable class. - Handles clean JSON arrays, markdown-fenced output, flat dicts, and other common model output variations via _safe_json_loads. 
""" @@ -548,7 +155,8 @@ def from_json(tool_calls_str: str) -> List[ToolCall]: @staticmethod def from_xml(raw_text: str) -> List[ToolCall]: - """Parse Qwen3-Coder XML-format tool calls into ToolCall objects. + """ + Parse Qwen3-Coder XML-format tool calls into ToolCall objects. Handles: - Wrapped: ... @@ -558,7 +166,10 @@ def from_xml(raw_text: str) -> List[ToolCall]: - Multi-line parameter values - Missing closing tags """ - logger.debug(f"XML Parser: Parsing tool calls ({len(raw_text)} chars)") + logger.debug( + f"XML Parser: Parsing tool calls from XML " f"({len(raw_text)} chars)" + ) + logger.debug(f"XML Parser: Raw input: {raw_text[:500]}...") # Stage 1: Strip think blocks text = _strip_think_blocks(raw_text) @@ -592,12 +203,15 @@ def from_xml(raw_text: str) -> List[ToolCall]: if not is_wrapped: logger.debug( "XML Parser: Found bare block without " - " wrapper" + " wrapper (common Qwen3-Coder behavior)" ) function_blocks.append((func_match.group(1), func_match.group(2))) if not function_blocks: - logger.warning("XML Parser: No blocks found") + logger.warning( + f"XML Parser: No blocks found in text: " + f"{text[:200]}..." 
+ ) return [] # Stage 4: Parse each function block into a ToolCall @@ -612,6 +226,7 @@ def from_xml(raw_text: str) -> List[ToolCall]: value_raw = param_match.group(2) value = _coerce_param_value(value_raw) params[key] = value + logger.debug(f"XML Parser: param '{key}' = {repr(value)[:100]}") arguments_json = json.dumps(params, ensure_ascii=False) @@ -619,6 +234,10 @@ def from_xml(raw_text: str) -> List[ToolCall]: function=Tool(name=func_name, arguments=arguments_json) ) tool_calls.append(tool_call) + logger.debug( + f"XML Parser: Parsed tool call: {func_name}" + f"({', '.join(params.keys())})" + ) logger.debug(f"XML Parser: Successfully parsed {len(tool_calls)} tool call(s)") return tool_calls @@ -629,11 +248,12 @@ def from_xml(raw_text: str) -> List[ToolCall]: @staticmethod def from_auto(raw_text: str) -> List[ToolCall]: - """Auto-detect format and parse. + """ + Auto-detect format and parse. Tries in order: 1. Pure JSON (standard TabbyAPI / Llama) - 2. JSON inside wrappers (Qwen3-Instruct style) + 2. JSON inside wrapper (Qwen3-Instruct style) 3. 
XML with tags (Qwen3-Coder style) """ logger.debug("Auto Parser: Attempting format auto-detection") @@ -646,32 +266,31 @@ def from_auto(raw_text: str) -> List[ToolCall]: except (json.JSONDecodeError, ValueError, KeyError) as e: logger.debug(f"Auto Parser: Not JSON ({e}), trying next format") - # Attempt 2: JSON inside wrappers (Qwen3-Instruct) + # Attempt 2: JSON inside wrapper (Qwen3-Instruct) try: - all_tool_calls = [] for match in TOOL_CALL_BLOCK_RE.finditer(raw_text): inner = match.group(1).strip() - if inner.startswith("{") or inner.startswith("["): + if inner.startswith("{"): parsed = json.loads(inner) if isinstance(parsed, dict): parsed = [parsed] - if isinstance(parsed, list): - for tc in parsed: - name = tc.get("name", "") - arguments = tc.get("arguments", {}) - if isinstance(arguments, dict): - arguments = json.dumps(arguments) - elif not isinstance(arguments, str): - arguments = json.dumps(arguments) - all_tool_calls.append( - ToolCall(function=Tool(name=name, arguments=arguments)) - ) - if all_tool_calls: - logger.debug( - "Auto Parser: Detected JSON-inside-tool_call " - f"format ({len(all_tool_calls)} call(s))" - ) - return all_tool_calls + tool_calls = [] + for tc in parsed: + name = tc.get("name", "") + arguments = tc.get("arguments", {}) + if isinstance(arguments, dict): + arguments = json.dumps(arguments) + elif not isinstance(arguments, str): + arguments = json.dumps(arguments) + tool_calls.append( + ToolCall(function=Tool(name=name, arguments=arguments)) + ) + if tool_calls: + logger.debug( + "Auto Parser: Detected JSON-inside-tool_call " + "format (Qwen3-Instruct style)" + ) + return tool_calls except (json.JSONDecodeError, ValueError, KeyError) as e: logger.debug(f"Auto Parser: Not JSON-in-tool_call ({e}), trying XML") @@ -688,57 +307,21 @@ def from_auto(raw_text: str) -> List[ToolCall]: # ------------------------------------------------------------------ @staticmethod - def _parser_dispatcher() -> Dict[str, Callable[[str], List[ToolCall]]]: - 
"""Registry for parser-key-specific handlers.""" - if not ToolCallProcessor._PARSER_DISPATCHER: - ToolCallProcessor._PARSER_DISPATCHER = { - "deepseek_v3": ToolCallProcessor.from_deepseek_v3, - "deepseek_v31": ToolCallProcessor.from_deepseek_v31, - "deepseek_v32": ToolCallProcessor.from_deepseek_v32, - "hermes": ToolCallProcessor.from_hermes, - "llama": ToolCallProcessor.from_llama, - "llama3_json": ToolCallProcessor.from_llama, - "llama4_json": ToolCallProcessor.from_llama, - "openai": ToolCallProcessor.from_openai, - "pythonic": ToolCallProcessor.from_pythonic, - "qwen3_coder": ToolCallProcessor.from_xml, - "qwen3_xml": ToolCallProcessor.from_xml, - } - return ToolCallProcessor._PARSER_DISPATCHER - - @staticmethod - def parse( - tool_calls_str: str, format: str = "json", parser_key: str | None = None - ) -> List[ToolCall]: - """Dispatch tool call parsing to the appropriate format handler. + def parse(tool_calls_str: str, format: str = "json") -> List[ToolCall]: + """ + Dispatch tool call parsing to the appropriate format handler. Args: tool_calls_str: Raw tool call text from model generation. format: One of ``"json"``, ``"xml"``, ``"auto"``. - parser_key: Optional vLLM-compatible parser key. Returns: List of parsed ToolCall objects. Empty list on parse failure (never raises). """ - try: - if parser_key: - canonical_key = resolve_tool_call_parser_key(parser_key) or parser_key - parser = ToolCallProcessor._parser_dispatcher().get(canonical_key) - if parser: - try: - parsed = parser(tool_calls_str) - except Exception as exc: - logger.warning( - "Parser '{}' failed: {}. 
Falling back to format '{}'.", - canonical_key, - str(exc), - format, - ) - else: - if parsed: - return parsed + logger.debug(f"ToolCallProcessor.parse: format={format}") + try: if format == "xml": return ToolCallProcessor.from_xml(tool_calls_str) elif format == "auto": @@ -750,24 +333,10 @@ def parse( f"ToolCallProcessor.parse: Failed to parse tool calls " f"(format={format}): {e}" ) - return [] - - # ------------------------------------------------------------------ - # Filtering - # ------------------------------------------------------------------ - - @staticmethod - def filter_by_name( - tool_calls: List[ToolCall], function_name: str - ) -> List[ToolCall]: - """Filter parsed tool calls to only those matching a function name.""" - filtered = [tc for tc in tool_calls if tc.function.name == function_name] - if not filtered: - logger.warning( - f"filter_by_name: No tool calls matched '{function_name}' " - f"(had {len(tool_calls)} call(s))" + logger.debug( + f"ToolCallProcessor.parse: Raw text was: " f"{tool_calls_str[:500]}..." ) - return filtered + return [] # ------------------------------------------------------------------ # Content / tool-call separation @@ -777,7 +346,8 @@ def filter_by_name( def extract_content_and_tools( raw_text: str, ) -> Tuple[str, List[ToolCall]]: - """Separate plain text content from XML tool call blocks. + """ + Separate plain text content from XML tool call blocks. Used when the model mixes reasoning text with tool calls, e.g.: ``"I'll help with that: ...`` @@ -785,6 +355,8 @@ def extract_content_and_tools( Returns: Tuple of (remaining_content, tool_calls). 
""" + logger.debug("extract_content_and_tools: Separating content and tools") + text = _strip_think_blocks(raw_text) # Collect all XML regions to exclude from content diff --git a/templates/tool_calls/qwen3_coder.jinja b/templates/tool_calls/qwen3_coder.jinja index 0df78172..15272747 100644 --- a/templates/tool_calls/qwen3_coder.jinja +++ b/templates/tool_calls/qwen3_coder.jinja @@ -1,6 +1,4 @@ -{# SPDX-License-Identifier: Apache-2.0 #} -{# SPDX-FileCopyrightText: Copyright contributors to the vLLM project #} -{# TabbyAPI Metadata #} +{# TabbyAPI Metadata #} {%- set tool_call_format = "xml" -%} {%- set tool_start = "" -%} {%- set tool_end = "" -%} @@ -122,4 +120,4 @@ {%- endfor %} {%- if add_generation_prompt %} {{- '<|im_start|>assistant\n' }} -{%- endif %} +{%- endif %} \ No newline at end of file From e5f9948488c564c435d9450ead00f90bfc9f2dd2 Mon Sep 17 00:00:00 2001 From: devnen Date: Sat, 14 Feb 2026 16:15:02 +0100 Subject: [PATCH 08/19] Broader model compatibility, tool_choice support, bug fixes and cleanup --- common/templating.py | 5 + endpoints/OAI/types/tools.py | 13 ++ endpoints/OAI/utils/chat_completion.py | 37 +++-- endpoints/OAI/utils/tools.py | 214 ++++++++++++++++++------- 4 files changed, 191 insertions(+), 78 deletions(-) diff --git a/common/templating.py b/common/templating.py index e353a30a..dda06d85 100644 --- a/common/templating.py +++ b/common/templating.py @@ -34,6 +34,7 @@ class TemplateMetadata: stop_strings: List[str] = field(default_factory=list) tool_start: Optional[str] = None + tool_end: Optional[str] = None tool_call_format: str = "json" @@ -97,6 +98,10 @@ async def extract_metadata(self, template_vars: dict): if isinstance(template_module.tool_start, str): template_metadata.tool_start = template_module.tool_start + if hasattr(template_module, "tool_end"): + if isinstance(template_module.tool_end, str): + template_metadata.tool_end = template_module.tool_end + if hasattr(template_module, "tool_call_format"): fmt = 
template_module.tool_call_format if isinstance(fmt, str) and fmt in VALID_TOOL_CALL_FORMATS: diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py index 1428fa99..1e572663 100644 --- a/endpoints/OAI/types/tools.py +++ b/endpoints/OAI/types/tools.py @@ -40,3 +40,16 @@ class ToolCall(BaseModel): function: Tool type: Literal["function"] = "function" index: Optional[int] = None + + +class NamedToolFunction(BaseModel): + """Represents a named function reference for tool_choice.""" + + name: str + + +class NamedToolChoice(BaseModel): + """Represents a named tool choice (forces a specific function call).""" + + function: NamedToolFunction + type: Literal["function"] = "function" diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 8a3593a3..1afe1f6c 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -32,7 +32,7 @@ ChatCompletionStreamChoice, ) from endpoints.OAI.types.common import UsageStats -from endpoints.OAI.types.tools import ToolCall +from endpoints.OAI.types.tools import NamedToolChoice, ToolCall from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector from endpoints.OAI.utils.parser_options import ( list_tool_call_parsers, @@ -62,6 +62,7 @@ def _create_response( generations: List[dict], model_name: Optional[str], tool_call_format: str = "json", + tool_choice=None, ): """Create a chat completion response from the provided text.""" @@ -80,6 +81,10 @@ def _create_response( tool_calls_raw = generation.get("tool_calls") if tool_calls_raw: parsed = ToolCallProcessor.parse(tool_calls_raw, format=tool_call_format) + if parsed and isinstance(tool_choice, NamedToolChoice): + parsed = ToolCallProcessor.filter_by_name( + parsed, tool_choice.function.name + ) if parsed: message.tool_calls = parsed else: @@ -512,7 +517,7 @@ async def stream_generate_chat_completion( raise CancelledError() # Handle options if a tool model is present - if 
(tool_start or force_tool_pass) and data.tool_choice != "none": + if tool_start and data.tool_choice != "none": if "stop_str" in generation: generations = await generate_tool_calls( prompt, @@ -532,6 +537,10 @@ async def stream_generate_chat_completion( parsed = ToolCallProcessor.parse( tool_calls_raw, format=tool_call_format ) + if parsed and isinstance(data.tool_choice, NamedToolChoice): + parsed = ToolCallProcessor.filter_by_name( + parsed, data.tool_choice.function.name + ) if parsed: for tc_chunk in _build_tool_call_chunks( parsed, @@ -701,6 +710,7 @@ async def generate_chat_completion( generations, model_path.name, tool_call_format=tool_call_format, + tool_choice=data.tool_choice, ) logger.info(f"Finished chat completion request {request.state.id}") @@ -730,6 +740,10 @@ async def generate_tool_calls( gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_choice = data.tool_choice + + if tool_choice == "none": + return generations # Tracks which generations asked for a tool call tool_idx: List[int] = [] @@ -785,23 +799,12 @@ async def generate_tool_calls( if precursor_text: tool_prompt = tool_prompt + precursor_text - # For native generation mode: append tool_start back to prompt. - # The stop string was consumed by the first pass and not included - # in full_text, but the model expects to continue after tool_start. - # Include a trailing newline to match the canonical template format. - if use_native_generation and tool_start: - tool_prompt = tool_prompt + tool_start + "\n" - # For XML/auto mode: append tool_start back to prompt. # The stop string was consumed by the first pass and not included # in full_text, but the model expects to continue after . # Include a trailing newline to match the canonical template format. 
- if tool_call_format in ("xml", "auto"): - prompt = prompt + tool_start + "\n" - logger.debug( - f"generate_tool_calls: Appended '{tool_start}\\n' " - f"to prompt for XML continuation" - ) + if tool_call_format in ("xml", "auto") and tool_start: + tool_prompt = tool_prompt + tool_start + "\n" gen_request_id = gen.get("request_id") tool_request_id = f"{gen_request_id}-tool" @@ -829,10 +832,6 @@ async def generate_tool_calls( if tool_call_format in ("xml", "auto"): # Prepend tool_start to reconstruct complete XML for parser raw_text = tool_start + "\n" + raw_text - logger.debug( - f"generate_tool_calls: Raw XML tool call output " - f"({len(raw_text)} chars): {raw_text[:500]}..." - ) generations[gen_idx]["tool_calls"] = raw_text diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index f0114e11..05eaf143 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -60,10 +60,14 @@ THINK_BLOCK_RE = re.compile(r".*?\s*", re.DOTALL) THINK_UNCLOSED_RE = re.compile(r"(?!.*).*$", re.DOTALL) +# Markdown code fence patterns +CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE) +CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE) + def _strip_think_blocks(text: str) -> str: - """ - Strip ... blocks from text. + """Strip ... blocks from text. + Handles both complete and unclosed blocks (quantization can cause the model to never close a think tag). """ @@ -88,8 +92,7 @@ def _strip_think_blocks(text: str) -> str: def _coerce_param_value(raw: str) -> Any: - """ - Coerce a raw parameter value string to the appropriate Python type. + """Coerce a raw parameter value string to the appropriate Python type. Strategy (safe, no eval()): 1. 
Strip leading/trailing newlines (official template emits \\n @@ -122,18 +125,115 @@ def _coerce_param_value(raw: str) -> Any: class ToolCallProcessor: # ------------------------------------------------------------------ - # JSON parsing (existing behavior, unchanged) + # JSON normalization helpers # ------------------------------------------------------------------ @staticmethod - def from_json(tool_calls_str: str) -> List[ToolCall]: - """Postprocess tool call JSON to a parseable class.""" + def _normalize_tool_calls(raw) -> list: + """Normalize model-emitted tool call payloads into OAI-like objects. - logger.debug( - f"JSON Parser: Parsing tool calls from JSON " - f"({len(tool_calls_str)} chars)" + Accepted forms: + - [{"type":"function","function":{"name":...,"arguments":{...}}}] + - [{"name":...,"arguments":{...}}] + - {"name":...,"arguments":{...}} + """ + if isinstance(raw, dict): + raw = [raw] + if not isinstance(raw, list): + raise ValueError("tool_calls payload is not list/dict") + + normalized: list = [] + for item in raw: + if not isinstance(item, dict): + continue + + if "function" in item and isinstance(item["function"], dict): + fn = item["function"] + name = fn.get("name") + arguments = fn.get("arguments", {}) + else: + name = item.get("name") + arguments = item.get("arguments", {}) + + if name is None: + continue + + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {"input": arguments} + + normalized.append( + { + "type": "function", + "function": { + "name": name, + "arguments": arguments if isinstance(arguments, dict) else {}, + }, + } + ) + return normalized + + @staticmethod + def _safe_json_loads(payload: str) -> list: + """Best-effort JSON parse for model-emitted tool payloads. + + Handles: clean JSON, markdown-fenced JSON, JSON substrings in + surrounding text, flat {name, arguments} dicts, and single objects. 
+ """ + # Direct parse + try: + return ToolCallProcessor._normalize_tool_calls(json.loads(payload)) + except (json.JSONDecodeError, ValueError): + pass + + # Clean up common model artifacts (markdown fences, whitespace) + cleaned = payload.strip() + cleaned = CODE_FENCE_RE.sub("", cleaned) + cleaned = CODE_FENCE_END_RE.sub("", cleaned) + cleaned = cleaned.strip() + + # Try cleaned + try: + return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned)) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON array substring + start = cleaned.find("[") + end = cleaned.rfind("]") + if start != -1 and end != -1 and end > start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[start : end + 1]) + ) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON object substring + obj_start = cleaned.find("{") + obj_end = cleaned.rfind("}") + if obj_start != -1 and obj_end != -1 and obj_end > obj_start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[obj_start : obj_end + 1]) + ) + except (json.JSONDecodeError, ValueError): + pass + + raise json.JSONDecodeError( + "Could not extract valid JSON from payload", payload, 0 ) + # ------------------------------------------------------------------ + # JSON parsing + # ------------------------------------------------------------------ + + @staticmethod + def from_json(tool_calls_str: str) -> List[ToolCall]: + """Postprocess tool call JSON to a parseable class. + Handles clean JSON arrays, markdown-fenced output, flat dicts, and other common model output variations via _safe_json_loads. """ @@ -155,8 +255,7 @@ def from_json(tool_calls_str: str) -> List[ToolCall]: @staticmethod def from_xml(raw_text: str) -> List[ToolCall]: - """ - Parse Qwen3-Coder XML-format tool calls into ToolCall objects. + """Parse Qwen3-Coder XML-format tool calls into ToolCall objects. Handles: - Wrapped: ... 
@@ -166,10 +265,7 @@ def from_xml(raw_text: str) -> List[ToolCall]: - Multi-line parameter values - Missing closing tags """ - logger.debug( - f"XML Parser: Parsing tool calls from XML " f"({len(raw_text)} chars)" - ) - logger.debug(f"XML Parser: Raw input: {raw_text[:500]}...") + logger.debug(f"XML Parser: Parsing tool calls ({len(raw_text)} chars)") # Stage 1: Strip think blocks text = _strip_think_blocks(raw_text) @@ -203,15 +299,12 @@ def from_xml(raw_text: str) -> List[ToolCall]: if not is_wrapped: logger.debug( "XML Parser: Found bare block without " - " wrapper (common Qwen3-Coder behavior)" + " wrapper" ) function_blocks.append((func_match.group(1), func_match.group(2))) if not function_blocks: - logger.warning( - f"XML Parser: No blocks found in text: " - f"{text[:200]}..." - ) + logger.warning("XML Parser: No blocks found") return [] # Stage 4: Parse each function block into a ToolCall @@ -226,7 +319,6 @@ def from_xml(raw_text: str) -> List[ToolCall]: value_raw = param_match.group(2) value = _coerce_param_value(value_raw) params[key] = value - logger.debug(f"XML Parser: param '{key}' = {repr(value)[:100]}") arguments_json = json.dumps(params, ensure_ascii=False) @@ -234,10 +326,6 @@ def from_xml(raw_text: str) -> List[ToolCall]: function=Tool(name=func_name, arguments=arguments_json) ) tool_calls.append(tool_call) - logger.debug( - f"XML Parser: Parsed tool call: {func_name}" - f"({', '.join(params.keys())})" - ) logger.debug(f"XML Parser: Successfully parsed {len(tool_calls)} tool call(s)") return tool_calls @@ -248,12 +336,11 @@ def from_xml(raw_text: str) -> List[ToolCall]: @staticmethod def from_auto(raw_text: str) -> List[ToolCall]: - """ - Auto-detect format and parse. + """Auto-detect format and parse. Tries in order: 1. Pure JSON (standard TabbyAPI / Llama) - 2. JSON inside wrapper (Qwen3-Instruct style) + 2. JSON inside wrappers (Qwen3-Instruct style) 3. 
XML with tags (Qwen3-Coder style) """ logger.debug("Auto Parser: Attempting format auto-detection") @@ -266,31 +353,32 @@ def from_auto(raw_text: str) -> List[ToolCall]: except (json.JSONDecodeError, ValueError, KeyError) as e: logger.debug(f"Auto Parser: Not JSON ({e}), trying next format") - # Attempt 2: JSON inside wrapper (Qwen3-Instruct) + # Attempt 2: JSON inside wrappers (Qwen3-Instruct) try: + all_tool_calls = [] for match in TOOL_CALL_BLOCK_RE.finditer(raw_text): inner = match.group(1).strip() - if inner.startswith("{"): + if inner.startswith("{") or inner.startswith("["): parsed = json.loads(inner) if isinstance(parsed, dict): parsed = [parsed] - tool_calls = [] - for tc in parsed: - name = tc.get("name", "") - arguments = tc.get("arguments", {}) - if isinstance(arguments, dict): - arguments = json.dumps(arguments) - elif not isinstance(arguments, str): - arguments = json.dumps(arguments) - tool_calls.append( - ToolCall(function=Tool(name=name, arguments=arguments)) - ) - if tool_calls: - logger.debug( - "Auto Parser: Detected JSON-inside-tool_call " - "format (Qwen3-Instruct style)" - ) - return tool_calls + if isinstance(parsed, list): + for tc in parsed: + name = tc.get("name", "") + arguments = tc.get("arguments", {}) + if isinstance(arguments, dict): + arguments = json.dumps(arguments) + elif not isinstance(arguments, str): + arguments = json.dumps(arguments) + all_tool_calls.append( + ToolCall(function=Tool(name=name, arguments=arguments)) + ) + if all_tool_calls: + logger.debug( + "Auto Parser: Detected JSON-inside-tool_call " + f"format ({len(all_tool_calls)} call(s))" + ) + return all_tool_calls except (json.JSONDecodeError, ValueError, KeyError) as e: logger.debug(f"Auto Parser: Not JSON-in-tool_call ({e}), trying XML") @@ -308,8 +396,7 @@ def from_auto(raw_text: str) -> List[ToolCall]: @staticmethod def parse(tool_calls_str: str, format: str = "json") -> List[ToolCall]: - """ - Dispatch tool call parsing to the appropriate format handler. 
+ """Dispatch tool call parsing to the appropriate format handler. Args: tool_calls_str: Raw tool call text from model generation. @@ -319,8 +406,6 @@ def parse(tool_calls_str: str, format: str = "json") -> List[ToolCall]: List of parsed ToolCall objects. Empty list on parse failure (never raises). """ - logger.debug(f"ToolCallProcessor.parse: format={format}") - try: if format == "xml": return ToolCallProcessor.from_xml(tool_calls_str) @@ -333,11 +418,25 @@ def parse(tool_calls_str: str, format: str = "json") -> List[ToolCall]: f"ToolCallProcessor.parse: Failed to parse tool calls " f"(format={format}): {e}" ) - logger.debug( - f"ToolCallProcessor.parse: Raw text was: " f"{tool_calls_str[:500]}..." - ) return [] + # ------------------------------------------------------------------ + # Filtering + # ------------------------------------------------------------------ + + @staticmethod + def filter_by_name( + tool_calls: List[ToolCall], function_name: str + ) -> List[ToolCall]: + """Filter parsed tool calls to only those matching a function name.""" + filtered = [tc for tc in tool_calls if tc.function.name == function_name] + if not filtered: + logger.warning( + f"filter_by_name: No tool calls matched '{function_name}' " + f"(had {len(tool_calls)} call(s))" + ) + return filtered + # ------------------------------------------------------------------ # Content / tool-call separation # ------------------------------------------------------------------ @@ -346,8 +445,7 @@ def parse(tool_calls_str: str, format: str = "json") -> List[ToolCall]: def extract_content_and_tools( raw_text: str, ) -> Tuple[str, List[ToolCall]]: - """ - Separate plain text content from XML tool call blocks. + """Separate plain text content from XML tool call blocks. Used when the model mixes reasoning text with tool calls, e.g.: ``"I'll help with that: ...`` @@ -355,8 +453,6 @@ def extract_content_and_tools( Returns: Tuple of (remaining_content, tool_calls). 
""" - logger.debug("extract_content_and_tools: Separating content and tools") - text = _strip_think_blocks(raw_text) # Collect all XML regions to exclude from content From f1f488bac0cf35574b70cd560735de1991f479a5 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Tue, 24 Feb 2026 21:40:49 +0900 Subject: [PATCH 09/19] fix(oai): preserve parser_key parser dispatch after stacking PRs --- endpoints/OAI/utils/tools.py | 336 ++++++++++++++++++++++++++++++++++- 1 file changed, 334 insertions(+), 2 deletions(-) diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index 05eaf143..bf6e2a0d 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -1,11 +1,16 @@ """Tool call processing utilities for OAI server.""" +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast import json import re from loguru import logger -from typing import Any, List, Tuple +from typing import Any, Callable, Dict, List, Tuple from endpoints.OAI.types.tools import ToolCall, Tool +from endpoints.OAI.utils.parser_options import resolve_tool_call_parser_key TOOL_CALL_SCHEMA = { @@ -64,6 +69,24 @@ CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE) CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE) +# DeepSeek family patterns +DEEPSEEK_V31_CALL_RE = re.compile( + r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)<|tool▁call▁end|>", + re.DOTALL, +) +DEEPSEEK_V3_CALL_RE = re.compile( + r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)\n```json\n(?P.*?)\n```(?:\s*)<|tool▁call▁end|>", # noqa: E501 + re.DOTALL, +) +DEEPSEEK_V32_INVOKE_RE = re.compile( + r'<|DSML|invoke\s+name="(?P[^"]+)"\s*>(?P.*?)', + re.DOTALL, +) +DEEPSEEK_V32_PARAM_RE = re.compile( + r'<|DSML|parameter\s+name="(?P[^"]+)"(?:\s+string="(?Ptrue|false)")?\s*>(?P.*?)', # noqa: E501 + re.DOTALL, +) + def _strip_think_blocks(text: str) -> str: """Strip ... blocks from text. 
@@ -123,6 +146,7 @@ def _coerce_param_value(raw: str) -> Any: class ToolCallProcessor: + _PARSER_DISPATCHER: Dict[str, Callable[[str], List[ToolCall]]] = {} # ------------------------------------------------------------------ # JSON normalization helpers @@ -226,10 +250,279 @@ def _safe_json_loads(payload: str) -> list: "Could not extract valid JSON from payload", payload, 0 ) + @staticmethod + def _build_tool_calls_from_normalized(raw: Any) -> List[ToolCall]: + """Normalize dict/list payload and build ToolCall models.""" + normalized = ToolCallProcessor._normalize_tool_calls(raw) + for tool_call in normalized: + tool_call["function"]["arguments"] = json.dumps( + tool_call["function"]["arguments"], ensure_ascii=False + ) + return [ToolCall(**tool_call) for tool_call in normalized] + + @staticmethod + def _decode_json_sequence(text: str) -> List[Any]: + """Decode multiple JSON values from a single string.""" + decoder = json.JSONDecoder() + values: List[Any] = [] + idx = 0 + while idx < len(text): + while idx < len(text) and text[idx] in " \t\r\n,;": + idx += 1 + if idx >= len(text): + break + if text.startswith("<|python_tag|>", idx): + idx += len("<|python_tag|>") + continue + try: + value, end = decoder.raw_decode(text[idx:]) + except json.JSONDecodeError: + break + values.append(value) + idx += end + return values + + @staticmethod + def _coerce_argument_payload(arguments_raw: str) -> str: + """Normalize raw argument payload to a JSON string where possible.""" + payload = arguments_raw.strip() + if not payload: + return "{}" + try: + return json.dumps(json.loads(payload), ensure_ascii=False) + except (json.JSONDecodeError, ValueError, TypeError): + return payload + + @staticmethod + def _ast_to_literal(node: ast.AST) -> Any: + """Safely convert AST literal nodes to Python primitives.""" + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.List): + return [ToolCallProcessor._ast_to_literal(item) for item in node.elts] + if 
isinstance(node, ast.Tuple): + return [ToolCallProcessor._ast_to_literal(item) for item in node.elts] + if isinstance(node, ast.Dict): + result = {} + for key, value in zip(node.keys, node.values): + literal_key = ToolCallProcessor._ast_to_literal(key) # type: ignore[arg-type] + if not isinstance(literal_key, str): + raise ValueError("pythonic parser requires string dict keys") + result[literal_key] = ToolCallProcessor._ast_to_literal(value) + return result + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): + return -ToolCallProcessor._ast_to_literal(node.operand) + raise ValueError(f"unsupported pythonic AST node: {type(node).__name__}") + # ------------------------------------------------------------------ # JSON parsing # ------------------------------------------------------------------ + @staticmethod + def from_hermes(raw_text: str) -> List[ToolCall]: + """Parse Hermes-style JSON tool calls (often wrapped in ).""" + text = _strip_think_blocks(raw_text) + wrapped_calls = [] + for match in TOOL_CALL_BLOCK_RE.finditer(text): + inner = match.group(1).strip() + if not inner: + continue + try: + parsed = json.loads(inner) + except (json.JSONDecodeError, ValueError): + continue + wrapped_calls.extend(ToolCallProcessor._build_tool_calls_from_normalized(parsed)) + + if wrapped_calls: + return wrapped_calls + + return ToolCallProcessor.from_json(text) + + @staticmethod + def from_llama(raw_text: str) -> List[ToolCall]: + """Parse Llama JSON tool calls (single/multiple JSON objects).""" + text = _strip_think_blocks(raw_text).strip() + if text.startswith("<|python_tag|>"): + text = text[len("<|python_tag|>") :].lstrip() + + try: + parsed = ToolCallProcessor.from_json(text) + if parsed: + return parsed + except (json.JSONDecodeError, ValueError, KeyError): + pass + + decoded = ToolCallProcessor._decode_json_sequence(text) + if not decoded: + return [] + + flattened = [] + for item in decoded: + if isinstance(item, list): + flattened.extend(item) + else: 
+ flattened.append(item) + + return ToolCallProcessor._build_tool_calls_from_normalized(flattened) + + @staticmethod + def from_openai(raw_text: str) -> List[ToolCall]: + """Best-effort parser for OpenAI/Harmony-style text payloads.""" + text = _strip_think_blocks(raw_text).strip() + try: + parsed = ToolCallProcessor.from_json(text) + if parsed: + return parsed + except (json.JSONDecodeError, ValueError, KeyError): + pass + + decoded = ToolCallProcessor._decode_json_sequence(text) + tool_calls: List[ToolCall] = [] + normalized_items = [] + for value in decoded: + candidates = value if isinstance(value, list) else [value] + for item in candidates: + if not isinstance(item, dict): + continue + + nested = item.get("tool_calls") + if nested: + try: + tool_calls.extend( + ToolCallProcessor._build_tool_calls_from_normalized(nested) + ) + except (ValueError, KeyError, TypeError): + pass + + recipient = item.get("recipient") + content = item.get("content") + if isinstance(recipient, str) and recipient.startswith("functions."): + fn_name = recipient.split("functions.", 1)[1] + if isinstance(content, str): + payload = ToolCallProcessor._coerce_argument_payload(content) + elif content is None: + payload = "{}" + else: + payload = json.dumps(content, ensure_ascii=False) + tool_calls.append( + ToolCall(function=Tool(name=fn_name, arguments=payload)) + ) + continue + + if "name" in item: + normalized_items.append(item) + + if normalized_items: + tool_calls.extend( + ToolCallProcessor._build_tool_calls_from_normalized(normalized_items) + ) + + return tool_calls + + @staticmethod + def from_pythonic(raw_text: str) -> List[ToolCall]: + """Parse Pythonic list-of-calls tool syntax.""" + text = _strip_think_blocks(raw_text).strip() + if text.startswith("<|python_tag|>"): + text = text[len("<|python_tag|>") :].lstrip() + if not text: + return [] + + if not text.startswith("[") and re.match(r"^[A-Za-z_]\w*\s*\(", text): + text = f"[{text}]" + + expression = ast.parse(text, 
mode="eval").body + call_nodes = expression.elts if isinstance(expression, ast.List) else [expression] + + tool_calls = [] + for node in call_nodes: + if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name): + continue + args_dict: Dict[str, Any] = {} + if node.args: + args_dict["_args"] = [ + ToolCallProcessor._ast_to_literal(argument) + for argument in node.args + ] + for keyword in node.keywords: + if keyword.arg is None: + continue + args_dict[keyword.arg] = ToolCallProcessor._ast_to_literal(keyword.value) + + tool_calls.append( + ToolCall( + function=Tool( + name=node.func.id, + arguments=json.dumps(args_dict, ensure_ascii=False), + ) + ) + ) + + return tool_calls + + @staticmethod + def from_deepseek_v31(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3.1 tool call syntax.""" + tool_calls = [] + for match in DEEPSEEK_V31_CALL_RE.finditer(raw_text): + name = match.group("name").strip() + if not name: + continue + arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) + return tool_calls + + @staticmethod + def from_deepseek_v3(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3 tool call syntax.""" + tool_calls = [] + for match in DEEPSEEK_V3_CALL_RE.finditer(raw_text): + name = match.group("name").strip() + if not name: + continue + arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) + + if tool_calls: + return tool_calls + + return ToolCallProcessor.from_deepseek_v31(raw_text) + + @staticmethod + def from_deepseek_v32(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3.2 DSML tool call syntax.""" + tool_calls = [] + for invoke in DEEPSEEK_V32_INVOKE_RE.finditer(raw_text): + function_name = invoke.group("name").strip() + if not function_name: + continue + + params: Dict[str, Any] = {} + body = invoke.group("body") + for param in 
DEEPSEEK_V32_PARAM_RE.finditer(body): + key = param.group("name").strip() + value_raw = param.group("value") + is_string = param.group("string") == "true" + if is_string: + value = value_raw.strip("\n") + else: + value = _coerce_param_value(value_raw) + params[key] = value + + tool_calls.append( + ToolCall( + function=Tool( + name=function_name, + arguments=json.dumps(params, ensure_ascii=False), + ) + ) + ) + + if tool_calls: + return tool_calls + + return ToolCallProcessor.from_deepseek_v31(raw_text) + @staticmethod def from_json(tool_calls_str: str) -> List[ToolCall]: """Postprocess tool call JSON to a parseable class. @@ -395,18 +688,57 @@ def from_auto(raw_text: str) -> List[ToolCall]: # ------------------------------------------------------------------ @staticmethod - def parse(tool_calls_str: str, format: str = "json") -> List[ToolCall]: + def _parser_dispatcher() -> Dict[str, Callable[[str], List[ToolCall]]]: + """Registry for parser-key-specific handlers.""" + if not ToolCallProcessor._PARSER_DISPATCHER: + ToolCallProcessor._PARSER_DISPATCHER = { + "deepseek_v3": ToolCallProcessor.from_deepseek_v3, + "deepseek_v31": ToolCallProcessor.from_deepseek_v31, + "deepseek_v32": ToolCallProcessor.from_deepseek_v32, + "hermes": ToolCallProcessor.from_hermes, + "llama": ToolCallProcessor.from_llama, + "llama3_json": ToolCallProcessor.from_llama, + "llama4_json": ToolCallProcessor.from_llama, + "openai": ToolCallProcessor.from_openai, + "pythonic": ToolCallProcessor.from_pythonic, + "qwen3_coder": ToolCallProcessor.from_xml, + "qwen3_xml": ToolCallProcessor.from_xml, + } + return ToolCallProcessor._PARSER_DISPATCHER + + @staticmethod + def parse( + tool_calls_str: str, format: str = "json", parser_key: str | None = None + ) -> List[ToolCall]: """Dispatch tool call parsing to the appropriate format handler. Args: tool_calls_str: Raw tool call text from model generation. format: One of ``"json"``, ``"xml"``, ``"auto"``. + parser_key: Optional vLLM-compatible parser key. 
Returns: List of parsed ToolCall objects. Empty list on parse failure (never raises). """ try: + if parser_key: + canonical_key = resolve_tool_call_parser_key(parser_key) or parser_key + parser = ToolCallProcessor._parser_dispatcher().get(canonical_key) + if parser: + try: + parsed = parser(tool_calls_str) + except Exception as exc: + logger.warning( + "Parser '{}' failed: {}. Falling back to format '{}'.", + canonical_key, + str(exc), + format, + ) + else: + if parsed: + return parsed + if format == "xml": return ToolCallProcessor.from_xml(tool_calls_str) elif format == "auto": From c86a6ddfadf7213895109ecae0a0cc48cd342012 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Tue, 24 Feb 2026 22:22:35 +0900 Subject: [PATCH 10/19] feat(oai): add mistral tool-call parser compatibility --- docs/10.-Tool-Calling.md | 1 + endpoints/OAI/utils/parser_options.py | 1 + endpoints/OAI/utils/tools.py | 115 ++++++++++++++++++++++++++ tests/parser_options_test.py | 1 + tests/tool_parser_test.py | 41 +++++++++ 5 files changed, 159 insertions(+) diff --git a/docs/10.-Tool-Calling.md b/docs/10.-Tool-Calling.md index aaed88cf..cbb772f6 100644 --- a/docs/10.-Tool-Calling.md +++ b/docs/10.-Tool-Calling.md @@ -31,6 +31,7 @@ The following model config options are available to align behavior with vLLM: Supported parser keys include: - `hermes` - `llama` (alias of `llama3_json`), `llama3_json`, `llama4_json` +- `mistral` - `openai` - `pythonic` - `qwen3_coder`, `qwen3_xml` diff --git a/endpoints/OAI/utils/parser_options.py b/endpoints/OAI/utils/parser_options.py index e1da446c..84340f88 100644 --- a/endpoints/OAI/utils/parser_options.py +++ b/endpoints/OAI/utils/parser_options.py @@ -60,6 +60,7 @@ "deepseek_v31", "deepseek_v32", "llama4_pythonic", + "mistral", "pythonic", "qwen3_coder", "qwen3_xml", diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index bf6e2a0d..9758e4ce 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -6,6 +6,8 @@ import 
ast import json import re +from random import choices +from string import ascii_letters, digits from loguru import logger from typing import Any, Callable, Dict, List, Tuple @@ -87,6 +89,9 @@ re.DOTALL, ) +MISTRAL_TOOL_START = "[TOOL_CALLS]" +MISTRAL_ID_ALPHANUMERIC = ascii_letters + digits + def _strip_think_blocks(text: str) -> str: """Strip ... blocks from text. @@ -293,6 +298,64 @@ def _coerce_argument_payload(arguments_raw: str) -> str: except (json.JSONDecodeError, ValueError, TypeError): return payload + @staticmethod + def _normalize_mistral_tool_call_id(raw_id: Any) -> str: + """Normalize tool call IDs to Mistral's 9-char alphanumeric format.""" + if isinstance(raw_id, str): + candidate = re.sub(r"[^A-Za-z0-9]", "", raw_id) + if len(candidate) >= 9: + return candidate[-9:] + return "".join(choices(MISTRAL_ID_ALPHANUMERIC, k=9)) + + @staticmethod + def _build_mistral_tool_call(name: str, arguments: Any, raw_id: Any = None) -> ToolCall: + if isinstance(arguments, str): + payload = ToolCallProcessor._coerce_argument_payload(arguments) + else: + payload = json.dumps(arguments, ensure_ascii=False) + return ToolCall( + id=ToolCallProcessor._normalize_mistral_tool_call_id(raw_id), + function=Tool(name=name, arguments=payload), + ) + + @staticmethod + def _parse_mistral_json_tool_calls(payload: str) -> List[ToolCall]: + """Parse JSON-style Mistral tool calls following [TOOL_CALLS].""" + decoded = ToolCallProcessor._decode_json_sequence(payload) + if not decoded: + try: + decoded = [json.loads(payload)] + except (json.JSONDecodeError, ValueError): + return [] + + tool_calls: List[ToolCall] = [] + for item in decoded: + candidates = item if isinstance(item, list) else [item] + for candidate in candidates: + if not isinstance(candidate, dict): + continue + + if "function" in candidate and isinstance(candidate["function"], dict): + fn = candidate["function"] + name = fn.get("name") + arguments = fn.get("arguments", {}) + tool_id = candidate.get("id") + else: + name = 
candidate.get("name") + arguments = candidate.get("arguments", {}) + tool_id = candidate.get("id") + + if not isinstance(name, str) or not name: + continue + + tool_calls.append( + ToolCallProcessor._build_mistral_tool_call( + name=name, arguments=arguments, raw_id=tool_id + ) + ) + + return tool_calls + @staticmethod def _ast_to_literal(node: ast.AST) -> Any: """Safely convert AST literal nodes to Python primitives.""" @@ -523,6 +586,57 @@ def from_deepseek_v32(raw_text: str) -> List[ToolCall]: return ToolCallProcessor.from_deepseek_v31(raw_text) + @staticmethod + def from_mistral(raw_text: str) -> List[ToolCall]: + """Parse Mistral [TOOL_CALLS] payloads for both tokenizer formats.""" + text = raw_text.strip() + + # Non-Mistral outputs should remain compatible with existing JSON logic. + if MISTRAL_TOOL_START not in text: + return ToolCallProcessor.from_json(text) + + split_payloads = [ + chunk.strip() + for chunk in text.split(MISTRAL_TOOL_START)[1:] + if chunk.strip() + ] + if not split_payloads: + return [] + + # pre-v11 format: [TOOL_CALLS] [{"name": "...", "arguments": {...}}] + if len(split_payloads) == 1: + json_calls = ToolCallProcessor._parse_mistral_json_tool_calls( + split_payloads[0] + ) + if json_calls: + return json_calls + + # v11+ format: [TOOL_CALLS]name{...}[TOOL_CALLS]name{...} + tool_calls: List[ToolCall] = [] + for payload in split_payloads: + start = payload.find("{") + end = payload.rfind("}") + if start == -1 or end < start: + continue + + function_name = payload[:start].strip() + if not function_name: + continue + + arguments = payload[start : end + 1] + tool_calls.append( + ToolCallProcessor._build_mistral_tool_call( + name=function_name, + arguments=arguments, + ) + ) + + if tool_calls: + return tool_calls + + # Final fallback for malformed payloads. + return ToolCallProcessor.from_json(split_payloads[-1]) + @staticmethod def from_json(tool_calls_str: str) -> List[ToolCall]: """Postprocess tool call JSON to a parseable class. 
@@ -699,6 +813,7 @@ def _parser_dispatcher() -> Dict[str, Callable[[str], List[ToolCall]]]: "llama": ToolCallProcessor.from_llama, "llama3_json": ToolCallProcessor.from_llama, "llama4_json": ToolCallProcessor.from_llama, + "mistral": ToolCallProcessor.from_mistral, "openai": ToolCallProcessor.from_openai, "pythonic": ToolCallProcessor.from_pythonic, "qwen3_coder": ToolCallProcessor.from_xml, diff --git a/tests/parser_options_test.py b/tests/parser_options_test.py index 0d328bc0..4f9648cc 100644 --- a/tests/parser_options_test.py +++ b/tests/parser_options_test.py @@ -36,4 +36,5 @@ def test_native_generation_flags_cover_native_syntax_parsers(): assert parser_uses_native_tool_generation("qwen3_coder", "json") is True assert parser_uses_native_tool_generation("deepseek_v31", "json") is True assert parser_uses_native_tool_generation("pythonic", "json") is True + assert parser_uses_native_tool_generation("mistral", "json") is True assert parser_uses_native_tool_generation("hermes", "json") is False diff --git a/tests/tool_parser_test.py b/tests/tool_parser_test.py index 6d8b9372..8b93a726 100644 --- a/tests/tool_parser_test.py +++ b/tests/tool_parser_test.py @@ -183,6 +183,47 @@ def test_parse_with_openai_parser_handles_functions_recipient(): assert _arguments_dict(parsed[0]) == {"city": "Seoul"} +def test_parse_with_mistral_parser_handles_pre_v11_json(): + payload = ( + '[TOOL_CALLS] [{"name":"get_weather","arguments":{"city":"Seoul","days":2}}]' + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="mistral") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 2} + assert parsed[0].id.isalnum() + assert len(parsed[0].id) == 9 + + +def test_parse_with_mistral_parser_handles_v11_style_segments(): + payload = ( + '[TOOL_CALLS]search{"q":"tabbyapi"}' + '[TOOL_CALLS]lookup{"id":42}' + ) + + parsed = ToolCallProcessor.parse(payload, format="json", 
parser_key="mistral") + + assert len(parsed) == 2 + assert parsed[0].function.name == "search" + assert _arguments_dict(parsed[0]) == {"q": "tabbyapi"} + assert parsed[1].function.name == "lookup" + assert _arguments_dict(parsed[1]) == {"id": 42} + assert parsed[1].id.isalnum() + assert len(parsed[1].id) == 9 + + +def test_parse_with_mistral_parser_falls_back_to_standard_json(): + payload = '[{"name":"lookup","arguments":{"id":42}}]' + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="mistral") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert _arguments_dict(parsed[0]) == {"id": 42} + + def test_parser_key_dispatch_overrides_format_for_qwen3_xml(): payload = ( "" From c0f93874b415949182cee6e769c4f6d69278d771 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Tue, 24 Feb 2026 22:22:41 +0900 Subject: [PATCH 11/19] feat(config): add tokenizer_mode and mistral-safe tool ID handling --- backends/exllamav2/model.py | 3 + backends/exllamav3/model.py | 33 +++ common/config_models.py | 16 +- config_sample.yml | 8 +- docs/02.-Server-options.md | 1 + endpoints/OAI/utils/chat_completion.py | 334 ++++++++++++++++++++----- endpoints/core/types/model.py | 5 + 7 files changed, 333 insertions(+), 67 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index d9341f10..ce24b65a 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -83,6 +83,7 @@ class ExllamaV2Container(BaseModelContainer): cache_mode: str = "FP16" draft_cache_mode: str = "FP16" max_batch_size: Optional[int] = None + tokenizer_mode: str = "auto" # GPU split vars gpu_split: List[float] = [] @@ -120,6 +121,7 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs self.model_dir = model_directory self.config.model_dir = str(model_directory.resolve()) self.hf_model = hf_model + self.tokenizer_mode = str(unwrap(kwargs.get("tokenizer_mode"), "auto")).lower() # Make the max seq len 4096 before 
preparing the config # This is a better default than 2048 @@ -440,6 +442,7 @@ def model_info(self): max_batch_size=self.max_batch_size, cache_mode=self.cache_mode, chunk_size=self.config.max_input_len, + tokenizer_mode=self.tokenizer_mode, use_vision=self.use_vision, draft=draft_model_card, ) diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 9780f940..b2a3cc05 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -45,6 +45,17 @@ from endpoints.core.types.model import ModelCard, ModelCardParameters +_SUPPORTED_TOKENIZER_MODES = {"auto", "hf", "mistral"} + + +def _supports_mistral_tokenizer_mode(model_directory: pathlib.Path) -> bool: + return ( + (model_directory / "tekken.json").exists() + or (model_directory / "tokenizer.model").exists() + or any(model_directory.glob("tokenizer.model.v*")) + ) + + class ExllamaV3Container(BaseModelContainer): """Abstract base class for model containers.""" @@ -88,6 +99,7 @@ class ExllamaV3Container(BaseModelContainer): chunk_size: int = 2048 max_rq_tokens: Optional[int] = 2048 max_batch_size: Optional[int] = None + tokenizer_mode: str = "auto" # Required methods @classmethod @@ -110,6 +122,26 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs self.model_dir = model_directory self.hf_model = hf_model + requested_tokenizer_mode = str(unwrap(kwargs.get("tokenizer_mode"), "auto")).lower() + if requested_tokenizer_mode not in _SUPPORTED_TOKENIZER_MODES: + logger.warning( + "Unknown tokenizer_mode '{}' requested. Falling back to 'auto'.", + requested_tokenizer_mode, + ) + requested_tokenizer_mode = "auto" + + if requested_tokenizer_mode == "mistral": + if _supports_mistral_tokenizer_mode(model_directory): + logger.info("Using tokenizer_mode='mistral' compatibility path.") + else: + logger.warning( + "tokenizer_mode='mistral' requested but model does not appear " + "to use Mistral tokenizer assets. Falling back to default mode." 
+ ) + requested_tokenizer_mode = "auto" + + self.tokenizer_mode = requested_tokenizer_mode + self.config = Config.from_directory(str(model_directory.resolve())) self.model = Model.from_config(self.config) self.tokenizer = Tokenizer.from_config(self.config) @@ -369,6 +401,7 @@ def model_info(self) -> ModelCard: max_batch_size=self.max_batch_size, cache_mode=self.cache_mode, chunk_size=self.chunk_size, + tokenizer_mode=self.tokenizer_mode, use_vision=self.use_vision, ) diff --git a/common/config_models.py b/common/config_models.py index c13434d5..66aa89ff 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -289,11 +289,21 @@ class ModelConfig(BaseConfigModel): description=( "Set the prompt template for this model. (default: None)\n" "If empty, attempts to look for the model's chat template.\n" - "If a model contains multiple templates in its tokenizer_config.json,\n" - "set prompt_template to the name of the template you want to use.\n" + "If a model contains multiple templates in its tokenizer_config.json,\n" + "set prompt_template to the name of the template you want to use.\n" "NOTE: Only works with chat completion message lists!" ), ) + tokenizer_mode: Optional[str] = Field( + "auto", + description=( + "Tokenizer compatibility mode for chat formatting.\n" + "Compatible values: auto, hf, mistral.\n" + "mistral applies Mistral-specific message normalization " + "(tool-call ID handling) and falls back to default behavior " + "for non-Mistral models." + ), + ) reasoning_parser: Optional[str] = Field( None, description=( @@ -316,7 +326,7 @@ class ModelConfig(BaseConfigModel): "Tool parser key for model-generated tool call output.\n" "Equivalent to vLLM's --tool-call-parser.\n" "Built-in parser keys include: hermes, llama/llama3_json/llama4_json,\n" - "openai, pythonic, qwen3_coder, qwen3_xml,\n" + "mistral, openai, pythonic, qwen3_coder, qwen3_xml,\n" "deepseek_v3, deepseek_v31, deepseek_v32." 
), ) diff --git a/config_sample.yml b/config_sample.yml index bd03afb2..6f53a78a 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -153,6 +153,12 @@ model: # NOTE: Only works with chat completion message lists! prompt_template: + # Tokenizer compatibility mode for chat formatting (default: auto). + # Values: auto, hf, mistral. + # mistral applies Mistral-style tool-call ID normalization and + # falls back to default behavior for non-Mistral models. + tokenizer_mode: auto + # Reasoning parser key for splitting hidden reasoning and final content. # Compatible keys include: basic, exaone4, deepseek_r1, deepseek_v3. # If omitted, TabbyAPI defaults to `basic`. @@ -166,7 +172,7 @@ model: # Tool parser key for model-generated tool call text. # Equivalent to vLLM --tool-call-parser. # Built-in values include: - # hermes, llama (alias of llama3_json), llama3_json, llama4_json, + # hermes, llama (alias of llama3_json), llama3_json, llama4_json, mistral, # openai, pythonic, qwen3_coder, qwen3_xml, # deepseek_v3, deepseek_v31, deepseek_v32. tool_call_parser: diff --git a/docs/02.-Server-options.md b/docs/02.-Server-options.md index 4da7766b..2efc9c3a 100644 --- a/docs/02.-Server-options.md +++ b/docs/02.-Server-options.md @@ -73,6 +73,7 @@ Note: Most of the options here will only apply on initial model load/startup (ep | chunk_size | Int (2048) | Amount of tokens per chunk with ingestion. A lower value reduces VRAM usage at the cost of ingestion speed. | | max_batch_size | Int (None) | The absolute maximum amount of prompts to process at one time. This value is automatically adjusted based on cache size. | | prompt_template | String (None) | Name of a jinja2 chat template to apply for this model. Must be located in the `templates` directory. | +| tokenizer_mode | String ("auto") | Tokenizer compatibility mode for chat formatting. 
Supported values are `auto`, `hf`, `mistral`; `mistral` applies Mistral-specific tool-call ID handling and falls back to default behavior on non-Mistral models. | | reasoning_parser | String (None) | Reasoning parser key used to split reasoning and final answer text (vLLM-compatible names, default parser behavior is `basic`). | | enable_auto_tool_choice | Bool (False) | Enables vLLM-style automatic tool choice handling. Equivalent to `--enable-auto-tool-choice` and requires `tool_call_parser`. | | tool_call_parser | String (None) | vLLM-compatible tool parser key used to parse model-emitted tool calls. Equivalent to `--tool-call-parser`. | diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 1afe1f6c..beb20879 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -42,6 +42,181 @@ from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA +_SUPPORTED_TOKENIZER_MODES = {"auto", "hf", "mistral"} + + +@dataclass +class _StreamReasoningState: + text: str = "" + token_ids: List[int] = field(default_factory=list) + + +class _TokenizerAdapter: + """Expose the minimal tokenizer interface required by reasoning parsers.""" + + def __init__(self): + self._vocab = None + + def get_vocab(self) -> dict[str, int]: + if self._vocab is not None: + return self._vocab + + tokenizer = model.container.tokenizer + if hasattr(tokenizer, "get_vocab"): + self._vocab = tokenizer.get_vocab() + return self._vocab + + pieces = tokenizer.get_id_to_piece_list(True) + vocab: dict[str, int] = {} + for token_id, piece in enumerate(pieces): + if piece not in vocab: + vocab[piece] = token_id + self._vocab = vocab + return vocab + + +def _token_ids_from_generation(generation: dict) -> List[int]: + token_ids = generation.get("token_ids") + if token_ids is None: + return [] + if isinstance(token_ids, list): + return token_ids + if hasattr(token_ids, "flatten"): + return token_ids.flatten().tolist() 
+ return list(token_ids) + + +def _get_tokenizer_mode() -> str: + container_mode = getattr(model.container, "tokenizer_mode", None) + tokenizer_mode = str( + unwrap(container_mode, unwrap(config.model.tokenizer_mode, "auto")) or "auto" + ).lower() + if tokenizer_mode not in _SUPPORTED_TOKENIZER_MODES: + logger.warning( + "Unknown tokenizer_mode '{}' configured; falling back to 'auto'.", + tokenizer_mode, + ) + return "auto" + return tokenizer_mode + + +def _truncate_mistral_tool_ids(message_dicts: List[dict]) -> None: + """Match vLLM mistral_common behavior by truncating tool IDs to 9 chars.""" + for message in message_dicts: + role = message.get("role") + + if role == "assistant": + for tool_call in message.get("tool_calls", []): + tool_id = tool_call.get("id") + if isinstance(tool_id, str) and len(tool_id) > 9: + tool_call["id"] = tool_id[-9:] + + if role in {"tool", "tool_results"}: + tool_call_id = message.get("tool_call_id") + if isinstance(tool_call_id, str) and len(tool_call_id) > 9: + message["tool_call_id"] = tool_call_id[-9:] + + +def _sanitize_mistral_tool_choice_ids(request_data: ChatCompletionRequest) -> None: + """Normalize tool_choice IDs so mistral templates don't fail validation.""" + for message in request_data.messages: + if message.role == "assistant" and message.tool_calls: + for tool_call in message.tool_calls: + if isinstance(tool_call.id, str) and len(tool_call.id) > 9: + tool_call.id = tool_call.id[-9:] + elif message.role in {"tool", "tool_results"} and message.tool_call_id: + if len(message.tool_call_id) > 9: + message.tool_call_id = message.tool_call_id[-9:] + + +def _build_parser_instance(parser_key: str, template_kwargs: dict): + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_key) + return parser_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) + + +def _build_reasoning_parser(request_data: ChatCompletionRequest): + parser_key = unwrap(config.model.reasoning_parser, "basic") or "basic" + + template_kwargs 
= unwrap(request_data.template_vars, {}) + + try: + return _build_parser_instance(parser_key, template_kwargs) + except KeyError as exc: + raise HTTPException(400, str(exc)) from exc + except RuntimeError as exc: + # Keep compatibility for models that do not expose thinking tags. + if parser_key == "basic": + logger.warning( + "Reasoning parser 'basic' could not initialize ({}). " + "Falling back to identity parser.", + str(exc), + ) + identity_cls = ReasoningParserManager.get_reasoning_parser("identity") + return identity_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) + + # Mistral parser only applies when think tokens exist. If a + # non-Mistral model is selected, fall back to basic behavior. + if parser_key == "mistral": + logger.warning( + "Reasoning parser 'mistral' could not initialize ({}). " + "Falling back to 'basic'.", + str(exc), + ) + try: + return _build_parser_instance("basic", template_kwargs) + except Exception: + identity_cls = ReasoningParserManager.get_reasoning_parser("identity") + return identity_cls( + _TokenizerAdapter(), chat_template_kwargs=template_kwargs + ) + + raise HTTPException(400, str(exc)) from exc + + +def _validate_and_get_tool_call_format( + request_data: ChatCompletionRequest, default_format: str +) -> str: + tool_choice = request_data.tool_choice + parser_key = config.model.tool_call_parser + enable_auto = bool(config.model.enable_auto_tool_choice) + parser_names = list_tool_call_parsers() + + if parser_key and parser_key not in parser_names: + parsers_str = ", ".join(sorted(parser_names)) + raise HTTPException( + 400, + f"invalid tool call parser: {parser_key} (choose from {{{parsers_str}}})", + ) + + if tool_choice == "auto" and (not enable_auto or not parser_key): + raise HTTPException( + 400, + '"auto" tool choice requires --enable-auto-tool-choice and ' + "--tool-call-parser to be set", + ) + + if tool_choice not in (None, "none", "auto") and parser_key is None: + raise HTTPException( + 400, + 
f'tool_choice="{tool_choice}" requires --tool-call-parser to be set', + ) + + if ( + tool_choice == "none" + and config.model.exclude_tools_when_tool_choice_none + and request_data.tools + ): + request_data.tools = None + + resolved_format = resolve_tool_call_format(parser_key, default_format) + if not resolved_format: + raise HTTPException( + 400, + f"Could not resolve format for tool_call_parser={parser_key}", + ) + return resolved_format + + def _serialize_stream_chunk(chunk) -> str: """Serialize a streaming chunk with OpenAI-compatible field handling. @@ -80,13 +255,18 @@ def _create_response( tool_calls_raw = generation.get("tool_calls") if tool_calls_raw: - parsed = ToolCallProcessor.parse(tool_calls_raw, format=tool_call_format) + parsed = ToolCallProcessor.parse( + tool_calls_raw, + format=tool_call_format, + parser_key=parser_key, + ) if parsed and isinstance(tool_choice, NamedToolChoice): parsed = ToolCallProcessor.filter_by_name( parsed, tool_choice.function.name ) if parsed: message.tool_calls = parsed + message.content = None else: logger.warning( "Tool call text present but parsing returned no results " @@ -262,6 +442,7 @@ def _build_tool_call_chunks( tool_calls: List[ToolCall], request_id: str, model_name: str, + choice_index: int, ) -> List[ChatCompletionStreamChunk]: """Build the OpenAI-standard streaming sequence for tool calls. 
@@ -294,7 +475,7 @@ def _build_tool_call_chunks( id=chunk_id, choices=[ ChatCompletionStreamChoice( - index=0, + index=choice_index, delta=tool_call_message, finish_reason=None, ) @@ -306,7 +487,7 @@ def _build_tool_call_chunks( # Use model_construct to prevent Pydantic's smart Union from # coercing the empty dict {} into ChatCompletionMessage(role="user") finish_choice = ChatCompletionStreamChoice.model_construct( - index=0, + index=choice_index, delta={}, finish_reason="tool_calls", logprobs=None, @@ -364,6 +545,9 @@ async def format_messages_with_template( message_dicts.append(message.model_dump(exclude_none=True)) + if _get_tokenizer_mode() == "mistral": + _truncate_mistral_tool_ids(message_dicts) + # Pre-template: convert tool_call arguments from JSON strings to dicts. # OpenAI-compatible clients (Kilo, Roo, etc.) send arguments as JSON # strings per the OAI spec, but Qwen3-Coder's template calls @@ -399,6 +583,8 @@ async def apply_chat_template(data: ChatCompletionRequest): # Locally store tools dict tools = data.model_dump()["tools"] + if _get_tokenizer_mode() == "mistral": + _sanitize_mistral_tool_choice_ids(data) try: data.template_vars.update( @@ -464,7 +650,7 @@ async def stream_generate_chat_completion( gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start - tool_call_format = model.container.prompt_template.metadata.tool_call_format + default_tool_call_format = model.container.prompt_template.metadata.tool_call_format disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: @@ -516,8 +702,47 @@ async def stream_generate_chat_completion( if disconnect_task.done(): raise CancelledError() + # Stream collector will push an exception to the queue if it fails + if isinstance(generation, Exception): + raise generation + + if "text" in generation and generation.get("finish_reason") is None: + idx = generation["index"] + state = reasoning_states[idx] + + delta_text = 
unwrap(generation.get("text"), "") + delta_token_ids = _token_ids_from_generation(generation) + + current_text = state.text + delta_text + current_token_ids = state.token_ids + delta_token_ids + + delta_message = reasoning_parser.extract_reasoning_streaming( + state.text, + current_text, + delta_text, + state.token_ids, + current_token_ids, + delta_token_ids, + ) + + state.text = current_text + state.token_ids = current_token_ids + + if delta_message is None: + continue + + generation["text"] = delta_message.content + if data.include_reasoning: + generation["reasoning"] = delta_message.reasoning + generation["reasoning_content"] = delta_message.reasoning + else: + generation["reasoning"] = None + generation["reasoning_content"] = None + if generation["text"] is None: + continue + # Handle options if a tool model is present - if tool_start and data.tool_choice != "none": + if (tool_start or force_tool_pass) and data.tool_choice != "none": if "stop_str" in generation: generations = await generate_tool_calls( prompt, @@ -531,52 +756,6 @@ async def stream_generate_chat_completion( # Only one generation present in this case generation = generations[0] - # Emit proper three-phase tool-call streaming sequence - if "tool_calls" in generation: - tool_calls_raw = generation["tool_calls"] - parsed = ToolCallProcessor.parse( - tool_calls_raw, format=tool_call_format - ) - if parsed and isinstance(data.tool_choice, NamedToolChoice): - parsed = ToolCallProcessor.filter_by_name( - parsed, data.tool_choice.function.name - ) - if parsed: - for tc_chunk in _build_tool_call_chunks( - parsed, - request.state.id, - model_path.name, - ): - yield _serialize_stream_chunk(tc_chunk) - - # Handle completion and usage after tool calls - if ( - all(task.done() for task in gen_tasks) - and gen_queue.empty() - ): - if ( - data.stream_options - and data.stream_options.include_usage - ): - usage_chunk = _create_stream_chunk( - request.state.id, - generation, - model_path.name, - 
is_usage_chunk=True, - ) - yield _serialize_stream_chunk(usage_chunk) - - logger.info( - "Finished chat completion streaming " - f"request {request.state.id}" - ) - yield "[DONE]" - break - continue - - elif "text" in generation: - current_generation_text += generation["text"] - # Emit proper three-phase tool-call streaming sequence if "tool_calls" in generation: tool_calls_raw = generation["tool_calls"] @@ -652,6 +831,8 @@ async def stream_generate_chat_completion( # Get out if the request gets disconnected handle_request_disconnect("Chat completion generation cancelled by user.") + except HTTPException as exc: + yield get_generator_error(str(exc.detail)) except Exception: yield get_generator_error( "Chat completion aborted. Please check the server console." @@ -670,7 +851,10 @@ async def generate_chat_completion( ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start - tool_call_format = model.container.prompt_template.metadata.tool_call_format + default_tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) try: logger.info(f"Received chat completion request {request.state.id}") @@ -705,6 +889,20 @@ async def generate_chat_completion( tool_call_format=tool_call_format, ) + reasoning_parser = _build_reasoning_parser(data) + for generation in generations: + reasoning, content = reasoning_parser.extract_reasoning( + unwrap(generation.get("text"), ""), + data, + ) + + if not data.include_reasoning: + reasoning = None + + generation["reasoning"] = reasoning + generation["reasoning_content"] = reasoning + generation["text"] = content + response = _create_response( request.state.id, generations, @@ -739,8 +937,16 @@ async def generate_tool_calls( ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start - tool_call_format = 
model.container.prompt_template.metadata.tool_call_format + if tool_call_format is None: + default_tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) tool_choice = data.tool_choice + parser_key = config.model.tool_call_parser + use_native_generation = parser_uses_native_tool_generation( + parser_key, tool_call_format + ) if tool_choice == "none": return generations @@ -751,12 +957,14 @@ async def generate_tool_calls( # Copy to make sure the parent JSON schema doesn't get modified tool_data = data.model_copy(deep=True) - if tool_call_format in ("xml", "auto"): - # XML / auto mode: let the model generate its natural output - # without JSON schema constraint + if use_native_generation: + # Native syntax mode: let the model generate its natural tool-call + # representation without JSON schema constraint. logger.debug( - f"generate_tool_calls: Using '{tool_call_format}' mode " - f"(no JSON schema constraint)" + "generate_tool_calls: Using parser '{}' in native mode " + "(format={}, no JSON schema constraint)", + parser_key or "template-default", + tool_call_format, ) # Remove tool_start from stop strings so the model can emit @@ -799,11 +1007,11 @@ async def generate_tool_calls( if precursor_text: tool_prompt = tool_prompt + precursor_text - # For XML/auto mode: append tool_start back to prompt. + # For native generation mode: append tool_start back to prompt. # The stop string was consumed by the first pass and not included - # in full_text, but the model expects to continue after . + # in full_text, but the model expects to continue after tool_start. # Include a trailing newline to match the canonical template format. 
- if tool_call_format in ("xml", "auto") and tool_start: + if use_native_generation and tool_start: tool_prompt = tool_prompt + tool_start + "\n" gen_request_id = gen.get("request_id") @@ -829,8 +1037,8 @@ async def generate_tool_calls( for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True): raw_text = tool_call["text"] - if tool_call_format in ("xml", "auto"): - # Prepend tool_start to reconstruct complete XML for parser + if use_native_generation and tool_start: + # Prepend tool_start to reconstruct complete native payload. raw_text = tool_start + "\n" + raw_text generations[gen_idx]["tool_calls"] = raw_text diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 84229294..f5f0e061 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -20,6 +20,7 @@ class ModelCardParameters(BaseModel): rope_alpha: Optional[float] = 1.0 max_batch_size: Optional[int] = 1 chunk_size: Optional[int] = 2048 + tokenizer_mode: Optional[str] = "auto" prompt_template: Optional[str] = None prompt_template_content: Optional[str] = None use_vision: Optional[bool] = False @@ -111,6 +112,10 @@ class ModelLoadRequest(BaseModel): chunk_size: Optional[int] = None output_chunking: Optional[bool] = True prompt_template: Optional[str] = None + tokenizer_mode: Optional[str] = Field( + description="Tokenizer compatibility mode (auto, hf, mistral)", + default=None, + ) vision: Optional[bool] = None # Non-config arguments From e8a76208587d17740eafabdbe6a4ac690f13723f Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Tue, 24 Feb 2026 22:22:49 +0900 Subject: [PATCH 12/19] test(reasoning): align mistral parser token handling with mistral_common --- .../OAI/reasoning/mistral_reasoning_parser.py | 14 +++- tests/mistral_reasoning_parser_test.py | 70 +++++++++++++++++++ 2 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 tests/mistral_reasoning_parser_test.py diff --git a/endpoints/OAI/reasoning/mistral_reasoning_parser.py 
b/endpoints/OAI/reasoning/mistral_reasoning_parser.py index 99436dad..bda3fabf 100644 --- a/endpoints/OAI/reasoning/mistral_reasoning_parser.py +++ b/endpoints/OAI/reasoning/mistral_reasoning_parser.py @@ -12,11 +12,21 @@ class MistralReasoningParser(BaseThinkingReasoningParser): @property def start_token(self) -> str: - return "[THINK]" + try: + from mistral_common.tokens.tokenizers.base import SpecialTokens + + return SpecialTokens.begin_think + except Exception: + return "[THINK]" @property def end_token(self) -> str: - return "[/THINK]" + try: + from mistral_common.tokens.tokenizers.base import SpecialTokens + + return SpecialTokens.end_think + except Exception: + return "[/THINK]" def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: has_eot = False diff --git a/tests/mistral_reasoning_parser_test.py b/tests/mistral_reasoning_parser_test.py new file mode 100644 index 00000000..d598c018 --- /dev/null +++ b/tests/mistral_reasoning_parser_test.py @@ -0,0 +1,70 @@ +"""Tests for Mistral reasoning parser parity with vLLM behavior.""" + +from endpoints.OAI.reasoning.mistral_reasoning_parser import MistralReasoningParser + + +class _FakeTokenizer: + def get_vocab(self): + return { + "[THINK]": 301, + "[/THINK]": 302, + } + + +def _parser() -> MistralReasoningParser: + return MistralReasoningParser(_FakeTokenizer()) + + +def test_extract_reasoning_with_valid_think_section(): + parser = _parser() + + reasoning, content = parser.extract_reasoning( + "[THINK]This is a reasoning section[/THINK]This is the rest", + request=None, + ) + + assert reasoning == "This is a reasoning section" + assert content == "This is the rest" + + +def test_extract_reasoning_with_invalid_end_token_only(): + parser = _parser() + + reasoning, content = parser.extract_reasoning( + "This is a reasoning section[/THINK]This is the rest", + request=None, + ) + + assert reasoning is None + assert content == "This is a reasoning sectionThis is the rest" + + +def 
test_extract_reasoning_with_begin_token_only(): + parser = _parser() + + reasoning, content = parser.extract_reasoning( + "[THINK]This is a reasoning section", + request=None, + ) + + assert reasoning == "This is a reasoning section" + assert content is None + + +def test_extract_reasoning_without_think_tokens(): + parser = _parser() + + reasoning, content = parser.extract_reasoning("This is content", request=None) + + assert reasoning is None + assert content == "This is content" + + +def test_is_reasoning_end_and_extract_content_ids(): + parser = _parser() + + assert parser.is_reasoning_end([1, parser.start_token_id, parser.end_token_id]) is True + assert parser.is_reasoning_end([1, 2, 3]) is False + + assert parser.extract_content_ids([7, parser.start_token_id, 9, parser.end_token_id, 10]) == [7, 10] + assert parser.extract_content_ids([7, parser.start_token_id, 9]) == [7] From c74cc176d2ba3d289bdd378de232ef35430f050f Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Tue, 24 Feb 2026 23:18:43 +0900 Subject: [PATCH 13/19] test(model): skip exllamav2 model tests when dependency is missing --- tests/model_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/model_test.py b/tests/model_test.py index 662ccbc5..d01258aa 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -1,5 +1,8 @@ """Test the model container.""" +import pytest + +pytest.importorskip("exllamav2") from backends.exllamav2.model import ModelContainer From 88014e77475de92bf729a3cc0b8f3763f0905757 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Wed, 25 Feb 2026 20:59:30 +0900 Subject: [PATCH 14/19] Align ExLlama tokenizer modes with vLLM and enforce flashinfer path --- backends/exllamav2/model.py | 31 ++++++- backends/exllamav3/model.py | 54 ++++++------ common/config_models.py | 11 ++- common/hardware.py | 28 ++++++- common/optional_dependencies.py | 16 ++-- common/tokenizer_modes.py | 110 +++++++++++++++++++++++++ config_sample.yml | 8 +- docs/01.-Getting-Started.md | 2 +- 
docs/02.-Server-options.md | 3 +- endpoints/OAI/utils/chat_completion.py | 18 ++-- endpoints/core/types/model.py | 10 ++- pyproject.toml | 33 +++----- tests/mistral_tokenizer_mode_test.py | 76 +++++++++++++++++ tests/wheel_test.py | 29 +++++-- 14 files changed, 349 insertions(+), 80 deletions(-) create mode 100644 common/tokenizer_modes.py create mode 100644 tests/mistral_tokenizer_mode_test.py diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ce24b65a..b313ac89 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -46,7 +46,12 @@ from common.multimodal import MultimodalEmbeddingWrapper from common.optional_dependencies import check_package_version from common.sampling import BaseSamplerRequest +from common.tabby_config import config from common.templating import PromptTemplate, find_prompt_template +from common.tokenizer_modes import ( + normalize_tokenizer_mode, + should_enable_mistral_tokenizer_mode, +) from common.transformers_utils import HFModel from common.utils import calculate_rope_alpha, coalesce, unwrap from endpoints.core.types.model import ModelCard, ModelCardParameters @@ -84,6 +89,7 @@ class ExllamaV2Container(BaseModelContainer): draft_cache_mode: str = "FP16" max_batch_size: Optional[int] = None tokenizer_mode: str = "auto" + mistral_tokenizer_models: List[str] = [] # GPU split vars gpu_split: List[float] = [] @@ -121,7 +127,29 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs self.model_dir = model_directory self.config.model_dir = str(model_directory.resolve()) self.hf_model = hf_model - self.tokenizer_mode = str(unwrap(kwargs.get("tokenizer_mode"), "auto")).lower() + self.tokenizer_mode, mode_message = normalize_tokenizer_mode( + coalesce(kwargs.get("tokenizer_mode"), config.model.tokenizer_mode, "auto") + ) + if mode_message: + logger.warning(mode_message) + + if self.tokenizer_mode == "mistral": + mistral_tokenizer_models = coalesce( + 
kwargs.get("mistral_tokenizer_models"), + config.model.mistral_tokenizer_models, + [], + ) + self.mistral_tokenizer_models = list(mistral_tokenizer_models) + if should_enable_mistral_tokenizer_mode( + model_directory, mistral_tokenizer_models + ): + logger.info("Using tokenizer_mode='mistral' compatibility path.") + else: + logger.warning( + "tokenizer_mode='mistral' requested but model does not appear " + "to support mistral tokenizer mode. Falling back to default mode." + ) + self.tokenizer_mode = "auto" # Make the max seq len 4096 before preparing the config # This is a better default than 2048 @@ -443,6 +471,7 @@ def model_info(self): cache_mode=self.cache_mode, chunk_size=self.config.max_input_len, tokenizer_mode=self.tokenizer_mode, + mistral_tokenizer_models=self.mistral_tokenizer_models, use_vision=self.use_vision, draft=draft_model_card, ) diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index b2a3cc05..3a3e797b 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -34,28 +34,22 @@ log_metrics, log_prompt, ) -from common.hardware import hardware_supports_flash_attn +from common.hardware import hardware_supports_flashinfer from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper from common.optional_dependencies import check_package_version from common.sampling import BaseSamplerRequest +from common.tabby_config import config from common.templating import PromptTemplate, find_prompt_template +from common.tokenizer_modes import ( + normalize_tokenizer_mode, + should_enable_mistral_tokenizer_mode, +) from common.transformers_utils import HFModel from common.utils import coalesce, unwrap from endpoints.core.types.model import ModelCard, ModelCardParameters -_SUPPORTED_TOKENIZER_MODES = {"auto", "hf", "mistral"} - - -def _supports_mistral_tokenizer_mode(model_directory: pathlib.Path) -> bool: - return ( - (model_directory / "tekken.json").exists() - or (model_directory / 
"tokenizer.model").exists() - or any(model_directory.glob("tokenizer.model.v*")) - ) - - class ExllamaV3Container(BaseModelContainer): """Abstract base class for model containers.""" @@ -100,6 +94,7 @@ class ExllamaV3Container(BaseModelContainer): max_rq_tokens: Optional[int] = 2048 max_batch_size: Optional[int] = None tokenizer_mode: str = "auto" + mistral_tokenizer_models: List[str] = [] # Required methods @classmethod @@ -119,19 +114,26 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs # Make sure ExllamaV3 is up to date check_package_version("exllamav3", "0.0.7") + check_package_version("flashinfer-python", "0.6.3") self.model_dir = model_directory self.hf_model = hf_model - requested_tokenizer_mode = str(unwrap(kwargs.get("tokenizer_mode"), "auto")).lower() - if requested_tokenizer_mode not in _SUPPORTED_TOKENIZER_MODES: - logger.warning( - "Unknown tokenizer_mode '{}' requested. Falling back to 'auto'.", - requested_tokenizer_mode, - ) - requested_tokenizer_mode = "auto" + requested_tokenizer_mode, mode_message = normalize_tokenizer_mode( + coalesce(kwargs.get("tokenizer_mode"), config.model.tokenizer_mode, "auto") + ) + if mode_message: + logger.warning(mode_message) + mistral_tokenizer_models = coalesce( + kwargs.get("mistral_tokenizer_models"), + config.model.mistral_tokenizer_models, + [], + ) + self.mistral_tokenizer_models = list(mistral_tokenizer_models) if requested_tokenizer_mode == "mistral": - if _supports_mistral_tokenizer_mode(model_directory): + if should_enable_mistral_tokenizer_mode( + model_directory, mistral_tokenizer_models + ): logger.info("Using tokenizer_mode='mistral' compatibility path.") else: logger.warning( @@ -243,12 +245,12 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs value / 1024 for value in autosplit_reserve_megabytes ] - if not hardware_supports_flash_attn(gpu_device_list): + if not hardware_supports_flashinfer(gpu_device_list): gpu_unsupported_message = 
( "Unable to run ExllamaV3 because an unsupported GPU is " "found in this configuration. \n" - "All GPUs must be ampere " - "(30 series) or newer. AMD GPUs are not supported." + "All GPUs must be Ampere " + "(30 series) or newer for flashinfer. AMD GPUs are not supported." ) logger.warning(gpu_unsupported_message) @@ -402,6 +404,7 @@ def model_info(self) -> ModelCard: cache_mode=self.cache_mode, chunk_size=self.chunk_size, tokenizer_mode=self.tokenizer_mode, + mistral_tokenizer_models=self.mistral_tokenizer_models, use_vision=self.use_vision, ) @@ -465,8 +468,10 @@ async def load_gen(self, progress_callback=None, **kwargs): Progress updates """ + acquired_lock = False try: await self.load_lock.acquire() + acquired_lock = True # Wait for existing generation jobs to finish await self.wait_for_jobs(kwargs.get("skip_wait")) @@ -487,7 +492,8 @@ async def load_gen(self, progress_callback=None, **kwargs): self.loaded = True logger.info("Model successfully loaded.") finally: - self.load_lock.release() + if acquired_lock and self.load_lock.locked(): + self.load_lock.release() async with self.load_condition: self.load_condition.notify_all() diff --git a/common/config_models.py b/common/config_models.py index 66aa89ff..c61801b8 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -298,12 +298,21 @@ class ModelConfig(BaseConfigModel): "auto", description=( "Tokenizer compatibility mode for chat formatting.\n" - "Compatible values: auto, hf, mistral.\n" + "Compatible values: auto, hf, slow, mistral, deepseek_v32.\n" + "slow is normalized to hf for ExLlama backends.\n" "mistral applies Mistral-specific message normalization " "(tool-call ID handling) and falls back to default behavior " "for non-Mistral models." 
), ) + mistral_tokenizer_models: Optional[List[str]] = Field( + default_factory=list, + description=( + "Optional allowlist for tokenizer_mode='mistral'.\n" + "When set, only listed model names/paths will use mistral mode.\n" + "If empty, Tabby auto-detects Mistral-family models." + ), + ) reasoning_parser: Optional[str] = Field( None, description=( diff --git a/common/hardware.py b/common/hardware.py index 10723c5e..2e3ac25b 100644 --- a/common/hardware.py +++ b/common/hardware.py @@ -1,6 +1,13 @@ import torch +def _min_compute_capability(gpu_device_list: list[int]) -> int: + return min( + torch.cuda.get_device_capability(device = device_idx)[0] + for device_idx in gpu_device_list + ) + + def hardware_supports_flash_attn(gpu_device_list: list[int]): """ Check whether all GPUs in list support FA2 @@ -9,10 +16,23 @@ def hardware_supports_flash_attn(gpu_device_list: list[int]): AMD is also unsupported until ROCm updates its FA2 fork """ - min_compute_capability = min( - torch.cuda.get_device_capability(device=device_idx)[0] - for device_idx in gpu_device_list - ) + min_compute_capability = _min_compute_capability(gpu_device_list) + + if torch.version.hip or min_compute_capability < 8: + return False + else: + return True + + +def hardware_supports_flashinfer(gpu_device_list: list[int]): + """ + Check whether all GPUs in list support flashinfer. + + Keep the same minimum as the legacy FA2 path for now: + Ampere (SM80) or newer, CUDA only. 
+ """ + + min_compute_capability = _min_compute_capability(gpu_device_list) if torch.version.hip or min_compute_capability < 8: return False diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py index 5a23a2ee..e4195bdd 100644 --- a/common/optional_dependencies.py +++ b/common/optional_dependencies.py @@ -14,12 +14,12 @@ class DependenciesModel(BaseModel): """Model of which optional dependencies are installed.""" - torch: bool - exllamav2: bool - exllamav3: bool - flash_attn: bool - infinity_emb: bool - sentence_transformers: bool + torch: bool + exllamav2: bool + exllamav3: bool + flashinfer: bool + infinity_emb: bool + sentence_transformers: bool @computed_field @property @@ -28,8 +28,8 @@ def extras(self) -> bool: @computed_field @property - def inference(self) -> bool: - return self.torch and (self.exllamav2 or (self.exllamav3 and self.flash_attn)) + def inference(self) -> bool: + return self.torch and (self.exllamav2 or (self.exllamav3 and self.flashinfer)) def is_installed(package_name: str) -> bool: diff --git a/common/tokenizer_modes.py b/common/tokenizer_modes.py new file mode 100644 index 00000000..48c0f93b --- /dev/null +++ b/common/tokenizer_modes.py @@ -0,0 +1,110 @@ +"""Helpers for tokenizer compatibility mode detection.""" + +import json +import pathlib +from typing import Iterable + +VLLM_COMPAT_TOKENIZER_MODES = { + "auto", + "hf", + "slow", + "mistral", + "deepseek_v32", +} + + +def normalize_tokenizer_mode(tokenizer_mode: str | None) -> tuple[str, str | None]: + mode = str(tokenizer_mode or "auto").lower() + if mode not in VLLM_COMPAT_TOKENIZER_MODES: + return ( + "auto", + f"Unknown tokenizer_mode '{mode}' requested. Falling back to 'auto'.", + ) + + if mode == "slow": + return ( + "hf", + "tokenizer_mode='slow' requested, but ExLlama backends do not expose " + "a distinct slow tokenizer path. 
Using 'hf' compatibility mode.", + ) + + return mode, None + + +def _read_model_type(model_directory: pathlib.Path) -> str: + config_path = model_directory / "config.json" + if not config_path.exists(): + return "" + + try: + with open(config_path, "r", encoding="utf-8") as config_file: + return str(json.load(config_file).get("model_type", "")).lower() + except Exception: + return "" + + +def has_mistral_tokenizer_assets(model_directory: pathlib.Path) -> bool: + return ( + (model_directory / "tekken.json").exists() + or (model_directory / "tokenizer.model").exists() + or any(model_directory.glob("tokenizer.model.v*")) + ) + + +def supports_mistral_tokenizer_mode(model_directory: pathlib.Path) -> bool: + """ + Return True when mistral tokenizer mode is safe to enable for this model. + + vLLM uses mistral-common only for Mistral-family models in auto mode. + Match that intent by requiring both: + 1. A mistral-family model type. + 2. Mistral tokenizer assets. + """ + + model_type = _read_model_type(model_directory) + is_mistral_family = model_type.startswith("mistral") or model_type.startswith( + "mixtral" + ) + + return is_mistral_family and has_mistral_tokenizer_assets(model_directory) + + +def _matches_allowlist( + model_directory: pathlib.Path, mistral_tokenizer_models: Iterable[str] +) -> bool: + model_name = model_directory.name.lower() + model_path = model_directory.as_posix().lower().rstrip("/") + + for entry in mistral_tokenizer_models: + normalized = str(entry).strip().lower().strip("/") + if not normalized: + continue + + if model_name == normalized: + return True + + if model_path.endswith(normalized): + return True + + return False + + +def should_enable_mistral_tokenizer_mode( + model_directory: pathlib.Path, + mistral_tokenizer_models: Iterable[str] | None = None, +) -> bool: + """ + Decide whether mistral tokenizer mode should be enabled. + + If an explicit allowlist is configured, only listed Mistral-family models + can use mistral mode. 
If no allowlist is provided, fallback to auto + detection (mistral-family model + tokenizer assets). + """ + + allowlist = list(mistral_tokenizer_models or []) + if allowlist: + return _matches_allowlist(model_directory, allowlist) and ( + supports_mistral_tokenizer_mode(model_directory) + ) + + return supports_mistral_tokenizer_mode(model_directory) diff --git a/config_sample.yml b/config_sample.yml index 6f53a78a..e7d75424 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -154,11 +154,17 @@ model: prompt_template: # Tokenizer compatibility mode for chat formatting (default: auto). - # Values: auto, hf, mistral. + # Values: auto, hf, slow, mistral, deepseek_v32. + # Note: ExLlama backends treat "slow" as "hf". # mistral applies Mistral-style tool-call ID normalization and # falls back to default behavior for non-Mistral models. tokenizer_mode: auto + # Optional allowlist for tokenizer_mode='mistral'. + # If set, only listed model names/paths can use mistral mode. + # Leave empty to keep auto-detection behavior. + mistral_tokenizer_models: [] + # Reasoning parser key for splitting hidden reasoning and final content. # Compatible keys include: basic, exaone4, deepseek_r1, deepseek_v3. # If omitted, TabbyAPI defaults to `basic`. diff --git a/docs/01.-Getting-Started.md b/docs/01.-Getting-Started.md index 330116c6..513883c8 100644 --- a/docs/01.-Getting-Started.md +++ b/docs/01.-Getting-Started.md @@ -101,7 +101,7 @@ These scripts exit after running their respective tasks. To start TabbyAPI, run 1. `pip install -U .[cu12]` = CUDA 12.x 2. `pip install -U .[amd]` = ROCm 6.0 -If you don't want to update dependencies that come from wheels (torch, exllamav2, and flash attention 2), use `pip install .` or pass the `--nowheel` flag when invoking the start scripts. +If you don't want to update dependencies that come from wheels (torch, exllamav2/exllamav3, and flashinfer), use `pip install .` or pass the `--nowheel` flag when invoking the start scripts. 
### Update Exllamav2 diff --git a/docs/02.-Server-options.md b/docs/02.-Server-options.md index 2efc9c3a..9ca6683c 100644 --- a/docs/02.-Server-options.md +++ b/docs/02.-Server-options.md @@ -73,7 +73,8 @@ Note: Most of the options here will only apply on initial model load/startup (ep | chunk_size | Int (2048) | Amount of tokens per chunk with ingestion. A lower value reduces VRAM usage at the cost of ingestion speed. | | max_batch_size | Int (None) | The absolute maximum amount of prompts to process at one time. This value is automatically adjusted based on cache size. | | prompt_template | String (None) | Name of a jinja2 chat template to apply for this model. Must be located in the `templates` directory. | -| tokenizer_mode | String ("auto") | Tokenizer compatibility mode for chat formatting. Supported values are `auto`, `hf`, `mistral`; `mistral` applies Mistral-specific tool-call ID handling and falls back to default behavior on non-Mistral models. | +| tokenizer_mode | String ("auto") | Tokenizer compatibility mode for chat formatting. Supported values are `auto`, `hf`, `slow`, `mistral`, `deepseek_v32`; `slow` maps to `hf` on ExLlama backends, and `mistral` applies Mistral-specific tool-call ID handling with fallback on non-Mistral models. | +| mistral_tokenizer_models | List[String] ([]) | Optional allowlist for `tokenizer_mode: mistral`. If set, only listed model names/paths use mistral mode. Leave empty to keep auto-detection behavior. | | reasoning_parser | String (None) | Reasoning parser key used to split reasoning and final answer text (vLLM-compatible names, default parser behavior is `basic`). | | enable_auto_tool_choice | Bool (False) | Enables vLLM-style automatic tool choice handling. Equivalent to `--enable-auto-tool-choice` and requires `tool_call_parser`. | | tool_call_parser | String (None) | vLLM-compatible tool parser key used to parse model-emitted tool calls. Equivalent to `--tool-call-parser`. 
| diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index beb20879..62eb899a 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -19,6 +19,7 @@ request_disconnect_loop, ) from common.tabby_config import config +from common.tokenizer_modes import normalize_tokenizer_mode from common.utils import unwrap from endpoints.OAI.reasoning import ReasoningParserManager from endpoints.OAI.types.chat_completion import ( @@ -42,9 +43,6 @@ from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA -_SUPPORTED_TOKENIZER_MODES = {"auto", "hf", "mistral"} - - @dataclass class _StreamReasoningState: text: str = "" @@ -88,15 +86,11 @@ def _token_ids_from_generation(generation: dict) -> List[int]: def _get_tokenizer_mode() -> str: container_mode = getattr(model.container, "tokenizer_mode", None) - tokenizer_mode = str( - unwrap(container_mode, unwrap(config.model.tokenizer_mode, "auto")) or "auto" - ).lower() - if tokenizer_mode not in _SUPPORTED_TOKENIZER_MODES: - logger.warning( - "Unknown tokenizer_mode '{}' configured; falling back to 'auto'.", - tokenizer_mode, - ) - return "auto" + tokenizer_mode, mode_message = normalize_tokenizer_mode( + unwrap(container_mode, unwrap(config.model.tokenizer_mode, "auto")) + ) + if mode_message: + logger.warning(mode_message) return tokenizer_mode diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index f5f0e061..d090fce8 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -21,6 +21,7 @@ class ModelCardParameters(BaseModel): max_batch_size: Optional[int] = 1 chunk_size: Optional[int] = 2048 tokenizer_mode: Optional[str] = "auto" + mistral_tokenizer_models: Optional[List[str]] = Field(default_factory=list) prompt_template: Optional[str] = None prompt_template_content: Optional[str] = None use_vision: Optional[bool] = False @@ -113,9 +114,16 @@ class ModelLoadRequest(BaseModel): 
output_chunking: Optional[bool] = True prompt_template: Optional[str] = None tokenizer_mode: Optional[str] = Field( - description="Tokenizer compatibility mode (auto, hf, mistral)", + description="Tokenizer compatibility mode (auto, hf, slow, mistral, deepseek_v32)", default=None, ) + mistral_tokenizer_models: Optional[List[str]] = Field( + default_factory=list, + description=( + "Optional allowlist for tokenizer_mode='mistral'. " + "Only listed Mistral-family models can use mistral mode." + ), + ) vision: Optional[bool] = None # Non-config arguments diff --git a/pyproject.toml b/pyproject.toml index 7a04cce9..12d38b00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,26 +78,19 @@ cu12 = [ "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+cu128.torch2.9.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl3 - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version 
== '3.13'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", - - # Windows FA2 from https://github.com/kingbri1/flash-attention/releases - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - - # Linux FA2 from https://github.com/kingbri1/flash-attention/releases - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version 
== '3.13'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp312-cp312-linux_x86_64.whl ; platform_system 
== 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + + # FlashInfer backend for ExLlamaV3 (CUDA 12.8 wheel set) + "flashinfer-python @ https://github.com/flashinfer-ai/flashinfer/releases/download/v0.6.3/flashinfer_python-0.6.3-py3-none-any.whl ; platform_system == 'Linux' and platform_machine == 'x86_64'", + "flashinfer-jit-cache @ https://github.com/flashinfer-ai/flashinfer/releases/download/v0.6.3/flashinfer_jit_cache-0.6.3+cu128-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64'", + ] amd = [ # Torch triton for ROCm diff --git a/tests/mistral_tokenizer_mode_test.py b/tests/mistral_tokenizer_mode_test.py new file mode 100644 index 00000000..1d422051 --- /dev/null +++ b/tests/mistral_tokenizer_mode_test.py @@ -0,0 +1,76 @@ +"""Tests for mistral tokenizer mode auto-detection.""" + +import json + +from common.tokenizer_modes import ( + normalize_tokenizer_mode, + should_enable_mistral_tokenizer_mode, + supports_mistral_tokenizer_mode, +) + + +def _write_config(directory, model_type: str) -> None: + with open(directory / "config.json", "w", encoding="utf-8") as config_file: + json.dump({"model_type": model_type}, config_file) + + +def test_supports_mistral_with_tekken(tmp_path): + _write_config(tmp_path, "mistral3") + (tmp_path / "tekken.json").write_text("{}", encoding="utf-8") + + assert supports_mistral_tokenizer_mode(tmp_path) is True + + +def test_supports_mistral_with_sentencepiece_variant(tmp_path): + 
_write_config(tmp_path, "mistral") + (tmp_path / "tokenizer.model.v3").write_text("dummy", encoding="utf-8") + + assert supports_mistral_tokenizer_mode(tmp_path) is True + + +def test_rejects_non_mistral_with_sentencepiece_tokenizer(tmp_path): + _write_config(tmp_path, "gemma2") + (tmp_path / "tokenizer.model").write_text("dummy", encoding="utf-8") + + assert supports_mistral_tokenizer_mode(tmp_path) is False + + +def test_allowlist_enables_listed_mistral_model(tmp_path): + _write_config(tmp_path, "mistral3") + (tmp_path / "tekken.json").write_text("{}", encoding="utf-8") + + assert should_enable_mistral_tokenizer_mode( + tmp_path, [tmp_path.name, "other-model"] + ) + + +def test_allowlist_disables_non_mistral_even_if_listed(tmp_path): + _write_config(tmp_path, "gemma2") + (tmp_path / "tokenizer.model").write_text("dummy", encoding="utf-8") + + assert not should_enable_mistral_tokenizer_mode(tmp_path, [tmp_path.name]) + + +def test_allowlist_disables_unlisted_mistral_model(tmp_path): + _write_config(tmp_path, "mistral3") + (tmp_path / "tekken.json").write_text("{}", encoding="utf-8") + + assert not should_enable_mistral_tokenizer_mode(tmp_path, ["another-model"]) + + +def test_normalize_tokenizer_mode_accepts_deepseek_v32(): + normalized, message = normalize_tokenizer_mode("deepseek_v32") + assert normalized == "deepseek_v32" + assert message is None + + +def test_normalize_tokenizer_mode_maps_slow_to_hf(): + normalized, message = normalize_tokenizer_mode("slow") + assert normalized == "hf" + assert message is not None + + +def test_normalize_tokenizer_mode_unknown_falls_back_to_auto(): + normalized, message = normalize_tokenizer_mode("unknown_mode") + assert normalized == "auto" + assert message is not None diff --git a/tests/wheel_test.py b/tests/wheel_test.py index 733c6188..1300d100 100644 --- a/tests/wheel_test.py +++ b/tests/wheel_test.py @@ -2,23 +2,40 @@ from importlib.metadata import version from importlib.util import find_spec +from packaging import 
version as package_version successful_packages = [] errored_packages = [] if find_spec("flash_attn") is not None: - print(f"Flash attention on version {version('flash_attn')} successfully imported") - successful_packages.append("flash_attn") + torch_version = version("torch").split("+")[0] if find_spec("torch") else "0" + if package_version.parse(torch_version) >= package_version.parse("2.10.0"): + print( + f"Flash attention 2 detected with torch {torch_version}. " + "This combination is unsupported for the flashinfer migration." + ) + errored_packages.append("flash_attn") + else: + print( + "Flash attention 2 is installed. " + "ExLlamaV3 now uses flashinfer." + ) + +if find_spec("flashinfer") is not None: + print( + "FlashInfer on version " + f"{version('flashinfer-python')} successfully imported" + ) + successful_packages.append("flashinfer") else: - print("Flash attention 2 is not found in your environment.") - errored_packages.append("flash_attn") + print("FlashInfer is not found in your environment.") + errored_packages.append("flashinfer") if find_spec("exllamav2") is not None: print(f"Exllamav2 on version {version('exllamav2')} successfully imported") successful_packages.append("exllamav2") else: - print("Exllamav2 is not found in your environment.") - errored_packages.append("exllamav2") + print("Exllamav2 is not found in your environment (optional).") if find_spec("torch") is not None: print(f"Torch on version {version('torch')} successfully imported") From e52fde2ac98900380ee6ed323577fbdef05e0241 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Fri, 27 Feb 2026 00:03:47 +0900 Subject: [PATCH 15/19] fix: surface and avoid exllamav3 startup lock deadlock --- common/model.py | 66 ++++++++++++++++++++++++++++++++++++-------- common/multimodal.py | 26 ++++++++--------- 2 files changed, 65 insertions(+), 27 deletions(-) diff --git a/common/model.py b/common/model.py index 4ac4861a..5afafb06 100644 --- a/common/model.py +++ b/common/model.py @@ -10,7 +10,7 @@ from 
fastapi import HTTPException from loguru import logger from ruamel.yaml import YAML -from typing import Dict, Optional +from typing import Dict, Optional, Type from backends.base_model_container import BaseModelContainer from common.logger import get_loading_progress_bar @@ -25,24 +25,60 @@ embeddings_container = None -_BACKEND_REGISTRY: Dict[str, BaseModelContainer] = {} +_BACKEND_REGISTRY: Dict[str, Type[BaseModelContainer]] = {} +_BACKEND_REGISTRY_INITIALIZED = False +_INFINITY_CONTAINER_CLASS = None -if dependencies.exllamav2: - from backends.exllamav2.model import ExllamaV2Container - _BACKEND_REGISTRY["exllamav2"] = ExllamaV2Container +def _log_exllamav3_lock_hint(): + """Warn about stale torch extension lock files before ExLlama import.""" + lock_paths = sorted( + pathlib.Path.home().glob(".cache/torch_extensions/*/exllamav3_ext/lock") + ) + if not lock_paths: + return + + sample_paths = ", ".join(str(path) for path in lock_paths[:3]) + remaining = len(lock_paths) - 3 + if remaining > 0: + sample_paths += f", ... (+{remaining} more)" + + logger.warning( + "Detected torch extension lock file(s): " + f"{sample_paths}. Startup may appear frozen while waiting on this lock. " + "If no build process (ninja/nvcc) is active, remove the lock file and retry." + ) + +def _ensure_backend_registry(): + """Initialize backend registry lazily to avoid heavy imports at startup.""" + global _BACKEND_REGISTRY_INITIALIZED + global _INFINITY_CONTAINER_CLASS -if dependencies.exllamav3: - from backends.exllamav3.model import ExllamaV3Container + if _BACKEND_REGISTRY_INITIALIZED: + return - _BACKEND_REGISTRY["exllamav3"] = ExllamaV3Container + if dependencies.exllamav2: + from backends.exllamav2.model import ExllamaV2Container + _BACKEND_REGISTRY["exllamav2"] = ExllamaV2Container + + if dependencies.exllamav3: + logger.info( + "Initializing exllamav3 backend. " + "First run or source changes may trigger extension build." 
+ ) + _log_exllamav3_lock_hint() + from backends.exllamav3.model import ExllamaV3Container -if dependencies.extras: - from backends.infinity.model import InfinityContainer + _BACKEND_REGISTRY["exllamav3"] = ExllamaV3Container - embeddings_container: Optional[InfinityContainer] = None + if dependencies.extras: + from backends.infinity.model import InfinityContainer + + _INFINITY_CONTAINER_CLASS = InfinityContainer + + _BACKEND_REGISTRY_INITIALIZED = True class ModelType(Enum): @@ -159,6 +195,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): # Fetch the extra HF configuration options hf_model = await HFModel.from_directory(model_path) + _ensure_backend_registry() # Override the max sequence length based on user max_seq_len = kwargs.get("max_seq_len") @@ -255,6 +292,8 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs): "pip install -U .[extras]" ) + _ensure_backend_registry() + # Check if the model is already loaded if embeddings_container and embeddings_container.engine: loaded_model_name = embeddings_container.model_dir.name @@ -270,7 +309,10 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs): # Reset to prepare for a new container embeddings_container = None - new_embeddings_container = InfinityContainer(model_path) + if _INFINITY_CONTAINER_CLASS is None: + raise ImportError("Infinity backend is unavailable in the current environment.") + + new_embeddings_container = _INFINITY_CONTAINER_CLASS(model_path) await new_embeddings_container.load(**kwargs) embeddings_container = new_embeddings_container diff --git a/common/multimodal.py b/common/multimodal.py index b92386f3..ee4a1d30 100644 --- a/common/multimodal.py +++ b/common/multimodal.py @@ -1,5 +1,3 @@ -from backends.exllamav2.vision import get_image_embedding_exl2 -from backends.exllamav3.vision import get_image_embedding_exl3 from common import model from loguru import logger from pydantic import BaseModel, Field @@ -7,11 +5,6 @@ from 
common.optional_dependencies import dependencies -if dependencies.exllamav2: - from exllamav2 import ExLlamaV2VisionTower -if dependencies.exllamav3: - from exllamav3 import Model - class MultimodalEmbeddingWrapper(BaseModel): """Common multimodal embedding wrapper""" @@ -23,21 +16,24 @@ class MultimodalEmbeddingWrapper(BaseModel): async def add(self, url: str): # Determine the type of vision embedding to use if not self.type: - if dependencies.exllamav2 and isinstance( - model.container.vision_model, ExLlamaV2VisionTower - ): - self.type = "ExLlamaV2MMEmbedding" - elif dependencies.exllamav3 and isinstance( - model.container.vision_model, Model - ): - self.type = "MMEmbedding" + container = model.container + if container and getattr(container, "vision_model", None): + container_name = container.__class__.__name__ + if dependencies.exllamav2 and container_name == "ExllamaV2Container": + self.type = "ExLlamaV2MMEmbedding" + elif dependencies.exllamav3 and container_name == "ExllamaV3Container": + self.type = "MMEmbedding" # Create the embedding if self.type == "ExLlamaV2MMEmbedding": + from backends.exllamav2.vision import get_image_embedding_exl2 + embedding = await get_image_embedding_exl2(url) self.content.append(embedding) self.text_alias.append(embedding.text_alias) elif self.type == "MMEmbedding": + from backends.exllamav3.vision import get_image_embedding_exl3 + embedding = await get_image_embedding_exl3(url) self.content.append(embedding) self.text_alias.append(embedding.text_alias) From 824c6c0697fc876a270ce9d936a328b0cbc63990 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Fri, 27 Feb 2026 01:05:10 +0900 Subject: [PATCH 16/19] Align qwen3 reasoning + expand tool parser parity coverage --- .../OAI/reasoning/qwen3_reasoning_parser.py | 84 ++++++- endpoints/OAI/utils/tools.py | 214 +++++++++++++++++- tests/parser_options_test.py | 57 +++++ tests/qwen3_reasoning_parser_test.py | 114 ++++++++++ tests/reasoning_parser_registry_test.py | 38 ++++ 
tests/tool_parser_test.py | 139 ++++++++++++ 6 files changed, 639 insertions(+), 7 deletions(-) create mode 100644 tests/qwen3_reasoning_parser_test.py create mode 100644 tests/reasoning_parser_registry_test.py diff --git a/endpoints/OAI/reasoning/qwen3_reasoning_parser.py b/endpoints/OAI/reasoning/qwen3_reasoning_parser.py index 7e6a941b..8d6c3be9 100644 --- a/endpoints/OAI/reasoning/qwen3_reasoning_parser.py +++ b/endpoints/OAI/reasoning/qwen3_reasoning_parser.py @@ -3,12 +3,34 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence from typing import Any +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser class Qwen3ReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for the Qwen3/Qwen3.5 model family. + + Qwen3.5 chat templates prefill `` in the prompt, so streaming output + usually begins with reasoning text and only later emits ``. + This parser mirrors vLLM behavior for that path while also honoring + `enable_thinking=False` by routing all deltas as normal content. + """ + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + enable_thinking = chat_kwargs.get("enable_thinking") + if enable_thinking is None: + enable_thinking = chat_kwargs.get("thinking") + + # Keep default behavior when no explicit switch is provided. 
+ self.thinking_enabled = True if enable_thinking is None else bool(enable_thinking) + @property def start_token(self) -> str: return "" @@ -17,12 +39,68 @@ def start_token(self) -> str: def end_token(self) -> str: return "" + def _strip_reasoning_tags(self, text: str) -> str: + if not text: + return "" + return text.replace(self.start_token, "").replace(self.end_token, "") + def extract_reasoning( self, model_output: str, request: Any ) -> tuple[str | None, str | None]: - if self.start_token not in model_output or self.end_token not in model_output: + if not self.thinking_enabled: + content = self._strip_reasoning_tags(model_output) + return None, content or None + + # Qwen3.5 templates prefill in the prompt. + # If appears in output (legacy templates), strip it. + model_output_parts = model_output.partition(self.start_token) + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) + + if self.end_token not in model_output: return None, model_output - _, _, tail = model_output.partition(self.start_token) - reasoning, _, content = tail.partition(self.end_token) + reasoning, _, content = model_output.partition(self.end_token) return reasoning or None, content or None + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + if not self.thinking_enabled: + cleaned = self._strip_reasoning_tags(delta_text) + if not cleaned: + return None + return DeltaMessage(content=cleaned) + + # Handle old templates where model may generate . 
+ if self.start_token_id in delta_token_ids: + start_idx = delta_text.find(self.start_token) + if start_idx >= 0: + delta_text = delta_text[start_idx + len(self.start_token) :] + + if self.end_token_id in delta_token_ids: + end_index = delta_text.find(self.end_token) + if end_index >= 0: + reasoning = delta_text[:end_index] + content = delta_text[end_index + len(self.end_token) :] + if not reasoning and not content: + return None + return DeltaMessage( + reasoning=reasoning if reasoning else None, + content=content if content else None, + ) + # End token id is present but text was already stripped by backend. + return None + + if not delta_text: + return None + if self.end_token_id in previous_token_ids: + return DeltaMessage(content=delta_text) + return DeltaMessage(reasoning=delta_text) diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index 9758e4ce..3b1b4981 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -49,9 +49,11 @@ re.DOTALL, ) -# Matches BODY blocks +# Matches BODY blocks. +# Supports complete and partially-closed function sections to keep parity +# with vLLM behavior on generation cutoffs. 
FUNCTION_RE = re.compile( - r"(.*?)", + r"(.*?)|(.*)$", re.DOTALL, ) @@ -71,6 +73,35 @@ CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE) CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE) +# Jamba / MiniMax tagged JSON blocks +TOOL_CALLS_TAG_RE = re.compile(r"(.*?)", re.DOTALL) + +# GLM-4.5 style: function_name\n{...} +GLM45_CALL_RE = re.compile( + r"\s*(?P[^\n<]+?)\s*\n(?P.*?)", + re.DOTALL, +) + +# MiniMax-M2 XML-like syntax +MINIMAX_M2_CALL_RE = re.compile( + r"(.*?)", + re.DOTALL, +) +MINIMAX_M2_INVOKE_RE = re.compile( + r".*?)>(?P.*?)", + re.DOTALL, +) +MINIMAX_M2_PARAM_RE = re.compile( + r".*?)>(?P.*?)", + re.DOTALL, +) + +# Seed-OSS tags +SEED_THINK_BLOCK_RE = re.compile(r".*?\s*", re.DOTALL) +SEED_THINK_UNCLOSED_RE = re.compile(r"(?!.*).*$", re.DOTALL) +SEED_TOOL_CALL_START = "" +SEED_TOOL_CALL_END = "" + # DeepSeek family patterns DEEPSEEK_V31_CALL_RE = re.compile( r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)<|tool▁call▁end|>", @@ -146,17 +177,32 @@ def _coerce_param_value(raw: str) -> Any: except (json.JSONDecodeError, ValueError): pass + # Handle Python-like literals often emitted by coder models, + # e.g. {'k': 'v'} for object parameters. + try: + return ast.literal_eval(stripped) + except (ValueError, SyntaxError): + pass + # Fall back to string — never eval() return stripped class ToolCallProcessor: _PARSER_DISPATCHER: Dict[str, Callable[[str], List[ToolCall]]] = {} + _MISSING_PARSER_WARNED: set[str] = set() # ------------------------------------------------------------------ # JSON normalization helpers # ------------------------------------------------------------------ + @staticmethod + def _strip_quotes(value: str) -> str: + value = value.strip() + if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}: + return value[1:-1] + return value + @staticmethod def _normalize_tool_calls(raw) -> list: """Normalize model-emitted tool call payloads into OAI-like objects. 
@@ -356,6 +402,44 @@ def _parse_mistral_json_tool_calls(payload: str) -> List[ToolCall]: return tool_calls + @staticmethod + def _parse_tagged_json_payload(payload: str) -> List[ToolCall]: + payload = payload.strip() + if not payload: + return [] + + # Prefer full JSON parse first (array/object). + try: + return ToolCallProcessor._build_tool_calls_from_normalized(json.loads(payload)) + except (json.JSONDecodeError, ValueError, TypeError, KeyError): + pass + + # Fallback: decode a sequence of JSON values. + decoded = ToolCallProcessor._decode_json_sequence(payload) + if decoded: + flattened = [] + for item in decoded: + if isinstance(item, list): + flattened.extend(item) + else: + flattened.append(item) + return ToolCallProcessor._build_tool_calls_from_normalized(flattened) + + # Fallback: line-delimited JSON objects. + lines = [line.strip().rstrip(",") for line in payload.splitlines() if line.strip()] + parsed_lines = [] + for line in lines: + if not line.startswith("{"): + continue + try: + parsed_lines.append(json.loads(line)) + except (json.JSONDecodeError, ValueError): + continue + if parsed_lines: + return ToolCallProcessor._build_tool_calls_from_normalized(parsed_lines) + + return [] + @staticmethod def _ast_to_literal(node: ast.AST) -> Any: """Safely convert AST literal nodes to Python primitives.""" @@ -637,6 +721,95 @@ def from_mistral(raw_text: str) -> List[ToolCall]: # Final fallback for malformed payloads. return ToolCallProcessor.from_json(split_payloads[-1]) + @staticmethod + def from_tagged_tool_calls(raw_text: str) -> List[ToolCall]: + """Parse ... 
tagged payloads (Jamba/MiniMax).""" + text = _strip_think_blocks(raw_text) + matches = TOOL_CALLS_TAG_RE.findall(text) + if not matches: + return [] + + tool_calls: List[ToolCall] = [] + for payload in matches: + tool_calls.extend(ToolCallProcessor._parse_tagged_json_payload(payload)) + return tool_calls + + @staticmethod + def from_glm45(raw_text: str) -> List[ToolCall]: + """Parse GLM-4.5 style name\\n{args} payloads.""" + text = _strip_think_blocks(raw_text) + tool_calls: List[ToolCall] = [] + for match in GLM45_CALL_RE.finditer(text): + name = match.group("name").strip() + if not name: + continue + args = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append( + ToolCall( + function=Tool( + name=name, + arguments=args, + ) + ) + ) + return tool_calls + + @staticmethod + def from_minimax_m2(raw_text: str) -> List[ToolCall]: + """Parse MiniMax-M2 XML-like tool call payloads.""" + text = _strip_think_blocks(raw_text) + tool_calls: List[ToolCall] = [] + + for call in MINIMAX_M2_CALL_RE.finditer(text): + call_body = call.group(1) + for invoke in MINIMAX_M2_INVOKE_RE.finditer(call_body): + fn_name = ToolCallProcessor._strip_quotes(invoke.group("name")) + if not fn_name: + continue + + params: Dict[str, Any] = {} + invoke_body = invoke.group("body") + for param in MINIMAX_M2_PARAM_RE.finditer(invoke_body): + key = ToolCallProcessor._strip_quotes(param.group("name")) + if not key: + continue + value = _coerce_param_value(param.group("value")) + params[key] = value + + tool_calls.append( + ToolCall( + function=Tool( + name=fn_name, + arguments=json.dumps(params, ensure_ascii=False), + ) + ) + ) + + return tool_calls + + @staticmethod + def from_seed_oss(raw_text: str) -> List[ToolCall]: + """Parse Seed-OSS XML-style tool calls by adapting to Qwen3 XML format.""" + text = SEED_THINK_BLOCK_RE.sub("", raw_text) + text = SEED_THINK_UNCLOSED_RE.sub("", text) + text = text.replace(SEED_TOOL_CALL_START, "") + text = text.replace(SEED_TOOL_CALL_END, 
"") + return ToolCallProcessor.from_xml(text) + + @staticmethod + def from_olmo3(raw_text: str) -> List[ToolCall]: + """Parse OLMo3 pythonic tool calls, optionally wrapped by .""" + text = _strip_think_blocks(raw_text).strip() + wrapped = re.search(r"(.*?)", text, re.DOTALL) + if wrapped: + text = wrapped.group(1).strip() + + lines = [line.strip() for line in text.splitlines() if line.strip()] + if len(lines) > 1 and all(re.match(r"^[A-Za-z_]\w*\s*\(", line) for line in lines): + text = "[" + ", ".join(lines) + "]" + + return ToolCallProcessor.from_pythonic(text) + @staticmethod def from_json(tool_calls_str: str) -> List[ToolCall]: """Postprocess tool call JSON to a parseable class. @@ -697,7 +870,9 @@ def from_xml(raw_text: str) -> List[ToolCall]: for match in TOOL_CALL_BLOCK_RE.finditer(text): inner = match.group(1) for func_match in FUNCTION_RE.finditer(inner): - function_blocks.append((func_match.group(1), func_match.group(2))) + name = func_match.group(1) if func_match.group(1) is not None else func_match.group(3) + body = func_match.group(2) if func_match.group(2) is not None else func_match.group(4) + function_blocks.append((name, body)) # Find bare blocks NOT inside any wrapped region for func_match in FUNCTION_RE.finditer(text): @@ -708,7 +883,9 @@ def from_xml(raw_text: str) -> List[ToolCall]: "XML Parser: Found bare block without " " wrapper" ) - function_blocks.append((func_match.group(1), func_match.group(2))) + name = func_match.group(1) if func_match.group(1) is not None else func_match.group(3) + body = func_match.group(2) if func_match.group(2) is not None else func_match.group(4) + function_blocks.append((name, body)) if not function_blocks: logger.warning("XML Parser: No blocks found") @@ -809,15 +986,36 @@ def _parser_dispatcher() -> Dict[str, Callable[[str], List[ToolCall]]]: "deepseek_v3": ToolCallProcessor.from_deepseek_v3, "deepseek_v31": ToolCallProcessor.from_deepseek_v31, "deepseek_v32": ToolCallProcessor.from_deepseek_v32, + "ernie45": 
ToolCallProcessor.from_hermes, + "functiongemma": ToolCallProcessor.from_auto, + "gigachat3": ToolCallProcessor.from_auto, + "glm45": ToolCallProcessor.from_glm45, + "glm47": ToolCallProcessor.from_glm45, + "granite": ToolCallProcessor.from_json, + "granite-20b-fc": ToolCallProcessor.from_auto, "hermes": ToolCallProcessor.from_hermes, + "hunyuan_a13b": ToolCallProcessor.from_auto, + "internlm": ToolCallProcessor.from_auto, + "jamba": ToolCallProcessor.from_tagged_tool_calls, + "kimi_k2": ToolCallProcessor.from_auto, "llama": ToolCallProcessor.from_llama, "llama3_json": ToolCallProcessor.from_llama, "llama4_json": ToolCallProcessor.from_llama, + "llama4_pythonic": ToolCallProcessor.from_pythonic, + "longcat": ToolCallProcessor.from_hermes, + "minimax": ToolCallProcessor.from_tagged_tool_calls, + "minimax_m2": ToolCallProcessor.from_minimax_m2, "mistral": ToolCallProcessor.from_mistral, + "olmo3": ToolCallProcessor.from_olmo3, "openai": ToolCallProcessor.from_openai, + "phi4_mini_json": ToolCallProcessor.from_json, "pythonic": ToolCallProcessor.from_pythonic, "qwen3_coder": ToolCallProcessor.from_xml, "qwen3_xml": ToolCallProcessor.from_xml, + "seed_oss": ToolCallProcessor.from_seed_oss, + "step3": ToolCallProcessor.from_auto, + "step3p5": ToolCallProcessor.from_xml, + "xlam": ToolCallProcessor.from_auto, } return ToolCallProcessor._PARSER_DISPATCHER @@ -853,6 +1051,14 @@ def parse( else: if parsed: return parsed + elif canonical_key not in ToolCallProcessor._MISSING_PARSER_WARNED: + ToolCallProcessor._MISSING_PARSER_WARNED.add(canonical_key) + logger.warning( + "No dedicated tool parser handler for key '{}'; " + "falling back to format parser '{}'.", + canonical_key, + format, + ) if format == "xml": return ToolCallProcessor.from_xml(tool_calls_str) diff --git a/tests/parser_options_test.py b/tests/parser_options_test.py index 4f9648cc..4cc26e11 100644 --- a/tests/parser_options_test.py +++ b/tests/parser_options_test.py @@ -1,11 +1,49 @@ """Tests for 
vLLM-compatible parser option mapping.""" from endpoints.OAI.utils.parser_options import ( + TOOL_CALL_PARSER_FORMATS, list_tool_call_parsers, parser_uses_native_tool_generation, resolve_tool_call_parser_key, resolve_tool_call_format, ) +from endpoints.OAI.utils.tools import ToolCallProcessor + + +VLLM_CANONICAL_TOOL_PARSERS = { + "deepseek_v3", + "deepseek_v31", + "deepseek_v32", + "ernie45", + "glm45", + "glm47", + "granite", + "granite-20b-fc", + "hermes", + "hunyuan_a13b", + "internlm", + "jamba", + "kimi_k2", + "llama3_json", + "llama4_json", + "llama4_pythonic", + "longcat", + "minimax", + "minimax_m2", + "mistral", + "olmo3", + "openai", + "phi4_mini_json", + "pythonic", + "qwen3_coder", + "qwen3_xml", + "seed_oss", + "step3", + "step3p5", + "xlam", + "gigachat3", + "functiongemma", +} def test_parser_key_registry_contains_core_vllm_keys(): @@ -19,6 +57,25 @@ def test_parser_key_registry_contains_core_vllm_keys(): assert "llama" in parser_keys +def test_parser_key_registry_matches_vllm_set_plus_local_aliases(): + parser_keys = list_tool_call_parsers() + + # Canonical set should match current vLLM registry. + assert VLLM_CANONICAL_TOOL_PARSERS.issubset(parser_keys) + assert set(TOOL_CALL_PARSER_FORMATS.keys()) - {"auto"} == VLLM_CANONICAL_TOOL_PARSERS + + # Local compatibility alias. 
+ assert "llama" in parser_keys + + +def test_every_configured_canonical_parser_has_dispatch_handler(): + dispatcher = ToolCallProcessor._parser_dispatcher() + canonical = set(TOOL_CALL_PARSER_FORMATS.keys()) - {"auto"} + + missing = sorted(canonical - set(dispatcher.keys())) + assert missing == [] + + def test_resolve_tool_call_format_uses_vllm_mapping(): assert resolve_tool_call_format("openai", "json") == "json" assert resolve_tool_call_format("qwen3_coder", "json") == "xml" diff --git a/tests/qwen3_reasoning_parser_test.py b/tests/qwen3_reasoning_parser_test.py new file mode 100644 index 00000000..f27996c3 --- /dev/null +++ b/tests/qwen3_reasoning_parser_test.py @@ -0,0 +1,114 @@ +"""Tests for Qwen3 reasoning parser parity with modern Qwen3.5 behavior.""" + +from endpoints.OAI.reasoning.qwen3_reasoning_parser import Qwen3ReasoningParser + + +class _FakeTokenizer: + def get_vocab(self): + return { + "": 101, + "": 102, + } + + +def _parser(enable_thinking=None) -> Qwen3ReasoningParser: + kwargs = {} + if enable_thinking is not None: + kwargs["chat_template_kwargs"] = {"enable_thinking": enable_thinking} + return Qwen3ReasoningParser(_FakeTokenizer(), **kwargs) + + +def test_non_stream_extract_thinking_mode_with_prefilled_start_token(): + parser = _parser(enable_thinking=True) + + reasoning, content = parser.extract_reasoning("reasoninganswer", request=None) + assert reasoning == "reasoning" + assert content == "answer" + + +def test_non_stream_extract_without_end_token_treated_as_content(): + parser = _parser(enable_thinking=True) + + reasoning, content = parser.extract_reasoning("reasoning only", request=None) + assert reasoning is None + assert content == "reasoning only" + + +def test_non_stream_extract_non_thinking_mode_content_only(): + parser = _parser(enable_thinking=False) + + reasoning, content = parser.extract_reasoning("hiddenvisible", request=None) + assert reasoning is None + assert content == "hiddenvisible" + + +def 
test_streaming_prefilled_think_mode_splits_reasoning_and_content(): + parser = _parser(enable_thinking=True) + + first = parser.extract_reasoning_streaming( + previous_text="", + current_text="analysis ", + delta_text="analysis ", + previous_token_ids=[], + current_token_ids=[11], + delta_token_ids=[11], + ) + assert first is not None + assert first.reasoning == "analysis " + assert first.content is None + + second = parser.extract_reasoning_streaming( + previous_text="analysis ", + current_text="analysis stepfinal ", + delta_text="stepfinal ", + previous_token_ids=[11], + current_token_ids=[11, 12, 102, 13], + delta_token_ids=[12, 102, 13], + ) + assert second is not None + assert second.reasoning == "step" + assert second.content == "final " + + third = parser.extract_reasoning_streaming( + previous_text="analysis stepfinal ", + current_text="analysis stepfinal answer", + delta_text="answer", + previous_token_ids=[11, 12, 102, 13], + current_token_ids=[11, 12, 102, 13, 14], + delta_token_ids=[14], + ) + assert third is not None + assert third.reasoning is None + assert third.content == "answer" + + +def test_streaming_non_thinking_mode_emits_content_only(): + parser = _parser(enable_thinking=False) + + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="plain output", + delta_text="plain output", + previous_token_ids=[], + current_token_ids=[11, 12], + delta_token_ids=[11, 12], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "plain output" + + +def test_streaming_strips_generated_start_token_when_present(): + parser = _parser(enable_thinking=True) + + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="reason", + delta_text="reason", + previous_token_ids=[], + current_token_ids=[101, 11], + delta_token_ids=[101, 11], + ) + assert delta is not None + assert delta.reasoning == "reason" + assert delta.content is None diff --git a/tests/reasoning_parser_registry_test.py 
b/tests/reasoning_parser_registry_test.py new file mode 100644 index 00000000..ccf1d46e --- /dev/null +++ b/tests/reasoning_parser_registry_test.py @@ -0,0 +1,38 @@ +"""Tests for reasoning parser registry parity with vLLM.""" + +from endpoints.OAI.reasoning import ReasoningParserManager + + +VLLM_CANONICAL_REASONING_PARSERS = { + "deepseek_r1", + "deepseek_v3", + "ernie45", + "glm45", + "openai_gptoss", + "granite", + "holo2", + "hunyuan_a13b", + "kimi_k2", + "minimax_m2", + "minimax_m2_append_think", + "mistral", + "olmo3", + "qwen3", + "seed_oss", + "step3", + "step3p5", +} + + +def test_reasoning_registry_contains_all_vllm_canonical_parsers(): + registered = set(ReasoningParserManager.list_registered()) + missing = sorted(VLLM_CANONICAL_REASONING_PARSERS - registered) + assert missing == [] + + +def test_reasoning_registry_allows_local_extensions(): + registered = set(ReasoningParserManager.list_registered()) + # Local compatibility/default parsers that may not exist in vLLM. + assert "identity" in registered + assert "basic" in registered + assert "exaone4" in registered diff --git a/tests/tool_parser_test.py b/tests/tool_parser_test.py index 8b93a726..7c57d5e1 100644 --- a/tests/tool_parser_test.py +++ b/tests/tool_parser_test.py @@ -37,6 +37,36 @@ def test_from_xml_parses_qwen3_coder_style_blocks(): assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 3} +def test_from_xml_supports_single_quote_object_parameter(): + payload = ( + "" + "\n{'key': 'value'}\n" + "" + ) + + parsed = ToolCallProcessor.from_xml(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "test_types" + assert _arguments_dict(parsed[0]) == {"obj_param": {"key": "value"}} + + +def test_from_xml_parses_incomplete_function_block_at_generation_cutoff(): + payload = ( + "I'll call a tool. 
" + "" + "\nSeoul\n" + "\n3\n" + # Missing on purpose + ) + + parsed = ToolCallProcessor.from_xml(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 3} + + def test_from_auto_parses_json_inside_tool_call_wrapper(): payload = ( "" @@ -171,6 +201,115 @@ def test_parse_with_deepseek_v32_parser(): assert _arguments_dict(parsed[0]) == {"location": "Seoul", "days": 3} +def test_parse_with_ernie45_parser_handles_tool_call_json(): + payload = ( + "" + '{"name":"get_weather","arguments":{"city":"Seoul"}}' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="ernie45") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul"} + + +def test_parse_with_jamba_parser_handles_tool_calls_tag_array(): + payload = ( + "" + '[{"name":"get_weather","arguments":{"city":"Seoul","days":2}}]' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="jamba") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 2} + + +def test_parse_with_minimax_parser_handles_line_delimited_json(): + payload = ( + "\n" + '{"name":"foo","arguments":{"x":1}}\n' + '{"name":"bar","arguments":{"y":2}}\n' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="minimax") + + assert len(parsed) == 2 + assert parsed[0].function.name == "foo" + assert _arguments_dict(parsed[0]) == {"x": 1} + assert parsed[1].function.name == "bar" + assert _arguments_dict(parsed[1]) == {"y": 2} + + +def test_parse_with_glm45_parser_handles_name_and_json_body(): + payload = 'lookup\n{"id":42,"q":"tabbyapi"}' + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="glm45") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert 
_arguments_dict(parsed[0]) == {"id": 42, "q": "tabbyapi"} + + +def test_parse_with_minimax_m2_parser_handles_invoke_parameters(): + payload = ( + '' + '42' + 'tabbyapi' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="minimax_m2") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert _arguments_dict(parsed[0]) == {"id": 42, "query": "tabbyapi"} + + +def test_parse_with_seed_oss_parser_handles_seed_xml(): + payload = ( + "" + "\nSeoul\n" + "\n2\n" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="seed_oss") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 2} + + +def test_parse_with_olmo3_parser_handles_function_calls_wrapper(): + payload = "\nget_weather(city='Seoul', days=2)\n" + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="olmo3") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 2} + + +def test_parse_with_step3p5_parser_handles_qwen3_xml_shape(): + payload = ( + "" + "\ntabbyapi\n" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="step3p5") + + assert len(parsed) == 1 + assert parsed[0].function.name == "search" + assert _arguments_dict(parsed[0]) == {"q": "tabbyapi"} + + def test_parse_with_openai_parser_handles_functions_recipient(): payload = ( '[{"recipient":"functions.get_weather","content":"{\\"city\\":\\"Seoul\\"}"}]' From 39845d2f7ef6e259a47dab1b404bab378a052da9 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Mon, 2 Mar 2026 12:08:34 +0900 Subject: [PATCH 17/19] feat(chat): add DeepSeek-VL2 built-in serializer --- backends/exllamav3/model.py | 13 +- common/image_util.py | 11 +- endpoints/OAI/router.py | 3 +- endpoints/OAI/utils/chat_completion.py | 145 ++++++++++++++++++++- tests/deepseek_vl2_chat_serializer_test.py | 114 
++++++++++++++++ 5 files changed, 274 insertions(+), 12 deletions(-) create mode 100644 tests/deepseek_vl2_chat_serializer_test.py diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 3a3e797b..86d5f63a 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -301,10 +301,15 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs f'Using template "{self.prompt_template.name}" for chat completions.' ) else: - logger.warning( - "Chat completions are disabled because a prompt " - "template wasn't provided or auto-detected." - ) + if self.config.architecture == "DeepseekVLV2ForCausalLM": + logger.info( + "Using built-in DeepSeek-VL2 chat serializer for chat completions." + ) + else: + logger.warning( + "Chat completions are disabled because a prompt " + "template wasn't provided or auto-detected." + ) return self diff --git a/common/image_util.py b/common/image_util.py index 9790cfe9..ab5dfd50 100644 --- a/common/image_util.py +++ b/common/image_util.py @@ -49,4 +49,13 @@ async def get_image(url: str) -> Image: raise HTTPException(400, error_message) - return Image.open(io.BytesIO(bytes_image)) + try: + image = Image.open(io.BytesIO(bytes_image)) + image.load() + return image + except Exception as e: + error_message = handle_request_error( + "Failed to read or decode image data stream.", + exc_info=False, + ).error.message + raise HTTPException(400, error_message) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 8f4e7a4e..d03f61ec 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -16,6 +16,7 @@ from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse from endpoints.OAI.utils.chat_completion import ( apply_chat_template, + chat_completions_available, generate_chat_completion, stream_generate_chat_completion, ) @@ -113,7 +114,7 @@ async def chat_completion_request( else: await check_model_container() - if model.container.prompt_template 
is None: + if not chat_completions_available(): error_message = handle_request_error( "Chat completions are disabled because a prompt template is not set.", exc_info=False, diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 62eb899a..8beb8f8e 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -49,6 +49,9 @@ class _StreamReasoningState: token_ids: List[int] = field(default_factory=list) +DEEPSEEK_VL2_ARCH = "DeepseekVLV2ForCausalLM" + + class _TokenizerAdapter: """Expose the minimal tokenizer interface required by reasoning parsers.""" @@ -94,6 +97,24 @@ def _get_tokenizer_mode() -> str: return tokenizer_mode +def _uses_builtin_chat_serializer() -> bool: + container = model.container + cfg = getattr(container, "config", None) + return bool(cfg and getattr(cfg, "architecture", None) == DEEPSEEK_VL2_ARCH) + + +def chat_completions_available() -> bool: + return bool(getattr(model.container, "prompt_template", None)) or _uses_builtin_chat_serializer() + + +def _get_template_tooling_defaults() -> tuple[Optional[str], str]: + prompt_template = getattr(model.container, "prompt_template", None) + metadata = getattr(prompt_template, "metadata", None) + if metadata is None: + return None, "json" + return metadata.tool_start, metadata.tool_call_format + + def _truncate_mistral_tool_ids(message_dicts: List[dict]) -> None: """Match vLLM mistral_common behavior by truncating tool IDs to 9 chars.""" for message in message_dicts: @@ -498,6 +519,9 @@ def _build_tool_call_chunks( async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict): """Adding metadata is a one-time process.""" + if model.container.prompt_template is None: + return + template_metadata = await model.container.prompt_template.extract_metadata( template_vars ) @@ -569,6 +593,91 @@ async def format_messages_with_template( return prompt, mm_embeddings, template_vars +async def 
format_messages_with_builtin_serializer( + messages: List[ChatCompletionMessage], +): + """Serialize messages for models that define their own chat protocol.""" + + if not _uses_builtin_chat_serializer(): + raise HTTPException(400, "No built-in chat serializer is available.") + + mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None + pending_system: List[str] = [] + segments: List[str] = [] + last_non_system_role: Optional[str] = None + + for message in messages: + content = message.content + if isinstance(content, list): + concatenated_content = "" + previous_part_type: Optional[str] = None + for part in content: + if part.type == "text" and part.text: + if previous_part_type == "image_url" and concatenated_content: + concatenated_content += "\n" + concatenated_content += part.text + elif part.type == "image_url" and mm_embeddings: + if ( + previous_part_type == "text" + and concatenated_content + and not concatenated_content.endswith("\n") + ): + concatenated_content += "\n" + elif previous_part_type == "image_url" and concatenated_content: + concatenated_content += "\n" + await mm_embeddings.add(part.image_url.url) + concatenated_content += mm_embeddings.text_alias[-1] + previous_part_type = part.type + content = concatenated_content + + normalized_content = content or "" + role = (message.role or "user").lower() + + if role == "system": + if normalized_content: + pending_system.append(normalized_content) + continue + + if role == "user": + if pending_system: + normalized_content = "\n\n".join( + pending_system + ([normalized_content] if normalized_content else []) + ) + pending_system.clear() + + segments.append(f"<|User|>: {normalized_content}") + last_non_system_role = "user" + continue + + if role == "assistant": + segments.append(f"<|Assistant|>: {normalized_content}") + last_non_system_role = "assistant" + continue + + if role in {"tool", "tool_results"}: + tool_payload = normalized_content + if message.tool_call_id: + 
prefix = f"[tool_call_id={message.tool_call_id}]" + tool_payload = ( + f"{prefix}\n{tool_payload}" if tool_payload else prefix + ) + segments.append(f"<|User|>: {tool_payload}") + last_non_system_role = "user" + continue + + raise HTTPException( + 400, + f"Unsupported role '{message.role}' for built-in DeepSeek-VL2 chat serialization.", + ) + + if pending_system: + segments.append(f"<|User|>: {'\\n\\n'.join(pending_system)}") + last_non_system_role = "user" + + prompt = "\n\n".join(segments) + return prompt, mm_embeddings, {"last_non_system_role": last_non_system_role} + + async def apply_chat_template(data: ChatCompletionRequest): """ Compile the prompt and get any additional stop strings from the template. @@ -580,6 +689,33 @@ async def apply_chat_template(data: ChatCompletionRequest): if _get_tokenizer_mode() == "mistral": _sanitize_mistral_tool_choice_ids(data) + if _uses_builtin_chat_serializer(): + if data.tools or data.functions or data.tool_choice not in (None, "none"): + raise HTTPException( + 400, + "Tool calling is not supported for the built-in DeepSeek-VL2 chat serializer.", + ) + + prompt, mm_embeddings, serializer_state = ( + await format_messages_with_builtin_serializer(data.messages) + ) + + last_non_system_role = serializer_state.get("last_non_system_role") + + if data.add_generation_prompt and last_non_system_role != "assistant": + prompt = f"{prompt}\n\n<|Assistant|>: " if prompt else "<|Assistant|>: " + + if data.response_prefix: + if data.add_generation_prompt: + prompt += data.response_prefix + else: + logger.warning( + "Could not add response prefix because " + "add_generation_prompt is False" + ) + + return prompt, mm_embeddings + try: data.template_vars.update( { @@ -643,8 +779,7 @@ async def stream_generate_chat_completion( abort_event = asyncio.Event() gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] - tool_start = model.container.prompt_template.metadata.tool_start - default_tool_call_format = 
model.container.prompt_template.metadata.tool_call_format + tool_start, default_tool_call_format = _get_template_tooling_defaults() disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: @@ -844,8 +979,7 @@ async def generate_chat_completion( model_path: pathlib.Path, ): gen_tasks: List[asyncio.Task] = [] - tool_start = model.container.prompt_template.metadata.tool_start - default_tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_start, default_tool_call_format = _get_template_tooling_defaults() tool_call_format = _validate_and_get_tool_call_format( data, default_tool_call_format ) @@ -930,9 +1064,8 @@ async def generate_tool_calls( tool_call_format: Optional[str] = None, ): gen_tasks: List[asyncio.Task] = [] - tool_start = model.container.prompt_template.metadata.tool_start + tool_start, default_tool_call_format = _get_template_tooling_defaults() if tool_call_format is None: - default_tool_call_format = model.container.prompt_template.metadata.tool_call_format tool_call_format = _validate_and_get_tool_call_format( data, default_tool_call_format ) diff --git a/tests/deepseek_vl2_chat_serializer_test.py b/tests/deepseek_vl2_chat_serializer_test.py new file mode 100644 index 00000000..3cdbf732 --- /dev/null +++ b/tests/deepseek_vl2_chat_serializer_test.py @@ -0,0 +1,114 @@ +"""Regression tests for DeepSeek-VL2 built-in chat serialization.""" + +import asyncio +import sys +from pathlib import Path +from types import SimpleNamespace + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from endpoints.OAI.utils import chat_completion as cc + + +class _FakeMultimodalEmbeddingWrapper: + """Minimal multimodal stub that emits stable text aliases.""" + + def __init__(self): + self.text_alias = [] + self.urls = [] + + async def add(self, url: str): + self.urls.append(url) + self.text_alias.append(f"") + + +def _message(role: str, content, tool_call_id=None): + return SimpleNamespace(role=role, 
content=content, tool_call_id=tool_call_id) + + +def _text_part(text: str): + return SimpleNamespace(type="text", text=text, image_url=None) + + +def _image_part(url: str): + return SimpleNamespace( + type="image_url", + text=None, + image_url=SimpleNamespace(url=url), + ) + + +def _install_deepseek_vl2_serializer(monkeypatch): + container = SimpleNamespace( + config=SimpleNamespace(architecture=cc.DEEPSEEK_VL2_ARCH), + use_vision=True, + ) + monkeypatch.setattr(cc.model, "container", container, raising=False) + monkeypatch.setattr( + cc, + "MultimodalEmbeddingWrapper", + _FakeMultimodalEmbeddingWrapper, + ) + + +def test_builtin_serializer_supports_official_multi_image_interleaving(monkeypatch): + _install_deepseek_vl2_serializer(monkeypatch) + + messages = [ + _message( + "user", + [ + _text_part("This is image_1: "), + _image_part("image://1"), + _text_part("This is image_2: "), + _image_part("image://2"), + _text_part("This is image_3: "), + _image_part("image://3"), + _text_part(" Can you tell me what are in the images?"), + ], + ), + _message("assistant", ""), + ] + + prompt, mm_embeddings, serializer_state = asyncio.run( + cc.format_messages_with_builtin_serializer(messages) + ) + + assert prompt == ( + "<|User|>: This is image_1: \n\n" + "This is image_2: \n\n" + "This is image_3: \n\n" + " Can you tell me what are in the images?\n\n" + "<|Assistant|>: " + ) + assert mm_embeddings is not None + assert mm_embeddings.urls == ["image://1", "image://2", "image://3"] + assert mm_embeddings.text_alias == ["", "", ""] + assert serializer_state["last_non_system_role"] == "assistant" + + +def test_builtin_serializer_preserves_grounding_markup(monkeypatch): + _install_deepseek_vl2_serializer(monkeypatch) + + grounding_text = ( + "<|ref|>The red square<|/ref|> is the target. 
" + "<|grounding|><|det|>[0.1,0.2,0.9,0.9]<|/det|><|/grounding|>" + ) + messages = [ + _message( + "user", + [ + _image_part("image://1"), + _text_part(grounding_text), + ], + ), + ] + + prompt, mm_embeddings, serializer_state = asyncio.run( + cc.format_messages_with_builtin_serializer(messages) + ) + + assert prompt == f"<|User|>: \n{grounding_text}" + assert mm_embeddings is not None + assert mm_embeddings.urls == ["image://1"] + assert serializer_state["last_non_system_role"] == "user" From ae72f8deb0c13270b9e61ff879812368198848a5 Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Mon, 2 Mar 2026 13:34:11 +0900 Subject: [PATCH 18/19] fix(reasoning): restore qwen3-next parser parity --- .../OAI/reasoning/qwen3_reasoning_parser.py | 26 +++++++++++++++++-- tests/qwen3_reasoning_parser_test.py | 24 +++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/endpoints/OAI/reasoning/qwen3_reasoning_parser.py b/endpoints/OAI/reasoning/qwen3_reasoning_parser.py index 8d6c3be9..0b7cb417 100644 --- a/endpoints/OAI/reasoning/qwen3_reasoning_parser.py +++ b/endpoints/OAI/reasoning/qwen3_reasoning_parser.py @@ -28,8 +28,12 @@ def __init__(self, tokenizer: Any, *args, **kwargs): if enable_thinking is None: enable_thinking = chat_kwargs.get("thinking") - # Keep default behavior when no explicit switch is provided. - self.thinking_enabled = True if enable_thinking is None else bool(enable_thinking) + # Only force "prefilled " behavior when the template explicitly + # exposes a thinking switch. Templates like Qwen3-Next's tokenizer + # config do not, and should fall back to normal tag-based parsing. 
+ self.thinking_enabled = ( + None if enable_thinking is None else bool(enable_thinking) + ) @property def start_token(self) -> str: @@ -47,6 +51,14 @@ def _strip_reasoning_tags(self, text: str) -> str: def extract_reasoning( self, model_output: str, request: Any ) -> tuple[str | None, str | None]: + if self.thinking_enabled is None: + if self.start_token not in model_output or self.end_token not in model_output: + return None, model_output or None + + _, _, tail = model_output.partition(self.start_token) + reasoning, _, content = tail.partition(self.end_token) + return reasoning or None, content or None + if not self.thinking_enabled: content = self._strip_reasoning_tags(model_output) return None, content or None @@ -73,6 +85,16 @@ def extract_reasoning_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], ) -> DeltaMessage | None: + if self.thinking_enabled is None: + return super().extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) + if not self.thinking_enabled: cleaned = self._strip_reasoning_tags(delta_text) if not cleaned: diff --git a/tests/qwen3_reasoning_parser_test.py b/tests/qwen3_reasoning_parser_test.py index f27996c3..53f89f07 100644 --- a/tests/qwen3_reasoning_parser_test.py +++ b/tests/qwen3_reasoning_parser_test.py @@ -42,6 +42,14 @@ def test_non_stream_extract_non_thinking_mode_content_only(): assert content == "hiddenvisible" +def test_non_stream_without_explicit_thinking_switch_treats_plain_text_as_content(): + parser = _parser() + + reasoning, content = parser.extract_reasoning("OK", request=None) + assert reasoning is None + assert content == "OK" + + def test_streaming_prefilled_think_mode_splits_reasoning_and_content(): parser = _parser(enable_thinking=True) @@ -98,6 +106,22 @@ def test_streaming_non_thinking_mode_emits_content_only(): assert delta.content == "plain output" +def 
test_streaming_without_explicit_thinking_switch_emits_content_only(): + parser = _parser() + + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="OK", + delta_text="OK", + previous_token_ids=[], + current_token_ids=[11], + delta_token_ids=[11], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "OK" + + def test_streaming_strips_generated_start_token_when_present(): parser = _parser(enable_thinking=True) From 19eb1c86da3189293142a803b46c59d8e94baf1c Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Mon, 2 Mar 2026 16:45:04 +0900 Subject: [PATCH 19/19] feat(config): add exllamav3 attention backend policy --- backends/exllamav3/model.py | 75 ++++++++++++++++++++++++++++----- common/config_models.py | 34 +++++++++------ common/optional_dependencies.py | 15 ++++--- config_sample.yml | 6 +++ endpoints/core/types/model.py | 11 ++++- 5 files changed, 112 insertions(+), 29 deletions(-) diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 86d5f63a..a804fc50 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -20,6 +20,7 @@ Model, Tokenizer, ) +from exllamav3.modules.attn import has_flash_attn_backend, has_flashinfer_backend from exllamav3.cache import CacheLayer_quant from backends.exllamav3.grammar import ExLlamaV3Grammar from loguru import logger @@ -34,7 +35,7 @@ log_metrics, log_prompt, ) -from common.hardware import hardware_supports_flashinfer +from common.hardware import hardware_supports_flash_attn, hardware_supports_flashinfer from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper from common.optional_dependencies import check_package_version @@ -95,6 +96,8 @@ class ExllamaV3Container(BaseModelContainer): max_batch_size: Optional[int] = None tokenizer_mode: str = "auto" mistral_tokenizer_models: List[str] = [] + attention_backend: str = "auto" + resolved_attention_backend: Optional[str] = None # Required methods 
@classmethod @@ -114,10 +117,17 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs # Make sure ExllamaV3 is up to date check_package_version("exllamav3", "0.0.7") - check_package_version("flashinfer-python", "0.6.3") self.model_dir = model_directory self.hf_model = hf_model + requested_attention_backend = unwrap(kwargs.get("attention_backend"), "auto") + if requested_attention_backend not in ("auto", "flash_attn", "flashinfer"): + raise ValueError( + "Invalid attention_backend " + f"'{requested_attention_backend}'. " + "Expected one of: auto, flash_attn, flashinfer." + ) + self.attention_backend = requested_attention_backend requested_tokenizer_mode, mode_message = normalize_tokenizer_mode( coalesce(kwargs.get("tokenizer_mode"), config.model.tokenizer_mode, "auto") ) @@ -245,17 +255,59 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs value / 1024 for value in autosplit_reserve_megabytes ] - if not hardware_supports_flashinfer(gpu_device_list): - gpu_unsupported_message = ( - "Unable to run ExllamaV3 because an unsupported GPU is " - "found in this configuration. \n" - "All GPUs must be Ampere " - "(30 series) or newer for flashinfer. AMD GPUs are not supported." + flash_attn_available = ( + has_flash_attn_backend() and hardware_supports_flash_attn(gpu_device_list) + ) + flashinfer_available = ( + has_flashinfer_backend() and hardware_supports_flashinfer(gpu_device_list) + ) + + def _unsupported_backend_message(backend_name: str) -> str: + package_name = ( + "flash_attn" if backend_name == "flash_attn" else "flashinfer-python" + ) + return ( + f"Unable to use the requested ExllamaV3 attention backend " + f"'{backend_name}'.\n" + f"The required package ({package_name}) is missing or unsupported " + "on the selected GPUs. All GPUs must be Ampere (30 series) or " + "newer, CUDA only." 
) - logger.warning(gpu_unsupported_message) + if self.attention_backend == "flash_attn": + if not flash_attn_available: + message = _unsupported_backend_message("flash_attn") + logger.warning(message) + raise RuntimeError(message) + self.resolved_attention_backend = "flash_attn" + elif self.attention_backend == "flashinfer": + if not flashinfer_available: + message = _unsupported_backend_message("flashinfer") + logger.warning(message) + raise RuntimeError(message) + check_package_version("flashinfer-python", "0.6.3") + self.resolved_attention_backend = "flashinfer" + else: + if flash_attn_available: + self.resolved_attention_backend = "flash_attn" + elif flashinfer_available: + check_package_version("flashinfer-python", "0.6.3") + self.resolved_attention_backend = "flashinfer" + else: + message = ( + "Unable to run ExllamaV3 because no supported cache-capable " + "attention backend is available.\n" + "Install flash_attn or flashinfer-python, and use Ampere-class " + "CUDA GPUs or newer." 
+ ) + logger.warning(message) + raise RuntimeError(message) - raise RuntimeError(gpu_unsupported_message) + logger.info( + "Attention backend policy: {} (resolved: {})", + self.attention_backend, + self.resolved_attention_backend, + ) # Store the max_seq_len arg user_max_seq_len = kwargs.get("max_seq_len") @@ -410,6 +462,8 @@ def model_info(self) -> ModelCard: chunk_size=self.chunk_size, tokenizer_mode=self.tokenizer_mode, mistral_tokenizer_models=self.mistral_tokenizer_models, + attention_backend=self.attention_backend, + resolved_attention_backend=self.resolved_attention_backend, use_vision=self.use_vision, ) @@ -560,6 +614,7 @@ async def create_generator(self): tokenizer=self.tokenizer, max_batch_size=self.max_batch_size, max_chunk_size=self.chunk_size, + attn_mode=self.resolved_attention_backend or self.attention_backend, ) # Update the state of the container var diff --git a/common/config_models.py b/common/config_models.py index c61801b8..6265f4e6 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -9,8 +9,9 @@ from typing import List, Literal, Optional, Union -CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] -CACHE_TYPE = Union[CACHE_SIZES, constr(pattern=r"^[2-8]\s*,\s*[2-8]$")] +CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] +CACHE_TYPE = Union[CACHE_SIZES, constr(pattern=r"^[2-8]\s*,\s*[2-8]$")] +ATTENTION_BACKENDS = Literal["auto", "flash_attn", "flashinfer"] class Metadata(BaseModel): @@ -165,16 +166,25 @@ class ModelConfig(BaseConfigModel): "Example: ['max_seq_len', 'cache_mode']." 
), ) - backend: Optional[str] = Field( - None, - description=( - "Backend to use for this model (auto-detect if not specified)\n" - "Options: exllamav2, exllamav3" - ), - ) - max_seq_len: Optional[int] = Field( - None, - description=( + backend: Optional[str] = Field( + None, + description=( + "Backend to use for this model (auto-detect if not specified)\n" + "Options: exllamav2, exllamav3" + ), + ) + attention_backend: Optional[ATTENTION_BACKENDS] = Field( + "auto", + description=( + "Attention backend policy for exllamav3 (default: auto).\n" + "Options: auto, flash_attn, flashinfer.\n" + "This chooses the cache-capable attention backend at model init time.\n" + "SDPA remains an internal fallback for non-cache paths and unsupported cases." + ), + ) + max_seq_len: Optional[int] = Field( + None, + description=( "Max sequence length (default: 4096).\n" "Set to -1 to fetch from the model's config.json" ), diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py index e4195bdd..811bc67a 100644 --- a/common/optional_dependencies.py +++ b/common/optional_dependencies.py @@ -11,12 +11,13 @@ __all__ = ["dependencies"] -class DependenciesModel(BaseModel): - """Model of which optional dependencies are installed.""" - +class DependenciesModel(BaseModel): + """Model of which optional dependencies are installed.""" + torch: bool exllamav2: bool exllamav3: bool + flash_attn: bool flashinfer: bool infinity_emb: bool sentence_transformers: bool @@ -26,10 +27,12 @@ class DependenciesModel(BaseModel): def extras(self) -> bool: return self.infinity_emb and self.sentence_transformers - @computed_field - @property + @computed_field + @property def inference(self) -> bool: - return self.torch and (self.exllamav2 or (self.exllamav3 and self.flashinfer)) + return self.torch and ( + self.exllamav2 or (self.exllamav3 and (self.flash_attn or self.flashinfer)) + ) def is_installed(package_name: str) -> bool: diff --git a/config_sample.yml b/config_sample.yml index 
e7d75424..bbb0188d 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -78,6 +78,12 @@ model: # Options: exllamav2, exllamav3 backend: + # Attention backend policy for exllamav3 (default: auto). + # Options: auto, flash_attn, flashinfer + # Picks the cache-capable attention backend once at model init time. + # SDPA remains an internal fallback for non-cache paths and unsupported cases. + attention_backend: auto + # Max sequence length (default: min(max_position_embeddings, cache_size)). # Set to -1 to fetch from the model's config.json max_seq_len: diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index d090fce8..6f4e66a6 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -4,7 +4,7 @@ from time import time from typing import List, Literal, Optional, Union -from common.config_models import LoggingConfig +from common.config_models import ATTENTION_BACKENDS, LoggingConfig from common.tabby_config import config @@ -22,6 +22,8 @@ class ModelCardParameters(BaseModel): chunk_size: Optional[int] = 2048 tokenizer_mode: Optional[str] = "auto" mistral_tokenizer_models: Optional[List[str]] = Field(default_factory=list) + attention_backend: Optional[str] = "auto" + resolved_attention_backend: Optional[str] = None prompt_template: Optional[str] = None prompt_template_content: Optional[str] = None use_vision: Optional[bool] = False @@ -81,6 +83,13 @@ class ModelLoadRequest(BaseModel): description="Backend to use", default=None, ) + attention_backend: Optional[ATTENTION_BACKENDS] = Field( + description=( + "Attention backend policy for exllamav3 " + "(auto, flash_attn, flashinfer)" + ), + default=None, + ) max_seq_len: Optional[int] = Field( description="Leave this blank to use the model's base sequence length", default=None,