diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/LICENSES/Apache-2.0.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/THIRD_PARTY_LICENSES.md b/THIRD_PARTY_LICENSES.md new file mode 100644 index 00000000..f732bb73 --- /dev/null +++ b/THIRD_PARTY_LICENSES.md @@ -0,0 +1,37 @@ +# Third-Party License Notices + +This repository is licensed under AGPL-3.0 (see `LICENSE`). 
+ +The following files include code adapted from the vLLM project and are +licensed under Apache License 2.0: + +- `endpoints/OAI/reasoning/abs_reasoning_parsers.py` +- `endpoints/OAI/reasoning/basic_parsers.py` +- `endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py` +- `endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py` +- `endpoints/OAI/reasoning/ernie45_reasoning_parser.py` +- `endpoints/OAI/reasoning/exaone4_reasoning_parser.py` +- `endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py` +- `endpoints/OAI/reasoning/gptoss_reasoning_parser.py` +- `endpoints/OAI/reasoning/granite_reasoning_parser.py` +- `endpoints/OAI/reasoning/holo2_reasoning_parser.py` +- `endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py` +- `endpoints/OAI/reasoning/identity_reasoning_parser.py` +- `endpoints/OAI/reasoning/kimi_k2_reasoning_parser.py` +- `endpoints/OAI/reasoning/minimax_m2_reasoning_parser.py` +- `endpoints/OAI/reasoning/mistral_reasoning_parser.py` +- `endpoints/OAI/reasoning/olmo3_reasoning_parser.py` +- `endpoints/OAI/reasoning/qwen3_reasoning_parser.py` +- `endpoints/OAI/reasoning/seedoss_reasoning_parser.py` +- `endpoints/OAI/reasoning/step3_reasoning_parser.py` +- `endpoints/OAI/reasoning/step3p5_reasoning_parser.py` +- `endpoints/OAI/reasoning/__init__.py` +- `endpoints/OAI/utils/parser_options.py` +- `endpoints/OAI/utils/tools.py` +- `templates/tool_calls/qwen3_coder.jinja` + +Source project: +- vLLM: https://github.com/vllm-project/vllm + +The Apache-2.0 license text is provided at: +- `LICENSES/Apache-2.0.txt` diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6e59dbe3..b313ac89 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -46,7 +46,12 @@ from common.multimodal import MultimodalEmbeddingWrapper from common.optional_dependencies import check_package_version from common.sampling import BaseSamplerRequest +from common.tabby_config import config from common.templating import PromptTemplate, 
find_prompt_template +from common.tokenizer_modes import ( + normalize_tokenizer_mode, + should_enable_mistral_tokenizer_mode, +) from common.transformers_utils import HFModel from common.utils import calculate_rope_alpha, coalesce, unwrap from endpoints.core.types.model import ModelCard, ModelCardParameters @@ -83,6 +88,8 @@ class ExllamaV2Container(BaseModelContainer): cache_mode: str = "FP16" draft_cache_mode: str = "FP16" max_batch_size: Optional[int] = None + tokenizer_mode: str = "auto" + mistral_tokenizer_models: List[str] = [] # GPU split vars gpu_split: List[float] = [] @@ -120,6 +127,29 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs self.model_dir = model_directory self.config.model_dir = str(model_directory.resolve()) self.hf_model = hf_model + self.tokenizer_mode, mode_message = normalize_tokenizer_mode( + coalesce(kwargs.get("tokenizer_mode"), config.model.tokenizer_mode, "auto") + ) + if mode_message: + logger.warning(mode_message) + + if self.tokenizer_mode == "mistral": + mistral_tokenizer_models = coalesce( + kwargs.get("mistral_tokenizer_models"), + config.model.mistral_tokenizer_models, + [], + ) + self.mistral_tokenizer_models = list(mistral_tokenizer_models) + if should_enable_mistral_tokenizer_mode( + model_directory, mistral_tokenizer_models + ): + logger.info("Using tokenizer_mode='mistral' compatibility path.") + else: + logger.warning( + "tokenizer_mode='mistral' requested but model does not appear " + "to support mistral tokenizer mode. Falling back to default mode." 
+ ) + self.tokenizer_mode = "auto" # Make the max seq len 4096 before preparing the config # This is a better default than 2048 @@ -440,6 +470,8 @@ def model_info(self): max_batch_size=self.max_batch_size, cache_mode=self.cache_mode, chunk_size=self.config.max_input_len, + tokenizer_mode=self.tokenizer_mode, + mistral_tokenizer_models=self.mistral_tokenizer_models, use_vision=self.use_vision, draft=draft_model_card, ) @@ -1426,12 +1458,15 @@ async def generate_gen( full_response += chunk chunk_tokens = result.get("token_ids") + token_ids = [] if chunk_tokens is not None: + token_ids = chunk_tokens.flatten().tolist() generated_tokens += chunk_tokens.size(dim=0) generation = { "request_id": request_id, "text": chunk, + "token_ids": token_ids, "prompt_tokens": context_len, "generated_tokens": generated_tokens, "offset": len(full_response), diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 50c30450..a804fc50 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -20,6 +20,7 @@ Model, Tokenizer, ) +from exllamav3.modules.attn import has_flash_attn_backend, has_flashinfer_backend from exllamav3.cache import CacheLayer_quant from backends.exllamav3.grammar import ExLlamaV3Grammar from loguru import logger @@ -34,12 +35,17 @@ log_metrics, log_prompt, ) -from common.hardware import hardware_supports_flash_attn +from common.hardware import hardware_supports_flash_attn, hardware_supports_flashinfer from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper from common.optional_dependencies import check_package_version from common.sampling import BaseSamplerRequest +from common.tabby_config import config from common.templating import PromptTemplate, find_prompt_template +from common.tokenizer_modes import ( + normalize_tokenizer_mode, + should_enable_mistral_tokenizer_mode, +) from common.transformers_utils import HFModel from common.utils import coalesce, unwrap from 
endpoints.core.types.model import ModelCard, ModelCardParameters @@ -88,6 +94,10 @@ class ExllamaV3Container(BaseModelContainer): chunk_size: int = 2048 max_rq_tokens: Optional[int] = 2048 max_batch_size: Optional[int] = None + tokenizer_mode: str = "auto" + mistral_tokenizer_models: List[str] = [] + attention_backend: str = "auto" + resolved_attention_backend: Optional[str] = None # Required methods @classmethod @@ -110,6 +120,40 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs self.model_dir = model_directory self.hf_model = hf_model + requested_attention_backend = unwrap(kwargs.get("attention_backend"), "auto") + if requested_attention_backend not in ("auto", "flash_attn", "flashinfer"): + raise ValueError( + "Invalid attention_backend " + f"'{requested_attention_backend}'. " + "Expected one of: auto, flash_attn, flashinfer." + ) + self.attention_backend = requested_attention_backend + requested_tokenizer_mode, mode_message = normalize_tokenizer_mode( + coalesce(kwargs.get("tokenizer_mode"), config.model.tokenizer_mode, "auto") + ) + if mode_message: + logger.warning(mode_message) + + mistral_tokenizer_models = coalesce( + kwargs.get("mistral_tokenizer_models"), + config.model.mistral_tokenizer_models, + [], + ) + self.mistral_tokenizer_models = list(mistral_tokenizer_models) + if requested_tokenizer_mode == "mistral": + if should_enable_mistral_tokenizer_mode( + model_directory, mistral_tokenizer_models + ): + logger.info("Using tokenizer_mode='mistral' compatibility path.") + else: + logger.warning( + "tokenizer_mode='mistral' requested but model does not appear " + "to use Mistral tokenizer assets. Falling back to default mode." 
+ ) + requested_tokenizer_mode = "auto" + + self.tokenizer_mode = requested_tokenizer_mode + self.config = Config.from_directory(str(model_directory.resolve())) self.model = Model.from_config(self.config) self.tokenizer = Tokenizer.from_config(self.config) @@ -211,17 +255,59 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs value / 1024 for value in autosplit_reserve_megabytes ] - if not hardware_supports_flash_attn(gpu_device_list): - gpu_unsupported_message = ( - "Unable to run ExllamaV3 because an unsupported GPU is " - "found in this configuration. \n" - "All GPUs must be ampere " - "(30 series) or newer. AMD GPUs are not supported." + flash_attn_available = ( + has_flash_attn_backend() and hardware_supports_flash_attn(gpu_device_list) + ) + flashinfer_available = ( + has_flashinfer_backend() and hardware_supports_flashinfer(gpu_device_list) + ) + + def _unsupported_backend_message(backend_name: str) -> str: + package_name = ( + "flash_attn" if backend_name == "flash_attn" else "flashinfer-python" + ) + return ( + f"Unable to use the requested ExllamaV3 attention backend " + f"'{backend_name}'.\n" + f"The required package ({package_name}) is missing or unsupported " + "on the selected GPUs. All GPUs must be Ampere (30 series) or " + "newer, CUDA only." 
) - logger.warning(gpu_unsupported_message) + if self.attention_backend == "flash_attn": + if not flash_attn_available: + message = _unsupported_backend_message("flash_attn") + logger.warning(message) + raise RuntimeError(message) + self.resolved_attention_backend = "flash_attn" + elif self.attention_backend == "flashinfer": + if not flashinfer_available: + message = _unsupported_backend_message("flashinfer") + logger.warning(message) + raise RuntimeError(message) + check_package_version("flashinfer-python", "0.6.3") + self.resolved_attention_backend = "flashinfer" + else: + if flash_attn_available: + self.resolved_attention_backend = "flash_attn" + elif flashinfer_available: + check_package_version("flashinfer-python", "0.6.3") + self.resolved_attention_backend = "flashinfer" + else: + message = ( + "Unable to run ExllamaV3 because no supported cache-capable " + "attention backend is available.\n" + "Install flash_attn or flashinfer-python, and use Ampere-class " + "CUDA GPUs or newer." + ) + logger.warning(message) + raise RuntimeError(message) - raise RuntimeError(gpu_unsupported_message) + logger.info( + "Attention backend policy: {} (resolved: {})", + self.attention_backend, + self.resolved_attention_backend, + ) # Store the max_seq_len arg user_max_seq_len = kwargs.get("max_seq_len") @@ -267,10 +353,15 @@ async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs f'Using template "{self.prompt_template.name}" for chat completions.' ) else: - logger.warning( - "Chat completions are disabled because a prompt " - "template wasn't provided or auto-detected." - ) + if self.config.architecture == "DeepseekVLV2ForCausalLM": + logger.info( + "Using built-in DeepSeek-VL2 chat serializer for chat completions." + ) + else: + logger.warning( + "Chat completions are disabled because a prompt " + "template wasn't provided or auto-detected." 
+ ) return self @@ -369,6 +460,10 @@ def model_info(self) -> ModelCard: max_batch_size=self.max_batch_size, cache_mode=self.cache_mode, chunk_size=self.chunk_size, + tokenizer_mode=self.tokenizer_mode, + mistral_tokenizer_models=self.mistral_tokenizer_models, + attention_backend=self.attention_backend, + resolved_attention_backend=self.resolved_attention_backend, use_vision=self.use_vision, ) @@ -432,8 +527,10 @@ async def load_gen(self, progress_callback=None, **kwargs): Progress updates """ + acquired_lock = False try: await self.load_lock.acquire() + acquired_lock = True # Wait for existing generation jobs to finish await self.wait_for_jobs(kwargs.get("skip_wait")) @@ -454,7 +551,8 @@ async def load_gen(self, progress_callback=None, **kwargs): self.loaded = True logger.info("Model successfully loaded.") finally: - self.load_lock.release() + if acquired_lock and self.load_lock.locked(): + self.load_lock.release() async with self.load_condition: self.load_condition.notify_all() @@ -516,6 +614,7 @@ async def create_generator(self): tokenizer=self.tokenizer, max_batch_size=self.max_batch_size, max_chunk_size=self.chunk_size, + attn_mode=self.resolved_attention_backend or self.attention_backend, ) # Update the state of the container var @@ -996,6 +1095,7 @@ async def generate_gen( max_rq_tokens=self.max_rq_tokens, filters=grammar_handler.filters, ) + self.active_job_ids[request_id] = job generated_tokens = 0 full_response = "" @@ -1013,8 +1113,21 @@ async def generate_gen( if chunk: chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk)) full_response += chunk + + # Extract token IDs as a plain list for downstream consumers if isinstance(chunk_tokens, torch.Tensor): + token_id_list = chunk_tokens.flatten().tolist() generated_tokens += chunk_tokens.size(dim=0) + elif isinstance(chunk_tokens, tuple): + first = chunk_tokens[0] + if isinstance(first, torch.Tensor): + token_id_list = first.flatten().tolist() + else: + token_id_list = list(first) + 
generated_tokens += len(token_id_list) + else: + token_id_list = list(chunk_tokens) + generated_tokens += len(token_id_list) # Increase penalty range to generated token amount # TODO: @@ -1024,6 +1137,7 @@ async def generate_gen( generation = { "request_id": request_id, "text": chunk, + "token_ids": token_id_list, "prompt_tokens": context_len, "generated_tokens": generated_tokens, "offset": len(full_response), @@ -1044,8 +1158,6 @@ async def generate_gen( yield finish_chunk break - # Assign the active job to the request ID - self.active_job_ids[request_id] = job except asyncio.CancelledError: await job.cancel() diff --git a/common/config_models.py b/common/config_models.py index 0e71734c..6265f4e6 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -9,8 +9,9 @@ from typing import List, Literal, Optional, Union -CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] -CACHE_TYPE = Union[CACHE_SIZES, constr(pattern=r"^[2-8]\s*,\s*[2-8]$")] +CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] +CACHE_TYPE = Union[CACHE_SIZES, constr(pattern=r"^[2-8]\s*,\s*[2-8]$")] +ATTENTION_BACKENDS = Literal["auto", "flash_attn", "flashinfer"] class Metadata(BaseModel): @@ -165,16 +166,25 @@ class ModelConfig(BaseConfigModel): "Example: ['max_seq_len', 'cache_mode']." 
), ) - backend: Optional[str] = Field( - None, - description=( - "Backend to use for this model (auto-detect if not specified)\n" - "Options: exllamav2, exllamav3" - ), - ) - max_seq_len: Optional[int] = Field( - None, - description=( + backend: Optional[str] = Field( + None, + description=( + "Backend to use for this model (auto-detect if not specified)\n" + "Options: exllamav2, exllamav3" + ), + ) + attention_backend: Optional[ATTENTION_BACKENDS] = Field( + "auto", + description=( + "Attention backend policy for exllamav3 (default: auto).\n" + "Options: auto, flash_attn, flashinfer.\n" + "This chooses the cache-capable attention backend at model init time.\n" + "SDPA remains an internal fallback for non-cache paths and unsupported cases." + ), + ) + max_seq_len: Optional[int] = Field( + None, + description=( "Max sequence length (default: 4096).\n" "Set to -1 to fetch from the model's config.json" ), @@ -284,21 +294,73 @@ class ModelConfig(BaseConfigModel): ), ge=1, ) - prompt_template: Optional[str] = Field( - None, - description=( - "Set the prompt template for this model. (default: None)\n" - "If empty, attempts to look for the model's chat template.\n" - "If a model contains multiple templates in its tokenizer_config.json,\n" - "set prompt_template to the name of the template you want to use.\n" - "NOTE: Only works with chat completion message lists!" - ), - ) - vision: Optional[bool] = Field( - False, - description=( - "Enables vision support if the model supports it. (default: False)" - ), + prompt_template: Optional[str] = Field( + None, + description=( + "Set the prompt template for this model. (default: None)\n" + "If empty, attempts to look for the model's chat template.\n" + "If a model contains multiple templates in its tokenizer_config.json,\n" + "set prompt_template to the name of the template you want to use.\n" + "NOTE: Only works with chat completion message lists!" 
+ ), + ) + tokenizer_mode: Optional[str] = Field( + "auto", + description=( + "Tokenizer compatibility mode for chat formatting.\n" + "Compatible values: auto, hf, slow, mistral, deepseek_v32.\n" + "slow is normalized to hf for ExLlama backends.\n" + "mistral applies Mistral-specific message normalization " + "(tool-call ID handling) and falls back to default behavior " + "for non-Mistral models." + ), + ) + mistral_tokenizer_models: Optional[List[str]] = Field( + default_factory=list, + description=( + "Optional allowlist for tokenizer_mode='mistral'.\n" + "When set, only listed model names/paths will use mistral mode.\n" + "If empty, Tabby auto-detects Mistral-family models." + ), + ) + reasoning_parser: Optional[str] = Field( + None, + description=( + "Reasoning parser key used to split output into reasoning/content.\n" + "Compatible with vLLM parser naming (e.g. exaone4, deepseek_r1).\n" + "If omitted, defaults to 'basic'." + ), + ) + enable_auto_tool_choice: Optional[bool] = Field( + False, + description=( + "Enable auto tool choice for chat completions (default: False).\n" + "Equivalent to vLLM's --enable-auto-tool-choice.\n" + "Requires tool_call_parser to be set." + ), + ) + tool_call_parser: Optional[str] = Field( + None, + description=( + "Tool parser key for model-generated tool call output.\n" + "Equivalent to vLLM's --tool-call-parser.\n" + "Built-in parser keys include: hermes, llama/llama3_json/llama4_json,\n" + "mistral, openai, pythonic, qwen3_coder, qwen3_xml,\n" + "deepseek_v3, deepseek_v31, deepseek_v32." + ), + ) + exclude_tools_when_tool_choice_none: Optional[bool] = Field( + False, + description=( + "Exclude tool definitions from prompt when tool_choice='none'.\n" + "Equivalent to vLLM's --exclude-tools-when-tool-choice-none." + ), + ) + vision: Optional[bool] = Field( + False, + description=( + "Enables vision support if the model supports it. 
(default: False)" + ), ) _metadata: Metadata = PrivateAttr(Metadata()) diff --git a/common/hardware.py b/common/hardware.py index 10723c5e..2e3ac25b 100644 --- a/common/hardware.py +++ b/common/hardware.py @@ -1,6 +1,13 @@ import torch +def _min_compute_capability(gpu_device_list: list[int]) -> int: + return min( + torch.cuda.get_device_capability(device = device_idx)[0] + for device_idx in gpu_device_list + ) + + def hardware_supports_flash_attn(gpu_device_list: list[int]): """ Check whether all GPUs in list support FA2 @@ -9,10 +16,23 @@ def hardware_supports_flash_attn(gpu_device_list: list[int]): AMD is also unsupported until ROCm updates its FA2 fork """ - min_compute_capability = min( - torch.cuda.get_device_capability(device=device_idx)[0] - for device_idx in gpu_device_list - ) + min_compute_capability = _min_compute_capability(gpu_device_list) + + if torch.version.hip or min_compute_capability < 8: + return False + else: + return True + + +def hardware_supports_flashinfer(gpu_device_list: list[int]): + """ + Check whether all GPUs in list support flashinfer. + + Keep the same minimum as the legacy FA2 path for now: + Ampere (SM80) or newer, CUDA only. 
+ """ + + min_compute_capability = _min_compute_capability(gpu_device_list) if torch.version.hip or min_compute_capability < 8: return False diff --git a/common/image_util.py b/common/image_util.py index 9790cfe9..ab5dfd50 100644 --- a/common/image_util.py +++ b/common/image_util.py @@ -49,4 +49,13 @@ async def get_image(url: str) -> Image: raise HTTPException(400, error_message) - return Image.open(io.BytesIO(bytes_image)) + try: + image = Image.open(io.BytesIO(bytes_image)) + image.load() + return image + except Exception as e: + error_message = handle_request_error( + "Failed to read or decode image data stream.", + exc_info=False, + ).error.message + raise HTTPException(400, error_message) diff --git a/common/model.py b/common/model.py index 4ac4861a..5afafb06 100644 --- a/common/model.py +++ b/common/model.py @@ -10,7 +10,7 @@ from fastapi import HTTPException from loguru import logger from ruamel.yaml import YAML -from typing import Dict, Optional +from typing import Dict, Optional, Type from backends.base_model_container import BaseModelContainer from common.logger import get_loading_progress_bar @@ -25,24 +25,60 @@ embeddings_container = None -_BACKEND_REGISTRY: Dict[str, BaseModelContainer] = {} +_BACKEND_REGISTRY: Dict[str, Type[BaseModelContainer]] = {} +_BACKEND_REGISTRY_INITIALIZED = False +_INFINITY_CONTAINER_CLASS = None -if dependencies.exllamav2: - from backends.exllamav2.model import ExllamaV2Container - _BACKEND_REGISTRY["exllamav2"] = ExllamaV2Container +def _log_exllamav3_lock_hint(): + """Warn about stale torch extension lock files before ExLlama import.""" + lock_paths = sorted( + pathlib.Path.home().glob(".cache/torch_extensions/*/exllamav3_ext/lock") + ) + if not lock_paths: + return + + sample_paths = ", ".join(str(path) for path in lock_paths[:3]) + remaining = len(lock_paths) - 3 + if remaining > 0: + sample_paths += f", ... (+{remaining} more)" + + logger.warning( + "Detected torch extension lock file(s): " + f"{sample_paths}. 
Startup may appear frozen while waiting on this lock. " + "If no build process (ninja/nvcc) is active, remove the lock file and retry." + ) + +def _ensure_backend_registry(): + """Initialize backend registry lazily to avoid heavy imports at startup.""" + global _BACKEND_REGISTRY_INITIALIZED + global _INFINITY_CONTAINER_CLASS -if dependencies.exllamav3: - from backends.exllamav3.model import ExllamaV3Container + if _BACKEND_REGISTRY_INITIALIZED: + return - _BACKEND_REGISTRY["exllamav3"] = ExllamaV3Container + if dependencies.exllamav2: + from backends.exllamav2.model import ExllamaV2Container + _BACKEND_REGISTRY["exllamav2"] = ExllamaV2Container + + if dependencies.exllamav3: + logger.info( + "Initializing exllamav3 backend. " + "First run or source changes may trigger extension build." + ) + _log_exllamav3_lock_hint() + from backends.exllamav3.model import ExllamaV3Container -if dependencies.extras: - from backends.infinity.model import InfinityContainer + _BACKEND_REGISTRY["exllamav3"] = ExllamaV3Container - embeddings_container: Optional[InfinityContainer] = None + if dependencies.extras: + from backends.infinity.model import InfinityContainer + + _INFINITY_CONTAINER_CLASS = InfinityContainer + + _BACKEND_REGISTRY_INITIALIZED = True class ModelType(Enum): @@ -159,6 +195,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): # Fetch the extra HF configuration options hf_model = await HFModel.from_directory(model_path) + _ensure_backend_registry() # Override the max sequence length based on user max_seq_len = kwargs.get("max_seq_len") @@ -255,6 +292,8 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs): "pip install -U .[extras]" ) + _ensure_backend_registry() + # Check if the model is already loaded if embeddings_container and embeddings_container.engine: loaded_model_name = embeddings_container.model_dir.name @@ -270,7 +309,10 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs): # Reset to prepare for a new 
container embeddings_container = None - new_embeddings_container = InfinityContainer(model_path) + if _INFINITY_CONTAINER_CLASS is None: + raise ImportError("Infinity backend is unavailable in the current environment.") + + new_embeddings_container = _INFINITY_CONTAINER_CLASS(model_path) await new_embeddings_container.load(**kwargs) embeddings_container = new_embeddings_container diff --git a/common/multimodal.py b/common/multimodal.py index b92386f3..ee4a1d30 100644 --- a/common/multimodal.py +++ b/common/multimodal.py @@ -1,5 +1,3 @@ -from backends.exllamav2.vision import get_image_embedding_exl2 -from backends.exllamav3.vision import get_image_embedding_exl3 from common import model from loguru import logger from pydantic import BaseModel, Field @@ -7,11 +5,6 @@ from common.optional_dependencies import dependencies -if dependencies.exllamav2: - from exllamav2 import ExLlamaV2VisionTower -if dependencies.exllamav3: - from exllamav3 import Model - class MultimodalEmbeddingWrapper(BaseModel): """Common multimodal embedding wrapper""" @@ -23,21 +16,24 @@ class MultimodalEmbeddingWrapper(BaseModel): async def add(self, url: str): # Determine the type of vision embedding to use if not self.type: - if dependencies.exllamav2 and isinstance( - model.container.vision_model, ExLlamaV2VisionTower - ): - self.type = "ExLlamaV2MMEmbedding" - elif dependencies.exllamav3 and isinstance( - model.container.vision_model, Model - ): - self.type = "MMEmbedding" + container = model.container + if container and getattr(container, "vision_model", None): + container_name = container.__class__.__name__ + if dependencies.exllamav2 and container_name == "ExllamaV2Container": + self.type = "ExLlamaV2MMEmbedding" + elif dependencies.exllamav3 and container_name == "ExllamaV3Container": + self.type = "MMEmbedding" # Create the embedding if self.type == "ExLlamaV2MMEmbedding": + from backends.exllamav2.vision import get_image_embedding_exl2 + embedding = await get_image_embedding_exl2(url) 
self.content.append(embedding) self.text_alias.append(embedding.text_alias) elif self.type == "MMEmbedding": + from backends.exllamav3.vision import get_image_embedding_exl3 + embedding = await get_image_embedding_exl3(url) self.content.append(embedding) self.text_alias.append(embedding.text_alias) diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py index 5a23a2ee..811bc67a 100644 --- a/common/optional_dependencies.py +++ b/common/optional_dependencies.py @@ -11,25 +11,28 @@ __all__ = ["dependencies"] -class DependenciesModel(BaseModel): - """Model of which optional dependencies are installed.""" - - torch: bool - exllamav2: bool - exllamav3: bool - flash_attn: bool - infinity_emb: bool - sentence_transformers: bool +class DependenciesModel(BaseModel): + """Model of which optional dependencies are installed.""" + + torch: bool + exllamav2: bool + exllamav3: bool + flash_attn: bool + flashinfer: bool + infinity_emb: bool + sentence_transformers: bool @computed_field @property def extras(self) -> bool: return self.infinity_emb and self.sentence_transformers - @computed_field - @property - def inference(self) -> bool: - return self.torch and (self.exllamav2 or (self.exllamav3 and self.flash_attn)) + @computed_field + @property + def inference(self) -> bool: + return self.torch and ( + self.exllamav2 or (self.exllamav3 and (self.flash_attn or self.flashinfer)) + ) def is_installed(package_name: str) -> bool: diff --git a/common/templating.py b/common/templating.py index cc0cceb1..dda06d85 100644 --- a/common/templating.py +++ b/common/templating.py @@ -12,6 +12,7 @@ from jinja2.ext import loopcontrols from jinja2.sandbox import ImmutableSandboxedEnvironment from loguru import logger +from markupsafe import Markup from packaging import version @@ -24,12 +25,17 @@ class TemplateLoadError(Exception): pass +VALID_TOOL_CALL_FORMATS = {"json", "xml", "auto"} + + @dataclass class TemplateMetadata: """Represents the parsed metadata from a 
template.""" stop_strings: List[str] = field(default_factory=list) tool_start: Optional[str] = None + tool_end: Optional[str] = None + tool_call_format: str = "json" class PromptTemplate: @@ -46,6 +52,22 @@ class PromptTemplate: ) metadata: Optional[TemplateMetadata] = None + @staticmethod + def _tojson_compat(value, indent=None, ensure_ascii=True): + """Compatibility JSON filter for chat templates. + + Some model templates call ``tojson(ensure_ascii=False)`` while the + bundled Jinja filter may not accept that keyword in sandboxed mode. + """ + return Markup( + json.dumps( + value, + indent=indent, + ensure_ascii=ensure_ascii, + separators=(",", ": "), + ) + ) + async def extract_metadata(self, template_vars: dict): """ Returns deserialized template metadata from a chat template. @@ -76,6 +98,22 @@ async def extract_metadata(self, template_vars: dict): if isinstance(template_module.tool_start, str): template_metadata.tool_start = template_module.tool_start + if hasattr(template_module, "tool_end"): + if isinstance(template_module.tool_end, str): + template_metadata.tool_end = template_module.tool_end + + if hasattr(template_module, "tool_call_format"): + fmt = template_module.tool_call_format + if isinstance(fmt, str) and fmt in VALID_TOOL_CALL_FORMATS: + template_metadata.tool_call_format = fmt + logger.debug(f"Template tool_call_format: {fmt}") + else: + logger.warning( + f"Invalid tool_call_format '{fmt}' in template, " + f"defaulting to 'json'. 
" + f"Valid values: {VALID_TOOL_CALL_FORMATS}" + ) + self.metadata = template_metadata return template_metadata @@ -107,6 +145,7 @@ def raise_exception(message): self.environment.globals["strftime_now"] = strftime_now self.environment.globals["raise_exception"] = raise_exception + self.environment.filters["tojson"] = self._tojson_compat return self.environment.from_string(template_str) diff --git a/common/tokenizer_modes.py b/common/tokenizer_modes.py new file mode 100644 index 00000000..48c0f93b --- /dev/null +++ b/common/tokenizer_modes.py @@ -0,0 +1,110 @@ +"""Helpers for tokenizer compatibility mode detection.""" + +import json +import pathlib +from typing import Iterable + +VLLM_COMPAT_TOKENIZER_MODES = { + "auto", + "hf", + "slow", + "mistral", + "deepseek_v32", +} + + +def normalize_tokenizer_mode(tokenizer_mode: str | None) -> tuple[str, str | None]: + mode = str(tokenizer_mode or "auto").lower() + if mode not in VLLM_COMPAT_TOKENIZER_MODES: + return ( + "auto", + f"Unknown tokenizer_mode '{mode}' requested. Falling back to 'auto'.", + ) + + if mode == "slow": + return ( + "hf", + "tokenizer_mode='slow' requested, but ExLlama backends do not expose " + "a distinct slow tokenizer path. 
Using 'hf' compatibility mode.", + ) + + return mode, None + + +def _read_model_type(model_directory: pathlib.Path) -> str: + config_path = model_directory / "config.json" + if not config_path.exists(): + return "" + + try: + with open(config_path, "r", encoding="utf-8") as config_file: + return str(json.load(config_file).get("model_type", "")).lower() + except Exception: + return "" + + +def has_mistral_tokenizer_assets(model_directory: pathlib.Path) -> bool: + return ( + (model_directory / "tekken.json").exists() + or (model_directory / "tokenizer.model").exists() + or any(model_directory.glob("tokenizer.model.v*")) + ) + + +def supports_mistral_tokenizer_mode(model_directory: pathlib.Path) -> bool: + """ + Return True when mistral tokenizer mode is safe to enable for this model. + + vLLM uses mistral-common only for Mistral-family models in auto mode. + Match that intent by requiring both: + 1. A mistral-family model type. + 2. Mistral tokenizer assets. + """ + + model_type = _read_model_type(model_directory) + is_mistral_family = model_type.startswith("mistral") or model_type.startswith( + "mixtral" + ) + + return is_mistral_family and has_mistral_tokenizer_assets(model_directory) + + +def _matches_allowlist( + model_directory: pathlib.Path, mistral_tokenizer_models: Iterable[str] +) -> bool: + model_name = model_directory.name.lower() + model_path = model_directory.as_posix().lower().rstrip("/") + + for entry in mistral_tokenizer_models: + normalized = str(entry).strip().lower().strip("/") + if not normalized: + continue + + if model_name == normalized: + return True + + if model_path.endswith(normalized): + return True + + return False + + +def should_enable_mistral_tokenizer_mode( + model_directory: pathlib.Path, + mistral_tokenizer_models: Iterable[str] | None = None, +) -> bool: + """ + Decide whether mistral tokenizer mode should be enabled. + + If an explicit allowlist is configured, only listed Mistral-family models + can use mistral mode. 
If no allowlist is provided, fallback to auto + detection (mistral-family model + tokenizer assets). + """ + + allowlist = list(mistral_tokenizer_models or []) + if allowlist: + return _matches_allowlist(model_directory, allowlist) and ( + supports_mistral_tokenizer_mode(model_directory) + ) + + return supports_mistral_tokenizer_mode(model_directory) diff --git a/config_sample.yml b/config_sample.yml index 0b65f9e8..bbb0188d 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -78,6 +78,12 @@ model: # Options: exllamav2, exllamav3 backend: + # Attention backend policy for exllamav3 (default: auto). + # Options: auto, flash_attn, flashinfer + # Picks the cache-capable attention backend once at model init time. + # SDPA remains an internal fallback for non-cache paths and unsupported cases. + attention_backend: auto + # Max sequence length (default: min(max_position_embeddings, cache_size)). # Set to -1 to fetch from the model's config.json max_seq_len: @@ -153,6 +159,40 @@ model: # NOTE: Only works with chat completion message lists! prompt_template: + # Tokenizer compatibility mode for chat formatting (default: auto). + # Values: auto, hf, slow, mistral, deepseek_v32. + # Note: ExLlama backends treat "slow" as "hf". + # mistral applies Mistral-style tool-call ID normalization and + # falls back to default behavior for non-Mistral models. + tokenizer_mode: auto + + # Optional allowlist for tokenizer_mode='mistral'. + # If set, only listed model names/paths can use mistral mode. + # Leave empty to keep auto-detection behavior. + mistral_tokenizer_models: [] + + # Reasoning parser key for splitting hidden reasoning and final content. + # Compatible keys include: basic, exaone4, deepseek_r1, deepseek_v3. + # If omitted, TabbyAPI defaults to `basic`. + reasoning_parser: + + # Enable automatic tool selection (default: False). + # Equivalent to vLLM --enable-auto-tool-choice. + # Requires tool_call_parser to be set. 
+ enable_auto_tool_choice: false + + # Tool parser key for model-generated tool call text. + # Equivalent to vLLM --tool-call-parser. + # Built-in values include: + # hermes, llama (alias of llama3_json), llama3_json, llama4_json, mistral, + # openai, pythonic, qwen3_coder, qwen3_xml, + # deepseek_v3, deepseek_v31, deepseek_v32. + tool_call_parser: + + # Exclude tool definitions from prompt when tool_choice='none'. + # Equivalent to vLLM --exclude-tools-when-tool-choice-none. + exclude_tools_when_tool_choice_none: false + # Enables vision support if the model supports it. (default: False) vision: false diff --git a/docs/01.-Getting-Started.md b/docs/01.-Getting-Started.md index 330116c6..513883c8 100644 --- a/docs/01.-Getting-Started.md +++ b/docs/01.-Getting-Started.md @@ -101,7 +101,7 @@ These scripts exit after running their respective tasks. To start TabbyAPI, run 1. `pip install -U .[cu12]` = CUDA 12.x 2. `pip install -U .[amd]` = ROCm 6.0 -If you don't want to update dependencies that come from wheels (torch, exllamav2, and flash attention 2), use `pip install .` or pass the `--nowheel` flag when invoking the start scripts. +If you don't want to update dependencies that come from wheels (torch, exllamav2/exllamav3, and flashinfer), use `pip install .` or pass the `--nowheel` flag when invoking the start scripts. ### Update Exllamav2 diff --git a/docs/02.-Server-options.md b/docs/02.-Server-options.md index 98cee556..9ca6683c 100644 --- a/docs/02.-Server-options.md +++ b/docs/02.-Server-options.md @@ -53,27 +53,33 @@ Note: These are experimental flags that may be removed at any point. Note: Most of the options here will only apply on initial model load/startup (ephemeral). They will not persist unless you add the option name to `use_as_default`. 
-| Config Option | Type (Default) | Description | -| --------------------- | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| model_dir | String ("models") | Directory to look for models.

Note: Persisted across subsequent load requests | -| inline_model_loading | Bool (False) | Enables ability to switch models using the `model` argument in a generation request. More info in [Usage](https://github.com/theroyallab/tabbyAPI/wiki/03.-Usage#inline-loading) | -| use_dummy_models | Bool (False) | Send a dummy OAI model card when calling the `/v1/models` endpoint. Used for clients which enforce specific OAI models.

Note: Persisted across subsequent load requests | -| dummy_model_names | List[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests | -| model_name | String (None) | Folder name of a model to load. The below parameters will not apply unless this is filled out. | -| use_as_default | List[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.

Note: Also applies to the `draft` sub-block | -| max_seq_len | Float (None) | Maximum sequence length of the model. Uses the value from config.json if not specified here. Also called the max context length. | -| tensor_parallel | Bool (False) | Enables tensor parallelism. Automatically falls back to autosplit if GPU split isn't provided.

Note: `gpu_split_auto` is ignored when this is enabled. | -| gpu_split_auto | Bool (True) | Automatically split the model across multiple GPUs. Manual GPU split isn't used if this is enabled. | -| autosplit_reserve | List[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.

Represented as an array of MB per GPU used. | -| gpu_split | List[Float] ([]) | Float array of GBs to split a model between GPUs. | -| rope_scale | Float (1.0) | Adjustment for rope scale (or compress_pos_emb)

Note: If the model has YaRN support, this option will not apply. | -| rope_alpha | Float (None) | Adjustment for rope alpha. Leave blank to automatically calculate based on the max_seq_len.

Note: If the model has YaRN support, this option will not apply. | -| cache_mode | String ("FP16") | Cache mode for the model.

Options: FP16, Q8, Q6, Q4 | -| cache_size | Int (max_seq_len) | Size of the K/V cache

Note: If using CFG, the cache size should be 2 * max_seq_len. | -| chunk_size | Int (2048) | Amount of tokens per chunk with ingestion. A lower value reduces VRAM usage at the cost of ingestion speed. | -| max_batch_size | Int (None) | The absolute maximum amount of prompts to process at one time. This value is automatically adjusted based on cache size. | -| prompt_template | String (None) | Name of a jinja2 chat template to apply for this model. Must be located in the `templates` directory. | -| vision | Bool (False) | Enable vision support for the provided model (if it exists). | +| Config Option | Type (Default) | Description | +| ------------------------------------ | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| model_dir | String ("models") | Directory to look for models.

Note: Persisted across subsequent load requests | +| inline_model_loading | Bool (False) | Enables ability to switch models using the `model` argument in a generation request. More info in [Usage](https://github.com/theroyallab/tabbyAPI/wiki/03.-Usage#inline-loading) | +| use_dummy_models | Bool (False) | Send a dummy OAI model card when calling the `/v1/models` endpoint. Used for clients which enforce specific OAI models.

Note: Persisted across subsequent load requests | +| dummy_model_names | List[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests | +| model_name | String (None) | Folder name of a model to load. The below parameters will not apply unless this is filled out. | +| use_as_default | List[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.

Note: Also applies to the `draft` sub-block | +| max_seq_len | Float (None) | Maximum sequence length of the model. Uses the value from config.json if not specified here. Also called the max context length. | +| tensor_parallel | Bool (False) | Enables tensor parallelism. Automatically falls back to autosplit if GPU split isn't provided.

Note: `gpu_split_auto` is ignored when this is enabled. | +| gpu_split_auto | Bool (True) | Automatically split the model across multiple GPUs. Manual GPU split isn't used if this is enabled. | +| autosplit_reserve | List[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.

Represented as an array of MB per GPU used. | +| gpu_split | List[Float] ([]) | Float array of GBs to split a model between GPUs. | +| rope_scale | Float (1.0) | Adjustment for rope scale (or compress_pos_emb)

Note: If the model has YaRN support, this option will not apply. | +| rope_alpha | Float (None) | Adjustment for rope alpha. Leave blank to automatically calculate based on the max_seq_len.

Note: If the model has YaRN support, this option will not apply. | +| cache_mode | String ("FP16") | Cache mode for the model.

Options: FP16, Q8, Q6, Q4 | +| cache_size | Int (max_seq_len) | Size of the K/V cache

Note: If using CFG, the cache size should be 2 * max_seq_len. | +| chunk_size | Int (2048) | Amount of tokens per chunk with ingestion. A lower value reduces VRAM usage at the cost of ingestion speed. | +| max_batch_size | Int (None) | The absolute maximum amount of prompts to process at one time. This value is automatically adjusted based on cache size. | +| prompt_template | String (None) | Name of a jinja2 chat template to apply for this model. Must be located in the `templates` directory. | +| tokenizer_mode | String ("auto") | Tokenizer compatibility mode for chat formatting. Supported values are `auto`, `hf`, `slow`, `mistral`, `deepseek_v32`; `slow` maps to `hf` on ExLlama backends, and `mistral` applies Mistral-specific tool-call ID handling with fallback on non-Mistral models. | +| mistral_tokenizer_models | List[String] ([]) | Optional allowlist for `tokenizer_mode: mistral`. If set, only listed model names/paths use mistral mode. Leave empty to keep auto-detection behavior. | +| reasoning_parser | String (None) | Reasoning parser key used to split reasoning and final answer text (vLLM-compatible names, default parser behavior is `basic`). | +| enable_auto_tool_choice | Bool (False) | Enables vLLM-style automatic tool choice handling. Equivalent to `--enable-auto-tool-choice` and requires `tool_call_parser`. | +| tool_call_parser | String (None) | vLLM-compatible tool parser key used to parse model-emitted tool calls. Equivalent to `--tool-call-parser`. | +| exclude_tools_when_tool_choice_none | Bool (False) | Excludes tool definitions from the prompt when `tool_choice` is `"none"`. Equivalent to `--exclude-tools-when-tool-choice-none`. | +| vision | Bool (False) | Enable vision support for the provided model (if it exists). 
| ### Draft Model Options diff --git a/docs/04.-Chat-Completions.md b/docs/04.-Chat-Completions.md index 647ee92d..96044599 100644 --- a/docs/04.-Chat-Completions.md +++ b/docs/04.-Chat-Completions.md @@ -31,4 +31,11 @@ Now let's pass the custom var in the following template: I'm going to say {{ test_var }} ``` -Running render on this template will now result in: `I'm going to say hello!` \ No newline at end of file +Running render on this template will now result in: `I'm going to say hello!` + +### Reasoning controls + +TabbyAPI supports reasoning parser output separation with vLLM-compatible parser keys via `model.reasoning_parser` in `config.yml`. + +- `include_reasoning` request field: include or suppress reasoning output in responses +- `enable_thinking` / `thinking` request fields: accepted as top-level aliases and forwarded to template vars (`template_vars.enable_thinking`, `template_vars.thinking`) diff --git a/docs/10.-Tool-Calling.md b/docs/10.-Tool-Calling.md index 83e379a5..cbb772f6 100644 --- a/docs/10.-Tool-Calling.md +++ b/docs/10.-Tool-Calling.md @@ -12,11 +12,31 @@ TabbyAPI's tool calling implementation aligns with the [OpenAI Standard](https:/ TabbyAPI's tool implementation supports: - Tool calling when streaming - Calling multiple tools per turn +- `tool_choice` values: `none`, `auto`, `required`, and named function choice +- vLLM-style parser selection via `model.tool_call_parser` Current limitations: -- No support for `tool_choice` parameter (always assumed to be auto) - `strict` parameter not yet supported (OAI format ensured, but dtype and argument name choices not yet enforced) +### vLLM-compatible options + +The following model config options are available to align behavior with vLLM: + +- `enable_auto_tool_choice`: equivalent to `--enable-auto-tool-choice` +- `tool_call_parser`: equivalent to `--tool-call-parser` +- `exclude_tools_when_tool_choice_none`: equivalent to `--exclude-tools-when-tool-choice-none` + +`tool_choice="auto"` 
requires both `enable_auto_tool_choice: true` and `tool_call_parser` to be set. + +Supported parser keys include: +- `hermes` +- `llama` (alias of `llama3_json`), `llama3_json`, `llama4_json` +- `mistral` +- `openai` +- `pythonic` +- `qwen3_coder`, `qwen3_xml` +- `deepseek_v3`, `deepseek_v31`, `deepseek_v32` + ## Model Support TabbyAPI exposes controls within the `prompt_template` to accommodate models specifically tuned for tool calling and those that aren't. By default, TabbyAPI includes `chatml_with_headers_tool_calling.jinja`, a generic template built to support the Llama 3.1 family and other models following the ChatML (with headers) format. diff --git a/endpoints/OAI/reasoning/__init__.py b/endpoints/OAI/reasoning/__init__.py new file mode 100644 index 00000000..d7df6ee1 --- /dev/null +++ b/endpoints/OAI/reasoning/__init__.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from endpoints.OAI.reasoning.abs_reasoning_parsers import ( + DeltaMessage, + ReasoningParser, + ReasoningParserManager, +) +from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser +from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from endpoints.OAI.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser +from endpoints.OAI.reasoning.ernie45_reasoning_parser import Ernie45ReasoningParser +from endpoints.OAI.reasoning.exaone4_reasoning_parser import Exaone4ReasoningParser +from endpoints.OAI.reasoning.glm4_moe_reasoning_parser import ( + Glm4MoeModelReasoningParser, +) +from endpoints.OAI.reasoning.gptoss_reasoning_parser import GptOssReasoningParser +from endpoints.OAI.reasoning.granite_reasoning_parser import GraniteReasoningParser +from endpoints.OAI.reasoning.holo2_reasoning_parser import Holo2ReasoningParser +from endpoints.OAI.reasoning.hunyuan_a13b_reasoning_parser import ( + HunyuanA13BReasoningParser, +) +from 
endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser +from endpoints.OAI.reasoning.kimi_k2_reasoning_parser import KimiK2ReasoningParser +from endpoints.OAI.reasoning.minimax_m2_reasoning_parser import ( + MiniMaxM2AppendThinkReasoningParser, + MiniMaxM2ReasoningParser, +) +from endpoints.OAI.reasoning.mistral_reasoning_parser import MistralReasoningParser +from endpoints.OAI.reasoning.olmo3_reasoning_parser import Olmo3ReasoningParser +from endpoints.OAI.reasoning.qwen3_reasoning_parser import Qwen3ReasoningParser +from endpoints.OAI.reasoning.seedoss_reasoning_parser import SeedOSSReasoningParser +from endpoints.OAI.reasoning.step3_reasoning_parser import Step3ReasoningParser +from endpoints.OAI.reasoning.step3p5_reasoning_parser import Step3p5ReasoningParser + + +@ReasoningParserManager.register_module("identity") +class _IdentityParser(IdentityReasoningParser): + pass + + +@ReasoningParserManager.register_module("basic") +class _BasicParser(DeepSeekR1ReasoningParser): + pass + + +ReasoningParserManager.reasoning_parsers.update( + { + "deepseek_r1": DeepSeekR1ReasoningParser, + "deepseek_v3": DeepSeekV3ReasoningParser, + "ernie45": Ernie45ReasoningParser, + "exaone4": Exaone4ReasoningParser, + "glm45": Glm4MoeModelReasoningParser, + "openai_gptoss": GptOssReasoningParser, + "granite": GraniteReasoningParser, + "holo2": Holo2ReasoningParser, + "hunyuan_a13b": HunyuanA13BReasoningParser, + "kimi_k2": KimiK2ReasoningParser, + "minimax_m2": MiniMaxM2ReasoningParser, + "minimax_m2_append_think": MiniMaxM2AppendThinkReasoningParser, + "mistral": MistralReasoningParser, + "olmo3": Olmo3ReasoningParser, + "qwen3": Qwen3ReasoningParser, + "seed_oss": SeedOSSReasoningParser, + "step3": Step3ReasoningParser, + "step3p5": Step3p5ReasoningParser, + } +) + + +__all__ = [ + "BaseThinkingReasoningParser", + "DeltaMessage", + "DeepSeekR1ReasoningParser", + "DeepSeekV3ReasoningParser", + "Ernie45ReasoningParser", + "Exaone4ReasoningParser", + 
"Glm4MoeModelReasoningParser", + "GptOssReasoningParser", + "GraniteReasoningParser", + "Holo2ReasoningParser", + "HunyuanA13BReasoningParser", + "IdentityReasoningParser", + "KimiK2ReasoningParser", + "MiniMaxM2AppendThinkReasoningParser", + "MiniMaxM2ReasoningParser", + "MistralReasoningParser", + "Olmo3ReasoningParser", + "Qwen3ReasoningParser", + "ReasoningParser", + "ReasoningParserManager", + "SeedOSSReasoningParser", + "Step3ReasoningParser", + "Step3p5ReasoningParser", +] diff --git a/endpoints/OAI/reasoning/abs_reasoning_parsers.py b/endpoints/OAI/reasoning/abs_reasoning_parsers.py new file mode 100644 index 00000000..b81983d0 --- /dev/null +++ b/endpoints/OAI/reasoning/abs_reasoning_parsers.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Any + + +@dataclass +class DeltaMessage: + content: str | None = None + reasoning: str | None = None + + +class ReasoningParser(ABC): + def __init__(self, tokenizer: Any, *args, **kwargs): + self.model_tokenizer = tokenizer + + @property + def vocab(self) -> dict[str, int]: + return self.model_tokenizer.get_vocab() + + @abstractmethod + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + pass + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Sequence[int] + ) -> bool: + return self.is_reasoning_end(input_ids) + + @abstractmethod + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + pass + + @abstractmethod + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + pass + + @abstractmethod + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: 
Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + pass + + def prepare_structured_tag(self, original_tag: str | None, tool_server: Any | None): + return original_tag + + +class ReasoningParserManager: + reasoning_parsers: dict[str, type[ReasoningParser]] = {} + + @classmethod + def list_registered(cls) -> list[str]: + return sorted(cls.reasoning_parsers.keys()) + + @classmethod + def get_reasoning_parser(cls, name: str) -> type[ReasoningParser]: + parser = cls.reasoning_parsers.get(name) + if parser is None: + registered = ", ".join(cls.list_registered()) + raise KeyError( + f"Reasoning parser '{name}' not found. Available parsers: {registered}" + ) + return parser + + @classmethod + def register_module( + cls, + module_name: str, + ) -> Callable[[type[ReasoningParser]], type[ReasoningParser]]: + def _decorator(module: type[ReasoningParser]) -> type[ReasoningParser]: + cls.reasoning_parsers[module_name] = module + return module + + return _decorator diff --git a/endpoints/OAI/reasoning/basic_parsers.py b/endpoints/OAI/reasoning/basic_parsers.py new file mode 100644 index 00000000..f2dfc0c7 --- /dev/null +++ b/endpoints/OAI/reasoning/basic_parsers.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser + + +class BaseThinkingReasoningParser(ReasoningParser): + @property + @abstractmethod + def start_token(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def end_token(self) -> str: + raise NotImplementedError + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + if self.start_token_id is None or 
self.end_token_id is None: + raise RuntimeError( + f"{self.__class__.__name__} could not locate think tokens in tokenizer" + ) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + for token_id in reversed(input_ids): + if token_id == self.start_token_id: + return False + if token_id == self.end_token_id: + return True + return False + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + if self.end_token_id not in input_ids[:-1]: + return [] + return input_ids[input_ids.index(self.end_token_id) + 1 :] + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + model_output_parts = model_output.partition(self.start_token) + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) + + if self.end_token not in model_output: + return model_output or None, None + + reasoning, _, content = model_output.partition(self.end_token) + return reasoning or None, content or None + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] in [self.start_token_id, self.end_token_id] + ): + return None + + if self.start_token_id in previous_token_ids: + if self.end_token_id in delta_token_ids: + end_index = delta_text.find(self.end_token) + reasoning = delta_text[:end_index] or None + content = delta_text[end_index + len(self.end_token) :] or None + return DeltaMessage(reasoning=reasoning, content=content) + if self.end_token_id in previous_token_ids: + return DeltaMessage(content=delta_text or None) + return DeltaMessage(reasoning=delta_text or None) + + if self.start_token_id in delta_token_ids: + if self.end_token_id in delta_token_ids: + start_index = delta_text.find(self.start_token) + end_index = delta_text.find(self.end_token) + 
reasoning = delta_text[start_index + len(self.start_token) : end_index] + content = delta_text[end_index + len(self.end_token) :] + return DeltaMessage(reasoning=reasoning or None, content=content or None) + return DeltaMessage(reasoning=delta_text or None) + + return DeltaMessage(content=delta_text or None) diff --git a/endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py b/endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py new file mode 100644 index 00000000..3b93bb17 --- /dev/null +++ b/endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage +from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser + + +class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser): + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + ret = super().extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) + + if ( + ret is not None + and self.start_token_id not in previous_token_ids + and self.start_token_id not in delta_token_ids + ): + if self.end_token_id in delta_token_ids: + end_index = delta_text.find(self.end_token) + reasoning = delta_text[:end_index] or None + content = delta_text[end_index + len(self.end_token) :] or None + return DeltaMessage(reasoning=reasoning, content=content) + if self.end_token_id in previous_token_ids: + return DeltaMessage(content=delta_text or None) + return DeltaMessage(reasoning=delta_text or None) + + return ret 
diff --git a/endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py b/endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py new file mode 100644 index 00000000..5e8deb73 --- /dev/null +++ b/endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser +from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import ( + DeepSeekR1ReasoningParser, +) +from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser + + +class DeepSeekV3ReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.get("thinking", False)) + enable_thinking = bool(chat_kwargs.get("enable_thinking", False)) + thinking = thinking or enable_thinking + + if thinking: + self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + else: + self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + 
delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/endpoints/OAI/reasoning/ernie45_reasoning_parser.py b/endpoints/OAI/reasoning/ernie45_reasoning_parser.py new file mode 100644 index 00000000..bff91166 --- /dev/null +++ b/endpoints/OAI/reasoning/ernie45_reasoning_parser.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage +from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser + + +class Ernie45ReasoningParser(BaseThinkingReasoningParser): + response_start_token: str = "" + response_end_token: str = "" + newline_token: str = "<0x0A>" + + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + self.response_start_token_id = self.vocab.get(self.response_start_token) + self.response_end_token_id = self.vocab.get(self.response_end_token) + self.newline_token_id = self.vocab.get(self.newline_token) + self.parser_token_ids = [self.end_token_id, self.response_end_token_id] + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] + in [ + self.start_token_id, + self.end_token_id, + self.response_start_token_id, + self.response_end_token_id, + ] + ): + return None + + if self.end_token_id in delta_token_ids: + think_end_index = delta_text.find(self.end_token) + reasoning = delta_text[:think_end_index] + content = delta_text[think_end_index + len(self.end_token) 
:].lstrip("\n") + response_start_idx = content.find(self.response_start_token) + response_end_idx = content.rfind(self.response_end_token) + if response_start_idx != -1: + content = content[response_start_idx + len(self.response_start_token) :] + if response_end_idx != -1: + content = content[:response_end_idx] + return DeltaMessage(reasoning=reasoning, content=content or None) + + if self.end_token_id in previous_token_ids: + content = delta_text + if self.response_start_token_id in delta_token_ids: + content = content.lstrip("\n") + response_start_idx = content.find(self.response_start_token) + content = content[response_start_idx + len(self.response_start_token) :] + response_end_idx = content.rfind(self.response_end_token) + if response_end_idx != -1: + content = content[:response_end_idx] + elif self.response_end_token_id in delta_token_ids: + response_end_idx = content.rfind(self.response_end_token) + content = content[:response_end_idx] + + if previous_token_ids and previous_token_ids[-1] in self.parser_token_ids: + if delta_token_ids and delta_token_ids[0] == self.newline_token_id: + content = content.lstrip("\n") + if len(previous_token_ids) > 1 and previous_token_ids[-2] == self.end_token_id: + if delta_token_ids and delta_token_ids[0] == self.newline_token_id: + content = content.lstrip("\n") + + return DeltaMessage(content=content or None) + + return DeltaMessage(reasoning=delta_text) + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + reasoning, content = super().extract_reasoning(model_output, request) + if content: + start_idx = content.find(self.response_start_token) + end_idx = content.rfind(self.response_end_token) + if start_idx != -1 and end_idx != -1 and start_idx < end_idx: + content = content[start_idx + len(self.response_start_token) : end_idx] + return reasoning, content or None diff --git a/endpoints/OAI/reasoning/exaone4_reasoning_parser.py 
b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py new file mode 100644 index 00000000..5e2dd43f --- /dev/null +++ b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import ( + DeltaMessage, + ReasoningParser, + ReasoningParserManager, +) + + +@ReasoningParserManager.register_module("exaone4") +class Exaone4ReasoningParser(ReasoningParser): + """ + Reasoning parser for EXAONE 4.x models. + + Behavior notes: + - EXAONE uses `enable_thinking` (not `thinking`) to control reasoning mode. + - Templates may prefill ``, so streamed/output text can start directly + with reasoning text and close at ``. + """ + + start_token = "" + end_token = "" + # Tool-call starts supported by ToolCallProcessor parser families. + # We use these as fallback reasoning boundaries when a model emits + # tool syntax without closing . 
+ tool_start_markers = ( + "", + "", + "<|tool▁call▁begin|>", + "<|DSML|function_calls>", + "<|DSML|invoke", + "<|python_tag|>", + ) + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + self.thinking_enabled = bool(chat_kwargs.get("enable_thinking", False)) + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + + def _strip_reasoning_tokens(self, text: str) -> str: + if not text: + return "" + return text.replace(self.start_token, "").replace(self.end_token, "") + + def _trailing_overlap_len(self, text: str, token: str) -> int: + """Longest suffix overlap of text with token prefix.""" + max_len = min(len(text), len(token) - 1) + for size in range(max_len, 0, -1): + if text.endswith(token[:size]): + return size + return 0 + + def _find_first_marker(self, text: str, markers: Sequence[str]) -> tuple[int, str] | None: + first_idx = -1 + first_marker = "" + for marker in markers: + idx = text.find(marker) + if idx == -1: + continue + if first_idx == -1 or idx < first_idx: + first_idx = idx + first_marker = marker + if first_idx == -1: + return None + return first_idx, first_marker + + def _max_trailing_overlap_len(self, text: str, markers: Sequence[str]) -> int: + overlap = 0 + for marker in markers: + overlap = max(overlap, self._trailing_overlap_len(text, marker)) + return overlap + + def _split_reasoning_content_streaming( + self, text: str + ) -> tuple[str | None, str | None]: + """Split text into reasoning/content for streaming-safe diffing. + + Important: when end token is not yet complete, withhold a trailing + overlap with `` or tool-call prefixes to avoid leaking + partial control-tag bytes into reasoning output. This prevents + boundary-split regressions such as `answer` and + `{...}`. 
+ """ + if not self.thinking_enabled: + content = self._strip_reasoning_tokens(text) + return None, content or None + + body = text + if self.start_token in body: + _, _, body = body.partition(self.start_token) + + if self.end_token in body: + reasoning, _, content = body.partition(self.end_token) + return reasoning or None, self._strip_reasoning_tokens(content) or None + + marker_match = self._find_first_marker(body, self.tool_start_markers) + if marker_match is not None: + marker_index, _ = marker_match + reasoning = body[:marker_index] + content = body[marker_index:] + return reasoning or None, self._strip_reasoning_tokens(content) or None + + reasoning = body.replace(self.start_token, "") + overlap = max( + self._trailing_overlap_len(reasoning, self.end_token), + self._max_trailing_overlap_len(reasoning, self.tool_start_markers), + ) + if overlap: + reasoning = reasoning[:-overlap] + return reasoning or None, None + + def _delta_from_previous(self, previous: str | None, current: str | None) -> str | None: + if current is None: + return None + previous_text = previous or "" + if current.startswith(previous_text): + delta = current[len(previous_text) :] + else: + # Fallback for recovery paths where prefix alignment breaks. 
+ delta = current + return delta or None + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + if not self.thinking_enabled: + return True + if self.end_token_id is None: + return False + return any(token_id == self.end_token_id for token_id in reversed(input_ids)) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + if not self.thinking_enabled: + return input_ids + if self.end_token_id is None or self.end_token_id not in input_ids[:-1]: + return [] + return input_ids[input_ids.index(self.end_token_id) + 1 :] + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + if not self.thinking_enabled: + content = self._strip_reasoning_tokens(model_output) + return None, content or None + + if self.start_token in model_output: + _, _, model_output = model_output.partition(self.start_token) + + if self.end_token in model_output: + reasoning, _, content = model_output.partition(self.end_token) + content = self._strip_reasoning_tokens(content) + return reasoning or None, content or None + + reasoning = model_output.replace(self.start_token, "") + return reasoning or None, None + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + if not delta_text and not delta_token_ids: + return None + + if not self.thinking_enabled: + prev_reasoning, prev_content = self._split_reasoning_content_streaming( + previous_text + ) + cur_reasoning, cur_content = self._split_reasoning_content_streaming( + current_text + ) + content_delta = self._delta_from_previous(prev_content, cur_content) + if content_delta is None: + return None + return DeltaMessage(content=content_delta) + + if len(delta_token_ids) == 1 and ( + (self.start_token_id is not None and delta_token_ids[0] == self.start_token_id) + or (self.end_token_id is not None and 
from __future__ import annotations

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from endpoints.OAI.reasoning.holo2_reasoning_parser import Holo2ReasoningParser


class Glm4MoeModelReasoningParser(Holo2ReasoningParser):
    """GLM-4 MoE reasoning parser.

    Identical behavior to the Holo2 parser (thinking-gated delegation);
    registered as its own class purely for model-name dispatch.
    """
GptOssReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + def _split_harmony(self, text: str) -> tuple[str | None, str | None]: + # Minimal harmony-compatible splitter without vLLM parser dependency. + analysis_tag = "<|channel|>analysis<|message|>" + final_tag = "<|channel|>final<|message|>" + end_tag = "<|end|>" + + a_idx = text.find(analysis_tag) + f_idx = text.find(final_tag) + if a_idx == -1 and f_idx == -1: + return None, text or None + + reasoning = None + content = None + + if a_idx != -1: + a_start = a_idx + len(analysis_tag) + a_end = text.find(end_tag, a_start) + if a_end == -1: + a_end = f_idx if f_idx != -1 else len(text) + reasoning = text[a_start:a_end] or None + + if f_idx != -1: + f_start = f_idx + len(final_tag) + f_end = text.find(end_tag, f_start) + if f_end == -1: + f_end = len(text) + content = text[f_start:f_end] or None + + return reasoning, content + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + text = self.model_tokenizer.decode(input_ids) + return "<|channel|>final<|message|>" in text + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + _, content = self._split_harmony(self.model_tokenizer.decode(input_ids)) + if content is None: + return [] + return self.model_tokenizer.encode(content) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + prev_reasoning, prev_content = self._split_harmony(previous_text) + cur_reasoning, cur_content = self._split_harmony(current_text) + + reasoning_delta = None + content_delta = None + if cur_reasoning is not None: + prev_r = prev_reasoning or "" + reasoning_delta = ( + cur_reasoning[len(prev_r) :] + if cur_reasoning.startswith(prev_r) + else cur_reasoning + ) or None + if cur_content is not 
None: + prev_c = prev_content or "" + content_delta = ( + cur_content[len(prev_c) :] + if cur_content.startswith(prev_c) + else cur_content + ) or None + + if reasoning_delta is None and content_delta is None: + return None + return DeltaMessage(reasoning=reasoning_delta, content=content_delta) + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + return self._split_harmony(model_output) + + def prepare_structured_tag( + self, original_tag: str | None, tool_server: Any | None + ) -> str | None: + if original_tag is not None: + return original_tag + return json.dumps(NO_FUNC_REASONING_TAG) diff --git a/endpoints/OAI/reasoning/granite_reasoning_parser.py b/endpoints/OAI/reasoning/granite_reasoning_parser.py new file mode 100644 index 00000000..c60c3fac --- /dev/null +++ b/endpoints/OAI/reasoning/granite_reasoning_parser.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +try: + import regex as re +except ImportError: + import re + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser + + + +class GraniteReasoningParser(ReasoningParser): + """ + Reasoning parser for IBM Granite. + + IBM granite models currently use "Here is my thought process:" + and "Here is my response:" to separate its thinking / response outputs. + """ + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + # NOTE: There have been some observed occurrences of quantized + # instances of the current models using "Here's" instead of "Here is", + # so to be safe, we match on both. 
+ self.think_start_expr = r"(?:Here's|Here is) my thought process:" + self.response_start_expr = r"(?:Here's|Here is) my response:" + + self.reasoning_regex = re.compile( + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL + ) + + self.valid_think_starts = [ + "Here's my thought process:", + "Here is my thought process:", + ] + self.valid_response_starts = ["Here's my response:", "Here is my response:"] + + # Substrings to match for sequence boundaries on raw text + self.seq_boundary_end = ":" + self.seq_boundary_start = "Here" + + # The longest any thinking / start of response message can be + self.longest_think_start = max( + len(think_start) for think_start in self.valid_think_starts + ) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + text = self.model_tokenizer.decode(input_ids) + return any(resp in text for resp in self.valid_response_starts) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + text = self.model_tokenizer.decode(input_ids) + _, content = self.extract_reasoning(text, None) + if not content: + return [] + return self.model_tokenizer.encode(content) + + def extract_reasoning( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[str | None, str | None]: + """Extract the reasoning content & content sections, respectively. + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + Args: + model_output (str): Output of the model to be parsed. + request (ChatCompletionRequest): Request being processed. + + Returns: + tuple[Optional[str], Optional[str]]: Tuple pair containing the + reasoning content and non-reasoning content. 
+ """ + re_match = self.reasoning_regex.findall(model_output) + if not re_match: + return None, model_output + reasoning, response_content = re_match[0] + if not response_content: + return reasoning, None + return reasoning, response_content + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + """Extract the reasoning content / content emitted by granite models; + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + NOTE: Granite models do not use a special token to start their reasoning + and response sections; instead they have token sequences, e.g., + + Here is my thought process: Foo Here is my response: Bar + + This increases the complexity of correctly handling streams, since we + need to watch for specific sequences and correctly parse them without + dropping content that is potentially overlapping & spanning multiple + delta messages. + + Args: + previous_text (str): Previous text outside of this delta message. + current_text (str): Previous text + delta text. + delta_text (str): Text to consider and parse content from. + previous_token_ids (Sequence[int]): Token IDs of previous_text. + current_token_ids (Sequence[int]): Token IDs of current_text. + delta_token_ids (Sequence[int]): Token IDs of delta_text. + + Returns: + Union[DeltaMessage, None] + DeltaMessage with either reasoning content or content, or None. + """ + reasoning, resp_seq_len, content = self._get_content_sections(current_text) + # Either we haven't finished the start of the reasoning sequence, + # or the model is generating something unexpected. 
+ if not reasoning: + delta_message = self._get_delta_message_with_no_reasoning_bounds( + current_text, delta_text + ) + # We have a start of reasoning message, but have not yet finished + # the start of response sequence. + elif not content: + delta_message = self._get_delta_message_with_no_response_bounds( + current_text, reasoning, delta_text + ) + # We've finished both the start of reasoning and start of response seq. + else: + # This should never happen since we matched on the response + assert resp_seq_len is not None + delta_message = self._get_delta_message_with_both_bounds( + delta_text, reasoning, content, current_text, resp_seq_len + ) + if not delta_message.content and not delta_message.reasoning: + return None + return delta_message + + #### Implementation details of stream parsing for granite models + def _is_reasoning_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start reasoning seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible reasoning start seqs match. + """ + return any( + think_start.startswith(text) for think_start in self.valid_think_starts + ) + + def _is_response_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start response seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible response start seqs match. + """ + return any( + response_start.startswith(text) + for response_start in self.valid_response_starts + ) + + def _get_delta_message_with_no_reasoning_bounds( + self, + current_text: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has not yet completed + its start of reasoning sequence. + + Args: + current_text (str): The full previous + delta text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. 
+ """ + prev_longest_length = len(current_text) - len(delta_text) + is_substr = self._is_reasoning_start_substr(current_text) + was_substr = self._is_reasoning_start_substr(current_text[:prev_longest_length]) + + # Check if we just generated something NOT in the special token seq; + # if so, add everything that we previously skipped with this delta + # message and append everything to content in the future. + if was_substr and not is_substr: + return DeltaMessage( + reasoning=None, + content=current_text, + ) + if is_substr: + # Might still be in the special token sequence; return nothing + return DeltaMessage(reasoning=None, content=None) + # Otherwise the sequence has already been broken and we already + # corrected; just return the delta text as normal content. + return DeltaMessage(reasoning=None, content=delta_text) + + def _get_delta_message_with_no_response_bounds( + self, + current_text: str, + reasoning: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content with no (response) content. NOTE that we may have overlapping + tokens with the start of reasoning / start of response sequences on + either side of the delta text. + + Args: + current_text (str): The full previous + delta text. + reasoning (str): reasoning content from current_text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. + """ + # If we have no reasoning content or explicitly end with the start of + # response sequence, we are in transition to the response; need to be + # careful here, since the final token (:) will match the reasoning + # content and fully parse it out; we should not pass the : back. 
+ ends_with_start_response_seq = any( + current_text.endswith(response_start) + for response_start in self.valid_response_starts + ) + if reasoning is None or ends_with_start_response_seq: + return DeltaMessage(reasoning=None, content=None) + + # Consider previous / current text only within context of the reasoning + previous_text = reasoning[: -len(delta_text)] + current_text = reasoning + + # We need to be careful about adding unfinished response sequences; + # Find the place at which we MIGHT be starting a response sequence + prev_idx = previous_text.rfind(self.seq_boundary_start) + delta_idx = delta_text.rfind(self.seq_boundary_start) + + # Check the state of potential start of response substring matches. + prev_was_substr = ( + self._is_response_start_substr(previous_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_continues_substr = ( + self._is_response_start_substr(current_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_new_substr = ( + self._is_response_start_substr(delta_text[delta_idx:]) + if delta_idx >= 0 + else False + ) + + # Delta only contains potential continued response sequence text. + if delta_continues_substr: + return DeltaMessage(reasoning=None, content=None) + + if not prev_was_substr: + # Delta may be starting a new response seq but has other text too. + if delta_new_substr: + return DeltaMessage(reasoning=delta_text[:delta_idx], content=None) + # Normal case for most reasoning text (no potential special seqs). 
+ return DeltaMessage(reasoning=delta_text, content=None) + # The substring that previously seemed to be a potential response + # seq wasn't one; we need to add the content to the delta message, + # and also slice off the potential response sequence + elif delta_new_substr: + reasoning = previous_text[prev_idx:] + delta_text[:delta_idx] + return DeltaMessage(reasoning=reasoning, content=None) + # No new substring yet, and we broke our old one; take the whole delta + return DeltaMessage( + reasoning=previous_text[prev_idx:] + delta_text, + content=None, + ) + + def _get_delta_message_with_both_bounds( + self, + delta_text: str, + reasoning: str, + response_content: str, + current_text: str, + response_seq_len: int, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content and normal (response) content. + + Args: + delta_text: Text to consider and parse content from. + reasoning: reasoning content from current_text. + response_content: response content from current_text. + current_text: The full previous + delta text. + response_seq_len: Len of the complete response sequence used. + + Returns: + DeltaMessage: Message containing the parsed content. 
+ """ + # Always have content; take length to the end + delta_content = delta_text[-len(response_content) :] + reasoning_end_idx = len(delta_text) - (len(response_content) + response_seq_len) + + if reasoning_end_idx < 0: + delta_reasoning = None + else: + # Get the starting offset + start_reasoning_idx = ( + len(reasoning) + response_seq_len + len(response_content) - 1 + ) + delta_offset = len(current_text) - len(delta_text) + start_offset = start_reasoning_idx - delta_offset + if start_offset < 0: + start_offset = 0 + delta_reasoning = delta_text[start_offset:reasoning_end_idx] + + return DeltaMessage( + reasoning=delta_reasoning, + content=delta_content, + ) + + def _get_content_sections( + self, current_text: str + ) -> tuple[str | None, int | None, str | None]: + """Parse the text to extract the reasoning content / content + if we have them. + + Args: + current_text (str): The full previous + delta text. + + Returns: + tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3 + containing the reasoning content, the length of the response seq + (if there is one) and the non-reasoning content. 
+ """ + current_chunk_start = 0 + start_reasoning = None + parsed_content = False + delimiter_idxs = [ + idx + for idx, char in enumerate(current_text) + if char == self.seq_boundary_end + ] + + for current_chunk_end in delimiter_idxs: + current_chunk = current_text[current_chunk_start:current_chunk_end] + # Check to see if the start of reasoning seq if complete + if start_reasoning is None: + for think_start in self.valid_think_starts: + if current_chunk == think_start[:-1]: + start_reasoning = current_chunk_end + 1 + current_chunk_start = current_chunk_end + 1 + break + + # Check to see if the start of response seq if complete + elif not parsed_content: + for response_start in self.valid_response_starts: + if current_chunk[-len(response_start) + 1 :] == response_start[:-1]: + # Mark end of reasoning and start response content + # after the start of response sequence. + end_reasoning = current_chunk_end - len(response_start) + reasoning = current_text[start_reasoning:end_reasoning] + response_content = current_text[current_chunk_end + 1 :] + return reasoning, len(response_start), response_content + + if start_reasoning and not parsed_content: + return current_text[start_reasoning:], None, None + return None, None, None diff --git a/endpoints/OAI/reasoning/holo2_reasoning_parser.py b/endpoints/OAI/reasoning/holo2_reasoning_parser.py new file mode 100644 index 00000000..cdd4d356 --- /dev/null +++ b/endpoints/OAI/reasoning/holo2_reasoning_parser.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser +from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser + + +class 
Holo2ReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.get("thinking", True)) + enable_thinking = bool(chat_kwargs.get("enable_thinking", True)) + thinking = thinking and enable_thinking + self._parser = ( + DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + if thinking + else IdentityReasoningParser(tokenizer, *args, **kwargs) + ) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Sequence[int] + ) -> bool: + return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py b/endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py new file mode 100644 index 00000000..b3ae3c72 --- /dev/null +++ b/endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +try: + 
class HunyuanA13BReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Hunyuan A13B model.

    - Non-stream output: recognizes and extracts the reasoning ("think") and
      answer ("answer") sections from text using regular expressions.
    - Stream output: a token-id state machine (idle -> think -> response ->
      idle) tracks tag boundaries across deltas, buffering tokens that might
      belong to a boundary sequence.

    Tag token-id sequences (tag text reconstructed from these ids; the
    original angle-bracket markup was lost in transit — confirm upstream):
        think start:  "<think>\\n":               [14023, 771, 397]
        think ends:   "\\n</think>\\n<answer>\\n": [198, 524, 27963, 397, 27, 9399, 397]
        response ends: "</answer>":               [524, 9399, 29]
    """

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # NOTE(review): the tag literals below were reconstructed from the
        # token-id sequences documented in the class docstring; the source had
        # lost the angle-bracket text. Verify against the model tokenizer.
        self.think_start_expr = r"<think>\n"
        self.think_end_expr = r"\n</think>\n"

        self.response_start_expr = r"\n</think>\n<answer>\n"
        self.response_end_expr = r"\n</answer>"

        # Full form: optional "<think>…</think>" block followed by the answer.
        self.full_match_reasoning_regex = re.compile(
            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?"
            rf"(.*?){self.response_end_expr}",
            re.DOTALL,
        )

        # Fallback: answer section started but not yet terminated.
        self.half_match_reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL
        )

        # Token-id forms of the tags above (tokenizer specific).
        self.think_start_ids = [14023, 771, 397]
        self.think_start_ids_fast = [14023, 771, 1363]
        self.response_start_ids = [198, 524, 27963, 397, 27, 9399, 397]
        self.response_start_ids_fast = [524, 27963, 397, 27, 9399, 397]
        self.response_end_ids = [198, 524, 9399, 29]
        self.fast_think_ids = [14023, 771, 1363, 524, 27963, 397, 27, 9399, 397]

        # Streaming state machine. `expected_sequence` is the boundary we are
        # currently matching; think boundaries also have a "fast" alternate
        # spelling tracked via `expected_sequence_side`.
        self.current_state = "idle"
        self.all_states = ["reasoning", "response"]
        self.expected_sequence = self.think_start_ids
        # The think section can start in two ways, so track a side sequence.
        self.expected_sequence_side = self.think_start_ids_fast
        self.sequence_index = 0
        # Tokens/text held back while a boundary sequence is partially matched.
        self.token_buffer: list[int] = []
        self.text_buffer = ""

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Reasoning is over once the state machine reached the answer."""
        return self.current_state == "response"

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        # For hunyuan streaming reason parsing, the stream parse is called
        # first, and the same token is seen again by is_reasoning_end and
        # extract_content_ids. That id is not part of the content, so just
        # return [] here.
        return []

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        """Extract the reasoning content & content sections, respectively.

        If the sequence doesn't match what we expect, i.e., the model
        generates something else, all content is considered non-reasoning
        content.

        Args:
            model_output (str): Output of the model to be parsed.
            request (Any): Request being processed (unused here; kept for the
                common ReasoningParser interface).

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content.
        """
        re_match = self.full_match_reasoning_regex.findall(model_output)
        if re_match:
            reasoning, response_content = re_match[0]
            if len(reasoning) == 0:
                reasoning = None
            if len(response_content) == 0:
                response_content = None
            return reasoning, response_content

        # Half match: answer section opened but never closed.
        fallback_regex = self.half_match_reasoning_regex
        fallback_match = fallback_regex.findall(model_output)
        if fallback_match:
            reasoning, response_content = fallback_match[0]

            if response_content.endswith(self.response_end_expr):
                response_content = response_content[: -len(self.response_end_expr)]

            if len(reasoning) == 0:
                reasoning = None
            if len(response_content) == 0:
                response_content = None

            return reasoning, response_content

        # No recognizable structure: treat everything as plain content.
        return None, model_output

    def _is_strict_increasing_subsequence(
        self, subsequence: Sequence[int], sequence: Sequence[int]
    ) -> bool:
        """True if *subsequence* appears in order (not necessarily
        contiguously) within *sequence*."""
        if not subsequence:
            return False

        sub_idx = 0
        for num in sequence:
            if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
                sub_idx += 1
        return sub_idx == len(subsequence)

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract content using the token-id sequence state machine."""
        think_start_sequence = self.think_start_ids
        response_start_sequence = self.response_start_ids
        response_end_sequence = self.response_end_ids

        if not delta_token_ids:
            return None

        # Only the newest token advances the matcher.
        token = delta_token_ids[-1]

        def check_token_with_sequence(tok: int) -> bool:
            """Does *tok* continue the expected boundary sequence?

            Bounds-checked: the main and side sequences have different
            lengths, so an unguarded index could raise IndexError once the
            matcher advances past the shorter one (bug fixed here).
            """
            expected = self.expected_sequence
            main_ok = (
                self.sequence_index < len(expected)
                and tok == expected[self.sequence_index]
            )
            if self.current_state in ("idle", "think"):
                side = self.expected_sequence_side
                side_ok = (
                    self.sequence_index < len(side)
                    and tok == side[self.sequence_index]
                )
                return main_ok or side_ok
            return main_ok

        def check_last_token(tok: int) -> bool:
            """After consuming *tok*, has a full boundary been matched?"""
            if self.current_state == "idle" or self.current_state == "think":
                # Only return true if it was judged using a side sequence.
                if (
                    self.sequence_index - 1 < len(self.expected_sequence_side)
                    and tok == self.expected_sequence_side[self.sequence_index - 1]
                ):
                    return self.sequence_index == len(self.expected_sequence_side)
                else:
                    return self.sequence_index == len(self.expected_sequence)
            else:
                return self.sequence_index == len(self.expected_sequence)

        token_in_state_seq = check_token_with_sequence(token)

        if token_in_state_seq:
            # Hold the token back: it may be part of a boundary tag.
            self.token_buffer.append(token)
            self.text_buffer += delta_text
            self.sequence_index += 1

            # State changes idle -> think -> response -> idle.
            if check_last_token(token):
                if self.current_state == "idle":
                    self.current_state = "think"
                    self.expected_sequence = response_start_sequence
                    self.expected_sequence_side = self.response_start_ids_fast
                elif self.current_state == "think":
                    self.current_state = "response"
                    self.expected_sequence = response_end_sequence
                elif self.current_state == "response":
                    self.current_state = "idle"
                    self.expected_sequence = think_start_sequence
                    self.expected_sequence_side = self.think_start_ids_fast

                # Reset matching state; boundary text is never emitted.
                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""
        else:
            # Sequence broken - flush any buffered (false-positive) text.
            if self.token_buffer and len(self.token_buffer) > 0:
                buffered_content = self.text_buffer + delta_text
                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""

                if self.current_state == "think":
                    return DeltaMessage(reasoning=buffered_content, content=None)
                else:
                    return DeltaMessage(reasoning=None, content=buffered_content)
            else:
                # Nothing buffered; emit the delta according to state.
                if self.current_state == "think":
                    return DeltaMessage(reasoning=delta_text, content=None)
                else:
                    return DeltaMessage(reasoning=None, content=delta_text)

        # No content to send in this delta (mid-boundary).
        return None
