4 changes: 3 additions & 1 deletion tests/manager/synchronizer_test.py
@@ -170,13 +170,15 @@ def join_process(self, process: Process, process_name: str, timeout: int = 300):

    def wait_trainer_started(self, ray_namespace: str):
        ray.init(ignore_reinit_error=True)
-        while True:
+        for _ in range(20):
            try:
                ray.get_actor("queue-exp_buffer", namespace=ray_namespace)
                break
            except ValueError:
                print("waiting for trainer to start.")
                time.sleep(5)
+        else:
+            raise RuntimeError("Trainer failed to start.")
        return ray.get_actor("synchronizer", namespace=ray_namespace)

    def _check_metrics(
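Note on the pattern: the fix replaces an unbounded `while True` poll with a bounded retry that uses Python's `for`/`else`. The `else` clause runs only when the loop completes without hitting `break`, so 20 failed probes (about 100 seconds) now raise instead of hanging the test. A minimal sketch of the same pattern, with illustrative names:

```python
import time

def wait_until(probe, attempts: int = 20, delay: float = 5.0):
    """Poll `probe` until it returns True or the attempt budget runs out."""
    for _ in range(attempts):
        if probe():
            break  # success: the else clause below is skipped
        time.sleep(delay)
    else:
        # Reached only when the loop finished without `break`.
        raise RuntimeError("condition was never met")
```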
@@ -32,7 +32,7 @@ def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
        Args:
            exps: DataProto containing:
                - old_log_probs: student's sampling logprobs [batch, seq]
-                - teacher_log_probs: teacher's logprobs [batch, seq]
+                - teacher_logprobs: teacher's logprobs [batch, seq]
                - response_mask: mask for response tokens [batch, seq]

        Returns:
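This is a docstring-only rename, presumably aligning the documented key with the name actually present in the batch (`teacher_logprobs`). As a hedged illustration of how the three documented tensors typically combine, the helper and reduction below are illustrative, not this repo's loss:

```python
import torch

def masked_logprob_gap(old_log_probs: torch.Tensor,
                       teacher_logprobs: torch.Tensor,
                       response_mask: torch.Tensor) -> torch.Tensor:
    # Mean student-minus-teacher logprob over response tokens only.
    gap = (old_log_probs - teacher_logprobs) * response_mask
    return gap.sum() / response_mask.sum().clamp(min=1)
```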
1 change: 1 addition & 0 deletions trinity/common/config_validator.py
@@ -960,6 +960,7 @@ def _fill_taskset_config(taskset: TasksetConfig, index: int, is_eval: bool = False):
            set_if_none(taskset.rollout_args, attr, getattr(config.model, attr))
        set_if_none(taskset.rollout_args, "max_tokens", config.model.max_response_tokens)
        set_if_none(taskset.format, "chat_template", config.model.custom_chat_template)
+        taskset.workflow_args["checkpoint_job_dir"] = config.checkpoint_job_dir

    for i, taskset in enumerate(explorer_input.tasksets):
        _fill_taskset_config(taskset, i)
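The added line injects the job's checkpoint directory into every taskset's `workflow_args`, making it available to workflows without explicit plumbing. A hedged sketch of the consuming side (the key name comes from the diff; the helper and subdirectory are illustrative):

```python
import os

def resolve_workflow_dir(workflow_args: dict) -> str:
    # "checkpoint_job_dir" is injected by _fill_taskset_config above.
    return os.path.join(workflow_args["checkpoint_job_dir"], "workflow_outputs")
```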
2 changes: 1 addition & 1 deletion trinity/common/experience.py
@@ -357,7 +357,7 @@ def serialize_many(cls, experiences: List[Experience]) -> bytes:
                value = getattr(exp, field_name)
                if value is None:
                    continue
-                tensor_data[f"{index}:{field_name}"] = value.detach().cpu().contiguous()
+                tensor_data[f"{index}:{field_name}"] = value.detach().cpu().contiguous().clone()

            if exp.multi_modal_inputs is None:
                item_meta["multi_modal_input_keys"] = []
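Why the extra `.clone()`: `.contiguous()` is a no-op on a tensor that is already contiguous, so the result can still be a view sharing the storage of a much larger buffer, and serializing it can capture or pin that whole buffer. `.clone()` guarantees a compact private copy. A self-contained demonstration:

```python
import torch

base = torch.arange(1_000_000)
view = base[:8]                      # contiguous view into a large buffer
assert view.contiguous() is view     # no copy: .contiguous() returns self
compact = view.contiguous().clone()  # owns exactly its 8 elements
assert compact.untyped_storage().size() < view.untyped_storage().size()
```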
4 changes: 2 additions & 2 deletions trinity/common/models/vllm_model.py
@@ -499,7 +499,7 @@ async def sync_model(
            await self.async_llm.add_lora(self.get_lora_request(self.default_lora_path))
            self.model_version = model_version
            return model_version
-        await self.async_llm.reset_prefix_cache()
+        await self.async_llm.reset_prefix_cache(reset_running_requests=True)
        await self._collective_rpc("update_weight", timeout=timeout)
        self.logger.info(
            f"Synchronized model to version {model_version} using method {sync_method}."
@@ -577,7 +577,7 @@ def get_api_server_url(self) -> Optional[str]:
return f"http://{self.api_server_host}:{self.api_server_port}"

async def reset_prefix_cache(self) -> None:
await self.async_llm.reset_prefix_cache()
await self.async_llm.reset_prefix_cache(reset_running_requests=True)

def get_model_version(self) -> int:
return self.model_version
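Both call sites now pass `reset_running_requests=True`, so clearing the prefix cache also resets in-flight requests rather than only idle cached blocks. The ordering visible in the `sync_model` hunk above is what matters; a condensed sketch, not the full method:

```python
# Invalidate cached and in-flight KV before swapping in new weights,
# so no request continues generating against a stale model version.
async def _sync_weights(self, timeout=None):
    await self.async_llm.reset_prefix_cache(reset_running_requests=True)
    await self._collective_rpc("update_weight", timeout=timeout)
```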
76 changes: 67 additions & 9 deletions trinity/common/models/vllm_patch/__init__.py
@@ -1,4 +1,5 @@
import asyncio
+import json
from logging import Logger

import vllm
@@ -25,18 +26,75 @@ def vllm_patch():
    if vllm_version < parse_version("0.16.0"):
        raise ImportError("Please upgrade vllm to 0.16.0 or above to use transformers>=5.0.0.")

-    from transformers.configuration_utils import PreTrainedConfig
+    if vllm_version < parse_version("0.19.1"):
+        from transformers.configuration_utils import PreTrainedConfig

-    original_init = PreTrainedConfig.__init__
+        original_init = PreTrainedConfig.__init__

-    def new_init(self, *args, **kwargs):
-        if "ignore_keys_at_rope_validation" in kwargs:
-            kwargs["ignore_keys_at_rope_validation"] = set(
-                kwargs["ignore_keys_at_rope_validation"]
-            )
-        original_init(self, *args, **kwargs)
+        def new_init(self, *args, **kwargs):
+            if "ignore_keys_at_rope_validation" in kwargs:
+                kwargs["ignore_keys_at_rope_validation"] = set(
+                    kwargs["ignore_keys_at_rope_validation"]
+                )
+            original_init(self, *args, **kwargs)

-    PreTrainedConfig.__init__ = new_init
+        PreTrainedConfig.__init__ = new_init
+    if parse_version("0.20.0") <= vllm_version:
+        # TODO: add an upper bound when the following PR is merged
+        # https://github.com/vllm-project/vllm/pull/39772/changes
+        from vllm.tool_parsers.qwen3coder_tool_parser import (
+            FunctionCall,
+            Qwen3CoderToolParser,
+            ToolCall,
+            find_tool_properties,
+            logger,
+        )
+
+        if getattr(Qwen3CoderToolParser, "_is_patched", None) is None:
+
+            def new_parse_xml_function_call(self, function_call_str: str) -> ToolCall | None:
+                # Extract function name
+                end_index = function_call_str.find(">")
+                # If there's no ">" character, this is not a valid xml function call
+                if end_index == -1:
+                    return None
+                function_name = function_call_str[:end_index]
+                param_config = find_tool_properties(self.tools, function_name)
+                parameters = function_call_str[end_index + 1 :]
+                param_dict = {}
+                for match_text in self.tool_call_parameter_regex.findall(parameters):
+                    idx = match_text.find(">")
+                    # Skip malformed parameters missing the name>value separator
+                    # (e.g. truncated output) so other valid parameters can still
+                    # be parsed.
+                    if idx == -1:
+                        logger.warning(
+                            "Skipping malformed parameter without '>' separator "
+                            "in tool call for function '%s': %r",
+                            function_name,
+                            match_text,
+                        )
+                        continue
+                    param_name = match_text[:idx]
+                    param_value = str(match_text[idx + 1 :])
+                    # Remove a leading and trailing "\n"
+                    if param_value.startswith("\n"):
+                        param_value = param_value[1:]
+                    if param_value.endswith("\n"):
+                        param_value = param_value[:-1]
+
+                    param_dict[param_name] = self._convert_param_value(
+                        param_value, param_name, param_config, function_name
+                    )
+                return ToolCall(
+                    type="function",
+                    function=FunctionCall(
+                        name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False)
+                    ),
+                )
+
+            Qwen3CoderToolParser._is_patched = True
+            Qwen3CoderToolParser._parse_xml_function_call = new_parse_xml_function_call


def get_vllm_version():
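The behavioral change in the parser patch is the `idx == -1` guard: a parameter chunk missing its `name>value` separator (typically truncated model output) is skipped with a warning instead of derailing the whole tool call. A standalone demo of that behavior using a simplified stand-in regex (the real `tool_call_parameter_regex` lives in vLLM and is not reproduced here):

```python
import re

# Simplified stand-in for vLLM's tool_call_parameter_regex.
PARAM_RE = re.compile(r"<parameter=(.*?)</parameter>", re.DOTALL)

def parse_params(body: str) -> dict:
    params = {}
    for match_text in PARAM_RE.findall(body):
        idx = match_text.find(">")
        if idx == -1:
            continue  # malformed/truncated parameter: skip, keep the rest
        params[match_text[:idx]] = match_text[idx + 1 :].strip("\n")
    return params

ok = "<parameter=path>\n/tmp/x\n</parameter>"
bad = "<parameter=truncated</parameter>"
print(parse_params(ok + bad))  # -> {'path': '/tmp/x'}
```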