From 2b21800dcefbf461157b97a1374ddcea71dcb73e Mon Sep 17 00:00:00 2001
From: pxc
Date: Mon, 11 May 2026 11:37:21 +0800
Subject: [PATCH 01/15] fix sglang disk sync

---
 trinity/common/models/sglang_model.py | 2 +-
 trinity/manager/synchronizer.py       | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/trinity/common/models/sglang_model.py b/trinity/common/models/sglang_model.py
index 8c22086eb81..a37748eaa66 100644
--- a/trinity/common/models/sglang_model.py
+++ b/trinity/common/models/sglang_model.py
@@ -493,7 +493,7 @@ async def sync_model(
             )
             self.model_version = model_version
         elif method == SyncMethod.CHECKPOINT:
-            model_path = await self.synchronizer.get_latest_model_path.remote()
+            model_path = await self.synchronizer.get_latest_model_path.remote(use_huggingface=True)
             if model_path is not None:
                 await self.api_client.update_weights_from_disk(
                     model_path=model_path,
diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py
index 3c72364674c..d6ed8f7793f 100644
--- a/trinity/manager/synchronizer.py
+++ b/trinity/manager/synchronizer.py
@@ -356,14 +356,19 @@ async def get_latest_model_version(self) -> int:
         async with self._ready_condition:
             return self.model_version
 
-    async def get_latest_model_path(self) -> Optional[str]:
+    async def get_latest_model_path(self, use_huggingface: bool = False) -> Optional[str]:
         """
         Get the latest model path available in the synchronizer.
 
+        Args:
+            use_huggingface: Whether to return the Hugging Face model path.
+
         Returns:
             The current model path.
         """
         async with self._ready_condition:
+            if self.model_path and use_huggingface:
+                return self.model_path + "/huggingface"
             return self.model_path
 
     async def ready_to_nccl_sync(self, module: str, trainer_step: int) -> Union[int, None]:

From 38b60811924cf7fac8f1e1e21168c5ed09a1951d Mon Sep 17 00:00:00 2001
From: pxc
Date: Mon, 11 May 2026 14:12:02 +0800
Subject: [PATCH 02/15] move args into sglang patch

---
 trinity/common/models/sglang_model.py       | 54 +++++++------------
 .../common/models/sglang_patch/api_patch.py | 38 +++++++++----
 2 files changed, 48 insertions(+), 44 deletions(-)

diff --git a/trinity/common/models/sglang_model.py b/trinity/common/models/sglang_model.py
index a37748eaa66..04fbf56dea3 100644
--- a/trinity/common/models/sglang_model.py
+++ b/trinity/common/models/sglang_model.py
@@ -4,7 +4,7 @@
 import os
 import traceback
 from logging import Logger
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Sequence, Tuple
+from typing import Any, List, Literal, Optional, Sequence, Tuple
 
 import httpx
 import torch
@@ -16,8 +16,7 @@
 from trinity.common.models.model import BaseInferenceModel
 from trinity.manager.synchronizer import Synchronizer
 
-if TYPE_CHECKING:
-    from sglang.srt.server_args import ServerArgs
+SGLANG_API_KEY = "EMPTY"  # SGLang API server does not actually check the API key, so we can use a dummy value here.
 
 
 class SGLangClient:
@@ -234,7 +233,6 @@ def __init__(
         self.synchronizer = None
         self.state_dict_meta: List[Tuple[str, str, Tuple]] = []
         self.model_version = 0
-        self.server_args: Optional[ServerArgs] = None
         self._prepared = False
         self._has_weight_update_group = False
         self.async_lock = asyncio.Lock()
@@ -388,30 +386,6 @@ async def convert_messages_to_experience(
             "SGLangRolloutModel does not support convert_messages_to_experience."
         )
 
-    def _build_server_args(self, host: str, port: int):
-        from sglang.srt.server_args import ServerArgs
-
-        server_args_kwargs = {
-            "model_path": self.config.model_path,
-            "host": host,
-            "port": port,
-            "tp_size": self.config.tensor_parallel_size,
-            "dtype": self.config.dtype,
-            "mem_fraction_static": self.config.gpu_memory_utilization,
-            "served_model_name": self.config.name or self.config.model_path,
-            "trust_remote_code": self.config.trust_remote_code,
-            "context_length": self.config.max_model_len,
-            "enable_multimodal": self.config.enable_multimodal,
-            "skip_server_warmup": True,
-            "disable_piecewise_cuda_graph": True,
-            "api_key": "EMPTY",
-            "device": "cuda",
-        }
-        # if self.config.chat_template:
-        #     server_args_kwargs["chat_template"] = self.config.chat_template
-        # TODO: fill in nnodes and node_rank for distributed setups
-        return ServerArgs(**server_args_kwargs)
-
     def _get_api_server_exit_reason(self) -> Optional[str]:
         if self.api_server is None or not self.api_server.done():
             return None
@@ -422,7 +396,7 @@ def _get_api_server_exit_reason(self) -> Optional[str]:
 
     async def _wait_until_server_ready(self, server_url: str) -> None:
         max_retries = 100
-        assert self.server_args is not None and self.api_client is not None
+        assert self.api_client is not None
         for _ in range(max_retries):
             reason = self._get_api_server_exit_reason()
             if reason is not None:
@@ -445,17 +419,27 @@ async def run_api_server(self) -> bool:
         if self.api_server_host is None or self.api_server_port is None:
             self.api_server_host, self.api_server_port = self.get_available_address()
 
-        self.server_args = self._build_server_args(
+        self.api_server = get_api_server(
             host=self.api_server_host,
             port=self.api_server_port,
+            model_path=self.config.model_path,
+            tensor_parallel_size=self.config.tensor_parallel_size,
+            dtype=self.config.dtype,
+            served_model_name=self.config.name or self.config.model_path,
+            mem_fraction_static=self.config.gpu_memory_utilization,
+            trust_remote_code=self.config.trust_remote_code,
+            context_length=self.config.max_model_len,
+            enable_multimodal=self.config.enable_multimodal,
+            api_key=SGLANG_API_KEY,
+            logger=self.logger,
         )
-        self.api_server = get_api_server(self.server_args, logger=self.logger)
+        server_url = f"http://{self.api_server_host}:{self.api_server_port}"
         self.api_client = SGLangClient(
-            server_url=f"http://{self.api_server_host}:{self.api_server_port}",
-            api_key=self.server_args.api_key,
+            server_url=server_url,
+            api_key=SGLANG_API_KEY,
             logger=self.logger,
         )
-        await self._wait_until_server_ready(self.server_args.url())
+        await self._wait_until_server_ready(server_url)
         return True
 
     async def shutdown(self) -> None:
diff --git a/trinity/common/models/sglang_patch/api_patch.py b/trinity/common/models/sglang_patch/api_patch.py
index 87050893e2c..f01b249b83b 100644
--- a/trinity/common/models/sglang_patch/api_patch.py
+++ b/trinity/common/models/sglang_patch/api_patch.py
@@ -120,17 +120,37 @@ def _cleanup_task(task: "asyncio.Task[None]") -> None:
 
 
 def get_api_server(
-    server_args: ServerArgs,
+    host: str,
+    port: int,
+    model_path: Optional[str],
+    tensor_parallel_size: int,
+    dtype: str,
+    served_model_name: Optional[str],
+    mem_fraction_static: float,
+    trust_remote_code: bool,
+    context_length: Optional[int],
+    enable_multimodal: bool,
+    api_key: str,
     logger: Logger,
 ) -> "asyncio.Task[None]":
-    if server_args.enable_http2:
-        raise NotImplementedError("Embedded SGLang server does not support HTTP/2 yet.")
-    if server_args.tokenizer_worker_num != 1:
-        raise NotImplementedError(
-            "Embedded SGLang server currently supports tokenizer_worker_num == 1 only."
-        )
-    if server_args.enable_ssl_refresh:
-        raise NotImplementedError("Embedded SGLang server does not support SSL refresh yet.")
+    # TODO: fill in nnodes and node_rank for distributed setups
+    # TODO: fix chat template
+    server_args = ServerArgs(
+        host=host,
+        port=port,
+        model_path=model_path,
+        tp_size=tensor_parallel_size,
+        dtype=dtype,
+        served_model_name=served_model_name,
+        mem_fraction_static=mem_fraction_static,
+        trust_remote_code=trust_remote_code,
+        context_length=context_length,
+        enable_multimodal=enable_multimodal,
+        skip_server_warmup=True,
+        disable_piecewise_cuda_graph=True,
+        api_key=api_key,
+        device="cuda",
+    )
 
     (
         tokenizer_manager,

From fc2db143cad791f1a41bc7ed32ffe61447711a9c Mon Sep 17 00:00:00 2001
From: pxc
Date: Mon, 11 May 2026 18:23:21 +0800
Subject: [PATCH 03/15] update sglang tests

---
 tests/common/sglang_test.py | 272 ++++++++++++++++++++++++++++++
 1 file changed, 272 insertions(+)
 create mode 100644 tests/common/sglang_test.py

diff --git a/tests/common/sglang_test.py b/tests/common/sglang_test.py
new file mode 100644
index 00000000000..da1f54f7fa2
--- /dev/null
+++ b/tests/common/sglang_test.py
@@ -0,0 +1,272 @@
+import asyncio
+
+from parameterized import parameterized_class
+from transformers import AutoTokenizer
+
+from tests.tools import CHAT_TEMPLATE, RayUnittestBaseAsync, get_model_path, get_template_config
+from trinity.common.models import create_explorer_models
+from trinity.common.models.model import ModelWrapper
+
+
+async def prepare_engines(engines, auxiliary_engines):
+    prepare_model_refs = []
+    for engine in engines:
+        prepare_model_refs.append(engine.prepare.remote())
+    for engine_group in auxiliary_engines:
+        for engine in engine_group:
+            prepare_model_refs.append(engine.prepare.remote())
+    await asyncio.gather(*prepare_model_refs)
+
+
+def assert_experience_tokens_match_text(test_case, tokenizer, exp, prompt_contents, response_text):
+    full_text = tokenizer.decode(exp.tokens.tolist(), skip_special_tokens=False)
+    prompt_text = tokenizer.decode(exp.tokens[: exp.prompt_length].tolist(), skip_special_tokens=False)
+    decoded_response_text = tokenizer.decode(
+        exp.tokens[exp.prompt_length :].tolist(), skip_special_tokens=False
+    )
+
+    for prompt_content in prompt_contents:
+        test_case.assertIn(prompt_content, full_text)
+        test_case.assertIn(prompt_content, prompt_text)
+    test_case.assertIn(response_text, full_text)
+    test_case.assertIn(response_text, decoded_response_text)
+
+
+@parameterized_class(
+    (
+        "tensor_parallel_size",
+        "engine_num",
+        "enable_history",
+    ),
+    [
+        (1, 1, True),
+        (1, 2, False),
+        (2, 1, True),
+    ],
+)
+class TestSGLangOpenAIAPI(RayUnittestBaseAsync):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.explorer.rollout_model.engine_type = "sglang"
+        self.config.explorer.rollout_model.engine_num = self.engine_num
+        self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
+        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
+        self.config.explorer.rollout_model.enable_openai_api = True
+        self.config.explorer.rollout_model.base_port = 13000
+        self.config.check_and_update()
+
+        self.engines, self.auxiliary_engines = create_explorer_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], enable_history=self.enable_history)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.config.model.model_path,
+            trust_remote_code=self.config.explorer.rollout_model.trust_remote_code,
+        )
+
+    def _assert_experience_matches_text(self, exp, prompt_contents, response_text):
+        self.assertGreater(exp.prompt_length, 0)
+        self.assertGreater(len(exp.tokens), exp.prompt_length)
+        assert_experience_tokens_match_text(self, self.tokenizer, exp, prompt_contents, response_text)
+
+    def _assert_history_matches_responses(self, expected_count, prompt_contents, response_texts):
+        if not self.enable_history:
+            self.assertEqual(len(self.model_wrapper.history), 0)
+            return []
+
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), expected_count)
+        for exp, response_text in zip(exps, response_texts):
+            self.assertEqual(exp.response_text, response_text)
+            self._assert_experience_matches_text(exp, prompt_contents, response_text)
+        return exps
+
+    def _get_tool_call_case(self):
+        tool_messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "Use the weather tool result to answer what the weather is in Boston.",
+            },
+            {
+                "role": "assistant",
+                "content": "",
+                "tool_calls": [
+                    {
+                        "id": "call_weather_boston",
+                        "type": "function",
+                        "function": {
+                            "name": "get_current_weather",
+                            "arguments": '{"location":"Boston","unit":"fahrenheit"}',
+                        },
+                    }
+                ],
+            },
+            {
+                "role": "tool",
+                "tool_call_id": "call_weather_boston",
+                "content": "The weather in Boston is 72 F.",
+            },
+        ]
+        tool_prompt_contents = [
+            tool_messages[0]["content"],
+            tool_messages[1]["content"],
+            tool_messages[3]["content"],
+        ]
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {"type": "string"},
+                            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+        return tool_messages, tool_prompt_contents, tools
+
+    async def _collect_response_texts(self, response):
+        response_texts = []
+        for choice in response.choices:
+            self.assertIsNotNone(choice.message.content)
+            self.assertGreater(len(choice.message.content), 0)
+            response_texts.append(choice.message.content)
+        return response_texts
+
+    async def _collect_stream_contents(self, stream_response, n):
+        contents = ["" for _ in range(n)]
+        async for chunk in stream_response:
+            for choice in chunk.choices:
+                if choice.delta.content is not None:
+                    contents[choice.index] += choice.delta.content
+        return contents
+
+    async def test_chat_completions(self):
+        await prepare_engines(self.engines, self.auxiliary_engines)
+        await self.model_wrapper.prepare()
+
+        self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path)
+        self.assertIsNotNone(self.model_wrapper.api_address)
+
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Write one short sentence about Boston."},
+        ]
+        prompt_contents = [message["content"] for message in messages]
+
+        openai_client = self.model_wrapper.get_openai_async_client()
+        response = await openai_client.chat.completions.create(
+            model=openai_client.model_path,
+            messages=messages,
+            n=2,
+            temperature=0.7,
+            max_tokens=32,
+        )
+
+        self.assertEqual(len(response.choices), 2)
+        response_texts = await self._collect_response_texts(response)
+        self._assert_history_matches_responses(2, prompt_contents, response_texts)
+
+        tool_messages, tool_prompt_contents, tools = self._get_tool_call_case()
+        tool_response = await openai_client.chat.completions.create(
+            model=openai_client.model_path,
+            messages=tool_messages,
+            tools=tools,
+            tool_choice="none",
+            temperature=0.7,
+            max_tokens=32,
+        )
+
+        self.assertEqual(len(tool_response.choices), 1)
+        tool_response_texts = await self._collect_response_texts(tool_response)
+        self._assert_history_matches_responses(1, tool_prompt_contents, tool_response_texts)
+
+        stream_response = await openai_client.chat.completions.create(
+            model=openai_client.model_path,
+            messages=messages,
+            n=2,
+            stream=True,
+            temperature=0.7,
+            max_tokens=32,
+        )
+        stream_contents = await self._collect_stream_contents(stream_response, 2)
+
+        self.assertEqual(len(stream_contents), 2)
+        for content in stream_contents:
+            self.assertGreater(len(content), 0)
+        self._assert_history_matches_responses(2, prompt_contents, stream_contents)
+
+        stream_tool_response = await openai_client.chat.completions.create(
+            model=openai_client.model_path,
+            messages=tool_messages,
+            tools=tools,
+            tool_choice="none",
+            n=1,
+            stream=True,
+            temperature=0.7,
+            max_tokens=32,
+        )
+        stream_tool_contents = await self._collect_stream_contents(stream_tool_response, 1)
+
+        self.assertEqual(len(stream_tool_contents), 1)
+        self.assertGreater(len(stream_tool_contents[0]), 0)
+        self._assert_history_matches_responses(1, tool_prompt_contents, stream_tool_contents)
+
+        chat_exps = await self.model_wrapper.chat_async(
+            messages,
+            n=2,
+            temperature=0.7,
+            max_tokens=32,
+        )
+
+        self.assertEqual(len(chat_exps), 2)
+        for exp in chat_exps:
+            self.assertGreater(len(exp.response_text), 0)
+            self.assertGreater(exp.prompt_length, 0)
+            self.assertGreater(len(exp.tokens), exp.prompt_length)
+
+        if self.enable_history:
+            chat_history = self.model_wrapper.extract_experience_from_history()
+            self.assertEqual(len(chat_history), 2)
+            for exp, recorded_exp in zip(chat_exps, chat_history):
+                self.assertEqual(recorded_exp.response_text, exp.response_text)
+                self.assertEqual(recorded_exp.prompt_length, exp.prompt_length)
+                self._assert_experience_matches_text(
+                    recorded_exp, prompt_contents, exp.response_text
+                )
+        else:
+            self.assertEqual(len(self.model_wrapper.history), 0)
+
+        generate_prompt = "Write one short sentence about Boston."
+        generate_exps = await self.model_wrapper.generate_async(
+            [generate_prompt],
+            n=2,
+            temperature=0.7,
+            max_tokens=32,
+        )
+
+        self.assertEqual(len(generate_exps), 2)
+        for exp in generate_exps:
+            self.assertEqual(exp.prompt_text, generate_prompt)
+            self.assertGreater(len(exp.response_text), 0)
+            self.assertGreater(exp.prompt_length, 0)
+            self.assertGreater(len(exp.tokens), exp.prompt_length)
+
+        if self.enable_history:
+            generate_history = self.model_wrapper.extract_experience_from_history()
+            self.assertEqual(len(generate_history), 2)
+            for exp, recorded_exp in zip(generate_exps, generate_history):
+                self.assertEqual(recorded_exp.response_text, exp.response_text)
+                self.assertEqual(recorded_exp.prompt_text, exp.prompt_text)
+                self._assert_experience_matches_text(
+                    recorded_exp, [generate_prompt], exp.response_text
+                )
+        else:
+            self.assertEqual(len(self.model_wrapper.history), 0)

From 01e809a381e204f9301acd23c20d6f8e0232a7e8 Mon Sep 17 00:00:00 2001
From: pxc
Date: Mon, 11 May 2026 18:24:54 +0800
Subject: [PATCH 04/15] fix pre-commit

---
 tests/common/sglang_test.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/tests/common/sglang_test.py b/tests/common/sglang_test.py
index da1f54f7fa2..54317d53078 100644
--- a/tests/common/sglang_test.py
+++ b/tests/common/sglang_test.py
@@ -3,7 +3,12 @@
 from parameterized import parameterized_class
 from transformers import AutoTokenizer
 
-from tests.tools import CHAT_TEMPLATE, RayUnittestBaseAsync, get_model_path, get_template_config
+from tests.tools import (
+    CHAT_TEMPLATE,
+    RayUnittestBaseAsync,
+    get_model_path,
+    get_template_config,
+)
 from trinity.common.models import create_explorer_models
 from trinity.common.models.model import ModelWrapper
 
@@ -20,7 +25,9 @@ async def prepare_engines(engines, auxiliary_engines):
 
 def assert_experience_tokens_match_text(test_case, tokenizer, exp, prompt_contents, response_text):
     full_text = tokenizer.decode(exp.tokens.tolist(), skip_special_tokens=False)
-    prompt_text = tokenizer.decode(exp.tokens[: exp.prompt_length].tolist(), skip_special_tokens=False)
+    prompt_text = tokenizer.decode(
+        exp.tokens[: exp.prompt_length].tolist(), skip_special_tokens=False
+    )
     decoded_response_text = tokenizer.decode(
         exp.tokens[exp.prompt_length :].tolist(), skip_special_tokens=False
     )
@@ -67,7 +74,9 @@ def setUp(self):
     def _assert_experience_matches_text(self, exp, prompt_contents, response_text):
         self.assertGreater(exp.prompt_length, 0)
         self.assertGreater(len(exp.tokens), exp.prompt_length)
-        assert_experience_tokens_match_text(self, self.tokenizer, exp, prompt_contents, response_text)
+        assert_experience_tokens_match_text(
+            self, self.tokenizer, exp, prompt_contents, response_text
+        )

From a0f8ac5f48560857dd2617113bca8561f1a63276 Mon Sep 17 00:00:00 2001
From: pxc
Date: Mon, 11 May 2026 22:51:22 +0800
Subject: [PATCH 05/15] fix sglang weight sync

---
 perf/scripts/explorer/perf_workflow.py | 2 ++
 trinity/common/config.py               | 5 +++++
 trinity/common/models/sglang_model.py  | 1 +
 3 files changed, 8 insertions(+)

diff --git a/perf/scripts/explorer/perf_workflow.py b/perf/scripts/explorer/perf_workflow.py
index 4b705617875..d1e5f466eb7 100644
--- a/perf/scripts/explorer/perf_workflow.py
+++ b/perf/scripts/explorer/perf_workflow.py
@@ -77,4 +77,6 @@ async def run_async(self) -> List[Experience]:
                 else 0.0
             ),
         }
+        for exp in exps:
+            exp.reward = 1.0
         return exps
diff --git a/trinity/common/config.py b/trinity/common/config.py
index b6729bcd662..8c7257de197 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -573,6 +573,11 @@ class InferenceModelConfig:
     # For external API-based engine
     external_model_config: ExternalModelConfig = field(default_factory=ExternalModelConfig)
 
+    # for multi-node setup
+    nnode: int = 1
+    # ! DO NOT SET
+    node_rank: int = 0
+
     # ! DO NOT SET
     bundle_indices: str = ""
     engine_id: int = 0
diff --git a/trinity/common/models/sglang_model.py b/trinity/common/models/sglang_model.py
index 04fbf56dea3..a0a4c502ce7 100644
--- a/trinity/common/models/sglang_model.py
+++ b/trinity/common/models/sglang_model.py
@@ -226,6 +226,7 @@ def __init__(
         self.config.enable_openai_api = True
         os.environ["SGLANG_GRPC_PORT"] = "12345"  # a dummy port not actually used
         os.environ["SGLANG_ENABLE_GRPC"] = "0"
+        os.environ.setdefault("NCCL_SHM_DISABLE", "1")
         self.api_server_host: Optional[str] = None
         self.api_server_port: Optional[int] = None
         self.api_server: Optional[asyncio.Task[None]] = None

From 82261d2eb1e378508db02562f6c0ef046e211d60 Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 12 May 2026 11:57:28 +0800
Subject: [PATCH 06/15] fix dtype

---
 trinity/common/models/sglang_model.py    | 8 ++++++--
 trinity/common/models/vllm_worker.py     | 2 +-
 trinity/trainer/verl/fsdp_workers.py     | 6 ++++--
 trinity/trainer/verl/megatron_workers.py | 4 +++-
 4 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/trinity/common/models/sglang_model.py b/trinity/common/models/sglang_model.py
index a0a4c502ce7..24ae03b6723 100644
--- a/trinity/common/models/sglang_model.py
+++ b/trinity/common/models/sglang_model.py
@@ -226,7 +226,12 @@ def __init__(
         self.config.enable_openai_api = True
         os.environ["SGLANG_GRPC_PORT"] = "12345"  # a dummy port not actually used
         os.environ["SGLANG_ENABLE_GRPC"] = "0"
-        os.environ.setdefault("NCCL_SHM_DISABLE", "1")
+        os.environ[
+            "NCCL_P2P_DISABLE"
+        ] = "1"  # disable NCCL P2P to avoid potential issues in certain environments
+        os.environ[
+            "NCCL_SHM_DISABLE"
+        ] = "1"  # disable NCCL SHM to avoid potential issues in certain environments
         self.api_server_host: Optional[str] = None
         self.api_server_port: Optional[int] = None
         self.api_server: Optional[asyncio.Task[None]] = None
@@ -382,7 +387,6 @@ async def convert_messages_to_experience(
         tools=None,
         temperature: Optional[float] = None,
     ) -> Experience:
-        del messages, tools, temperature
         raise NotImplementedError(
             "SGLangRolloutModel does not support convert_messages_to_experience."
         )
diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py
index 527690a64ed..a38424166e8 100644
--- a/trinity/common/models/vllm_worker.py
+++ b/trinity/common/models/vllm_worker.py
@@ -84,7 +84,7 @@ def update_weight(self):
                 weight = state_dict[name]
                 weight = weight.to(self.device)
             else:
-                dtype = getattr(torch, dtype_str.split(".")[-1])
+                dtype = getattr(torch, dtype_str)
                 weight = torch.empty(shape, dtype=dtype, device=self.device)
             torch.distributed.broadcast(weight, 0, group=self._model_update_group)
             weight = weight.type(self.model_config.dtype)
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 1620bcf5987..54440cbb6e9 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -799,13 +799,15 @@ def setup_weight_sync_group(self):
                         else name
                     )
                     self.state_dict_meta.append(
-                        (realname, str(param.dtype), tuple(param.shape))
+                        (realname, str(param.dtype).split(".")[-1], tuple(param.shape))
                    )
                     param = None
                 torch.cuda.empty_cache()
         else:  # fsdp2
             for name, param in model.named_parameters():
-                self.state_dict_meta.append((name, str(param.dtype), tuple(param.shape)))
+                self.state_dict_meta.append(
+                    (name, str(param.dtype).split(".")[-1], tuple(param.shape))
+                )
 
         if torch.distributed.get_rank() == 0:
             import ray
diff --git a/trinity/trainer/verl/megatron_workers.py b/trinity/trainer/verl/megatron_workers.py
index 26d6b470f13..2a7b285f826 100644
--- a/trinity/trainer/verl/megatron_workers.py
+++ b/trinity/trainer/verl/megatron_workers.py
@@ -752,7 +752,9 @@ def setup_weight_sync_group(self):
         if self._is_offload_param:
             load_megatron_model_to_gpu(self.actor_module)
         for name, weight in self._get_tensor_generator():
-            self.state_dict_meta.append((name, str(weight.dtype), tuple(weight.shape)))
+            self.state_dict_meta.append(
+                (name, str(weight.dtype).split(".")[-1], tuple(weight.shape))
+            )
             del weight
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.actor_module)

From 6fd503a071ca94e0d44103d30e9a93615e98a01d Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 12 May 2026 17:33:29 +0800
Subject: [PATCH 07/15] add patch

---
 trinity/common/models/sglang_patch/__init__.py             | 2 +-
 .../models/sglang_patch/{api_patch.py => server_patch.py}  | 0
 2 files changed, 1 insertion(+), 1 deletion(-)
 rename trinity/common/models/sglang_patch/{api_patch.py => server_patch.py} (100%)

diff --git a/trinity/common/models/sglang_patch/__init__.py b/trinity/common/models/sglang_patch/__init__.py
index 7f2d6fdb07f..27b754acdc6 100644
--- a/trinity/common/models/sglang_patch/__init__.py
+++ b/trinity/common/models/sglang_patch/__init__.py
@@ -1,4 +1,4 @@
-from trinity.common.models.sglang_patch.api_patch import get_api_server
+from trinity.common.models.sglang_patch.server_patch import get_api_server
 
 __all__ = [
     "get_api_server",
diff --git a/trinity/common/models/sglang_patch/api_patch.py b/trinity/common/models/sglang_patch/server_patch.py
similarity index 100%
rename from trinity/common/models/sglang_patch/api_patch.py
rename to trinity/common/models/sglang_patch/server_patch.py

From 700eaef50824cc8e95ce77b8108f1f1176c3c869 Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 12 May 2026 19:18:59 +0800
Subject: [PATCH 08/15] fix docker file

---
 .github/workflows/docker/docker-compose.yaml |   4 +-
 scripts/docker/Dockerfile.uv                 |   8 +
 tests/common/sglang_test.py                  |  61 ++---
 tests/trainer/trainer_test.py                |  10 +-
 .../models/sglang_patch/openai_api_patch.py  | 216 ++++++++++++++++++
 .../models/sglang_patch/server_patch.py      |  60 +++++
 6 files changed, 322 insertions(+), 37 deletions(-)
 create mode 100644 trinity/common/models/sglang_patch/openai_api_patch.py

diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml
index 544916bdd93..54747ee9662 100644
--- a/.github/workflows/docker/docker-compose.yaml
+++ b/.github/workflows/docker/docker-compose.yaml
@@ -1,6 +1,6 @@
 services:
   trinity-node-1:
-    image: trinity-rft-unittest:20260506
+    image: trinity-rft-unittest:20260512
     cap_add:
       - SYS_PTRACE
     pull_policy: never
@@ -34,7 +34,7 @@ services:
         capabilities: [gpu]
 
   trinity-node-2:
-    image: trinity-rft-unittest:20260506
+    image: trinity-rft-unittest:20260512
     cap_add:
       - SYS_PTRACE
     pull_policy: never
diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv
index 740a9d796e1..ba405c0cec2 100644
--- a/scripts/docker/Dockerfile.uv
+++ b/scripts/docker/Dockerfile.uv
@@ -71,6 +71,14 @@ RUN . /opt/venv/bin/activate && MAX_JOBS=${BUILD_JOBS} \
     --config-settings="--build-option=--cuda_ext" \
     git+https://github.com/NVIDIA/apex.git
 
+# Install SGLang
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
+
+RUN . /opt/venv/bin/activate && git clone -b v0.5.11 https://github.com/sgl-project/sglang.git /tmp/sglang && \
+    uv pip install /tmp/sglang/python && \
+    rm -rf /tmp/sglang && \
+    uv pip install transformers==5.8.0
+
 # Set Env variables
 # ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64
diff --git a/tests/common/sglang_test.py b/tests/common/sglang_test.py
index 54317d53078..0a76eaa3ff0 100644
--- a/tests/common/sglang_test.py
+++ b/tests/common/sglang_test.py
@@ -197,36 +197,37 @@ async def test_chat_completions(self):
         tool_response_texts = await self._collect_response_texts(tool_response)
         self._assert_history_matches_responses(1, tool_prompt_contents, tool_response_texts)
 
-        stream_response = await openai_client.chat.completions.create(
-            model=openai_client.model_path,
-            messages=messages,
-            n=2,
-            stream=True,
-            temperature=0.7,
-            max_tokens=32,
-        )
-        stream_contents = await self._collect_stream_contents(stream_response, 2)
-
-        self.assertEqual(len(stream_contents), 2)
-        for content in stream_contents:
-            self.assertGreater(len(content), 0)
-        self._assert_history_matches_responses(2, prompt_contents, stream_contents)
-
-        stream_tool_response = await openai_client.chat.completions.create(
-            model=openai_client.model_path,
-            messages=tool_messages,
-            tools=tools,
-            tool_choice="none",
-            n=1,
-            stream=True,
-            temperature=0.7,
-            max_tokens=32,
-        )
-        stream_tool_contents = await self._collect_stream_contents(stream_tool_response, 1)
-
-        self.assertEqual(len(stream_tool_contents), 1)
-        self.assertGreater(len(stream_tool_contents[0]), 0)
-        self._assert_history_matches_responses(1, tool_prompt_contents, stream_tool_contents)
+        if not self.enable_history:
+            stream_response = await openai_client.chat.completions.create(
+                model=openai_client.model_path,
+                messages=messages,
+                n=2,
+                stream=True,
+                temperature=0.7,
+                max_tokens=32,
+            )
+            stream_contents = await self._collect_stream_contents(stream_response, 2)
+
+            self.assertEqual(len(stream_contents), 2)
+            for content in stream_contents:
+                self.assertGreater(len(content), 0)
+            self._assert_history_matches_responses(2, prompt_contents, stream_contents)
+
+            stream_tool_response = await openai_client.chat.completions.create(
+                model=openai_client.model_path,
+                messages=tool_messages,
+                tools=tools,
+                tool_choice="none",
+                n=1,
+                stream=True,
+                temperature=0.7,
+                max_tokens=32,
+            )
+            stream_tool_contents = await self._collect_stream_contents(stream_tool_response, 1)
+
+            self.assertEqual(len(stream_tool_contents), 1)
+            self.assertGreater(len(stream_tool_contents[0]), 0)
+            self._assert_history_matches_responses(1, tool_prompt_contents, stream_tool_contents)
 
         chat_exps = await self.model_wrapper.chat_async(
             messages,
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 69673abca68..4547daa7792 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -239,12 +239,11 @@ def tearDown(self):
 
 
 @parameterized_class(
-    ("fsdp_strategy", "offloading"),
+    ("fsdp_strategy", "offloading", "engine_type"),
     [
-        ("fsdp", False),
-        ("fsdp2", False),
-        ("fsdp", True),
-        ("fsdp2", True),
+        ("fsdp", False, "vllm"),
+        ("fsdp2", False, "vllm"),
+        ("fsdp2", True, "sglang"),
     ],
 )
 class TestTrainerGSM8K(BaseTrainerCase):
@@ -259,6 +258,7 @@ def test_trainer(self):
         }
         # self.config.algorithm.repeat_times = 8  # TODO: used for real testing
         # self.config.buffer.batch_size = 96  # TODO: used for real testing
+        self.config.explorer.rollout_model.engine_type = self.engine_type
         self.config.buffer.total_epochs = 1
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
         self.config.trainer.trainer_strategy = self.fsdp_strategy
diff --git a/trinity/common/models/sglang_patch/openai_api_patch.py b/trinity/common/models/sglang_patch/openai_api_patch.py
new file mode 100644
index 00000000000..0fc9aebb877
--- /dev/null
+++ b/trinity/common/models/sglang_patch/openai_api_patch.py
@@ -0,0 +1,216 @@
+import time
+from typing import Any, Dict, List, Optional
+
+from fastapi import Request
+from pydantic import model_serializer
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest as OriginalChatCompletionRequest,
+)
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionResponse as OriginalChatCompletionResponse,
+)
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionResponseChoice as OriginalChatCompletionResponseChoice,
+)
+from sglang.srt.entrypoints.openai.protocol import ChatMessage, SglExt
+from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat, logger
+from sglang.srt.entrypoints.openai.usage_processor import UsageProcessor
+from sglang.srt.entrypoints.openai.utils import (
+    process_cached_tokens_details_from_ret,
+    process_hidden_states_from_ret,
+    process_routed_experts_from_ret,
+)
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.parser.reasoning_parser import ReasoningParser
+
+
+# Add `return_token_ids` to the request
+class ChatCompletionRequest(OriginalChatCompletionRequest):
+    return_token_ids: bool = False
+
+
+class ChatCompletionResponseChoice(OriginalChatCompletionResponseChoice):
+    token_ids: Optional[List[int]] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.hidden_states is None:
+            data.pop("hidden_states", None)
+        if self.token_ids is None:
+            data.pop("token_ids", None)
+        return data
+
+
+class ChatCompletionResponse(OriginalChatCompletionResponse):
+    choices: List[ChatCompletionResponseChoice]
+    prompt_token_ids: Optional[List[int]] = None
+
+    @model_serializer(mode="wrap")
+    def _serialize(self, handler):
+        data = handler(self)
+        if self.prompt_token_ids is None:
+            data.pop("prompt_token_ids", None)
+        if self.sglext is None:
+            data.pop("sglext", None)
+        return data
+
+
+class PatchedOpenAIServingChat(OpenAIServingChat):
+    """This is a patched version of OpenAIServingChat which supports returning
+    `prompt_token_ids` and `token_ids` in non-streaming mode."""
+
+    async def _handle_non_streaming_request(
+        self,
+        adapted_request: GenerateReqInput,
+        request: ChatCompletionRequest,
+        raw_request: Request,
+    ):
+        assert hasattr(
+            request, "return_token_ids"
+        ), "You are using an unpatched version of OpenAIServingChat."
+        try:
+            ret = await self.tokenizer_manager.generate_request(
+                adapted_request, raw_request
+            ).__anext__()
+        except ValueError as e:
+            return self.create_error_response(str(e))
+
+        if not isinstance(ret, list):
+            ret = [ret]
+
+        response = self._build_chat_response(
+            request,
+            adapted_request,
+            ret,
+            int(time.time()),
+        )
+
+        return response
+
+    def _build_chat_response(
+        self,
+        request: ChatCompletionRequest,
+        adapted_request: GenerateReqInput,
+        ret: List[Dict[str, Any]],
+        created: int,
+    ):
+        """Build chat completion response from generation results"""
+        choices = []
+
+        # Build sglext at response level (from first ret_item, as these are per-request)
+        first_ret = ret[0]
+        routed_experts = process_routed_experts_from_ret(first_ret, request)
+        cached_tokens_details = process_cached_tokens_details_from_ret(first_ret, request)
+        response_sglext = None
+        if routed_experts or cached_tokens_details:
+            response_sglext = SglExt(
+                routed_experts=routed_experts,
+                cached_tokens_details=cached_tokens_details,
+            )
+
+        for idx, ret_item in enumerate(ret):
+            # Process logprobs
+            choice_logprobs = None
+            if request.logprobs:
+                choice_logprobs = self._process_response_logprobs(ret_item)
+
+            # Handle hidden states
+            hidden_states = process_hidden_states_from_ret(ret_item, request)
+
+            finish_reason = ret_item["meta_info"]["finish_reason"]
+            text = ret_item["text"]
+
+            # Handle reasoning content
+            reasoning_text = None
+            reasoning_parser = self.reasoning_parser
+            if reasoning_parser and request.separate_reasoning:
+                is_force_reasoning = (
+                    self.template_manager.force_reasoning
+                    or self._get_reasoning_from_request(request)
+                )
+                try:
+                    parser = ReasoningParser(
+                        model_type=reasoning_parser,
+                        stream_reasoning=False,
+                        force_reasoning=is_force_reasoning,
+                        request=request,
+                    )
+                    reasoning_text, text = parser.parse_non_stream(text)
+                except Exception as e:
+                    logger.error(f"Reasoning parsing error: {e}")
+                    return self.create_error_response(
+                        "Failed to parse reasoning content",
+                        err_type="InternalServerError",
+                        status_code=500,
+                    )
+
+            # Handle tool calls
+            tool_calls = None
+            if request.tool_choice != "none" and request.tools and self.tool_call_parser:
+                history_tool_calls_cnt = self._get_history_tool_calls_cnt(request)
+                tool_calls, text, finish_reason = self._process_tool_calls(
+                    text,
+                    request.tools,
+                    finish_reason,
+                    request.tool_choice,
+                    history_tool_calls_cnt,
+                )
+
+            choice_data = ChatCompletionResponseChoice(
+                index=idx,
+                message=ChatMessage(
+                    role="assistant",
+                    content=text if text else None,
+                    tool_calls=tool_calls,
+                    reasoning_content=reasoning_text if reasoning_text else None,
+                ),
+                logprobs=choice_logprobs,
+                token_ids=(ret_item.get("output_ids") if request.return_token_ids else None),
+                finish_reason=finish_reason["type"] if finish_reason else None,
+                matched_stop=(
+                    finish_reason["matched"]
+                    if finish_reason and "matched" in finish_reason
+                    else None
+                ),
+                hidden_states=hidden_states,
+            )
+            choices.append(choice_data)
+
+        # Calculate usage
+        usage = UsageProcessor.calculate_response_usage(
+            ret,
+            n_choices=request.n,
+            enable_cache_report=self.tokenizer_manager.server_args.enable_cache_report,
+        )
+
+        response = ChatCompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            created=created,
+            model=request.model,
+            choices=choices,
+            usage=usage,
+            prompt_token_ids=(
+                (
+                    adapted_request.input_ids[0]
+                    if adapted_request.input_ids and isinstance(adapted_request.input_ids[0], list)
+                    else adapted_request.input_ids
+                )
+                if request.return_token_ids
+                else None
+            ),
+            metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
+            sglext=response_sglext,
+        )
+        logger.info(f"Generated response: {response.model_dump()}")
+        return response
+
+    async def _handle_streaming_request(
+        self,
+        adapted_request: GenerateReqInput,
+        request: ChatCompletionRequest,
+        raw_request: Request,
+    ):
+        if request.return_token_ids:
+            raise ValueError("return_token_ids is not supported in streaming mode.")
+        return await super()._handle_streaming_request(adapted_request, request, raw_request)
diff --git a/trinity/common/models/sglang_patch/server_patch.py b/trinity/common/models/sglang_patch/server_patch.py
index f01b249b83b..da2a1f78c6a 100644
--- a/trinity/common/models/sglang_patch/server_patch.py
+++ b/trinity/common/models/sglang_patch/server_patch.py
@@ -5,6 +5,13 @@
 from typing import Callable, Dict, List, Optional
 
 import uvicorn
+from fastapi.dependencies.utils import (
+    _should_embed_body_fields,
+    get_body_field,
+    get_dependant,
+    get_flat_dependant,
+)
+from fastapi.routing import APIRoute, request_response
 from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.entrypoints.http_server import (
     _execute_server_warmup,
@@ -20,6 +27,57 @@
 from sglang.srt.utils import kill_process_tree
 from sglang.srt.utils.watchdog import SubprocessWatchdog
 
+from trinity.common.models.sglang_patch.openai_api_patch import (
+    ChatCompletionRequest as PatchedChatCompletionRequest,
+)
+from trinity.common.models.sglang_patch.openai_api_patch import (
+    ChatCompletionResponse as PatchedChatCompletionResponse,
+)
+from trinity.common.models.sglang_patch.openai_api_patch import (
+    ChatCompletionResponseChoice as PatchedChatCompletionResponseChoice,
+)
+from trinity.common.models.sglang_patch.openai_api_patch import PatchedOpenAIServingChat
+
+
+def _refresh_chat_completion_routes() -> None:
+    for route in app.routes:
+        if not isinstance(route, APIRoute):
+            continue
+        if route.path not in {"/v1/chat/completions", "/invocations"}:
+            continue
+
+        route.endpoint.__annotations__["request"] = PatchedChatCompletionRequest
+        route.dependant = get_dependant(
+            path=route.path_format,
+            call=route.endpoint,
+            scope="function",
+        )
+        flat_dependant = get_flat_dependant(route.dependant)
+        embed_body_fields = _should_embed_body_fields(flat_dependant.body_params)
+        setattr(route, "_flat_dependant", flat_dependant)
+        setattr(route, "_embed_body_fields", embed_body_fields)
+        route.body_field = get_body_field(
+            flat_dependant=flat_dependant,
+            name=route.unique_id,
+            embed_body_fields=embed_body_fields,
+        )
+        route.app = request_response(route.get_route_handler())
+
+
+def _apply_openai_api_monkey_patch() -> None:
+    import sglang.srt.entrypoints.http_server as http_server_module
+    import sglang.srt.entrypoints.openai.protocol as protocol_module
+    import sglang.srt.entrypoints.openai.serving_chat as serving_chat_module
+
+    protocol_module.ChatCompletionRequest = PatchedChatCompletionRequest
+    serving_chat_module.ChatCompletionRequest = PatchedChatCompletionRequest
+    serving_chat_module.ChatCompletionResponse = PatchedChatCompletionResponse
+    serving_chat_module.ChatCompletionResponseChoice = PatchedChatCompletionResponseChoice
+    http_server_module.ChatCompletionRequest = PatchedChatCompletionRequest
+    http_server_module.OpenAIServingChat = PatchedOpenAIServingChat
+
+    _refresh_chat_completion_routes()
+
 
 def _setup_and_run_http_server(  # noqa: C901
     server_args: ServerArgs,
@@ -133,6 +191,8 @@ def get_api_server(
     api_key: str,
     logger: Logger,
 ) -> "asyncio.Task[None]":
+    _apply_openai_api_monkey_patch()
+
     # TODO: fill in nnodes and node_rank for distributed setups
     # TODO: fix chat template
     server_args = ServerArgs(

From 646031097329fd9b32d9b8cc54d6894acf4afb5b Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 12 May 2026 19:27:00 +0800
Subject: [PATCH 09/15] add sglang to trainer test

---
 tests/trainer/trainer_test.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 4547daa7792..cbc10e92688 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -76,10 +76,10 @@ def setUp(self):
 
 
 @parameterized_class(
-    ("strategy",),
+    ("strategy", "engine_type"),
     [
-        ("fsdp2",),
-        ("megatron",),
+        ("fsdp2", "vllm"),
+        ("megatron", "sglang"),
     ],
 )
 class TestTrainerCountdown(BaseTrainerCase):
@@ -92,6 +92,7 @@ def test_trainer(self):
             "original_max_position_embeddings": 16384,
         }
         self.config.model.rope_theta = 10000
+        self.config.explorer.rollout_model.engine_type = self.engine_type
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         self.config.buffer.explorer_input.taskset.data_selector = DataSelectorConfig(
             selector_type="shuffle", seed=42

From b9e9f5cfbc27a7cd30798d89904b7a9ba3978725 Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 12 May 2026 19:31:03 +0800
Subject: [PATCH 10/15] fix dtype

---
 trinity/manager/synchronizer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py
index d6ed8f7793f..8e8f135b305 100644
--- a/trinity/manager/synchronizer.py
+++ b/trinity/manager/synchronizer.py
@@ -291,7 +291,9 @@ async def get_state_dict_meta(self):
         )
         update_weight_args_list = []
         for name, param in self.model_state_dict.items():
-            update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
+            update_weight_args_list.append(
+                (name, str(param.dtype).split(".")[-1], tuple(param.shape))
+            )
         return update_weight_args_list
 
     async def setup_weight_sync_group(

From c5d8d6ef266f71c30601543cd6d85eb74b92d6cc Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 12 May 2026 20:20:08 +0800
Subject: [PATCH 11/15] fix trainer tests

---
 trinity/common/models/__init__.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py
index ea9fddf398f..2e692d74277 100644
--- a/trinity/common/models/__init__.py
+++ b/trinity/common/models/__init__.py
@@ -275,14 +275,14 @@ def create_sglang_explorer_models(
     from trinity.common.models.sglang_model import SGLangRolloutModel
 
     models = []
+    engine_pg = placement_group(
+        [{"GPU": config.tensor_parallel_size} for _ in range(config.engine_num)],
+        strategy="PACK",
+    )
+    ray.get(engine_pg.ready())
    for i in range(config.engine_num):
         model_config = deepcopy(config)
         model_config.engine_id = i
-        engine_pg = placement_group(
-            [{"GPU": model_config.tensor_parallel_size}],
-            strategy="PACK",
-        )
-        ray.get(engine_pg.ready())
         models.append(
             ray.remote(SGLangRolloutModel)
             .options(
@@ -293,7 +293,7 @@ def create_sglang_explorer_models(
                 scheduling_strategy=PlacementGroupSchedulingStrategy(
                     placement_group=engine_pg,
                     placement_group_capture_child_tasks=True,
-                    placement_group_bundle_index=0,
+                    placement_group_bundle_index=i,
                 ),
             )
             .remote(

From 83d525e429231787b39e4bfe60cbd02430b74657 Mon Sep 17 00:00:00 2001
From: pxc
Date: Tue, 12 May 2026 20:43:16 +0800
Subject: [PATCH 12/15] fix trainer test

---
 tests/trainer/trainer_test.py         | 4 ++--
 trinity/common/models/sglang_model.py | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index cbc10e92688..599132cd0be 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -78,8 +78,8 @@ def setUp(self):
 @parameterized_class(
     ("strategy", "engine_type"),
     [
-        ("fsdp2", "vllm"),
-        ("megatron", "sglang"),
+        ("fsdp2", "sglang"),
+        ("megatron", "vllm"),
     ],
 )
 class TestTrainerCountdown(BaseTrainerCase):
diff --git a/trinity/common/models/sglang_model.py b/trinity/common/models/sglang_model.py
index 24ae03b6723..464cc6aa733 100644
--- a/trinity/common/models/sglang_model.py
+++ b/trinity/common/models/sglang_model.py
@@ -115,7 +115,7 @@ async def update_weights_from_distributed(
         self,
         state_dict_meta_list: List[Tuple[str, str, Tuple]],
         group_name: str,
-        flash_cache: bool = True,
+        flush_cache: bool = True,
         abort_all_requests: bool = True,
         weight_version: Optional[str] = None,
         timeout: float = 300,
@@ -128,7 +128,7 @@ async def update_weights_from_distributed(
             "dtypes": dtypes,
             "shapes": shapes,
             "group_name": group_name,
-            "flash_cache": flash_cache,
+            "flush_cache": flush_cache,
             "abort_all_requests": abort_all_requests,
             "weight_version": weight_version,
         }
@@ -482,7 +482,7 @@ async def sync_model(
             )
             self.model_version = model_version
         elif method == SyncMethod.CHECKPOINT:
-            model_path = await self.synchronizer.get_latest_model_path.remote()
+            model_path = await self.synchronizer.get_latest_model_path.remote(use_huggingface=True)
             if model_path is not None:
                 await self.api_client.update_weights_from_disk(
                     model_path=model_path,

From 325843f5e1626a891084ae877ecfb2a84ca31015 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 13 May 2026 10:13:18 +0800
Subject: [PATCH 13/15] rename config

---
 trinity/common/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trinity/common/config.py b/trinity/common/config.py
index 8c7257de197..a25c4d6b2a0 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -574,7 +574,7 @@ class InferenceModelConfig:
     external_model_config: ExternalModelConfig = field(default_factory=ExternalModelConfig)
 
     # for multi-node setup
-    nnode: int = 1
+    nnodes: int = 1
     # ! DO NOT SET
     node_rank: int = 0
 

From df286eba39c85df45a5840df92b46707eca60912 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 13 May 2026 10:43:29 +0800
Subject: [PATCH 14/15] fix port allocation

---
 tests/common/config_test.py    | 21 +++++++++++++++++++++
 trinity/common/models/model.py | 11 ++++++++---
 trinity/explorer/explorer.py   |  8 ++++----
 3 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 7c9bfbe1671..3ea8426f8ba 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -73,6 +73,27 @@ def test_inference_model_without_base_port_uses_ephemeral_port(self):
 
         self.assertGreater(port, 0)
 
+    def test_inference_model_random_port_ignores_base_port(self):
+        requested_port = 9005
+        model = DummyInferenceModel(InferenceModelConfig(base_port=9000, engine_id=5))
+
+        _, port = model.get_available_address(random_port=True)
+
+        self.assertNotEqual(port, requested_port)
+        self.assertGreater(port, 0)
+
+    def test_inference_model_random_port_can_use_port_reserved_by_api_server(self):
+        requested_port = 9006
+        model = DummyInferenceModel(InferenceModelConfig(base_port=9000, engine_id=6))
+
+        with socket.socket() as occupied_socket:
+            occupied_socket.bind(("", requested_port))
+
+            _, port = model.get_available_address(random_port=True)
+
+            self.assertNotEqual(port, requested_port)
+            self.assertGreater(port, 0)
+
     def test_load_default_config(self):
         config = get_template_config()
         config.buffer.batch_size = 8
diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index 37b9ebf4e48..75e13d4c01e 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -64,10 +64,15 @@ async def sync_model(
     def get_model_version(self) -> int:
         """Get the checkpoint version."""
 
-    def get_available_address(self) -> Tuple[str, int]:
-        """Get the address of the actor."""
+    def get_available_address(self, random_port: bool = False) -> Tuple[str, int]:
+        """Get an available address on the current actor node.
+
+        Args:
+            random_port: Whether to skip the configured ``base_port`` convention and
+                allocate an ephemeral port on the current node directly.
+        """
         address = ray.util.get_node_ip_address()
-        if self.config.base_port is not None:
+        if not random_port and self.config.base_port is not None:
             configured_port = self.config.base_port + self.config.engine_id
             with socket.socket() as s:
                 try:
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 631036ad8f2..34c2a28aec6 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -123,7 +123,7 @@ async def setup_model_level_weight_sync_group(self):
         refs = []
         world_size = self.config.explorer.rollout_model.tensor_parallel_size
         for model in self.models:
-            master_address, master_port = await model.get_available_address.remote()
+            master_address, master_port = await model.get_available_address.remote(random_port=True)
             self.logger.info(
                 f"Initialize process group for model weight synchronization, "
                 f"master_address={master_address}, master_port={master_port}, "
@@ -216,9 +216,9 @@ async def prepare(self) -> None:
                 # In serving mode, each engine will setup its own process group
                 await self.setup_model_level_weight_sync_group()
             else:
-                master_address, master_port = await self.models[
-                    0
-                ].get_available_address.remote()
+                master_address, master_port = await self.models[0].get_available_address.remote(
+                    random_port=True
+                )
                 await self.setup_weight_sync_group(master_address, master_port)
 
         self.rollout_coordinator = RolloutCoordinator.get_actor(

From d268ec63e1cae29e2d26fb529b312d845b4fe1b9 Mon Sep 17 00:00:00 2001
From: pxc
Date: Wed, 13 May 2026 11:05:28 +0800
Subject: [PATCH 15/15] fix comments

---
 scripts/docker/Dockerfile.uv                       |  7 ++++---
 trinity/common/models/sglang_model.py              | 12 ++++++------
 .../common/models/sglang_patch/openai_api_patch.py |  6 ++----
 trinity/manager/synchronizer.py                    |  2 +-
 4 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv
index ba405c0cec2..6f237831d44 100644
--- a/scripts/docker/Dockerfile.uv
+++ b/scripts/docker/Dockerfile.uv
@@ -72,9 +72,10 @@ RUN . /opt/venv/bin/activate && MAX_JOBS=${BUILD_JOBS} \
     git+https://github.com/NVIDIA/apex.git
 
 # Install SGLang
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
-
-RUN . /opt/venv/bin/activate && git clone -b v0.5.11 https://github.com/sgl-project/sglang.git /tmp/sglang && \
+RUN RUSTUP_HOME=/usr/local/rustup CARGO_HOME=/usr/local/cargo \
+    curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+RUN . /opt/venv/bin/activate && . /usr/local/cargo/env && \
+    git clone -b v0.5.11 https://github.com/sgl-project/sglang.git /tmp/sglang && \
     uv pip install /tmp/sglang/python && \
     rm -rf /tmp/sglang && \
     uv pip install transformers==5.8.0
diff --git a/trinity/common/models/sglang_model.py b/trinity/common/models/sglang_model.py
index 464cc6aa733..a35e3caf431 100644
--- a/trinity/common/models/sglang_model.py
+++ b/trinity/common/models/sglang_model.py
@@ -226,12 +226,12 @@ def __init__(
         self.config.enable_openai_api = True
         os.environ["SGLANG_GRPC_PORT"] = "12345"  # a dummy port not actually used
         os.environ["SGLANG_ENABLE_GRPC"] = "0"
-        os.environ[
-            "NCCL_P2P_DISABLE"
-        ] = "1"  # disable NCCL P2P to avoid potential issues in certain environments
-        os.environ[
-            "NCCL_SHM_DISABLE"
-        ] = "1"  # disable NCCL SHM to avoid potential issues in certain environments
+        os.environ.setdefault(
+            "NCCL_P2P_DISABLE", "1"
+        )  # default to disabling NCCL P2P, but preserve any explicit process configuration
+        os.environ.setdefault(
+            "NCCL_SHM_DISABLE", "1"
+        )  # default to disabling NCCL SHM, but preserve any explicit process configuration
         self.api_server_host: Optional[str] = None
         self.api_server_port: Optional[int] = None
         self.api_server: Optional[asyncio.Task[None]] = None
diff --git a/trinity/common/models/sglang_patch/openai_api_patch.py b/trinity/common/models/sglang_patch/openai_api_patch.py
index 0fc9aebb877..993a1522878 100644
--- a/trinity/common/models/sglang_patch/openai_api_patch.py
+++ b/trinity/common/models/sglang_patch/openai_api_patch.py
@@ -66,9 +66,8 @@ async def _handle_non_streaming_request(
         request: ChatCompletionRequest,
         raw_request: Request,
     ):
-        assert hasattr(
-            request, "return_token_ids"
-        ), "You are using an unpatched version of OpenAIServingChat."
+        if not hasattr(request, "return_token_ids"):
+            raise RuntimeError("You are using an unpatched version of OpenAIServingChat.")
         try:
             ret = await self.tokenizer_manager.generate_request(
                 adapted_request, raw_request
@@ -202,7 +201,6 @@ def _build_chat_response(
             metadata={"weight_version": ret[0]["meta_info"]["weight_version"]},
             sglext=response_sglext,
         )
-        logger.info(f"Generated response: {response.model_dump()}")
         return response
 
     async def _handle_streaming_request(
diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py
index 8e8f135b305..6e2a3bdda23 100644
--- a/trinity/manager/synchronizer.py
+++ b/trinity/manager/synchronizer.py
@@ -370,7 +370,7 @@ async def get_latest_model_path(self, use_huggingface: bool = False) -> Optional
         """
         async with self._ready_condition:
             if self.model_path and use_huggingface:
-                return self.model_path + "/huggingface"
+                return os.path.join(self.model_path, "huggingface")
             return self.model_path
 
     async def ready_to_nccl_sync(self, module: str, trainer_step: int) -> Union[int, None]:
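
Note on the dtype convention that patches 06 and 10 converge on: str(torch.bfloat16) yields the string "torch.bfloat16", so the senders (fsdp_workers, megatron_workers, and the synchronizer) now strip the "torch." prefix before putting dtype names into the state-dict metadata, and the vLLM worker rebuilds the dtype with a plain getattr(torch, dtype_str). The snippet below is a minimal sketch of that round-trip; the helper names are illustrative only and do not exist in the codebase.

    import torch

    def dtype_to_wire(dtype: torch.dtype) -> str:
        # str(torch.bfloat16) == "torch.bfloat16"; keep only the bare name,
        # matching what the patched trainer workers now emit.
        return str(dtype).split(".")[-1]

    def dtype_from_wire(name: str) -> torch.dtype:
        # Inverse used by vllm_worker.update_weight: getattr(torch, "bfloat16")
        # is torch.bfloat16.
        return getattr(torch, name)

    if __name__ == "__main__":
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            assert dtype_from_wire(dtype_to_wire(dtype)) is dtype

Normalizing at the sender keeps a single wire format ("bfloat16" rather than "torch.bfloat16") for every receiver, which is presumably what lets the same metadata tuples drive both the vLLM broadcast path and the SGLang update_weights_from_distributed endpoint.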