class KimiK2ReasoningParser(ReasoningParser):
    """Delegating reasoning parser for Kimi K2 models.

    Uses DeepSeek-R1 style think-tag parsing when ``thinking`` is enabled in
    the chat template kwargs (the default) and an identity pass-through parser
    when it is disabled. Every hook forwards to the chosen delegate.
    """

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        template_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        if bool(template_kwargs.get("thinking", True)):
            self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
        else:
            self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Forward to the selected delegate parser."""
        return self._parser.is_reasoning_end(input_ids)

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Sequence[int]
    ) -> bool:
        """Forward to the selected delegate parser."""
        return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Forward to the selected delegate parser."""
        return self._parser.extract_content_ids(input_ids)

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        """Forward to the selected delegate parser."""
        return self._parser.extract_reasoning(model_output, request)

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Forward to the selected delegate parser."""
        return self._parser.extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
class MistralReasoningParser(BaseThinkingReasoningParser):
    """Reasoning parser for Mistral models using [THINK]/[/THINK] markers.

    Prefers the canonical special tokens from ``mistral_common`` when that
    package is importable and falls back to the literal marker strings
    otherwise.
    """

    @property
    def start_token(self) -> str:
        try:
            from mistral_common.tokens.tokenizers.base import SpecialTokens

            return SpecialTokens.begin_think
        except Exception:
            return "[THINK]"

    @property
    def end_token(self) -> str:
        try:
            from mistral_common.tokens.tokenizers.base import SpecialTokens

            return SpecialTokens.end_think
        except Exception:
            return "[/THINK]"

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """True once an end token appears after the most recent start token."""
        newest_start = -1
        newest_end = -1
        for pos, tok in enumerate(input_ids):
            if tok == self.start_token_id:
                newest_start = pos
            elif tok == self.end_token_id:
                newest_end = pos
        return newest_start != -1 and newest_end > newest_start

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Drop the first [THINK]…[/THINK] span (or its open tail) from ids."""
        start_idx: int | None = None
        end_idx: int | None = None
        for pos, tok in enumerate(input_ids):
            if tok == self.start_token_id and start_idx is None:
                start_idx = pos
            elif tok == self.end_token_id:
                end_idx = pos
                break

        if end_idx is None:
            # No close tag: either no reasoning at all, or an open block that
            # runs to the end of the sequence.
            return input_ids if start_idx is None else input_ids[:start_idx]
        if start_idx is None:
            # Stray close tag: remove just that token.
            return input_ids[:end_idx] + input_ids[end_idx + 1 :]
        return input_ids[:start_idx] + input_ids[end_idx + 1 :]

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        """Split *model_output* into (reasoning, content) around the markers."""
        if not model_output:
            return None, ""

        before, marker, after = model_output.partition(self.start_token)
        if marker:
            if self.end_token in after:
                reasoning, _, tail = after.partition(self.end_token)
                combined = before + tail
                return reasoning or None, combined or None
            # Open think block with no close: everything after it is reasoning.
            return after or None, before or None

        if self.end_token in before:
            # Stray close tag without an open: strip it from the content.
            head, _, tail = before.partition(self.end_token)
            return None, (head + tail) or None

        return None, before
@dt.dataclass(frozen=True)
class Indices:
    """Half-open [start, end) span within a string."""

    start: int
    end: int

    def __len__(self):
        return self.end - self.start


def string_overlap(a: str, b: str) -> tuple[Indices | None, Indices | None]:
    """Locate the overlap between *a* and *b*.

    Checks, in order: full containment of the shorter string in the longer
    one, then the longest suffix/prefix boundary overlaps in both directions.
    Returns a pair of spans (one into each argument, in argument order), or
    ``(None, None)`` when the strings share no such overlap.
    """
    # Normalize so `short` is never longer than `long_`; remember whether the
    # arguments were flipped so results can be returned in argument order.
    if len(a) < len(b):
        short, long_, flipped = a, b, False
    else:
        short, long_, flipped = b, a, True

    def ordered(span_short: Indices, span_long: Indices):
        return (span_long, span_short) if flipped else (span_short, span_long)

    if short in long_:
        at = long_.index(short)
        return ordered(Indices(0, len(short)), Indices(at, at + len(short)))

    # Longest suffix of `short` matching a prefix of `long_`.
    for size in range(len(short) - 1, 0, -1):
        if short[-size:] == long_[:size]:
            return ordered(
                Indices(len(short) - size, len(short)), Indices(0, size)
            )

    # Longest suffix of `long_` matching a prefix of `short`.
    for size in range(len(short) - 1, 0, -1):
        if long_[-size:] == short[:size]:
            return ordered(
                Indices(0, size), Indices(len(long_) - size, len(long_))
            )

    return None, None
class Olmo3ReasoningParser(ReasoningParser):
    """Reasoning parser for Olmo 3 models.

    Uses a single anchored regex for complete (non-streaming) outputs and an
    incremental :class:`Olmo3ReasoningBuffer` for streaming deltas.
    """

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        # NOTE(review): the tag literals had lost their angle-bracket markup in
        # transit; "<think>"/"</think>" restored to match the named groups used
        # below — confirm against upstream.
        self.think_start = r"<think>"
        self.think_end = r"</think>"

        # Named groups restored: extract_reasoning reads group("reasoning")
        # and group("content"), so the pattern must define them.
        reasoning_expr = (
            rf"^(?:{self.think_start})?(?P<reasoning>.*?)"
            rf"{self.think_end}(?P<content>.*)$"
        )
        self.reasoning_regex = re.compile(reasoning_expr, re.DOTALL)
        self.buffer = Olmo3ReasoningBuffer(
            think_start=self.think_start, think_end=self.think_end
        )

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Reasoning has ended once the close tag appears in the decoded text."""
        text = self.model_tokenizer.decode(input_ids)
        return self.think_end in text

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        # Content extraction is handled at the text level; no token ids are
        # reclassified here.
        return []

    def extract_reasoning(
        self,
        model_output: str,
        request: Any,
    ) -> tuple[str | None, str | None]:
        """Split a complete output into (reasoning, content).

        Without a close tag the whole output is treated as content.
        """
        re_match = self.reasoning_regex.match(model_output)
        if re_match:
            reasoning = re_match.group("reasoning") or None
            content = re_match.group("content") or None
            return reasoning, content
        return None, model_output

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Feed the delta into the buffer; flush again if a close tag landed."""
        delta_message = self.buffer.add_text(delta_text)
        if delta_message is None and self.buffer.think_end in self.buffer.buffer:
            delta_message = self.buffer.process_buffer()
        return delta_message
class Qwen3ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for the Qwen3/Qwen3.5 model family.

    Qwen3.5 chat templates prefill ``<think>`` in the prompt, so streaming
    output usually begins with reasoning text and only later emits
    ``</think>``. This parser mirrors vLLM behavior for that path while also
    honoring ``enable_thinking=False`` by routing all deltas as normal
    content.
    """

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        enable_thinking = chat_kwargs.get("enable_thinking")
        if enable_thinking is None:
            enable_thinking = chat_kwargs.get("thinking")

        # Only force "prefilled <think>" behavior when the template explicitly
        # exposes a thinking switch. Templates like Qwen3-Next's tokenizer
        # config do not, and should fall back to normal tag-based parsing.
        self.thinking_enabled = (
            None if enable_thinking is None else bool(enable_thinking)
        )

    @property
    def start_token(self) -> str:
        # NOTE(review): tag text restored; the angle-bracket literals were
        # stripped from the source in transit — confirm against upstream.
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

    def _strip_reasoning_tags(self, text: str) -> str:
        """Remove both think tags from *text* (used when thinking is off)."""
        if not text:
            return ""
        return text.replace(self.start_token, "").replace(self.end_token, "")

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        """Split a complete output into (reasoning, content)."""
        if self.thinking_enabled is None:
            # Template has no thinking switch: plain tag-based parsing.
            if self.start_token not in model_output or self.end_token not in model_output:
                return None, model_output or None

            _, _, tail = model_output.partition(self.start_token)
            reasoning, _, content = tail.partition(self.end_token)
            return reasoning or None, content or None

        if not self.thinking_enabled:
            content = self._strip_reasoning_tags(model_output)
            return None, content or None

        # Qwen3.5 templates prefill <think> in the prompt.
        # If <think> appears in output (legacy templates), strip it.
        model_output_parts = model_output.partition(self.start_token)
        model_output = (
            model_output_parts[2] if model_output_parts[1] else model_output_parts[0]
        )

        if self.end_token not in model_output:
            return None, model_output

        reasoning, _, content = model_output.partition(self.end_token)
        return reasoning or None, content or None

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Route streaming deltas to reasoning/content per the thinking mode."""
        if self.thinking_enabled is None:
            # No explicit switch: defer to the base tag-based streaming parser.
            return super().extract_reasoning_streaming(
                previous_text,
                current_text,
                delta_text,
                previous_token_ids,
                current_token_ids,
                delta_token_ids,
            )

        if not self.thinking_enabled:
            cleaned = self._strip_reasoning_tags(delta_text)
            if not cleaned:
                return None
            return DeltaMessage(content=cleaned)

        # Handle old templates where model may generate <think>.
        if self.start_token_id in delta_token_ids:
            start_idx = delta_text.find(self.start_token)
            if start_idx >= 0:
                delta_text = delta_text[start_idx + len(self.start_token) :]

        if self.end_token_id in delta_token_ids:
            end_index = delta_text.find(self.end_token)
            if end_index >= 0:
                reasoning = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token) :]
                if not reasoning and not content:
                    return None
                return DeltaMessage(
                    reasoning=reasoning if reasoning else None,
                    content=content if content else None,
                )
            # End token id is present but text was already stripped by backend.
            return None

        if not delta_text:
            return None
        if self.end_token_id in previous_token_ids:
            return DeltaMessage(content=delta_text)
        return DeltaMessage(reasoning=delta_text)
class Step3ReasoningParser(ReasoningParser):
    """Reasoning parser for Step3 models.

    Step3 output has no explicit open tag: everything before the think-end
    token is reasoning, everything after it is content.
    """

    def __init__(self, tokenizer: Any, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # NOTE(review): the token literal had lost its angle-bracket markup in
        # transit; "</think>" restored so the vocab lookup can succeed —
        # confirm against the model tokenizer. An empty string here would
        # always hit the RuntimeError below.
        self.think_end_token = "</think>"
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if self.think_end_token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end token in tokenizer"
            )

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Classify a streaming delta as reasoning and/or content."""
        # The boundary token alone carries no text worth emitting.
        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
            return None

        if self.think_end_token_id in delta_token_ids:
            # The delta straddles the boundary: split it at the end token.
            end_index = delta_text.find(self.think_end_token)
            reasoning = delta_text[:end_index]
            content = delta_text[end_index + len(self.think_end_token) :]
            return DeltaMessage(reasoning=reasoning, content=content or None)

        if self.think_end_token_id in previous_token_ids:
            return DeltaMessage(content=delta_text)

        return DeltaMessage(reasoning=delta_text)

    def extract_reasoning(
        self, model_output: str, request: Any
    ) -> tuple[str | None, str | None]:
        """Split a complete output into (reasoning, content).

        Without an end token the whole output counts as reasoning.
        """
        if self.think_end_token not in model_output:
            return model_output or None, None

        end_index = model_output.find(self.think_end_token)
        reasoning = model_output[:end_index]
        content = model_output[end_index + len(self.think_end_token) :]
        return reasoning or None, content or None

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        return self.think_end_token_id in input_ids

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Sequence[int]
    ) -> bool:
        return self.think_end_token_id in delta_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the token ids after the end token.

        The last position is excluded from the search so a just-emitted end
        token is not treated as a completed boundary yet.
        """
        if self.think_end_token_id not in input_ids[:-1]:
            return []
        return input_ids[input_ids.index(self.think_end_token_id) + 1 :]
    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Classify a streaming delta, normalizing the newline padding that
        Step3.5 emits around its think-end token.

        Stateful: uses ``self._pending_reasoning_newline`` to defer a trailing
        reasoning newline until the next delta (so it can be dropped if the
        end token follows), and decrements ``self.end_offset`` once content
        starts flowing (consumed by is_reasoning_end / _streaming).
        """
        # A delta arriving right after the end token: swallow the single
        # padding newline the template places before the content.
        if previous_text.endswith(self.end_token) and delta_text:
            if delta_text == "\n":
                return None
            if delta_text.startswith("\n"):
                remaining = delta_text.removeprefix("\n")
                return DeltaMessage(content=remaining) if remaining else None

        ret = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if ret is None:
            return None

        # No start token seen anywhere: reclassify the base parser's result
        # using only the end token (reasoning is implicit from the beginning).
        if (
            self.start_token_id not in previous_token_ids
            and self.start_token_id not in delta_token_ids
        ):
            if self.end_token_id in delta_token_ids:
                end_index = delta_text.find(self.end_token)
                reasoning = delta_text[:end_index]
                content = delta_text[end_index + len(self.end_token) :]
                ret = DeltaMessage(reasoning=reasoning, content=content or None)
            elif self.end_token_id in previous_token_ids:
                ret = DeltaMessage(content=delta_text)
            else:
                ret = DeltaMessage(reasoning=delta_text)

        reasoning_to_output = ret.reasoning
        content_to_output = ret.content

        if reasoning_to_output is not None:
            # Re-attach a newline withheld from the previous delta.
            if self._pending_reasoning_newline:
                reasoning_to_output = "\n" + reasoning_to_output
                self._pending_reasoning_newline = False

            # Withhold a trailing newline: if the end token comes next it is
            # template padding and must be dropped, otherwise it is re-emitted.
            if reasoning_to_output.endswith("\n"):
                reasoning_to_output = reasoning_to_output.removesuffix("\n")
                if self.end_token in delta_text:
                    self._pending_reasoning_newline = False
                else:
                    self._pending_reasoning_newline = True

        if content_to_output is not None:
            # Content has started: consume one unit of the end-offset budget
            # and drop the padding newline right after the end token.
            self.end_offset -= 1
            self._pending_reasoning_newline = False
            if self.end_token in delta_text and content_to_output.startswith("\n"):
                content_to_output = content_to_output.removeprefix("\n")

        # Normalize empties to None; suppress fully-empty deltas.
        reasoning_to_output = reasoning_to_output or None
        content_to_output = content_to_output or None
        if reasoning_to_output is None and content_to_output is None:
            return None

        return DeltaMessage(reasoning=reasoning_to_output, content=content_to_output)
a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 8f4e7a4e..d03f61ec 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -16,6 +16,7 @@ from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse from endpoints.OAI.utils.chat_completion import ( apply_chat_template, + chat_completions_available, generate_chat_completion, stream_generate_chat_completion, ) @@ -113,7 +114,7 @@ async def chat_completion_request( else: await check_model_container() - if model.container.prompt_template is None: + if not chat_completions_available(): error_message = handle_request_error( "Chat completions are disabled because a prompt template is not set.", exc_info=False, diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 52523149..61112d4b 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -1,10 +1,10 @@ -from pydantic import AliasChoices, BaseModel, Field, field_validator +from pydantic import AliasChoices, BaseModel, Field, field_validator, model_validator from time import time from typing import Literal, Union, List, Optional, Dict from uuid import uuid4 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest -from endpoints.OAI.types.tools import ToolSpec, ToolCall +from endpoints.OAI.types.tools import NamedToolChoice, ToolSpec, ToolCall class ChatCompletionLogprob(BaseModel): @@ -30,6 +30,8 @@ class ChatCompletionMessagePart(BaseModel): class ChatCompletionMessage(BaseModel): role: str = "user" content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None + reasoning: Optional[str] = None + reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = None tool_call_id: Optional[str] = None @@ -49,7 +51,7 @@ class ChatCompletionStreamChoice(BaseModel): # Index is 0 since we aren't using multiple choices index: int = 0 finish_reason: Optional[str] = None - delta: Union[ChatCompletionMessage, dict] 
= {} + delta: Union[ChatCompletionMessage, dict] = Field(default_factory=dict) logprobs: Optional[ChatCompletionLogprobs] = None @@ -59,18 +61,25 @@ class ChatCompletionRequest(CommonCompletionRequest): prompt_template: Optional[str] = None add_generation_prompt: Optional[bool] = True template_vars: Optional[dict] = Field( - default={}, + default_factory=dict, validation_alias=AliasChoices("template_vars", "chat_template_kwargs"), description="Aliases: chat_template_kwargs", ) + enable_thinking: Optional[bool] = None + thinking: Optional[bool] = None response_prefix: Optional[str] = None model: Optional[str] = None + include_reasoning: Optional[bool] = True # tools is follows the format OAI schema, functions is more flexible # both are available in the chat template. tools: Optional[List[ToolSpec]] = None functions: Optional[List[Dict]] = None + tool_choice: Optional[ + Union[Literal["none", "auto", "required"], NamedToolChoice] + ] = None + parallel_tool_calls: Optional[bool] = True # Chat completions requests do not have a BOS token preference. Backend # respects the tokenization config for the individual model. 
@@ -81,6 +90,20 @@ def force_bos_token(cls, v): """Always disable add_bos_token with chat completions.""" return None + @model_validator(mode="after") + def apply_thinking_aliases(self): + """Support clients that send thinking flags at the top-level.""" + template_vars = dict(self.template_vars or {}) + + if self.enable_thinking is not None and "enable_thinking" not in template_vars: + template_vars["enable_thinking"] = self.enable_thinking + + if self.thinking is not None and "thinking" not in template_vars: + template_vars["thinking"] = self.thinking + + self.template_vars = template_vars + return self + class ChatCompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}") diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py index b5b9611f..1e572663 100644 --- a/endpoints/OAI/types/tools.py +++ b/endpoints/OAI/types/tools.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Dict, Literal +from typing import Dict, Literal, Optional from uuid import uuid4 @@ -28,8 +28,28 @@ class Tool(BaseModel): class ToolCall(BaseModel): - """Represents an OAI tool description.""" + """Represents an OAI tool call. + + The ``index`` field is optional so it can be omitted in non-streaming + responses (where OpenAI does not include it) via ``exclude_none=True``, + while being set explicitly for streaming deltas where it is required + by strict validators like the Vercel AI SDK. 
+ """ - id: str = Field(default_factory=lambda: str(uuid4()).replace("-", "")[:9]) + id: str = Field(default_factory=lambda: f"call_{uuid4().hex[:24]}") function: Tool type: Literal["function"] = "function" + index: Optional[int] = None + + +class NamedToolFunction(BaseModel): + """Represents a named function reference for tool_choice.""" + + name: str + + +class NamedToolChoice(BaseModel): + """Represents a named tool choice (forces a specific function call).""" + + function: NamedToolFunction + type: Literal["function"] = "function" diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index b559bb2b..8beb8f8e 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -1,8 +1,10 @@ """Chat completion utilities for OAI server.""" import asyncio +import json import pathlib from asyncio import CancelledError +from dataclasses import dataclass, field from typing import List, Optional from fastapi import HTTPException, Request from jinja2 import TemplateError @@ -16,7 +18,10 @@ handle_request_error, request_disconnect_loop, ) +from common.tabby_config import config +from common.tokenizer_modes import normalize_tokenizer_mode from common.utils import unwrap +from endpoints.OAI.reasoning import ReasoningParserManager from endpoints.OAI.types.chat_completion import ( ChatCompletionLogprobs, ChatCompletionLogprob, @@ -28,24 +33,279 @@ ChatCompletionStreamChoice, ) from endpoints.OAI.types.common import UsageStats +from endpoints.OAI.types.tools import NamedToolChoice, ToolCall from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector +from endpoints.OAI.utils.parser_options import ( + list_tool_call_parsers, + parser_uses_native_tool_generation, + resolve_tool_call_format, +) from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA +@dataclass +class _StreamReasoningState: + text: str = "" + token_ids: List[int] = field(default_factory=list) + + 
+DEEPSEEK_VL2_ARCH = "DeepseekVLV2ForCausalLM" + + +class _TokenizerAdapter: + """Expose the minimal tokenizer interface required by reasoning parsers.""" + + def __init__(self): + self._vocab = None + + def get_vocab(self) -> dict[str, int]: + if self._vocab is not None: + return self._vocab + + tokenizer = model.container.tokenizer + if hasattr(tokenizer, "get_vocab"): + self._vocab = tokenizer.get_vocab() + return self._vocab + + pieces = tokenizer.get_id_to_piece_list(True) + vocab: dict[str, int] = {} + for token_id, piece in enumerate(pieces): + if piece not in vocab: + vocab[piece] = token_id + self._vocab = vocab + return vocab + + +def _token_ids_from_generation(generation: dict) -> List[int]: + token_ids = generation.get("token_ids") + if token_ids is None: + return [] + if isinstance(token_ids, list): + return token_ids + if hasattr(token_ids, "flatten"): + return token_ids.flatten().tolist() + return list(token_ids) + + +def _get_tokenizer_mode() -> str: + container_mode = getattr(model.container, "tokenizer_mode", None) + tokenizer_mode, mode_message = normalize_tokenizer_mode( + unwrap(container_mode, unwrap(config.model.tokenizer_mode, "auto")) + ) + if mode_message: + logger.warning(mode_message) + return tokenizer_mode + + +def _uses_builtin_chat_serializer() -> bool: + container = model.container + cfg = getattr(container, "config", None) + return bool(cfg and getattr(cfg, "architecture", None) == DEEPSEEK_VL2_ARCH) + + +def chat_completions_available() -> bool: + return bool(getattr(model.container, "prompt_template", None)) or _uses_builtin_chat_serializer() + + +def _get_template_tooling_defaults() -> tuple[Optional[str], str]: + prompt_template = getattr(model.container, "prompt_template", None) + metadata = getattr(prompt_template, "metadata", None) + if metadata is None: + return None, "json" + return metadata.tool_start, metadata.tool_call_format + + +def _truncate_mistral_tool_ids(message_dicts: List[dict]) -> None: + """Match vLLM 
mistral_common behavior by truncating tool IDs to 9 chars.""" + for message in message_dicts: + role = message.get("role") + + if role == "assistant": + for tool_call in message.get("tool_calls", []): + tool_id = tool_call.get("id") + if isinstance(tool_id, str) and len(tool_id) > 9: + tool_call["id"] = tool_id[-9:] + + if role in {"tool", "tool_results"}: + tool_call_id = message.get("tool_call_id") + if isinstance(tool_call_id, str) and len(tool_call_id) > 9: + message["tool_call_id"] = tool_call_id[-9:] + + +def _sanitize_mistral_tool_choice_ids(request_data: ChatCompletionRequest) -> None: + """Normalize tool_choice IDs so mistral templates don't fail validation.""" + for message in request_data.messages: + if message.role == "assistant" and message.tool_calls: + for tool_call in message.tool_calls: + if isinstance(tool_call.id, str) and len(tool_call.id) > 9: + tool_call.id = tool_call.id[-9:] + elif message.role in {"tool", "tool_results"} and message.tool_call_id: + if len(message.tool_call_id) > 9: + message.tool_call_id = message.tool_call_id[-9:] + + +def _build_parser_instance(parser_key: str, template_kwargs: dict): + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_key) + return parser_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) + + +def _build_reasoning_parser(request_data: ChatCompletionRequest): + parser_key = unwrap(config.model.reasoning_parser, "basic") or "basic" + + template_kwargs = unwrap(request_data.template_vars, {}) + + try: + return _build_parser_instance(parser_key, template_kwargs) + except KeyError as exc: + raise HTTPException(400, str(exc)) from exc + except RuntimeError as exc: + # Keep compatibility for models that do not expose thinking tags. + if parser_key == "basic": + logger.warning( + "Reasoning parser 'basic' could not initialize ({}). 
" + "Falling back to identity parser.", + str(exc), + ) + identity_cls = ReasoningParserManager.get_reasoning_parser("identity") + return identity_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) + + # Mistral parser only applies when think tokens exist. If a + # non-Mistral model is selected, fall back to basic behavior. + if parser_key == "mistral": + logger.warning( + "Reasoning parser 'mistral' could not initialize ({}). " + "Falling back to 'basic'.", + str(exc), + ) + try: + return _build_parser_instance("basic", template_kwargs) + except Exception: + identity_cls = ReasoningParserManager.get_reasoning_parser("identity") + return identity_cls( + _TokenizerAdapter(), chat_template_kwargs=template_kwargs + ) + + raise HTTPException(400, str(exc)) from exc + + +def _validate_and_get_tool_call_format( + request_data: ChatCompletionRequest, default_format: str +) -> str: + tool_choice = request_data.tool_choice + parser_key = config.model.tool_call_parser + enable_auto = bool(config.model.enable_auto_tool_choice) + parser_names = list_tool_call_parsers() + + if parser_key and parser_key not in parser_names: + parsers_str = ", ".join(sorted(parser_names)) + raise HTTPException( + 400, + f"invalid tool call parser: {parser_key} (choose from {{{parsers_str}}})", + ) + + if tool_choice == "auto" and (not enable_auto or not parser_key): + raise HTTPException( + 400, + '"auto" tool choice requires --enable-auto-tool-choice and ' + "--tool-call-parser to be set", + ) + + if tool_choice not in (None, "none", "auto") and parser_key is None: + raise HTTPException( + 400, + f'tool_choice="{tool_choice}" requires --tool-call-parser to be set', + ) + + if ( + tool_choice == "none" + and config.model.exclude_tools_when_tool_choice_none + and request_data.tools + ): + request_data.tools = None + + resolved_format = resolve_tool_call_format(parser_key, default_format) + if not resolved_format: + raise HTTPException( + 400, + f"Could not resolve format for 
tool_call_parser={parser_key}", + ) + return resolved_format + + +def _serialize_stream_chunk(chunk) -> str: + """Serialize a streaming chunk with OpenAI-compatible field handling. + + Uses exclude_none=True to strip irrelevant null fields (tool_calls, + tool_call_id, logprobs, usage) while ensuring finish_reason is always + present on each choice (as null when not set), matching OpenAI's + observed streaming behavior. + """ + d = chunk.model_dump(exclude_none=True) + for choice in d.get("choices", []): + if "finish_reason" not in choice: + choice["finish_reason"] = None + return json.dumps(d, ensure_ascii=False) + + def _create_response( - request_id: str, generations: List[dict], model_name: Optional[str] + request_id: str, + generations: List[dict], + model_name: Optional[str], + tool_call_format: str = "json", + tool_choice=None, ): """Create a chat completion response from the provided text.""" choices = [] + parser_key = config.model.tool_call_parser for index, generation in enumerate(generations): + reasoning = generation.get("reasoning") + reasoning_content = generation.get("reasoning_content") message = ChatCompletionMessage( - role="assistant", content=unwrap(generation.get("text"), "") + role="assistant", + content=generation.get("text"), + reasoning=reasoning, + reasoning_content=reasoning_content, ) - tool_calls = generation["tool_calls"] - if tool_calls: - message.tool_calls = ToolCallProcessor.from_json(tool_calls) + tool_calls_raw = generation.get("tool_calls") + if tool_calls_raw: + parsed = ToolCallProcessor.parse( + tool_calls_raw, + format=tool_call_format, + parser_key=parser_key, + ) + if parsed and isinstance(tool_choice, NamedToolChoice): + parsed = ToolCallProcessor.filter_by_name( + parsed, tool_choice.function.name + ) + if parsed: + message.tool_calls = parsed + message.content = None + else: + logger.warning( + "Tool call text present but parsing returned no results " + f"(format={tool_call_format})" + ) + + # Fallback: detect bare XML 
tool calls in content that were not + # caught by the two-pass system (model never emitted tool_start) + if ( + tool_call_format in ("xml", "auto") + and not message.tool_calls + and message.content + and " List[ChatCompletionStreamChunk]: + """Build the OpenAI-standard streaming sequence for tool calls. + + Emits two chunks: + 1. Tool-call chunk: role="assistant", complete tool_calls with + index/id/type/name/arguments (all data in one chunk). + 2. Finish chunk: empty delta, finish_reason="tool_calls". + + Complete arguments are sent in a single chunk rather than streamed + incrementally, which is valid per OpenAI's spec (clients concatenate + argument strings across deltas) and maximizes compatibility with + clients that may not implement multi-chunk tool-call assembly. + + The tool_calls are placed directly into a ChatCompletionMessage + (not a raw dict) so Pydantic validates them as ToolCall objects + with the index field preserved (ToolCall declares index as Optional[int]). + """ + chunk_id = f"chatcmpl-{request_id}" + + # Set index on each tool call for streaming + for idx, tc in enumerate(tool_calls): + tc.index = idx + + # Chunk 1: Complete tool call data + tool_call_message = ChatCompletionMessage( + role="assistant", + tool_calls=tool_calls, + ) + tool_chunk = ChatCompletionStreamChunk( + id=chunk_id, + choices=[ + ChatCompletionStreamChoice( + index=choice_index, + delta=tool_call_message, + finish_reason=None, + ) + ], + model=model_name, + ) + + # Chunk 2: Finish signal + # Use model_construct to prevent Pydantic's smart Union from + # coercing the empty dict {} into ChatCompletionMessage(role="user") + finish_choice = ChatCompletionStreamChoice.model_construct( + index=choice_index, + delta={}, + finish_reason="tool_calls", + logprobs=None, + ) + finish_chunk = ChatCompletionStreamChunk( + id=chunk_id, + choices=[finish_choice], + model=model_name, + ) + + return [tool_chunk, finish_chunk] + + async def _append_template_metadata(data: 
ChatCompletionRequest, template_vars: dict): """Adding metadata is a one-time process.""" + if model.container.prompt_template is None: + return + template_metadata = await model.container.prompt_template.extract_metadata( template_vars ) @@ -237,6 +563,27 @@ async def format_messages_with_template( message_dicts.append(message.model_dump(exclude_none=True)) + if _get_tokenizer_mode() == "mistral": + _truncate_mistral_tool_ids(message_dicts) + + # Pre-template: convert tool_call arguments from JSON strings to dicts. + # OpenAI-compatible clients (Kilo, Roo, etc.) send arguments as JSON + # strings per the OAI spec, but Qwen3-Coder's template calls + # .items() on arguments which requires a dict/mapping. + for msg in message_dicts: + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + func = tc.get("function", {}) + args = func.get("arguments") + if isinstance(args, str): + try: + func["arguments"] = json.loads(args) + except (json.JSONDecodeError, ValueError): + logger.warning( + "Failed to parse tool_call arguments JSON " + "string to dict, keeping as string" + ) + # Get all special tokens special_tokens_dict = model.container.get_special_tokens() @@ -246,6 +593,91 @@ async def format_messages_with_template( return prompt, mm_embeddings, template_vars +async def format_messages_with_builtin_serializer( + messages: List[ChatCompletionMessage], +): + """Serialize messages for models that define their own chat protocol.""" + + if not _uses_builtin_chat_serializer(): + raise HTTPException(400, "No built-in chat serializer is available.") + + mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None + pending_system: List[str] = [] + segments: List[str] = [] + last_non_system_role: Optional[str] = None + + for message in messages: + content = message.content + if isinstance(content, list): + concatenated_content = "" + previous_part_type: Optional[str] = None + for part in content: + if part.type == "text" and part.text: + if 
previous_part_type == "image_url" and concatenated_content: + concatenated_content += "\n" + concatenated_content += part.text + elif part.type == "image_url" and mm_embeddings: + if ( + previous_part_type == "text" + and concatenated_content + and not concatenated_content.endswith("\n") + ): + concatenated_content += "\n" + elif previous_part_type == "image_url" and concatenated_content: + concatenated_content += "\n" + await mm_embeddings.add(part.image_url.url) + concatenated_content += mm_embeddings.text_alias[-1] + previous_part_type = part.type + content = concatenated_content + + normalized_content = content or "" + role = (message.role or "user").lower() + + if role == "system": + if normalized_content: + pending_system.append(normalized_content) + continue + + if role == "user": + if pending_system: + normalized_content = "\n\n".join( + pending_system + ([normalized_content] if normalized_content else []) + ) + pending_system.clear() + + segments.append(f"<|User|>: {normalized_content}") + last_non_system_role = "user" + continue + + if role == "assistant": + segments.append(f"<|Assistant|>: {normalized_content}") + last_non_system_role = "assistant" + continue + + if role in {"tool", "tool_results"}: + tool_payload = normalized_content + if message.tool_call_id: + prefix = f"[tool_call_id={message.tool_call_id}]" + tool_payload = ( + f"{prefix}\n{tool_payload}" if tool_payload else prefix + ) + segments.append(f"<|User|>: {tool_payload}") + last_non_system_role = "user" + continue + + raise HTTPException( + 400, + f"Unsupported role '{message.role}' for built-in DeepSeek-VL2 chat serialization.", + ) + + if pending_system: + segments.append(f"<|User|>: {'\\n\\n'.join(pending_system)}") + last_non_system_role = "user" + + prompt = "\n\n".join(segments) + return prompt, mm_embeddings, {"last_non_system_role": last_non_system_role} + + async def apply_chat_template(data: ChatCompletionRequest): """ Compile the prompt and get any additional stop strings from 
the template. @@ -254,6 +686,35 @@ async def apply_chat_template(data: ChatCompletionRequest): # Locally store tools dict tools = data.model_dump()["tools"] + if _get_tokenizer_mode() == "mistral": + _sanitize_mistral_tool_choice_ids(data) + + if _uses_builtin_chat_serializer(): + if data.tools or data.functions or data.tool_choice not in (None, "none"): + raise HTTPException( + 400, + "Tool calling is not supported for the built-in DeepSeek-VL2 chat serializer.", + ) + + prompt, mm_embeddings, serializer_state = ( + await format_messages_with_builtin_serializer(data.messages) + ) + + last_non_system_role = serializer_state.get("last_non_system_role") + + if data.add_generation_prompt and last_non_system_role != "assistant": + prompt = f"{prompt}\n\n<|Assistant|>: " if prompt else "<|Assistant|>: " + + if data.response_prefix: + if data.add_generation_prompt: + prompt += data.response_prefix + else: + logger.warning( + "Could not add response prefix because " + "add_generation_prompt is False" + ) + + return prompt, mm_embeddings try: data.template_vars.update( @@ -318,11 +779,19 @@ async def stream_generate_chat_completion( abort_event = asyncio.Event() gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] - tool_start = model.container.prompt_template.metadata.tool_start + tool_start, default_tool_call_format = _get_template_tooling_defaults() disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: logger.info(f"Received chat completion streaming request {request.state.id}") + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) + reasoning_parser = _build_reasoning_parser(data) + reasoning_states = [_StreamReasoningState() for _ in range(0, data.n)] + force_tool_pass = data.tool_choice == "required" or isinstance( + data.tool_choice, NamedToolChoice + ) for idx in range(0, data.n): task_gen_params = data.model_copy(deep=True) @@ -342,18 +811,67 @@ async def stream_generate_chat_completion( 
gen_tasks.append(gen_task) - # Text accumulation for tool calls - current_generation_text = "" - # Consumer loop while True: + # Fast path: items already queued — no task overhead + if not gen_queue.empty(): + generation = gen_queue.get_nowait() + else: + # Slow path: queue empty — race get against disconnect + get_task = asyncio.create_task(gen_queue.get()) + done, _ = await asyncio.wait( + [get_task, disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if disconnect_task in done: + get_task.cancel() + raise CancelledError() + generation = get_task.result() + if disconnect_task.done(): raise CancelledError() - generation = await gen_queue.get() + # Stream collector will push an exception to the queue if it fails + if isinstance(generation, Exception): + raise generation + + if "text" in generation and generation.get("finish_reason") is None: + idx = generation["index"] + state = reasoning_states[idx] + + delta_text = unwrap(generation.get("text"), "") + delta_token_ids = _token_ids_from_generation(generation) + + current_text = state.text + delta_text + current_token_ids = state.token_ids + delta_token_ids + + delta_message = reasoning_parser.extract_reasoning_streaming( + state.text, + current_text, + delta_text, + state.token_ids, + current_token_ids, + delta_token_ids, + ) + + state.text = current_text + state.token_ids = current_token_ids + + if delta_message is None: + continue + + generation["text"] = delta_message.content + if data.include_reasoning: + generation["reasoning"] = delta_message.reasoning + generation["reasoning_content"] = delta_message.reasoning + else: + generation["reasoning"] = None + generation["reasoning_content"] = None + if generation["text"] is None: + continue # Handle options if a tool model is present - if tool_start: + if (tool_start or force_tool_pass) and data.tool_choice != "none": if "stop_str" in generation: generations = await generate_tool_calls( prompt, @@ -361,21 +879,64 @@ async def 
stream_generate_chat_completion( data, [generation], request, + tool_call_format=tool_call_format, ) # Only one generation present in this case generation = generations[0] - elif "text" in generation: - current_generation_text += generation["text"] - # Stream collector will push an exception to the queue if it fails - if isinstance(generation, Exception): - raise generation + # Emit proper three-phase tool-call streaming sequence + if "tool_calls" in generation: + tool_calls_raw = generation["tool_calls"] + parsed = ToolCallProcessor.parse( + tool_calls_raw, + format=tool_call_format, + parser_key=config.model.tool_call_parser, + ) + if parsed and isinstance(data.tool_choice, NamedToolChoice): + parsed = ToolCallProcessor.filter_by_name( + parsed, data.tool_choice.function.name + ) + if parsed: + for tc_chunk in _build_tool_call_chunks( + parsed, + request.state.id, + model_path.name, + choice_index=generation.get("index", 0), + ): + yield _serialize_stream_chunk(tc_chunk) + + # Handle completion and usage after tool calls + if ( + all(task.done() for task in gen_tasks) + and gen_queue.empty() + ): + if ( + data.stream_options + and data.stream_options.include_usage + ): + usage_chunk = _create_stream_chunk( + request.state.id, + generation, + model_path.name, + is_usage_chunk=True, + ) + yield _serialize_stream_chunk(usage_chunk) + + logger.info( + "Finished chat completion streaming " + f"request {request.state.id}" + ) + yield "[DONE]" + break + continue response = _create_stream_chunk( - request.state.id, generation, model_path.name + request.state.id, + generation, + model_path.name, ) - yield response.model_dump_json() + yield _serialize_stream_chunk(response) # Check if all tasks are completed if all(task.done() for task in gen_tasks) and gen_queue.empty(): @@ -387,7 +948,7 @@ async def stream_generate_chat_completion( model_path.name, is_usage_chunk=True, ) - yield usage_chunk.model_dump_json() + yield _serialize_stream_chunk(usage_chunk) logger.info( 
f"Finished chat completion streaming request {request.state.id}" @@ -398,13 +959,16 @@ async def stream_generate_chat_completion( except CancelledError: # Get out if the request gets disconnected - if not abort_event.is_set(): - abort_event.set() - handle_request_disconnect("Chat completion generation cancelled by user.") + handle_request_disconnect("Chat completion generation cancelled by user.") + except HTTPException as exc: + yield get_generator_error(str(exc.detail)) except Exception: yield get_generator_error( "Chat completion aborted. Please check the server console." ) + finally: + abort_event.set() + disconnect_task.cancel() async def generate_chat_completion( @@ -415,7 +979,10 @@ async def generate_chat_completion( model_path: pathlib.Path, ): gen_tasks: List[asyncio.Task] = [] - tool_start = model.container.prompt_template.metadata.tool_start + tool_start, default_tool_call_format = _get_template_tooling_defaults() + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) try: logger.info(f"Received chat completion request {request.state.id}") @@ -437,16 +1004,46 @@ async def generate_chat_completion( generations = await asyncio.gather(*gen_tasks) # Check all the generations and see if a tool call is required - if tool_start: + force_tool_pass = data.tool_choice == "required" or isinstance( + data.tool_choice, NamedToolChoice + ) + if tool_start or force_tool_pass: generations = await generate_tool_calls( - prompt, embeddings, data, generations, request + prompt, + embeddings, + data, + generations, + request, + tool_call_format=tool_call_format, + ) + + reasoning_parser = _build_reasoning_parser(data) + for generation in generations: + reasoning, content = reasoning_parser.extract_reasoning( + unwrap(generation.get("text"), ""), + data, ) - response = _create_response(request.state.id, generations, model_path.name) + if not data.include_reasoning: + reasoning = None + + generation["reasoning"] = reasoning + 
generation["reasoning_content"] = reasoning + generation["text"] = content + + response = _create_response( + request.state.id, + generations, + model_path.name, + tool_call_format=tool_call_format, + tool_choice=data.tool_choice, + ) logger.info(f"Finished chat completion request {request.state.id}") return response + except HTTPException: + raise except Exception as exc: error_message = handle_request_error( f"Chat completion {request.state.id} aborted. " @@ -462,29 +1059,87 @@ async def generate_tool_calls( prompt: str, embeddings: MultimodalEmbeddingWrapper, data: ChatCompletionRequest, - generations: List[str], + generations: List[dict], request: Request, + tool_call_format: Optional[str] = None, ): gen_tasks: List[asyncio.Task] = [] - tool_start = model.container.prompt_template.metadata.tool_start + tool_start, default_tool_call_format = _get_template_tooling_defaults() + if tool_call_format is None: + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) + tool_choice = data.tool_choice + parser_key = config.model.tool_call_parser + use_native_generation = parser_uses_native_tool_generation( + parser_key, tool_call_format + ) + + if tool_choice == "none": + return generations # Tracks which generations asked for a tool call tool_idx: List[int] = [] # Copy to make sure the parent JSON schema doesn't get modified tool_data = data.model_copy(deep=True) - tool_data.json_schema = TOOL_CALL_SCHEMA + + if use_native_generation: + # Native syntax mode: let the model generate its natural tool-call + # representation without JSON schema constraint. 
+ logger.debug( + "generate_tool_calls: Using parser '{}' in native mode " + "(format={}, no JSON schema constraint)", + parser_key or "template-default", + tool_call_format, + ) + + # Remove tool_start from stop strings so the model can emit + # multiple sequential blocks without stopping early + if ( + tool_start + and isinstance(tool_data.stop, list) + and tool_start in tool_data.stop + ): + tool_data.stop = [s for s in tool_data.stop if s != tool_start] + logger.debug( + f"generate_tool_calls: Removed '{tool_start}' from " + f"second-pass stop strings" + ) + else: + # JSON mode: constrained generation (existing behavior) + tool_data.json_schema = TOOL_CALL_SCHEMA for idx, gen in enumerate(generations): - if gen["stop_str"] != tool_start: + stop_str = gen.get("stop_str") + should_generate = stop_str == tool_start + + # Force tool generation if tool_choice requires it + if not should_generate and ( + tool_choice == "required" or isinstance(tool_choice, NamedToolChoice) + ): + should_generate = True + + if not should_generate: continue - logger.info(f"Detected tool call in chat completion request {request.state.id}") + logger.info( + f"Detected tool call in chat completion request " + f"{request.state.id} (format={tool_call_format})" + ) - # Append the existing generation text if present + # Build per-generation prompt (avoid mutating shared prompt) + tool_prompt = prompt precursor_text = gen.get("full_text") if precursor_text: - prompt = prompt + precursor_text + tool_prompt = tool_prompt + precursor_text + + # For native generation mode: append tool_start back to prompt. + # The stop string was consumed by the first pass and not included + # in full_text, but the model expects to continue after tool_start. + # Include a trailing newline to match the canonical template format. 
+ if use_native_generation and tool_start: + tool_prompt = tool_prompt + tool_start + "\n" gen_request_id = gen.get("request_id") tool_request_id = f"{gen_request_id}-tool" @@ -493,7 +1148,7 @@ async def generate_tool_calls( asyncio.create_task( model.container.generate( tool_request_id, - prompt, + tool_prompt, tool_data, mm_embeddings=embeddings, ) @@ -507,6 +1162,12 @@ async def generate_tool_calls( # Map tool calls to their appropriate generation for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True): - generations[gen_idx]["tool_calls"] = tool_call["text"] + raw_text = tool_call["text"] + + if use_native_generation and tool_start: + # Prepend tool_start to reconstruct complete native payload. + raw_text = tool_start + "\n" + raw_text + + generations[gen_idx]["tool_calls"] = raw_text return generations diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index f66d381d..c11a25bf 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -225,11 +225,24 @@ async def stream_generate_completion( # Consumer loop while True: + # Fast path: items already queued — no task overhead + if not gen_queue.empty(): + generation = gen_queue.get_nowait() + else: + # Slow path: queue empty — race get against disconnect + get_task = asyncio.create_task(gen_queue.get()) + done, _ = await asyncio.wait( + [get_task, disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if disconnect_task in done: + get_task.cancel() + raise CancelledError() + generation = get_task.result() + if disconnect_task.done(): raise CancelledError() - generation = await gen_queue.get() - # Stream collector will push an exception to the queue if it fails if isinstance(generation, Exception): raise generation @@ -245,15 +258,16 @@ async def stream_generate_completion( except CancelledError: # Get out if the request gets disconnected - if not abort_event.is_set(): - abort_event.set() - handle_request_disconnect( - f"Completion 
generation {request.state.id} cancelled by user." - ) + handle_request_disconnect( + f"Completion generation {request.state.id} cancelled by user." + ) except Exception: yield get_generator_error( f"Completion {request.state.id} aborted. Please check the server console." ) + finally: + abort_event.set() + disconnect_task.cancel() async def generate_completion( diff --git a/endpoints/OAI/utils/parser_options.py b/endpoints/OAI/utils/parser_options.py new file mode 100644 index 00000000..84340f88 --- /dev/null +++ b/endpoints/OAI/utils/parser_options.py @@ -0,0 +1,100 @@ +"""Parser option helpers for vLLM-compatible chat settings.""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Dict, Set + + +# Mirrors vLLM parser keys to keep CLI/config ergonomics familiar. +# Source of truth: vllm/tool_parsers/__init__.py::_TOOL_PARSERS_TO_REGISTER +# Format is the fallback parsing mode supported by ToolCallProcessor. +TOOL_CALL_PARSER_FORMATS: Dict[str, str] = { + "deepseek_v3": "json", + "deepseek_v31": "json", + "deepseek_v32": "json", + "ernie45": "json", + "glm45": "json", + "glm47": "json", + "granite-20b-fc": "json", + "granite": "json", + "hermes": "json", + "hunyuan_a13b": "json", + "internlm": "json", + "jamba": "json", + "kimi_k2": "json", + "llama3_json": "json", + "llama4_json": "json", + "llama4_pythonic": "json", + "longcat": "json", + "minimax_m2": "json", + "minimax": "json", + "mistral": "json", + "olmo3": "json", + "openai": "json", + "phi4_mini_json": "json", + "pythonic": "json", + "qwen3_coder": "xml", + "qwen3_xml": "xml", + "seed_oss": "json", + "step3": "json", + "step3p5": "json", + "xlam": "json", + "gigachat3": "json", + "functiongemma": "json", + # Convenience alias for mixed/inferred content + "auto": "auto", +} + +# Compatibility aliases accepted by this server. +# Keys are user-facing parser names, values are canonical parser keys. 
+TOOL_CALL_PARSER_ALIASES: Dict[str, str] = { + "llama": "llama3_json", +} + +# Parsers that should generate tool calls in their native syntax on tool pass +# (no JSON schema constraint). Most JSON-style parsers should stay constrained. +NATIVE_TOOL_GENERATION_PARSERS: Set[str] = { + "auto", + "deepseek_v3", + "deepseek_v31", + "deepseek_v32", + "llama4_pythonic", + "mistral", + "pythonic", + "qwen3_coder", + "qwen3_xml", +} + + +def resolve_tool_call_parser_key(tool_call_parser: str | None) -> str | None: + """Normalize a user parser key to its canonical key.""" + if not tool_call_parser: + return None + return TOOL_CALL_PARSER_ALIASES.get(tool_call_parser, tool_call_parser) + + +def list_tool_call_parsers() -> Set[str]: + return set(TOOL_CALL_PARSER_FORMATS.keys()).union(TOOL_CALL_PARSER_ALIASES.keys()) + + +def resolve_tool_call_format( + tool_call_parser: str | None, fallback_format: str +) -> str: + """Resolve effective parser format from configured parser key.""" + if not tool_call_parser: + return fallback_format + parser_key = resolve_tool_call_parser_key(tool_call_parser) + return TOOL_CALL_PARSER_FORMATS.get(parser_key, "") + + +def parser_uses_native_tool_generation( + tool_call_parser: str | None, fallback_format: str +) -> bool: + """Whether tool pass should use native model format (unconstrained).""" + if not tool_call_parser: + return fallback_format in ("xml", "auto") + parser_key = resolve_tool_call_parser_key(tool_call_parser) + if parser_key in NATIVE_TOOL_GENERATION_PARSERS: + return True + return resolve_tool_call_format(parser_key, fallback_format) in ("xml", "auto") diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index c1ebdedf..3b1b4981 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -1,8 +1,18 @@ +"""Tool call processing utilities for OAI server.""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast import json 
+import re +from random import choices +from string import ascii_letters, digits from loguru import logger -from typing import List +from typing import Any, Callable, Dict, List, Tuple -from endpoints.OAI.types.tools import ToolCall +from endpoints.OAI.types.tools import ToolCall, Tool +from endpoints.OAI.utils.parser_options import resolve_tool_call_parser_key TOOL_CALL_SCHEMA = { @@ -27,24 +37,1126 @@ }, } +# --------------------------------------------------------------------------- +# XML parsing regex patterns +# Derived from vLLM's Qwen3CoderToolParser and the official Qwen parser. +# These handle both complete and partially-closed tags. +# --------------------------------------------------------------------------- + +# Matches complete ... blocks +TOOL_CALL_BLOCK_RE = re.compile( + r"(.*?)", + re.DOTALL, +) + +# Matches BODY blocks. +# Supports complete and partially-closed function sections to keep parity +# with vLLM behavior on generation cutoffs. +FUNCTION_RE = re.compile( + r"(.*?)|(.*)$", + re.DOTALL, +) + +# Matches VALUE +# Terminates on: , next , or +PARAMETER_RE = re.compile( + r"(.*?)" + r"(?:|(?=)|(?=))", + re.DOTALL, +) + +# Think block patterns +THINK_BLOCK_RE = re.compile(r".*?\s*", re.DOTALL) +THINK_UNCLOSED_RE = re.compile(r"(?!.*).*$", re.DOTALL) + +# Markdown code fence patterns +CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE) +CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE) + +# Jamba / MiniMax tagged JSON blocks +TOOL_CALLS_TAG_RE = re.compile(r"(.*?)", re.DOTALL) + +# GLM-4.5 style: function_name\n{...} +GLM45_CALL_RE = re.compile( + r"\s*(?P[^\n<]+?)\s*\n(?P.*?)", + re.DOTALL, +) + +# MiniMax-M2 XML-like syntax +MINIMAX_M2_CALL_RE = re.compile( + r"(.*?)", + re.DOTALL, +) +MINIMAX_M2_INVOKE_RE = re.compile( + r".*?)>(?P.*?)", + re.DOTALL, +) +MINIMAX_M2_PARAM_RE = re.compile( + r".*?)>(?P.*?)", + re.DOTALL, +) + +# Seed-OSS tags +SEED_THINK_BLOCK_RE = re.compile(r".*?\s*", re.DOTALL) +SEED_THINK_UNCLOSED_RE = 
re.compile(r"(?!.*).*$", re.DOTALL) +SEED_TOOL_CALL_START = "" +SEED_TOOL_CALL_END = "" + +# DeepSeek family patterns +DEEPSEEK_V31_CALL_RE = re.compile( + r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)<|tool▁call▁end|>", + re.DOTALL, +) +DEEPSEEK_V3_CALL_RE = re.compile( + r"<|tool▁call▁begin|>(?P.*?)<|tool▁sep|>(?P.*?)\n```json\n(?P.*?)\n```(?:\s*)<|tool▁call▁end|>", # noqa: E501 + re.DOTALL, +) +DEEPSEEK_V32_INVOKE_RE = re.compile( + r'<|DSML|invoke\s+name="(?P[^"]+)"\s*>(?P.*?)', + re.DOTALL, +) +DEEPSEEK_V32_PARAM_RE = re.compile( + r'<|DSML|parameter\s+name="(?P[^"]+)"(?:\s+string="(?Ptrue|false)")?\s*>(?P.*?)', # noqa: E501 + re.DOTALL, +) + +MISTRAL_TOOL_START = "[TOOL_CALLS]" +MISTRAL_ID_ALPHANUMERIC = ascii_letters + digits + + +def _strip_think_blocks(text: str) -> str: + """Strip ... blocks from text. + + Handles both complete and unclosed blocks (quantization can cause + the model to never close a think tag). + """ + original = text + + # Complete blocks first + text = THINK_BLOCK_RE.sub("", text) + + # Unclosed block (think started but never closed — strip to end) + text = THINK_UNCLOSED_RE.sub("", text) + + if text != original: + if THINK_UNCLOSED_RE.search(original): + logger.warning( + "XML Parser: Stripped unclosed block " + "(possible quantization degradation)" + ) + else: + logger.debug("XML Parser: Stripped block(s) from output") + + return text + + +def _coerce_param_value(raw: str) -> Any: + """Coerce a raw parameter value string to the appropriate Python type. + + Strategy (safe, no eval()): + 1. Strip leading/trailing newlines (official template emits \\n + after opening tag and before closing tag). + 2. Try json.loads — handles objects, arrays, numbers, bools, null. + 3. Fall back to plain string. 
+ """ + # Strip template-inserted newlines around values + if raw.startswith("\n"): + raw = raw[1:] + if raw.endswith("\n"): + raw = raw[:-1] + + stripped = raw.strip() + + # Empty string + if not stripped: + return "" + + # Try JSON parse (handles objects, arrays, numbers, booleans, null) + try: + return json.loads(stripped) + except (json.JSONDecodeError, ValueError): + pass + + # Handle Python-like literals often emitted by coder models, + # e.g. {'k': 'v'} for object parameters. + try: + return ast.literal_eval(stripped) + except (ValueError, SyntaxError): + pass + + # Fall back to string — never eval() + return stripped + class ToolCallProcessor: + _PARSER_DISPATCHER: Dict[str, Callable[[str], List[ToolCall]]] = {} + _MISSING_PARSER_WARNED: set[str] = set() + + # ------------------------------------------------------------------ + # JSON normalization helpers + # ------------------------------------------------------------------ + + @staticmethod + def _strip_quotes(value: str) -> str: + value = value.strip() + if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}: + return value[1:-1] + return value + + @staticmethod + def _normalize_tool_calls(raw) -> list: + """Normalize model-emitted tool call payloads into OAI-like objects. 
+ + Accepted forms: + - [{"type":"function","function":{"name":...,"arguments":{...}}}] + - [{"name":...,"arguments":{...}}] + - {"name":...,"arguments":{...}} + """ + if isinstance(raw, dict): + raw = [raw] + if not isinstance(raw, list): + raise ValueError("tool_calls payload is not list/dict") + + normalized: list = [] + for item in raw: + if not isinstance(item, dict): + continue + + if "function" in item and isinstance(item["function"], dict): + fn = item["function"] + name = fn.get("name") + arguments = fn.get("arguments", {}) + else: + name = item.get("name") + arguments = item.get("arguments", {}) + + if name is None: + continue + + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {"input": arguments} + + normalized.append( + { + "type": "function", + "function": { + "name": name, + "arguments": arguments if isinstance(arguments, dict) else {}, + }, + } + ) + return normalized + + @staticmethod + def _safe_json_loads(payload: str) -> list: + """Best-effort JSON parse for model-emitted tool payloads. + + Handles: clean JSON, markdown-fenced JSON, JSON substrings in + surrounding text, flat {name, arguments} dicts, and single objects. 
+ """ + # Direct parse + try: + return ToolCallProcessor._normalize_tool_calls(json.loads(payload)) + except (json.JSONDecodeError, ValueError): + pass + + # Clean up common model artifacts (markdown fences, whitespace) + cleaned = payload.strip() + cleaned = CODE_FENCE_RE.sub("", cleaned) + cleaned = CODE_FENCE_END_RE.sub("", cleaned) + cleaned = cleaned.strip() + + # Try cleaned + try: + return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned)) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON array substring + start = cleaned.find("[") + end = cleaned.rfind("]") + if start != -1 and end != -1 and end > start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[start : end + 1]) + ) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON object substring + obj_start = cleaned.find("{") + obj_end = cleaned.rfind("}") + if obj_start != -1 and obj_end != -1 and obj_end > obj_start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[obj_start : obj_end + 1]) + ) + except (json.JSONDecodeError, ValueError): + pass + + raise json.JSONDecodeError( + "Could not extract valid JSON from payload", payload, 0 + ) + + @staticmethod + def _build_tool_calls_from_normalized(raw: Any) -> List[ToolCall]: + """Normalize dict/list payload and build ToolCall models.""" + normalized = ToolCallProcessor._normalize_tool_calls(raw) + for tool_call in normalized: + tool_call["function"]["arguments"] = json.dumps( + tool_call["function"]["arguments"], ensure_ascii=False + ) + return [ToolCall(**tool_call) for tool_call in normalized] + + @staticmethod + def _decode_json_sequence(text: str) -> List[Any]: + """Decode multiple JSON values from a single string.""" + decoder = json.JSONDecoder() + values: List[Any] = [] + idx = 0 + while idx < len(text): + while idx < len(text) and text[idx] in " \t\r\n,;": + idx += 1 + if idx >= len(text): + break + if text.startswith("<|python_tag|>", idx): + idx += 
len("<|python_tag|>") + continue + try: + value, end = decoder.raw_decode(text[idx:]) + except json.JSONDecodeError: + break + values.append(value) + idx += end + return values + + @staticmethod + def _coerce_argument_payload(arguments_raw: str) -> str: + """Normalize raw argument payload to a JSON string where possible.""" + payload = arguments_raw.strip() + if not payload: + return "{}" + try: + return json.dumps(json.loads(payload), ensure_ascii=False) + except (json.JSONDecodeError, ValueError, TypeError): + return payload + + @staticmethod + def _normalize_mistral_tool_call_id(raw_id: Any) -> str: + """Normalize tool call IDs to Mistral's 9-char alphanumeric format.""" + if isinstance(raw_id, str): + candidate = re.sub(r"[^A-Za-z0-9]", "", raw_id) + if len(candidate) >= 9: + return candidate[-9:] + return "".join(choices(MISTRAL_ID_ALPHANUMERIC, k=9)) + + @staticmethod + def _build_mistral_tool_call(name: str, arguments: Any, raw_id: Any = None) -> ToolCall: + if isinstance(arguments, str): + payload = ToolCallProcessor._coerce_argument_payload(arguments) + else: + payload = json.dumps(arguments, ensure_ascii=False) + return ToolCall( + id=ToolCallProcessor._normalize_mistral_tool_call_id(raw_id), + function=Tool(name=name, arguments=payload), + ) + + @staticmethod + def _parse_mistral_json_tool_calls(payload: str) -> List[ToolCall]: + """Parse JSON-style Mistral tool calls following [TOOL_CALLS].""" + decoded = ToolCallProcessor._decode_json_sequence(payload) + if not decoded: + try: + decoded = [json.loads(payload)] + except (json.JSONDecodeError, ValueError): + return [] + + tool_calls: List[ToolCall] = [] + for item in decoded: + candidates = item if isinstance(item, list) else [item] + for candidate in candidates: + if not isinstance(candidate, dict): + continue + + if "function" in candidate and isinstance(candidate["function"], dict): + fn = candidate["function"] + name = fn.get("name") + arguments = fn.get("arguments", {}) + tool_id = 
candidate.get("id") + else: + name = candidate.get("name") + arguments = candidate.get("arguments", {}) + tool_id = candidate.get("id") + + if not isinstance(name, str) or not name: + continue + + tool_calls.append( + ToolCallProcessor._build_mistral_tool_call( + name=name, arguments=arguments, raw_id=tool_id + ) + ) + + return tool_calls + + @staticmethod + def _parse_tagged_json_payload(payload: str) -> List[ToolCall]: + payload = payload.strip() + if not payload: + return [] + + # Prefer full JSON parse first (array/object). + try: + return ToolCallProcessor._build_tool_calls_from_normalized(json.loads(payload)) + except (json.JSONDecodeError, ValueError, TypeError, KeyError): + pass + + # Fallback: decode a sequence of JSON values. + decoded = ToolCallProcessor._decode_json_sequence(payload) + if decoded: + flattened = [] + for item in decoded: + if isinstance(item, list): + flattened.extend(item) + else: + flattened.append(item) + return ToolCallProcessor._build_tool_calls_from_normalized(flattened) + + # Fallback: line-delimited JSON objects. 
+ lines = [line.strip().rstrip(",") for line in payload.splitlines() if line.strip()] + parsed_lines = [] + for line in lines: + if not line.startswith("{"): + continue + try: + parsed_lines.append(json.loads(line)) + except (json.JSONDecodeError, ValueError): + continue + if parsed_lines: + return ToolCallProcessor._build_tool_calls_from_normalized(parsed_lines) + + return [] + + @staticmethod + def _ast_to_literal(node: ast.AST) -> Any: + """Safely convert AST literal nodes to Python primitives.""" + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.List): + return [ToolCallProcessor._ast_to_literal(item) for item in node.elts] + if isinstance(node, ast.Tuple): + return [ToolCallProcessor._ast_to_literal(item) for item in node.elts] + if isinstance(node, ast.Dict): + result = {} + for key, value in zip(node.keys, node.values): + literal_key = ToolCallProcessor._ast_to_literal(key) # type: ignore[arg-type] + if not isinstance(literal_key, str): + raise ValueError("pythonic parser requires string dict keys") + result[literal_key] = ToolCallProcessor._ast_to_literal(value) + return result + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): + return -ToolCallProcessor._ast_to_literal(node.operand) + raise ValueError(f"unsupported pythonic AST node: {type(node).__name__}") + + # ------------------------------------------------------------------ + # JSON parsing + # ------------------------------------------------------------------ + + @staticmethod + def from_hermes(raw_text: str) -> List[ToolCall]: + """Parse Hermes-style JSON tool calls (often wrapped in ).""" + text = _strip_think_blocks(raw_text) + wrapped_calls = [] + for match in TOOL_CALL_BLOCK_RE.finditer(text): + inner = match.group(1).strip() + if not inner: + continue + try: + parsed = json.loads(inner) + except (json.JSONDecodeError, ValueError): + continue + wrapped_calls.extend(ToolCallProcessor._build_tool_calls_from_normalized(parsed)) + + if 
wrapped_calls: + return wrapped_calls + + return ToolCallProcessor.from_json(text) + + @staticmethod + def from_llama(raw_text: str) -> List[ToolCall]: + """Parse Llama JSON tool calls (single/multiple JSON objects).""" + text = _strip_think_blocks(raw_text).strip() + if text.startswith("<|python_tag|>"): + text = text[len("<|python_tag|>") :].lstrip() + + try: + parsed = ToolCallProcessor.from_json(text) + if parsed: + return parsed + except (json.JSONDecodeError, ValueError, KeyError): + pass + + decoded = ToolCallProcessor._decode_json_sequence(text) + if not decoded: + return [] + + flattened = [] + for item in decoded: + if isinstance(item, list): + flattened.extend(item) + else: + flattened.append(item) + + return ToolCallProcessor._build_tool_calls_from_normalized(flattened) + + @staticmethod + def from_openai(raw_text: str) -> List[ToolCall]: + """Best-effort parser for OpenAI/Harmony-style text payloads.""" + text = _strip_think_blocks(raw_text).strip() + try: + parsed = ToolCallProcessor.from_json(text) + if parsed: + return parsed + except (json.JSONDecodeError, ValueError, KeyError): + pass + + decoded = ToolCallProcessor._decode_json_sequence(text) + tool_calls: List[ToolCall] = [] + normalized_items = [] + for value in decoded: + candidates = value if isinstance(value, list) else [value] + for item in candidates: + if not isinstance(item, dict): + continue + + nested = item.get("tool_calls") + if nested: + try: + tool_calls.extend( + ToolCallProcessor._build_tool_calls_from_normalized(nested) + ) + except (ValueError, KeyError, TypeError): + pass + + recipient = item.get("recipient") + content = item.get("content") + if isinstance(recipient, str) and recipient.startswith("functions."): + fn_name = recipient.split("functions.", 1)[1] + if isinstance(content, str): + payload = ToolCallProcessor._coerce_argument_payload(content) + elif content is None: + payload = "{}" + else: + payload = json.dumps(content, ensure_ascii=False) + tool_calls.append( + 
ToolCall(function=Tool(name=fn_name, arguments=payload)) + ) + continue + + if "name" in item: + normalized_items.append(item) + + if normalized_items: + tool_calls.extend( + ToolCallProcessor._build_tool_calls_from_normalized(normalized_items) + ) + + return tool_calls + + @staticmethod + def from_pythonic(raw_text: str) -> List[ToolCall]: + """Parse Pythonic list-of-calls tool syntax.""" + text = _strip_think_blocks(raw_text).strip() + if text.startswith("<|python_tag|>"): + text = text[len("<|python_tag|>") :].lstrip() + if not text: + return [] + + if not text.startswith("[") and re.match(r"^[A-Za-z_]\w*\s*\(", text): + text = f"[{text}]" + + expression = ast.parse(text, mode="eval").body + call_nodes = expression.elts if isinstance(expression, ast.List) else [expression] + + tool_calls = [] + for node in call_nodes: + if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name): + continue + args_dict: Dict[str, Any] = {} + if node.args: + args_dict["_args"] = [ + ToolCallProcessor._ast_to_literal(argument) + for argument in node.args + ] + for keyword in node.keywords: + if keyword.arg is None: + continue + args_dict[keyword.arg] = ToolCallProcessor._ast_to_literal(keyword.value) + + tool_calls.append( + ToolCall( + function=Tool( + name=node.func.id, + arguments=json.dumps(args_dict, ensure_ascii=False), + ) + ) + ) + + return tool_calls + + @staticmethod + def from_deepseek_v31(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3.1 tool call syntax.""" + tool_calls = [] + for match in DEEPSEEK_V31_CALL_RE.finditer(raw_text): + name = match.group("name").strip() + if not name: + continue + arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) + return tool_calls + + @staticmethod + def from_deepseek_v3(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3 tool call syntax.""" + tool_calls = [] + for match in 
DEEPSEEK_V3_CALL_RE.finditer(raw_text): + name = match.group("name").strip() + if not name: + continue + arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) + + if tool_calls: + return tool_calls + + return ToolCallProcessor.from_deepseek_v31(raw_text) + + @staticmethod + def from_deepseek_v32(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3.2 DSML tool call syntax.""" + tool_calls = [] + for invoke in DEEPSEEK_V32_INVOKE_RE.finditer(raw_text): + function_name = invoke.group("name").strip() + if not function_name: + continue + + params: Dict[str, Any] = {} + body = invoke.group("body") + for param in DEEPSEEK_V32_PARAM_RE.finditer(body): + key = param.group("name").strip() + value_raw = param.group("value") + is_string = param.group("string") == "true" + if is_string: + value = value_raw.strip("\n") + else: + value = _coerce_param_value(value_raw) + params[key] = value + + tool_calls.append( + ToolCall( + function=Tool( + name=function_name, + arguments=json.dumps(params, ensure_ascii=False), + ) + ) + ) + + if tool_calls: + return tool_calls + + return ToolCallProcessor.from_deepseek_v31(raw_text) + + @staticmethod + def from_mistral(raw_text: str) -> List[ToolCall]: + """Parse Mistral [TOOL_CALLS] payloads for both tokenizer formats.""" + text = raw_text.strip() + + # Non-Mistral outputs should remain compatible with existing JSON logic. 
+ if MISTRAL_TOOL_START not in text: + return ToolCallProcessor.from_json(text) + + split_payloads = [ + chunk.strip() + for chunk in text.split(MISTRAL_TOOL_START)[1:] + if chunk.strip() + ] + if not split_payloads: + return [] + + # pre-v11 format: [TOOL_CALLS] [{"name": "...", "arguments": {...}}] + if len(split_payloads) == 1: + json_calls = ToolCallProcessor._parse_mistral_json_tool_calls( + split_payloads[0] + ) + if json_calls: + return json_calls + + # v11+ format: [TOOL_CALLS]name{...}[TOOL_CALLS]name{...} + tool_calls: List[ToolCall] = [] + for payload in split_payloads: + start = payload.find("{") + end = payload.rfind("}") + if start == -1 or end < start: + continue + + function_name = payload[:start].strip() + if not function_name: + continue + + arguments = payload[start : end + 1] + tool_calls.append( + ToolCallProcessor._build_mistral_tool_call( + name=function_name, + arguments=arguments, + ) + ) + + if tool_calls: + return tool_calls + + # Final fallback for malformed payloads. + return ToolCallProcessor.from_json(split_payloads[-1]) + + @staticmethod + def from_tagged_tool_calls(raw_text: str) -> List[ToolCall]: + """Parse ... 
tagged payloads (Jamba/MiniMax).""" + text = _strip_think_blocks(raw_text) + matches = TOOL_CALLS_TAG_RE.findall(text) + if not matches: + return [] + + tool_calls: List[ToolCall] = [] + for payload in matches: + tool_calls.extend(ToolCallProcessor._parse_tagged_json_payload(payload)) + return tool_calls + + @staticmethod + def from_glm45(raw_text: str) -> List[ToolCall]: + """Parse GLM-4.5 style name\\n{args} payloads.""" + text = _strip_think_blocks(raw_text) + tool_calls: List[ToolCall] = [] + for match in GLM45_CALL_RE.finditer(text): + name = match.group("name").strip() + if not name: + continue + args = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append( + ToolCall( + function=Tool( + name=name, + arguments=args, + ) + ) + ) + return tool_calls + + @staticmethod + def from_minimax_m2(raw_text: str) -> List[ToolCall]: + """Parse MiniMax-M2 XML-like tool call payloads.""" + text = _strip_think_blocks(raw_text) + tool_calls: List[ToolCall] = [] + + for call in MINIMAX_M2_CALL_RE.finditer(text): + call_body = call.group(1) + for invoke in MINIMAX_M2_INVOKE_RE.finditer(call_body): + fn_name = ToolCallProcessor._strip_quotes(invoke.group("name")) + if not fn_name: + continue + + params: Dict[str, Any] = {} + invoke_body = invoke.group("body") + for param in MINIMAX_M2_PARAM_RE.finditer(invoke_body): + key = ToolCallProcessor._strip_quotes(param.group("name")) + if not key: + continue + value = _coerce_param_value(param.group("value")) + params[key] = value + + tool_calls.append( + ToolCall( + function=Tool( + name=fn_name, + arguments=json.dumps(params, ensure_ascii=False), + ) + ) + ) + + return tool_calls + + @staticmethod + def from_seed_oss(raw_text: str) -> List[ToolCall]: + """Parse Seed-OSS XML-style tool calls by adapting to Qwen3 XML format.""" + text = SEED_THINK_BLOCK_RE.sub("", raw_text) + text = SEED_THINK_UNCLOSED_RE.sub("", text) + text = text.replace(SEED_TOOL_CALL_START, "") + text = text.replace(SEED_TOOL_CALL_END, 
"") + return ToolCallProcessor.from_xml(text) + + @staticmethod + def from_olmo3(raw_text: str) -> List[ToolCall]: + """Parse OLMo3 pythonic tool calls, optionally wrapped by .""" + text = _strip_think_blocks(raw_text).strip() + wrapped = re.search(r"(.*?)", text, re.DOTALL) + if wrapped: + text = wrapped.group(1).strip() + + lines = [line.strip() for line in text.splitlines() if line.strip()] + if len(lines) > 1 and all(re.match(r"^[A-Za-z_]\w*\s*\(", line) for line in lines): + text = "[" + ", ".join(lines) + "]" + + return ToolCallProcessor.from_pythonic(text) + @staticmethod def from_json(tool_calls_str: str) -> List[ToolCall]: - """Postprocess tool call JSON to a parseable class""" + """Postprocess tool call JSON to a parseable class. - tool_calls = json.loads(tool_calls_str) + Handles clean JSON arrays, markdown-fenced output, flat dicts, + and other common model output variations via _safe_json_loads. + """ + logger.debug(f"JSON Parser: Parsing tool calls ({len(tool_calls_str)} chars)") + + tool_calls = ToolCallProcessor._safe_json_loads(tool_calls_str) for tool_call in tool_calls: tool_call["function"]["arguments"] = json.dumps( tool_call["function"]["arguments"] ) - return [ToolCall(**tool_call) for tool_call in tool_calls] + result = [ToolCall(**tool_call) for tool_call in tool_calls] + logger.debug(f"JSON Parser: Successfully parsed {len(result)} tool call(s)") + return result + + # ------------------------------------------------------------------ + # XML parsing (Qwen3-Coder / GLM-4.5 style) + # ------------------------------------------------------------------ @staticmethod - def dump(tool_calls: List[ToolCall]) -> List[dict]: + def from_xml(raw_text: str) -> List[ToolCall]: + """Parse Qwen3-Coder XML-format tool calls into ToolCall objects. + + Handles: + - Wrapped: ... + - Bare: ... 
(missing wrapper) + - Multiple sequential tool call blocks + - blocks (stripped) + - Multi-line parameter values + - Missing closing tags + """ + logger.debug(f"XML Parser: Parsing tool calls ({len(raw_text)} chars)") + + # Stage 1: Strip think blocks + text = _strip_think_blocks(raw_text) + + # Stage 2: Check for incomplete XML at end (generation cutoff) + stripped_end = text.rstrip() + if stripped_end.endswith(("<", "]*$", "", text) + + # Stage 3: Extract function blocks + # First, find all wrapped ... blocks + wrapped_positions = [ + (m.start(), m.end()) for m in TOOL_CALL_BLOCK_RE.finditer(text) + ] + + # Collect function blocks from inside wrapped regions + function_blocks = [] + for match in TOOL_CALL_BLOCK_RE.finditer(text): + inner = match.group(1) + for func_match in FUNCTION_RE.finditer(inner): + name = func_match.group(1) if func_match.group(1) is not None else func_match.group(3) + body = func_match.group(2) if func_match.group(2) is not None else func_match.group(4) + function_blocks.append((name, body)) + + # Find bare blocks NOT inside any wrapped region + for func_match in FUNCTION_RE.finditer(text): + pos = func_match.start() + is_wrapped = any(start <= pos < end for start, end in wrapped_positions) + if not is_wrapped: + logger.debug( + "XML Parser: Found bare block without " + " wrapper" + ) + name = func_match.group(1) if func_match.group(1) is not None else func_match.group(3) + body = func_match.group(2) if func_match.group(2) is not None else func_match.group(4) + function_blocks.append((name, body)) + + if not function_blocks: + logger.warning("XML Parser: No blocks found") + return [] + + # Stage 4: Parse each function block into a ToolCall + tool_calls = [] + for func_name_raw, func_body in function_blocks: + func_name = func_name_raw.strip() + + # Extract parameters + params = {} + for param_match in PARAMETER_RE.finditer(func_body): + key = param_match.group(1).strip() + value_raw = param_match.group(2) + value = 
_coerce_param_value(value_raw) + params[key] = value + + arguments_json = json.dumps(params, ensure_ascii=False) + + tool_call = ToolCall( + function=Tool(name=func_name, arguments=arguments_json) + ) + tool_calls.append(tool_call) + + logger.debug(f"XML Parser: Successfully parsed {len(tool_calls)} tool call(s)") + return tool_calls + + # ------------------------------------------------------------------ + # Auto-detect parsing (JSON → JSON-in-tool_call → XML) + # ------------------------------------------------------------------ + + @staticmethod + def from_auto(raw_text: str) -> List[ToolCall]: + """Auto-detect format and parse. + + Tries in order: + 1. Pure JSON (standard TabbyAPI / Llama) + 2. JSON inside wrappers (Qwen3-Instruct style) + 3. XML with tags (Qwen3-Coder style) """ - Convert ToolCall objects to a list of dictionaries. + logger.debug("Auto Parser: Attempting format auto-detection") + + # Attempt 1: Pure JSON array + try: + result = ToolCallProcessor.from_json(raw_text) + logger.debug("Auto Parser: Detected JSON format") + return result + except (json.JSONDecodeError, ValueError, KeyError) as e: + logger.debug(f"Auto Parser: Not JSON ({e}), trying next format") + + # Attempt 2: JSON inside wrappers (Qwen3-Instruct) + try: + all_tool_calls = [] + for match in TOOL_CALL_BLOCK_RE.finditer(raw_text): + inner = match.group(1).strip() + if inner.startswith("{") or inner.startswith("["): + parsed = json.loads(inner) + if isinstance(parsed, dict): + parsed = [parsed] + if isinstance(parsed, list): + for tc in parsed: + name = tc.get("name", "") + arguments = tc.get("arguments", {}) + if isinstance(arguments, dict): + arguments = json.dumps(arguments) + elif not isinstance(arguments, str): + arguments = json.dumps(arguments) + all_tool_calls.append( + ToolCall(function=Tool(name=name, arguments=arguments)) + ) + if all_tool_calls: + logger.debug( + "Auto Parser: Detected JSON-inside-tool_call " + f"format ({len(all_tool_calls)} call(s))" + ) + return 
all_tool_calls + except (json.JSONDecodeError, ValueError, KeyError) as e: + logger.debug(f"Auto Parser: Not JSON-in-tool_call ({e}), trying XML") + + # Attempt 3: XML format (Qwen3-Coder style) + result = ToolCallProcessor.from_xml(raw_text) + if result: + logger.debug("Auto Parser: Detected XML format") + else: + logger.warning("Auto Parser: All format detection attempts failed") + return result + + # ------------------------------------------------------------------ + # Dispatcher + # ------------------------------------------------------------------ + + @staticmethod + def _parser_dispatcher() -> Dict[str, Callable[[str], List[ToolCall]]]: + """Registry for parser-key-specific handlers.""" + if not ToolCallProcessor._PARSER_DISPATCHER: + ToolCallProcessor._PARSER_DISPATCHER = { + "deepseek_v3": ToolCallProcessor.from_deepseek_v3, + "deepseek_v31": ToolCallProcessor.from_deepseek_v31, + "deepseek_v32": ToolCallProcessor.from_deepseek_v32, + "ernie45": ToolCallProcessor.from_hermes, + "functiongemma": ToolCallProcessor.from_auto, + "gigachat3": ToolCallProcessor.from_auto, + "glm45": ToolCallProcessor.from_glm45, + "glm47": ToolCallProcessor.from_glm45, + "granite": ToolCallProcessor.from_json, + "granite-20b-fc": ToolCallProcessor.from_auto, + "hermes": ToolCallProcessor.from_hermes, + "hunyuan_a13b": ToolCallProcessor.from_auto, + "internlm": ToolCallProcessor.from_auto, + "jamba": ToolCallProcessor.from_tagged_tool_calls, + "kimi_k2": ToolCallProcessor.from_auto, + "llama": ToolCallProcessor.from_llama, + "llama3_json": ToolCallProcessor.from_llama, + "llama4_json": ToolCallProcessor.from_llama, + "llama4_pythonic": ToolCallProcessor.from_pythonic, + "longcat": ToolCallProcessor.from_hermes, + "minimax": ToolCallProcessor.from_tagged_tool_calls, + "minimax_m2": ToolCallProcessor.from_minimax_m2, + "mistral": ToolCallProcessor.from_mistral, + "olmo3": ToolCallProcessor.from_olmo3, + "openai": ToolCallProcessor.from_openai, + "phi4_mini_json": 
ToolCallProcessor.from_json, + "pythonic": ToolCallProcessor.from_pythonic, + "qwen3_coder": ToolCallProcessor.from_xml, + "qwen3_xml": ToolCallProcessor.from_xml, + "seed_oss": ToolCallProcessor.from_seed_oss, + "step3": ToolCallProcessor.from_auto, + "step3p5": ToolCallProcessor.from_xml, + "xlam": ToolCallProcessor.from_auto, + } + return ToolCallProcessor._PARSER_DISPATCHER + + @staticmethod + def parse( + tool_calls_str: str, format: str = "json", parser_key: str | None = None + ) -> List[ToolCall]: + """Dispatch tool call parsing to the appropriate format handler. + + Args: + tool_calls_str: Raw tool call text from model generation. + format: One of ``"json"``, ``"xml"``, ``"auto"``. + parser_key: Optional vLLM-compatible parser key. + + Returns: + List of parsed ToolCall objects. Empty list on parse failure + (never raises). + """ + try: + if parser_key: + canonical_key = resolve_tool_call_parser_key(parser_key) or parser_key + parser = ToolCallProcessor._parser_dispatcher().get(canonical_key) + if parser: + try: + parsed = parser(tool_calls_str) + except Exception as exc: + logger.warning( + "Parser '{}' failed: {}. 
Falling back to format '{}'.", + canonical_key, + str(exc), + format, + ) + else: + if parsed: + return parsed + elif canonical_key not in ToolCallProcessor._MISSING_PARSER_WARNED: + ToolCallProcessor._MISSING_PARSER_WARNED.add(canonical_key) + logger.warning( + "No dedicated tool parser handler for key '{}'; " + "falling back to format parser '{}'.", + canonical_key, + format, + ) + + if format == "xml": + return ToolCallProcessor.from_xml(tool_calls_str) + elif format == "auto": + return ToolCallProcessor.from_auto(tool_calls_str) + else: + return ToolCallProcessor.from_json(tool_calls_str) + except Exception as e: + logger.error( + f"ToolCallProcessor.parse: Failed to parse tool calls " + f"(format={format}): {e}" + ) + return [] + + # ------------------------------------------------------------------ + # Filtering + # ------------------------------------------------------------------ + + @staticmethod + def filter_by_name( + tool_calls: List[ToolCall], function_name: str + ) -> List[ToolCall]: + """Filter parsed tool calls to only those matching a function name.""" + filtered = [tc for tc in tool_calls if tc.function.name == function_name] + if not filtered: + logger.warning( + f"filter_by_name: No tool calls matched '{function_name}' " + f"(had {len(tool_calls)} call(s))" + ) + return filtered + + # ------------------------------------------------------------------ + # Content / tool-call separation + # ------------------------------------------------------------------ + + @staticmethod + def extract_content_and_tools( + raw_text: str, + ) -> Tuple[str, List[ToolCall]]: + """Separate plain text content from XML tool call blocks. + + Used when the model mixes reasoning text with tool calls, e.g.: + ``"I'll help with that: ...`` + + Returns: + Tuple of (remaining_content, tool_calls). 
+ """ + text = _strip_think_blocks(raw_text) + + # Collect all XML regions to exclude from content + xml_regions = [] + + # Wrapped tool call blocks + for match in TOOL_CALL_BLOCK_RE.finditer(text): + xml_regions.append((match.start(), match.end())) + + # Bare function blocks not inside wrappers + for match in FUNCTION_RE.finditer(text): + pos = match.start() + is_wrapped = any(start <= pos < end for start, end in xml_regions) + if not is_wrapped: + xml_regions.append((match.start(), match.end())) + + # Sort and extract content (everything outside XML regions) + xml_regions.sort() + content_parts = [] + last_end = 0 + for start, end in xml_regions: + if start > last_end: + part = text[last_end:start].strip() + if part: + content_parts.append(part) + last_end = end + if last_end < len(text): + part = text[last_end:].strip() + if part: + content_parts.append(part) + + content = " ".join(content_parts).strip() + + # Parse tool calls from the full text + tool_calls = ToolCallProcessor.from_xml(text) + + logger.debug( + f"extract_content_and_tools: Found {len(tool_calls)} tool " + f"call(s), content={'yes' if content else 'no'} " + f"({len(content)} chars)" + ) + + return content, tool_calls + + # ------------------------------------------------------------------ + # Serialisation helpers (unchanged from original) + # ------------------------------------------------------------------ + + @staticmethod + def dump(tool_calls: List[ToolCall]) -> List[dict]: + """Convert ToolCall objects to a list of dictionaries. Args: tool_calls (List[ToolCall]): List of ToolCall objects to convert @@ -65,8 +1177,7 @@ def dump(tool_calls: List[ToolCall]) -> List[dict]: @staticmethod def to_json(tool_calls: List[ToolCall]) -> str: - """ - Convert ToolCall objects to JSON string representation. + """Convert ToolCall objects to JSON string representation. 
Args: tool_calls (List[ToolCall]): List of ToolCall objects to convert diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 84229294..6f4e66a6 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -4,7 +4,7 @@ from time import time from typing import List, Literal, Optional, Union -from common.config_models import LoggingConfig +from common.config_models import ATTENTION_BACKENDS, LoggingConfig from common.tabby_config import config @@ -20,6 +20,10 @@ class ModelCardParameters(BaseModel): rope_alpha: Optional[float] = 1.0 max_batch_size: Optional[int] = 1 chunk_size: Optional[int] = 2048 + tokenizer_mode: Optional[str] = "auto" + mistral_tokenizer_models: Optional[List[str]] = Field(default_factory=list) + attention_backend: Optional[str] = "auto" + resolved_attention_backend: Optional[str] = None prompt_template: Optional[str] = None prompt_template_content: Optional[str] = None use_vision: Optional[bool] = False @@ -79,6 +83,13 @@ class ModelLoadRequest(BaseModel): description="Backend to use", default=None, ) + attention_backend: Optional[ATTENTION_BACKENDS] = Field( + description=( + "Attention backend policy for exllamav3 " + "(auto, flash_attn, flashinfer)" + ), + default=None, + ) max_seq_len: Optional[int] = Field( description="Leave this blank to use the model's base sequence length", default=None, @@ -111,6 +122,17 @@ class ModelLoadRequest(BaseModel): chunk_size: Optional[int] = None output_chunking: Optional[bool] = True prompt_template: Optional[str] = None + tokenizer_mode: Optional[str] = Field( + description="Tokenizer compatibility mode (auto, hf, slow, mistral, deepseek_v32)", + default=None, + ) + mistral_tokenizer_models: Optional[List[str]] = Field( + default_factory=list, + description=( + "Optional allowlist for tokenizer_mode='mistral'. " + "Only listed Mistral-family models can use mistral mode." 
+ ), + ) vision: Optional[bool] = None # Non-config arguments diff --git a/pyproject.toml b/pyproject.toml index 7a04cce9..12d38b00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,26 +78,19 @@ cu12 = [ "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.3.2/exllamav2-0.3.2+cu128.torch2.9.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl3 - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - 
"exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.21/exllamav3-0.0.21+cu128.torch2.9.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", - - # Windows FA2 from https://github.com/kingbri1/flash-attention/releases - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - - # Linux FA2 from https://github.com/kingbri1/flash-attention/releases - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and 
python_version == '3.11'", - "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.22/exllamav3-0.0.22+cu128.torch2.9.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' 
and platform_machine == 'x86_64' and python_version == '3.10'", + + # FlashInfer backend for ExLlamaV3 (CUDA 12.8 wheel set) + "flashinfer-python @ https://github.com/flashinfer-ai/flashinfer/releases/download/v0.6.3/flashinfer_python-0.6.3-py3-none-any.whl ; platform_system == 'Linux' and platform_machine == 'x86_64'", + "flashinfer-jit-cache @ https://github.com/flashinfer-ai/flashinfer/releases/download/v0.6.3/flashinfer_jit_cache-0.6.3+cu128-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64'", + ] amd = [ # Torch triton for ROCm diff --git a/templates/tool_calls/qwen3_coder.jinja b/templates/tool_calls/qwen3_coder.jinja new file mode 100644 index 00000000..15272747 --- /dev/null +++ b/templates/tool_calls/qwen3_coder.jinja @@ -0,0 +1,123 @@ +{# TabbyAPI Metadata #} +{%- set tool_call_format = "xml" -%} +{%- set tool_start = "" -%} +{%- set tool_end = "" -%} +{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%} + +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is string %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- else %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{%- endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." 
}} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {%- set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function 
calls\n' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value if args_value is string else args_value | tojson | safe %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at end of file diff --git a/tests/deepseek_vl2_chat_serializer_test.py b/tests/deepseek_vl2_chat_serializer_test.py new file mode 100644 index 00000000..3cdbf732 
--- /dev/null +++ b/tests/deepseek_vl2_chat_serializer_test.py @@ -0,0 +1,114 @@ +"""Regression tests for DeepSeek-VL2 built-in chat serialization.""" + +import asyncio +import sys +from pathlib import Path +from types import SimpleNamespace + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from endpoints.OAI.utils import chat_completion as cc + + +class _FakeMultimodalEmbeddingWrapper: + """Minimal multimodal stub that emits stable text aliases.""" + + def __init__(self): + self.text_alias = [] + self.urls = [] + + async def add(self, url: str): + self.urls.append(url) + self.text_alias.append(f"") + + +def _message(role: str, content, tool_call_id=None): + return SimpleNamespace(role=role, content=content, tool_call_id=tool_call_id) + + +def _text_part(text: str): + return SimpleNamespace(type="text", text=text, image_url=None) + + +def _image_part(url: str): + return SimpleNamespace( + type="image_url", + text=None, + image_url=SimpleNamespace(url=url), + ) + + +def _install_deepseek_vl2_serializer(monkeypatch): + container = SimpleNamespace( + config=SimpleNamespace(architecture=cc.DEEPSEEK_VL2_ARCH), + use_vision=True, + ) + monkeypatch.setattr(cc.model, "container", container, raising=False) + monkeypatch.setattr( + cc, + "MultimodalEmbeddingWrapper", + _FakeMultimodalEmbeddingWrapper, + ) + + +def test_builtin_serializer_supports_official_multi_image_interleaving(monkeypatch): + _install_deepseek_vl2_serializer(monkeypatch) + + messages = [ + _message( + "user", + [ + _text_part("This is image_1: "), + _image_part("image://1"), + _text_part("This is image_2: "), + _image_part("image://2"), + _text_part("This is image_3: "), + _image_part("image://3"), + _text_part(" Can you tell me what are in the images?"), + ], + ), + _message("assistant", ""), + ] + + prompt, mm_embeddings, serializer_state = asyncio.run( + cc.format_messages_with_builtin_serializer(messages) + ) + + assert prompt == ( + "<|User|>: This is image_1: \n\n" + "This is 
image_2: \n\n" + "This is image_3: \n\n" + " Can you tell me what are in the images?\n\n" + "<|Assistant|>: " + ) + assert mm_embeddings is not None + assert mm_embeddings.urls == ["image://1", "image://2", "image://3"] + assert mm_embeddings.text_alias == ["", "", ""] + assert serializer_state["last_non_system_role"] == "assistant" + + +def test_builtin_serializer_preserves_grounding_markup(monkeypatch): + _install_deepseek_vl2_serializer(monkeypatch) + + grounding_text = ( + "<|ref|>The red square<|/ref|> is the target. " + "<|grounding|><|det|>[0.1,0.2,0.9,0.9]<|/det|><|/grounding|>" + ) + messages = [ + _message( + "user", + [ + _image_part("image://1"), + _text_part(grounding_text), + ], + ), + ] + + prompt, mm_embeddings, serializer_state = asyncio.run( + cc.format_messages_with_builtin_serializer(messages) + ) + + assert prompt == f"<|User|>: \n{grounding_text}" + assert mm_embeddings is not None + assert mm_embeddings.urls == ["image://1"] + assert serializer_state["last_non_system_role"] == "user" diff --git a/tests/exaone4_reasoning_parser_test.py b/tests/exaone4_reasoning_parser_test.py new file mode 100644 index 00000000..135b37cb --- /dev/null +++ b/tests/exaone4_reasoning_parser_test.py @@ -0,0 +1,203 @@ +"""Tests for Exaone4 reasoning parser behavior.""" + +from endpoints.OAI.reasoning.exaone4_reasoning_parser import Exaone4ReasoningParser + + +class _FakeTokenizer: + def get_vocab(self): + return { + "": 101, + "": 102, + } + + +def _parser(enable_thinking: bool) -> Exaone4ReasoningParser: + return Exaone4ReasoningParser( + _FakeTokenizer(), + chat_template_kwargs={"enable_thinking": enable_thinking}, + ) + + +def test_non_thinking_mode_emits_content_only(): + parser = _parser(enable_thinking=False) + + reasoning, content = parser.extract_reasoning("hello", request=None) + assert reasoning is None + assert content == "hello" + + reasoning, content = parser.extract_reasoning("hello", request=None) + assert reasoning is None + assert content == 
"hello" + + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="hello", + delta_text="hello", + previous_token_ids=[], + current_token_ids=[1], + delta_token_ids=[1], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "hello" + + +def test_thinking_mode_extract_reasoning_and_content_non_streaming(): + parser = _parser(enable_thinking=True) + + reasoning, content = parser.extract_reasoning( + "reasonanswer", request=None + ) + assert reasoning == "reason" + assert content == "answer" + + reasoning, content = parser.extract_reasoning("reasonanswer", request=None) + assert reasoning == "reason" + assert content == "answer" + + +def test_thinking_mode_without_end_token_is_reasoning_only(): + parser = _parser(enable_thinking=True) + + reasoning, content = parser.extract_reasoning("reasoning only", request=None) + assert reasoning == "reasoning only" + assert content is None + + +def test_thinking_streaming_prefill_flow_without_start_token(): + parser = _parser(enable_thinking=True) + + first = parser.extract_reasoning_streaming( + previous_text="", + current_text="reason ", + delta_text="reason ", + previous_token_ids=[], + current_token_ids=[11], + delta_token_ids=[11], + ) + assert first is not None + assert first.reasoning == "reason " + assert first.content is None + + second = parser.extract_reasoning_streaming( + previous_text="reason ", + current_text="reason morefinal", + delta_text="morefinal", + previous_token_ids=[11], + current_token_ids=[11, 12, 102, 13], + delta_token_ids=[12, 102, 13], + ) + assert second is not None + assert second.reasoning == "more" + assert second.content == "final" + + third = parser.extract_reasoning_streaming( + previous_text="reason morefinal", + current_text="reason morefinal!", + delta_text="!", + previous_token_ids=[11, 12, 102, 13], + current_token_ids=[11, 12, 102, 13, 14], + delta_token_ids=[14], + ) + assert third is not None + assert third.reasoning is None 
+ assert third.content == "!" + + +def test_thinking_streaming_handles_split_end_token_boundary(): + parser = _parser(enable_thinking=True) + + first = parser.extract_reasoning_streaming( + previous_text="", + current_text="analysis {"name":"lookup","arguments":{}}' + + +def test_thinking_streaming_handles_split_deepseek_tool_boundary_without_end_token(): + parser = _parser(enable_thinking=True) + + first = parser.extract_reasoning_streaming( + previous_text="", + current_text="analysis <|tool▁call▁b", + delta_text="analysis <|tool▁call▁b", + previous_token_ids=[], + current_token_ids=[11], + delta_token_ids=[11], + ) + assert first is not None + assert first.reasoning == "analysis " + assert first.content is None + + second = parser.extract_reasoning_streaming( + previous_text="analysis <|tool▁call▁b", + current_text=( + "analysis <|tool▁call▁begin|>lookup<|tool▁sep|>{\"q\":\"tabby\"}" + "<|tool▁call▁end|>" + ), + delta_text='egin|>lookup<|tool▁sep|>{"q":"tabby"}<|tool▁call▁end|>', + previous_token_ids=[11], + current_token_ids=[11, 12], + delta_token_ids=[12], + ) + assert second is not None + assert second.reasoning is None + assert second.content == ( + '<|tool▁call▁begin|>lookup<|tool▁sep|>{"q":"tabby"}<|tool▁call▁end|>' + ) + + +def test_thinking_mode_content_ids_and_end_detection(): + parser = _parser(enable_thinking=True) + + assert parser.is_reasoning_end([1, 2, 102]) is True + assert parser.is_reasoning_end([1, 2, 3]) is False + + assert parser.extract_content_ids([10, 101, 20, 102, 30, 31]) == [30, 31] + assert parser.extract_content_ids([10, 101, 20]) == [] diff --git a/tests/mistral_reasoning_parser_test.py b/tests/mistral_reasoning_parser_test.py new file mode 100644 index 00000000..d598c018 --- /dev/null +++ b/tests/mistral_reasoning_parser_test.py @@ -0,0 +1,70 @@ +"""Tests for Mistral reasoning parser parity with vLLM behavior.""" + +from endpoints.OAI.reasoning.mistral_reasoning_parser import MistralReasoningParser + + +class _FakeTokenizer: + def 
get_vocab(self): + return { + "[THINK]": 301, + "[/THINK]": 302, + } + + +def _parser() -> MistralReasoningParser: + return MistralReasoningParser(_FakeTokenizer()) + + +def test_extract_reasoning_with_valid_think_section(): + parser = _parser() + + reasoning, content = parser.extract_reasoning( + "[THINK]This is a reasoning section[/THINK]This is the rest", + request=None, + ) + + assert reasoning == "This is a reasoning section" + assert content == "This is the rest" + + +def test_extract_reasoning_with_invalid_end_token_only(): + parser = _parser() + + reasoning, content = parser.extract_reasoning( + "This is a reasoning section[/THINK]This is the rest", + request=None, + ) + + assert reasoning is None + assert content == "This is a reasoning sectionThis is the rest" + + +def test_extract_reasoning_with_begin_token_only(): + parser = _parser() + + reasoning, content = parser.extract_reasoning( + "[THINK]This is a reasoning section", + request=None, + ) + + assert reasoning == "This is a reasoning section" + assert content is None + + +def test_extract_reasoning_without_think_tokens(): + parser = _parser() + + reasoning, content = parser.extract_reasoning("This is content", request=None) + + assert reasoning is None + assert content == "This is content" + + +def test_is_reasoning_end_and_extract_content_ids(): + parser = _parser() + + assert parser.is_reasoning_end([1, parser.start_token_id, parser.end_token_id]) is True + assert parser.is_reasoning_end([1, 2, 3]) is False + + assert parser.extract_content_ids([7, parser.start_token_id, 9, parser.end_token_id, 10]) == [7, 10] + assert parser.extract_content_ids([7, parser.start_token_id, 9]) == [7] diff --git a/tests/mistral_tokenizer_mode_test.py b/tests/mistral_tokenizer_mode_test.py new file mode 100644 index 00000000..1d422051 --- /dev/null +++ b/tests/mistral_tokenizer_mode_test.py @@ -0,0 +1,76 @@ +"""Tests for mistral tokenizer mode auto-detection.""" + +import json + +from common.tokenizer_modes import ( 
+ normalize_tokenizer_mode, + should_enable_mistral_tokenizer_mode, + supports_mistral_tokenizer_mode, +) + + +def _write_config(directory, model_type: str) -> None: + with open(directory / "config.json", "w", encoding="utf-8") as config_file: + json.dump({"model_type": model_type}, config_file) + + +def test_supports_mistral_with_tekken(tmp_path): + _write_config(tmp_path, "mistral3") + (tmp_path / "tekken.json").write_text("{}", encoding="utf-8") + + assert supports_mistral_tokenizer_mode(tmp_path) is True + + +def test_supports_mistral_with_sentencepiece_variant(tmp_path): + _write_config(tmp_path, "mistral") + (tmp_path / "tokenizer.model.v3").write_text("dummy", encoding="utf-8") + + assert supports_mistral_tokenizer_mode(tmp_path) is True + + +def test_rejects_non_mistral_with_sentencepiece_tokenizer(tmp_path): + _write_config(tmp_path, "gemma2") + (tmp_path / "tokenizer.model").write_text("dummy", encoding="utf-8") + + assert supports_mistral_tokenizer_mode(tmp_path) is False + + +def test_allowlist_enables_listed_mistral_model(tmp_path): + _write_config(tmp_path, "mistral3") + (tmp_path / "tekken.json").write_text("{}", encoding="utf-8") + + assert should_enable_mistral_tokenizer_mode( + tmp_path, [tmp_path.name, "other-model"] + ) + + +def test_allowlist_disables_non_mistral_even_if_listed(tmp_path): + _write_config(tmp_path, "gemma2") + (tmp_path / "tokenizer.model").write_text("dummy", encoding="utf-8") + + assert not should_enable_mistral_tokenizer_mode(tmp_path, [tmp_path.name]) + + +def test_allowlist_disables_unlisted_mistral_model(tmp_path): + _write_config(tmp_path, "mistral3") + (tmp_path / "tekken.json").write_text("{}", encoding="utf-8") + + assert not should_enable_mistral_tokenizer_mode(tmp_path, ["another-model"]) + + +def test_normalize_tokenizer_mode_accepts_deepseek_v32(): + normalized, message = normalize_tokenizer_mode("deepseek_v32") + assert normalized == "deepseek_v32" + assert message is None + + +def 
test_normalize_tokenizer_mode_maps_slow_to_hf(): + normalized, message = normalize_tokenizer_mode("slow") + assert normalized == "hf" + assert message is not None + + +def test_normalize_tokenizer_mode_unknown_falls_back_to_auto(): + normalized, message = normalize_tokenizer_mode("unknown_mode") + assert normalized == "auto" + assert message is not None diff --git a/tests/model_test.py b/tests/model_test.py index 662ccbc5..d01258aa 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -1,5 +1,8 @@ """Test the model container.""" +import pytest + +pytest.importorskip("exllamav2") from backends.exllamav2.model import ModelContainer diff --git a/tests/parser_options_test.py b/tests/parser_options_test.py new file mode 100644 index 00000000..4cc26e11 --- /dev/null +++ b/tests/parser_options_test.py @@ -0,0 +1,97 @@ +"""Tests for vLLM-compatible parser option mapping.""" + +from endpoints.OAI.utils.parser_options import ( + TOOL_CALL_PARSER_FORMATS, + list_tool_call_parsers, + parser_uses_native_tool_generation, + resolve_tool_call_parser_key, + resolve_tool_call_format, +) +from endpoints.OAI.utils.tools import ToolCallProcessor + + +VLLM_CANONICAL_TOOL_PARSERS = { + "deepseek_v3", + "deepseek_v31", + "deepseek_v32", + "ernie45", + "glm45", + "glm47", + "granite", + "granite-20b-fc", + "hermes", + "hunyuan_a13b", + "internlm", + "jamba", + "kimi_k2", + "llama3_json", + "llama4_json", + "llama4_pythonic", + "longcat", + "minimax", + "minimax_m2", + "mistral", + "olmo3", + "openai", + "phi4_mini_json", + "pythonic", + "qwen3_coder", + "qwen3_xml", + "seed_oss", + "step3", + "step3p5", + "xlam", + "gigachat3", + "functiongemma", +} + + +def test_parser_key_registry_contains_core_vllm_keys(): + parser_keys = list_tool_call_parsers() + + assert "openai" in parser_keys + assert "qwen3_coder" in parser_keys + assert "qwen3_xml" in parser_keys + assert "mistral" in parser_keys + assert "deepseek_v3" in parser_keys + assert "llama" in parser_keys + + +def 
test_parser_key_registry_matches_vllm_set_plus_local_aliases(): + parser_keys = list_tool_call_parsers() + + # Canonical set should match current vLLM registry. + assert VLLM_CANONICAL_TOOL_PARSERS.issubset(parser_keys) + assert set(TOOL_CALL_PARSER_FORMATS.keys()) - {"auto"} == VLLM_CANONICAL_TOOL_PARSERS + + # Local compatibility alias. + assert "llama" in parser_keys + + +def test_every_configured_canonical_parser_has_dispatch_handler(): + dispatcher = ToolCallProcessor._parser_dispatcher() + canonical = set(TOOL_CALL_PARSER_FORMATS.keys()) - {"auto"} + + missing = sorted(canonical - set(dispatcher.keys())) + assert missing == [] + + +def test_resolve_tool_call_format_uses_vllm_mapping(): + assert resolve_tool_call_format("openai", "json") == "json" + assert resolve_tool_call_format("qwen3_coder", "json") == "xml" + assert resolve_tool_call_format("auto", "json") == "auto" + assert resolve_tool_call_format("llama", "json") == "json" + assert resolve_tool_call_parser_key("llama") == "llama3_json" + + +def test_resolve_tool_call_format_falls_back_and_rejects_unknown(): + assert resolve_tool_call_format(None, "json") == "json" + assert resolve_tool_call_format("unknown_parser", "json") == "" + + +def test_native_generation_flags_cover_native_syntax_parsers(): + assert parser_uses_native_tool_generation("qwen3_coder", "json") is True + assert parser_uses_native_tool_generation("deepseek_v31", "json") is True + assert parser_uses_native_tool_generation("pythonic", "json") is True + assert parser_uses_native_tool_generation("mistral", "json") is True + assert parser_uses_native_tool_generation("hermes", "json") is False diff --git a/tests/qwen3_reasoning_parser_test.py b/tests/qwen3_reasoning_parser_test.py new file mode 100644 index 00000000..53f89f07 --- /dev/null +++ b/tests/qwen3_reasoning_parser_test.py @@ -0,0 +1,138 @@ +"""Tests for Qwen3 reasoning parser parity with modern Qwen3.5 behavior.""" + +from endpoints.OAI.reasoning.qwen3_reasoning_parser import 
Qwen3ReasoningParser + + +class _FakeTokenizer: + def get_vocab(self): + return { + "<think>": 101, + "</think>": 102, + } + + +def _parser(enable_thinking=None) -> Qwen3ReasoningParser: + kwargs = {} + if enable_thinking is not None: + kwargs["chat_template_kwargs"] = {"enable_thinking": enable_thinking} + return Qwen3ReasoningParser(_FakeTokenizer(), **kwargs) + + +def test_non_stream_extract_thinking_mode_with_prefilled_start_token(): + parser = _parser(enable_thinking=True) + + reasoning, content = parser.extract_reasoning("reasoning</think>answer", request=None) + assert reasoning == "reasoning" + assert content == "answer" + + +def test_non_stream_extract_without_end_token_treated_as_content(): + parser = _parser(enable_thinking=True) + + reasoning, content = parser.extract_reasoning("reasoning only", request=None) + assert reasoning is None + assert content == "reasoning only" + + +def test_non_stream_extract_non_thinking_mode_content_only(): + parser = _parser(enable_thinking=False) + + reasoning, content = parser.extract_reasoning("<think>hidden</think>visible", request=None) + assert reasoning is None + assert content == "<think>hidden</think>visible" + + +def test_non_stream_without_explicit_thinking_switch_treats_plain_text_as_content(): + parser = _parser() + + reasoning, content = parser.extract_reasoning("OK", request=None) + assert reasoning is None + assert content == "OK" + + +def test_streaming_prefilled_think_mode_splits_reasoning_and_content(): + parser = _parser(enable_thinking=True) + + first = parser.extract_reasoning_streaming( + previous_text="", + current_text="analysis ", + delta_text="analysis ", + previous_token_ids=[], + current_token_ids=[11], + delta_token_ids=[11], + ) + assert first is not None + assert first.reasoning == "analysis " + assert first.content is None + + second = parser.extract_reasoning_streaming( + previous_text="analysis ", + current_text="analysis step</think>final ", + delta_text="step</think>final ", + previous_token_ids=[11], + current_token_ids=[11, 12, 102, 13], + 
delta_token_ids=[12, 102, 13], + ) + assert second is not None + assert second.reasoning == "step" + assert second.content == "final " + + third = parser.extract_reasoning_streaming( + previous_text="analysis step</think>final ", + current_text="analysis step</think>final answer", + delta_text="answer", + previous_token_ids=[11, 12, 102, 13], + current_token_ids=[11, 12, 102, 13, 14], + delta_token_ids=[14], + ) + assert third is not None + assert third.reasoning is None + assert third.content == "answer" + + +def test_streaming_non_thinking_mode_emits_content_only(): + parser = _parser(enable_thinking=False) + + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="plain output", + delta_text="plain output", + previous_token_ids=[], + current_token_ids=[11, 12], + delta_token_ids=[11, 12], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "plain output" + + +def test_streaming_without_explicit_thinking_switch_emits_content_only(): + parser = _parser() + + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="OK", + delta_text="OK", + previous_token_ids=[], + current_token_ids=[11], + delta_token_ids=[11], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "OK" + + +def test_streaming_strips_generated_start_token_when_present(): + parser = _parser(enable_thinking=True) + + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="<think>reason", + delta_text="<think>reason", + previous_token_ids=[], + current_token_ids=[101, 11], + delta_token_ids=[101, 11], + ) + assert delta is not None + assert delta.reasoning == "reason" + assert delta.content is None diff --git a/tests/reasoning_parser_registry_test.py b/tests/reasoning_parser_registry_test.py new file mode 100644 index 00000000..ccf1d46e --- /dev/null +++ b/tests/reasoning_parser_registry_test.py @@ -0,0 +1,38 @@ +"""Tests for reasoning parser registry parity with vLLM.""" + +from
endpoints.OAI.reasoning import ReasoningParserManager + + +VLLM_CANONICAL_REASONING_PARSERS = { + "deepseek_r1", + "deepseek_v3", + "ernie45", + "glm45", + "openai_gptoss", + "granite", + "holo2", + "hunyuan_a13b", + "kimi_k2", + "minimax_m2", + "minimax_m2_append_think", + "mistral", + "olmo3", + "qwen3", + "seed_oss", + "step3", + "step3p5", +} + + +def test_reasoning_registry_contains_all_vllm_canonical_parsers(): + registered = set(ReasoningParserManager.list_registered()) + missing = sorted(VLLM_CANONICAL_REASONING_PARSERS - registered) + assert missing == [] + + +def test_reasoning_registry_allows_local_extensions(): + registered = set(ReasoningParserManager.list_registered()) + # Local compatibility/default parsers that may not exist in vLLM. + assert "identity" in registered + assert "basic" in registered + assert "exaone4" in registered diff --git a/tests/tool_parser_test.py b/tests/tool_parser_test.py new file mode 100644 index 00000000..7c57d5e1 --- /dev/null +++ b/tests/tool_parser_test.py @@ -0,0 +1,391 @@ +"""Tests for tool call parsing helpers.""" + +import json + +from endpoints.OAI.utils.tools import ToolCallProcessor + + +def _arguments_dict(tool_call): + return json.loads(tool_call.function.arguments) + + +def test_from_json_handles_markdown_fences_and_flat_shape(): + payload = """```json +[{"name": "get_weather", "arguments": {"city": "Seoul"}}] +```""" + + parsed = ToolCallProcessor.from_json(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul"} + + +def test_from_xml_parses_qwen3_coder_style_blocks(): + payload = ( + "internal reasoning" + "" + "\nSeoul\n" + "\n3\n" + "" + ) + + parsed = ToolCallProcessor.from_xml(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 3} + + +def test_from_xml_supports_single_quote_object_parameter(): + payload = ( + "" + 
"\n{'key': 'value'}\n" + "" + ) + + parsed = ToolCallProcessor.from_xml(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "test_types" + assert _arguments_dict(parsed[0]) == {"obj_param": {"key": "value"}} + + +def test_from_xml_parses_incomplete_function_block_at_generation_cutoff(): + payload = ( + "I'll call a tool. " + "" + "\nSeoul\n" + "\n3\n" + # Missing on purpose + ) + + parsed = ToolCallProcessor.from_xml(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 3} + + +def test_from_auto_parses_json_inside_tool_call_wrapper(): + payload = ( + "" + '{"name": "search", "arguments": {"query": "tabbyapi"}}' + "" + ) + + parsed = ToolCallProcessor.from_auto(payload) + + assert len(parsed) == 1 + assert parsed[0].function.name == "search" + assert _arguments_dict(parsed[0]) == {"query": "tabbyapi"} + + +def test_extract_content_and_tools_splits_content_from_xml_calls(): + payload = ( + "I will call a tool now. " + "\ntabby\n" + "" + " Done." + ) + + content, parsed = ToolCallProcessor.extract_content_and_tools(payload) + + assert "I will call a tool now." in content + assert "Done." 
in content + assert len(parsed) == 1 + assert parsed[0].function.name == "search" + + +def test_filter_by_name_keeps_only_requested_function(): + payload = ( + "[" + '{"name": "a", "arguments": {}},' + '{"name": "b", "arguments": {}}' + "]" + ) + parsed = ToolCallProcessor.from_json(payload) + + filtered = ToolCallProcessor.filter_by_name(parsed, "b") + + assert len(filtered) == 1 + assert filtered[0].function.name == "b" + + +def test_parse_with_hermes_parser_handles_wrapped_json(): + payload = ( + "" + '{"name":"weather","arguments":{"city":"Seoul"}}' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="hermes") + + assert len(parsed) == 1 + assert parsed[0].function.name == "weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul"} + + +def test_parse_with_llama_parser_handles_sequential_json(): + payload = ( + "<|python_tag|>" + '{"name":"a","arguments":{"x":1}};' + '{"name":"b","arguments":{"y":2}}' + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="llama") + + assert len(parsed) == 2 + assert parsed[0].function.name == "a" + assert _arguments_dict(parsed[0]) == {"x": 1} + assert parsed[1].function.name == "b" + assert _arguments_dict(parsed[1]) == {"y": 2} + + +def test_parse_with_pythonic_parser_extracts_function_calls(): + payload = "[get_weather(city='San Francisco', days=3)]" + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="pythonic") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "San Francisco", "days": 3} + + +def test_parse_with_deepseek_v31_parser(): + payload = ( + "<|tool▁calls▁begin|>" + '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>' + "<|tool▁calls▁end|>" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="deepseek_v31") + + assert len(parsed) == 1 + assert parsed[0].function.name == "foo" + assert _arguments_dict(parsed[0]) == {"x": 1} + + +def 
test_parse_with_deepseek_v3_parser(): + payload = ( + "<|tool▁calls▁begin|>" + "<|tool▁call▁begin|>function<|tool▁sep|>lookup\n" + "```json\n" + '{"q":"tabbyapi"}' + "\n```\n" + "<|tool▁call▁end|>" + "<|tool▁calls▁end|>" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="deepseek_v3") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert _arguments_dict(parsed[0]) == {"q": "tabbyapi"} + + +def test_parse_with_deepseek_v32_parser(): + payload = ( + "<|DSML|function_calls>" + '<|DSML|invoke name="get_weather">' + '<|DSML|parameter name="location" string="true">Seoul' + '<|DSML|parameter name="days" string="false">3' + "" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="deepseek_v32") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"location": "Seoul", "days": 3} + + +def test_parse_with_ernie45_parser_handles_tool_call_json(): + payload = ( + "" + '{"name":"get_weather","arguments":{"city":"Seoul"}}' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="ernie45") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul"} + + +def test_parse_with_jamba_parser_handles_tool_calls_tag_array(): + payload = ( + "" + '[{"name":"get_weather","arguments":{"city":"Seoul","days":2}}]' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="jamba") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 2} + + +def test_parse_with_minimax_parser_handles_line_delimited_json(): + payload = ( + "\n" + '{"name":"foo","arguments":{"x":1}}\n' + '{"name":"bar","arguments":{"y":2}}\n' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="minimax") + + assert len(parsed) == 2 + assert 
parsed[0].function.name == "foo" + assert _arguments_dict(parsed[0]) == {"x": 1} + assert parsed[1].function.name == "bar" + assert _arguments_dict(parsed[1]) == {"y": 2} + + +def test_parse_with_glm45_parser_handles_name_and_json_body(): + payload = 'lookup\n{"id":42,"q":"tabbyapi"}' + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="glm45") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert _arguments_dict(parsed[0]) == {"id": 42, "q": "tabbyapi"} + + +def test_parse_with_minimax_m2_parser_handles_invoke_parameters(): + payload = ( + '' + '42' + 'tabbyapi' + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="minimax_m2") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert _arguments_dict(parsed[0]) == {"id": 42, "query": "tabbyapi"} + + +def test_parse_with_seed_oss_parser_handles_seed_xml(): + payload = ( + "" + "\nSeoul\n" + "\n2\n" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="seed_oss") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 2} + + +def test_parse_with_olmo3_parser_handles_function_calls_wrapper(): + payload = "\nget_weather(city='Seoul', days=2)\n" + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="olmo3") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 2} + + +def test_parse_with_step3p5_parser_handles_qwen3_xml_shape(): + payload = ( + "" + "\ntabbyapi\n" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="step3p5") + + assert len(parsed) == 1 + assert parsed[0].function.name == "search" + assert _arguments_dict(parsed[0]) == {"q": "tabbyapi"} + + +def test_parse_with_openai_parser_handles_functions_recipient(): + payload = ( + 
'[{"recipient":"functions.get_weather","content":"{\\"city\\":\\"Seoul\\"}"}]' + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="openai") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul"} + + +def test_parse_with_mistral_parser_handles_pre_v11_json(): + payload = ( + '[TOOL_CALLS] [{"name":"get_weather","arguments":{"city":"Seoul","days":2}}]' + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="mistral") + + assert len(parsed) == 1 + assert parsed[0].function.name == "get_weather" + assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 2} + assert parsed[0].id.isalnum() + assert len(parsed[0].id) == 9 + + +def test_parse_with_mistral_parser_handles_v11_style_segments(): + payload = ( + '[TOOL_CALLS]search{"q":"tabbyapi"}' + '[TOOL_CALLS]lookup{"id":42}' + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="mistral") + + assert len(parsed) == 2 + assert parsed[0].function.name == "search" + assert _arguments_dict(parsed[0]) == {"q": "tabbyapi"} + assert parsed[1].function.name == "lookup" + assert _arguments_dict(parsed[1]) == {"id": 42} + assert parsed[1].id.isalnum() + assert len(parsed[1].id) == 9 + + +def test_parse_with_mistral_parser_falls_back_to_standard_json(): + payload = '[{"name":"lookup","arguments":{"id":42}}]' + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="mistral") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert _arguments_dict(parsed[0]) == {"id": 42} + + +def test_parser_key_dispatch_overrides_format_for_qwen3_xml(): + payload = ( + "" + "\ntabby\n" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="json", parser_key="qwen3_xml") + + assert len(parsed) == 1 + assert parsed[0].function.name == "search" + assert _arguments_dict(parsed[0]) == {"q": "tabby"} + + +def 
test_parser_failure_falls_back_to_format_parser(): + payload = ( + "" + "\n42\n" + "" + ) + + parsed = ToolCallProcessor.parse(payload, format="xml", parser_key="openai") + + assert len(parsed) == 1 + assert parsed[0].function.name == "lookup" + assert _arguments_dict(parsed[0]) == {"id": 42} diff --git a/tests/wheel_test.py b/tests/wheel_test.py index 733c6188..1300d100 100644 --- a/tests/wheel_test.py +++ b/tests/wheel_test.py @@ -2,23 +2,40 @@ from importlib.metadata import version from importlib.util import find_spec +from packaging import version as package_version successful_packages = [] errored_packages = [] if find_spec("flash_attn") is not None: - print(f"Flash attention on version {version('flash_attn')} successfully imported") - successful_packages.append("flash_attn") + torch_version = version("torch").split("+")[0] if find_spec("torch") else "0" + if package_version.parse(torch_version) >= package_version.parse("2.10.0"): + print( + f"Flash attention 2 detected with torch {torch_version}. " + "This combination is unsupported for the flashinfer migration." + ) + errored_packages.append("flash_attn") + else: + print( + "Flash attention 2 is installed. " + "ExLlamaV3 now uses flashinfer." 
+ ) + +if find_spec("flashinfer") is not None: + print( + "FlashInfer on version " + f"{version('flashinfer-python')} successfully imported" + ) + successful_packages.append("flashinfer") else: - print("Flash attention 2 is not found in your environment.") - errored_packages.append("flash_attn") + print("FlashInfer is not found in your environment.") + errored_packages.append("flashinfer") if find_spec("exllamav2") is not None: print(f"Exllamav2 on version {version('exllamav2')} successfully imported") successful_packages.append("exllamav2") else: - print("Exllamav2 is not found in your environment.") - errored_packages.append("exllamav2") + print("Exllamav2 is not found in your environment (optional).") if find_spec("torch") is not None: print(f"Torch on version {version('torch')} successfully imported")