diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py index bb45ac6093c..4efae61d3eb 100644 --- a/tests/manager/synchronizer_test.py +++ b/tests/manager/synchronizer_test.py @@ -170,13 +170,15 @@ def join_process(self, process: Process, process_name: str, timeout: int = 300): def wait_trainer_started(self, ray_namespace: str): ray.init(ignore_reinit_error=True) - while True: + for _ in range(20): try: ray.get_actor("queue-exp_buffer", namespace=ray_namespace) break except ValueError: print("waiting for trainer to start.") time.sleep(5) + else: + raise RuntimeError("Trainer failed to start.") return ray.get_actor("synchronizer", namespace=ray_namespace) def _check_metrics( diff --git a/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py b/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py index d2cdbf1082a..b9dc775e390 100644 --- a/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py +++ b/trinity/algorithm/advantage_fn/on_policy_distill_advantage.py @@ -32,7 +32,7 @@ def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: Args: exps: DataProto containing: - old_log_probs: student's sampling logprobs [batch, seq] - - teacher_log_probs: teacher's logprobs [batch, seq] + - teacher_logprobs: teacher's logprobs [batch, seq] - response_mask: mask for response tokens [batch, seq] Returns: diff --git a/trinity/common/config_validator.py b/trinity/common/config_validator.py index 729e5b2e97f..87caf737247 100644 --- a/trinity/common/config_validator.py +++ b/trinity/common/config_validator.py @@ -960,6 +960,7 @@ def _fill_taskset_config(taskset: TasksetConfig, index: int, is_eval: bool = Fal set_if_none(taskset.rollout_args, attr, getattr(config.model, attr)) set_if_none(taskset.rollout_args, "max_tokens", config.model.max_response_tokens) set_if_none(taskset.format, "chat_template", config.model.custom_chat_template) + taskset.workflow_args["checkpoint_job_dir"] = config.checkpoint_job_dir for i, taskset in enumerate(explorer_input.tasksets): _fill_taskset_config(taskset, i) diff --git a/trinity/common/experience.py b/trinity/common/experience.py index f8bb1e2722e..adeb2362f21 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -357,7 +357,7 @@ def serialize_many(cls, experiences: List[Experience]) -> bytes: value = getattr(exp, field_name) if value is None: continue - tensor_data[f"{index}:{field_name}"] = value.detach().cpu().contiguous() + tensor_data[f"{index}:{field_name}"] = value.detach().cpu().contiguous().clone() if exp.multi_modal_inputs is None: item_meta["multi_modal_input_keys"] = [] diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 01de7052613..d031f31525f 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -499,7 +499,7 @@ async def sync_model( await self.async_llm.add_lora(self.get_lora_request(self.default_lora_path)) self.model_version = model_version return model_version - await self.async_llm.reset_prefix_cache() + await self.async_llm.reset_prefix_cache(reset_running_requests=True) await self._collective_rpc("update_weight", timeout=timeout) self.logger.info( f"Synchronized model to version {model_version} using method {sync_method}." 
@@ -577,7 +577,7 @@ def get_api_server_url(self) -> Optional[str]: return f"http://{self.api_server_host}:{self.api_server_port}" async def reset_prefix_cache(self) -> None: - await self.async_llm.reset_prefix_cache() + await self.async_llm.reset_prefix_cache(reset_running_requests=True) def get_model_version(self) -> int: return self.model_version diff --git a/trinity/common/models/vllm_patch/__init__.py b/trinity/common/models/vllm_patch/__init__.py index f458ec65000..c41978b1b70 100644 --- a/trinity/common/models/vllm_patch/__init__.py +++ b/trinity/common/models/vllm_patch/__init__.py @@ -1,4 +1,5 @@ import asyncio +import json from logging import Logger import vllm @@ -25,18 +26,75 @@ def vllm_patch(): if vllm_version < parse_version("0.16.0"): raise ImportError("Please upgrade vllm to 0.16.0 or above to use transformers>=5.0.0.") - from transformers.configuration_utils import PreTrainedConfig - - original_init = PreTrainedConfig.__init__ + if vllm_version < parse_version("0.19.1"): + from transformers.configuration_utils import PreTrainedConfig + + original_init = PreTrainedConfig.__init__ + + def new_init(self, *args, **kwargs): + if "ignore_keys_at_rope_validation" in kwargs: + kwargs["ignore_keys_at_rope_validation"] = set( + kwargs["ignore_keys_at_rope_validation"] + ) + original_init(self, *args, **kwargs) + + PreTrainedConfig.__init__ = new_init + if parse_version("0.20.0") <= vllm_version: + # TODO: add upper bound when following PR is merged + # https://github.com/vllm-project/vllm/pull/39772/changes + from vllm.tool_parsers.qwen3coder_tool_parser import ( + FunctionCall, + Qwen3CoderToolParser, + ToolCall, + find_tool_properties, + logger, + ) - def new_init(self, *args, **kwargs): - if "ignore_keys_at_rope_validation" in kwargs: - kwargs["ignore_keys_at_rope_validation"] = set( - kwargs["ignore_keys_at_rope_validation"] + if getattr(Qwen3CoderToolParser, "_is_patched", None) is None: + + def new_parse_xml_function_call(self, function_call_str: str) -> ToolCall | None: + # Extract function name + end_index = function_call_str.find(">") + # If there's no ">" character, this is not a valid xml function call + if end_index == -1: + return None + function_name = function_call_str[:end_index] + param_config = find_tool_properties(self.tools, function_name) + parameters = function_call_str[end_index + 1 :] + param_dict = {} + for match_text in self.tool_call_parameter_regex.findall(parameters): + idx = match_text.find(">") + # Skip malformed parameters missing the name>value separator + # (e.g. truncated output) so other valid parameters can still + # be parsed. 
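                    # Illustrative example (an assumption about the matched text, not taken
                    # from the vllm source): a well-formed match such as "city>\nBeijing\n"
                    # splits into the name "city" and the value "Beijing", while a truncated
                    # match such as "city" contains no ">" and is skipped by the branch below
                    # rather than producing a bogus parameter.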
+ if idx == -1: + logger.warning( + "Skipping malformed parameter without '>' separator " + "in tool call for function '%s': %r", + function_name, + match_text, + ) + continue + param_name = match_text[:idx] + param_value = str(match_text[idx + 1 :]) + # Remove prefix and trailing \n + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = self._convert_param_value( + param_value, param_name, param_config, function_name + ) + return ToolCall( + type="function", + function=FunctionCall( + name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False) + ), ) - original_init(self, *args, **kwargs) - PreTrainedConfig.__init__ = new_init + Qwen3CoderToolParser._is_patched = True + Qwen3CoderToolParser._parse_xml_function_call = new_parse_xml_function_call def get_vllm_version(): diff --git a/trinity/common/patch/qwen3_5.py b/trinity/common/patch/qwen3_5.py index 4a7fef18368..828ce338160 100644 --- a/trinity/common/patch/qwen3_5.py +++ b/trinity/common/patch/qwen3_5.py @@ -6,16 +6,16 @@ import torch.distributed as dist from torch import Tensor from transformers.models.qwen3_5.modeling_qwen3_5 import ( - BaseModelOutputWithPast, + BaseModelOutputWithPooling, Cache, + F, Qwen3_5CausalLMOutputWithPast, Qwen3_5ForConditionalGeneration, Qwen3_5ModelOutputWithPast, TransformersKwargs, Unpack, - capture_outputs, - create_causal_mask, - merge_with_config_defaults, + apply_mask_to_padding_states, + can_return_tuple, ) from verl.utils.ulysses import all_gather_tensor @@ -70,109 +70,548 @@ def backward(ctx, grad_outputs: Tensor) -> Any: ) -# TODO: may optimize this function -def ulysses_gated_delta_net_forward_decorator(func): - @wraps(func) - def wrapper( - hidden_states: torch.Tensor, - **kwargs, - ): - from verl.utils.ulysses import ( - gather_outputs_and_unpad, - get_ulysses_sequence_parallel_group, - get_ulysses_sequence_parallel_world_size, - ) +_in_gate_delta_net_with_sp = False - ulysses_sp_size = get_ulysses_sequence_parallel_world_size() - if ulysses_sp_size > 1: - hidden_states = gather_outputs_and_unpad(hidden_states, gather_dim=1) - output = func(hidden_states, **kwargs) +def ulysses_gate_delta_net_decorator(net, ulysses_sp_size): + """Decorator to enable Ulysses Sequence Parallel for Qwen3.5 GateDeltaNet linear attention. - if ulysses_sp_size > 1: - group = get_ulysses_sequence_parallel_group() - output = Slice.apply(group, output, 1) + This decorator patches the GateDeltaNet module to support sequence parallelism using the Ulysses + strategy. It intercepts various operations (forward pass, projections, convolutions, and attention) + to properly scatter/gather tensors across sequence parallel ranks. + + Args: + net: The GateDeltaNet module to patch (typically a linear attention layer). + ulysses_sp_size: The sequence parallel world size. If 1, no patching is performed. + + Note: + - This function patches the module in-place and sets a `_is_patched` flag to avoid double-patching. + - The sequence parallel operations are controlled via a global `_in_gate_delta_net_with_sp` flag. + - The patching includes modifications to forward, in_proj_qkv, conv1d, torch.split, and chunk_gated_delta_rule. 
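    Example:
        A usage sketch mirroring how ``apply_monkey_patch`` applies this patch later in
        this diff (``ulysses_sp_size`` comes from the trainer configuration)::

            for layer in model.model.language_model.layers:
                if layer.layer_type == "linear_attention":
                    ulysses_gate_delta_net_decorator(layer.linear_attn, ulysses_sp_size)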
+ """ + if getattr(net, "_is_patched", False): + return + + net._is_patched = True + + # ulysses sequence parallel setup + from verl.utils.ulysses import ( + gather_heads_scatter_seq, + gather_seq_scatter_heads, + get_ulysses_sequence_parallel_group, + ) + + if ulysses_sp_size == 1: + # no need to patch + return + + # Patch net.forward + original_net_forward = net.forward + + @wraps(original_net_forward) + def new_net_forward(*args, **kwargs): + global _in_gate_delta_net_with_sp + _in_gate_delta_net_with_sp = True + output = original_net_forward(*args, **kwargs) + _in_gate_delta_net_with_sp = False return output - return wrapper + net.forward = new_net_forward + + # Patch in_proj_qkv + original_in_proj_qkv_forward = net.in_proj_qkv.forward + + @wraps(original_in_proj_qkv_forward) + def new_in_proj_qkv_forward(input): + output = original_in_proj_qkv_forward(input) + group = get_ulysses_sequence_parallel_group() + output = gather_seq_scatter_heads(output, seq_dim=1, head_dim=2, group=group) + return output + + net.in_proj_qkv.forward = new_in_proj_qkv_forward + + # Patch conv1d layer + original_conv1d_class = net.conv1d.__class__ + original_conv1d_getattr = original_conv1d_class.__getattr__ + + @wraps(original_conv1d_getattr) + def new_conv1d_getattr(self, name): + global _in_gate_delta_net_with_sp + attr = original_conv1d_getattr(self, name) + # bias is None in Qwen3.5, so no need to patch for bias + if name == "weight" and _in_gate_delta_net_with_sp: + group = get_ulysses_sequence_parallel_group() + return Slice.apply(group, attr, 0, True) + return attr + + new_conv1d_class = type( + f"UlyssesGated{original_conv1d_class.__name__}", + (original_conv1d_class,), + {"__getattr__": new_conv1d_getattr}, + ) + net.conv1d.__class__ = new_conv1d_class + + # Patch torch.split + if not getattr(torch.split, "_is_patched_by_ulysses_gate_delta_net", False): + original_split = torch.split + + @wraps(original_split) + def new_split(tensor, split_size_or_sections, dim=0): + global _in_gate_delta_net_with_sp + if _in_gate_delta_net_with_sp and dim == -1 and len(split_size_or_sections) == 3: + tensor = gather_heads_scatter_seq(tensor, seq_dim=1, head_dim=2) + + return original_split(tensor, split_size_or_sections, dim) + + torch.split = new_split + torch.split._is_patched_by_ulysses_gate_delta_net = True + + # Patch chunk_gated_delta_rule + original_chunk_gated_delta_rule = net.chunk_gated_delta_rule + + @wraps(original_chunk_gated_delta_rule) + def new_chunk_gated_delta_rule(query, key, value, g, beta, **kwargs): + query = gather_seq_scatter_heads(query, seq_dim=1, head_dim=2) + key = gather_seq_scatter_heads(key, seq_dim=1, head_dim=2) + value = gather_seq_scatter_heads(value, seq_dim=1, head_dim=2) + g = gather_seq_scatter_heads(g, seq_dim=1, head_dim=2) + beta = gather_seq_scatter_heads(beta, seq_dim=1, head_dim=2) + output, last_recurrent_state = original_chunk_gated_delta_rule( + query, key, value, g, beta, **kwargs + ) + output = gather_heads_scatter_seq(output, seq_dim=1, head_dim=2) + return output, last_recurrent_state + + net.chunk_gated_delta_rule = new_chunk_gated_delta_rule + + +# removed when following PR is merged +# https://github.com/huggingface/transformers/pull/45034/changes +def gate_delta_net_forward( + self, + hidden_states: torch.Tensor, + cache_params: Cache | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs, +): + """Forward pass for Qwen3.5 GateDeltaNet linear attention with packing support. 
+ + This implementation of the linear attention forward pass supports packed sequences for efficient + training, following the approach referenced in the Hugging Face transformers PR #45034. + It handles both incremental (cached) and non-cached inference modes. + + Args: + hidden_states: Input hidden states of shape (batch_size, seq_len, hidden_dim). + cache_params: Optional cache parameters for incremental decoding. + attention_mask: Optional attention mask to mask out padding positions. + **kwargs: Additional keyword arguments passed to sub-components (e.g., seq_idx for packed sequences). + + Returns: + Output tensor of shape (batch_size, seq_len, hidden_dim) after linear attention computation. + """ + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state(self.layer_idx) + and seq_len == 1 + ) + + # getting projected states from cache if it exists + if use_precomputed_states: + conv_state = cache_params.layers[self.layer_idx].conv_states + recurrent_state = cache_params.layers[self.layer_idx].recurrent_states + + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + if use_precomputed_states: + # 2. Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)) + conv_state = cache_params.update_conv_state(conv_state, self.layer_idx) + if self.causal_conv1d_fn is not None: + seq_idx = kwargs.get("seq_idx", None) + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=seq_idx, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + chunk_kwargs = {} + if getattr(self.chunk_gated_delta_rule, "__module__", "").startswith("fla."): + chunk_kwargs["cu_seqlens"] = kwargs.get("cu_seqlens", None) + + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + **chunk_kwargs, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, 
+ value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.update_recurrent_state(last_recurrent_state, self.layer_idx) + + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + return output -@merge_with_config_defaults -@capture_outputs -def qwen35_text_forward( +# removed when following PR is merged +# https://github.com/huggingface/transformers/pull/45034/changes +def decoder_layer_forward( self, - input_ids: torch.LongTensor | None = None, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + **kwargs: Unpack[TransformersKwargs], +) -> torch.FloatTensor: + """Forward pass for a Qwen3.5 decoder layer supporting packed sequences. + + This function implements a full transformer decoder layer with support for packed sequences + (packing training). It combines token mixing (via linear or full attention) with a feed-forward + network, with residual connections around each sub-layer. + + Args: + hidden_states: Input hidden states of shape (batch_size, seq_len, hidden_dim). + position_embeddings: Tuple of (cos_cached, sin_cached) for rotary position embeddings. + attention_mask: Optional attention mask. + position_ids: Optional position IDs for the sequence. + past_key_values: Optional cache for incremental decoding. + **kwargs: Additional arguments including: + - layer_type: Either 'linear_attention' or 'full_attention' to determine token mixer. + - seq_idx: Sequence indices for packed sequence training. + + Returns: + Output hidden states of same shape as input (batch_size, seq_len, hidden_dim). + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Token Mixer + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + attention_mask=attention_mask, + **kwargs, + ) + elif self.layer_type == "full_attention": + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +def qwen35_vision_fast_pos_embed_interpolate(self, grid_thw): + """Interpolate vision position embeddings for variable resolution inputs with proper device handling. + + This function performs bilinear interpolation of position embeddings to support variable spatial + resolutions. It fixes the device handling issue that occurred during CPU offloading, ensuring all + tensors are created and operated on the same device as the input. + + Args: + grid_thw: Tensor of shape (num_images, 3) containing temporal, height, and width dimensions + for each image in the batch. 
+ + Returns: + Interpolated position embeddings of shape (total_patches, embedding_dim) after merging, + where total_patches is the sum of all h*w for each image after spatial merging. + + Note: + - The function supports batch processing of multiple images with different resolutions. + - Spatial merging is applied based on config.spatial_merge_size. + - All tensors are properly placed on the same device as the input grid_thw. + """ + grid_thw_list = grid_thw.tolist() + grid_ts = [row[0] for row in grid_thw_list] + grid_hs = [row[1] for row in grid_thw_list] + grid_ws = [row[2] for row in grid_thw_list] + device = grid_thw.device # modified to ensure tensors are created on the correct device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in grid_thw_list: + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + +@can_return_tuple +def qwen35_model_forward( + self, + input_ids: torch.LongTensor = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - cache_position: torch.LongTensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.FloatTensor | None = None, + image_grid_thw: torch.LongTensor | None = None, + video_grid_thw: torch.LongTensor | None = None, + mm_token_type_ids: torch.IntTensor | None = None, **kwargs: Unpack[TransformersKwargs], -) -> BaseModelOutputWithPast: +) -> tuple | Qwen3_5ModelOutputWithPast: + """Qwen3.5 model forward pass with multimodal support and gradient synchronization across ranks. 
+ + This forward function handles multimodal training (images and/or videos) across multiple GPU ranks + with proper synchronization. When a rank doesn't have image/video inputs but other ranks do (common in + distributed training with different data samples), it creates dummy images/videos to maintain consistency + and avoid hanging in collective operations. + + Args: + input_ids: Token IDs of shape (batch_size, seq_len). + attention_mask: Attention mask for padding tokens. + position_ids: Position IDs for embeddings. + past_key_values: Cached key-values for incremental decoding. + inputs_embeds: Pre-computed input embeddings (alternative to input_ids). + pixel_values: Image pixel values of shape (num_images, channels, height, width). + pixel_values_videos: Video pixel values of shape (num_videos, frames, channels, height, width). + image_grid_thw: Grid dimensions (temporal, height, width) for images. + video_grid_thw: Grid dimensions (temporal, height, width) for videos. + mm_token_type_ids: Token type IDs to distinguish image, video, and text tokens. + **kwargs: Additional arguments. + + Returns: + Qwen3_5ModelOutputWithPast containing language model outputs with rope_deltas for position embeddings. + + Note: + - Dummy images/videos are created with shape based on spatial_merge_size when needed for gradient synchronization. + - Uses distributed communication (dist.all_reduce) to synchronize multimodal input availability across ranks. + """ + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.get_input_embeddings()(input_ids) - if use_cache and past_key_values is None: - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache + vision_config = self.config.vision_config + pixel_values_dim = ( + vision_config.in_channels + * vision_config.temporal_patch_size + * (vision_config.patch_size**2) + ) + merge_size = vision_config.spatial_merge_size - past_key_values = Qwen3_5DynamicCache(config=self.config) + device = inputs_embeds.device + has_mm_local = torch.tensor( + [int(pixel_values is not None), int(pixel_values_videos is not None)], device=device + ) + has_mm_global = has_mm_local.clone() + if dist.is_initialized(): + dist.all_reduce(has_mm_global) + has_mm_global = has_mm_global > 0 + + # check images + if has_mm_global[0].item(): + if not has_mm_local[0].item(): + pixel_values = torch.zeros( + (merge_size * merge_size, pixel_values_dim), dtype=torch.float32, device=device + ) + image_grid_thw = torch.ones((1, 3), dtype=torch.int64, device=device) + image_grid_thw[:, 1:] = merge_size - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + image_outputs: BaseModelOutputWithPooling = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True ) + image_embeds = image_outputs.pooler_output + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) - # 
mrope: the hard coded `3` is for temporal, height and width. - if position_ids is None: - position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) - elif position_ids.ndim == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - - if position_ids.ndim == 3 and position_ids.shape[0] == 4: - text_position_ids = position_ids[0] - position_ids = position_ids[1:] - else: - text_position_ids = position_ids[0] - - causal_mask = create_causal_mask( - config=self.config, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=text_position_ids, - ) - linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position) - - hidden_states = inputs_embeds - position_embeddings = self.rotary_emb(hidden_states, position_ids) + if has_mm_local[0].item(): + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + else: # patched for backward + inputs_embeds[0] = inputs_embeds[0] + image_embeds[0] * 0.0 + + # check videos + if has_mm_global[1].item(): + if not has_mm_local[1].item(): + pixel_values_videos = torch.zeros( + (merge_size * merge_size, pixel_values_dim), dtype=torch.float32, device=device + ) + video_grid_thw = torch.ones((1, 3), dtype=torch.int64, device=device) + video_grid_thw[:, 1:] = merge_size - for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): - layer_mask = ( - linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask + video_outputs: BaseModelOutputWithPooling = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True ) + video_embeds = video_outputs.pooler_output + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) - hidden_states = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=layer_mask, - position_ids=text_position_ids, + if has_mm_local[1].item(): + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + else: # patched for backward + inputs_embeds[0] = inputs_embeds[0] + video_embeds[0] * 0.0 + + if position_ids is None: + position_ids = self.compute_3d_position_ids( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, + mm_token_type_ids=mm_token_type_ids, ) - hidden_states = self.norm(hidden_states) + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) return Qwen3_5ModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, + **outputs, + rope_deltas=self.rope_deltas, ) @@ -189,6 +628,31 @@ def forward_with_torch_backend( temperature: float = 1.0, **kwargs, ) -> tuple | Qwen3_5CausalLMOutputForPPO: + """Compute log probabilities and entropy for reinforcement learning using PyTorch backend. 
+ + This function computes per-token log probabilities and entropy from the language model's hidden + states using a fused PyTorch-based linear projection. It's designed for PPO and other RL algorithms + that require per-token probability distributions over the vocabulary. + + Args: + input_ids: Token IDs of shape (batch_size, seq_len). + labels: Optional labels for loss computation. If None, input_ids are rolled to compute shifted targets. + temperature: Temperature scaling for softmax (default: 1.0). Used to control probability distribution sharpness. + **kwargs: Additional arguments passed to the model (e.g., attention_mask). + + Returns: + Qwen3_5CausalLMOutputForPPO containing: + - log_probs: Log probabilities of shape (batch_size, seq_len) + - entropy: Entropy values of shape (batch_size, seq_len) + - hidden_states: Hidden states from the model forward pass + + Raises: + RuntimeError: If neither labels nor input_ids is provided. + + Note: + - Uses FusedLinearForPPO for efficient torch-based computation. + - The log probability target is computed by rolling labels (or input_ids) by -1 to create next-token prediction targets. + """ from verl.utils.experimental.torch_functional import FusedLinearForPPO outputs = self.model(input_ids=input_ids, **kwargs) @@ -225,6 +689,32 @@ def forward_with_triton_backend( temperature: float = 1.0, **kwargs, ) -> tuple | Qwen3_5CausalLMOutputForPPO: + """Compute log probabilities and entropy for reinforcement learning using Triton kernel backend. + + This function computes per-token log probabilities and entropy from the language model's hidden + states using an optimized Triton kernel (linear_cross_entropy). It provides better performance + compared to the PyTorch backend for large vocabularies, suitable for PPO and other RL algorithms. + + Args: + input_ids: Token IDs of shape (batch_size, seq_len). + labels: Optional labels for loss computation. If None, input_ids are rolled to compute shifted targets. + temperature: Temperature scaling for softmax (default: 1.0). Used to control probability distribution sharpness. + **kwargs: Additional arguments passed to the model (e.g., attention_mask). + + Returns: + Qwen3_5CausalLMOutputForPPO containing: + - log_probs: Log probabilities of shape (batch_size, seq_len) + - entropy: Entropy values of shape (batch_size, seq_len) + - hidden_states: Hidden states from the model forward pass + + Raises: + RuntimeError: If neither labels nor input_ids is provided. + + Note: + - Uses the linear_cross_entropy Triton kernel from verl for highly optimized computation. + - The log probability target is computed by rolling labels (or input_ids) by -1 to create next-token prediction targets. + - Generally faster than forward_with_torch_backend for large vocabulary sizes. + """ from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy outputs = self.model(input_ids=input_ids, **kwargs) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 40b4efe80a8..631036ad8f2 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -288,7 +288,8 @@ async def explore_step(self) -> bool: await self.shutdown() return False self.explore_step_num += 1 - assert self.rollout_coordinator is not None, "Rollout coordinator must be prepared first." 
+ if self.rollout_coordinator is None: + return False await self.rollout_coordinator.submit_batch.remote( batch_id=self.explore_step_num, tasks=tasks, @@ -422,7 +423,9 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int self.monitor.log(metric, step=end_step) async def _finish_explore_step(self, step: int, model_version: int) -> None: - assert self.rollout_coordinator is not None, "Rollout coordinator must be prepared first." + if self.rollout_coordinator is None: + return + metric = {"rollout/model_version": model_version} with Timer(metric, "explorer/time/wait_explore_step"): result = await self.rollout_coordinator.finalize_train_batch.remote(step) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 05a7fbaebab..565d02d89c6 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -257,7 +257,7 @@ async def run_with_retry( self.logger.error(status.message) except asyncio.TimeoutError: run_task_ref = None - last_exception_msg = f"Timeout when running task of batch {task.batch_id} at runner {self.runner_id} at attempt {attempt + 1}: {task.task}" + last_exception_msg = f"Timeout ({timeout} s) when running task of batch {task.batch_id} at runner {self.runner_id} at attempt {attempt + 1}: {task.task}" self.logger.error(last_exception_msg) status = Status( completed_runs=0, diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index c13d25a0bea..254b9cc4d76 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -18,16 +18,26 @@ Modified from https://github.com/volcengine/verl/blob/v0.7.1/verl/workers/actor/dp_actor.py """ -import logging -import os - import torch +import verl.utils.torch_functional as verl_F from torch import nn from verl import DataProto +from verl.utils.attention_utils import ( + index_first_axis, + pad_input, + rearrange, + unpad_input, +) from verl.utils.debug import GPUMemoryLogger from verl.utils.device import get_device_id from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import prepare_dynamic_batch +from verl.utils.torch_functional import logprobs_from_logits +from verl.utils.ulysses import ( + gather_outputs_and_unpad, + ulysses_pad, + ulysses_pad_and_slice_inputs, +) from verl.workers.actor.dp_actor import DataParallelPPOActor as DPActor from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN @@ -35,11 +45,46 @@ from trinity.algorithm.kl_fn.kl_fn import DummyKLFn from trinity.algorithm.utils import prefix_metrics from trinity.common.config import AlgorithmConfig +from trinity.utils.log import get_logger __all__ = ["DataParallelPPOActor"] -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +logger = get_logger(in_ray_actor=True) + + +def get_seq_idx(cu_seqlens: torch.Tensor, total_nnz: int) -> torch.Tensor: + """ + Build `seq_idx` from `cu_seqlens`, mapping each packed position to its + original sequence id. + + Args: + cu_seqlens: Shape (batch + 1,). Cumulative sequence lengths from + `unpad_input`. + For example, [0, 3, 7, 10] means sequence 0 has length 3, + sequence 1 has length 4, and sequence 2 has length 3. + total_nnz: Total number of packed tokens, i.e. `cu_seqlens[-1]`. + + Returns: + Shape (total_nnz,), where each position is the original sequence id + (0-indexed). For example, [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]. 
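    Example:
        An illustrative call using the values above::

            cu_seqlens = torch.tensor([0, 3, 7, 10], dtype=torch.int32)
            seq_idx = get_seq_idx(cu_seqlens, total_nnz=10)
            # seq_idx -> tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2], dtype=torch.int32)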
+ """ + device = cu_seqlens.device + batch_size = cu_seqlens.shape[0] - 1 + seq_idx = torch.zeros(total_nnz, dtype=torch.int32, device=device) + + # Use cu_seqlens differences: place 1 at each sequence start index, then + # apply cumsum to recover sequence ids. + # Example: cu_seqlens = [0, 3, 7, 10] + # Set 1 at indices [3, 7], then cumsum -> [0,0,0,1,1,1,1,2,2,2] + seq_idx.scatter_( + dim=0, + # Start index of each new sequence (exclude the last endpoint). + index=cu_seqlens[1:-1].long(), + src=torch.ones(batch_size - 1, dtype=torch.int32, device=device), + ) + seq_idx = seq_idx.cumsum(dim=0, dtype=torch.int32) + + return seq_idx class DataParallelPPOActor(DPActor): @@ -62,6 +107,335 @@ def set_algorithm(self, algorithm_config: AlgorithmConfig): **algorithm_config.entropy_loss_fn_args ) + def _forward_micro_batch( # noqa: C901 + self, + micro_batch: dict[str, torch.Tensor], + temperature: float, + calculate_entropy: bool = False, + ) -> dict[str, torch.Tensor]: + """ + Returns: + dict[str, torch.Tensor]: + log_probs: (bs, response_len) + if calculate_entropy is True: + entropys: (bs, response_len) + if calculate_sum_pi_squared is False: + sum_pi_squared: (bs, response_len) + """ + calculate_sum_pi_squared = self.config.get("calculate_sum_pi_squared", False) + sum_pi_squared_checkpointing = self.config.get("sum_pi_squared_checkpointing", False) + # PrefixGrouper path for shared-prefix optimization + if self.use_prefix_grouper: + can_use_pg = ( + not self.use_remove_padding + and not self.use_ulysses_sp + and not self.use_fused_kernels + and not self.use_dynamic_bsz + ) + if can_use_pg and "response_mask" in micro_batch and "uid" in micro_batch: + from verl.trainer.ppo.prefix_grouper_utils import ( + forward_micro_batch_with_prefix_grouper, + ) + + return forward_micro_batch_with_prefix_grouper( + micro_batch=micro_batch, + model=self.actor_module, + temperature=temperature, + calculate_entropy=calculate_entropy, + device_name=self.device_name, + param_dtype=self.param_dtype, + use_chunking_entropy=self.config.get( + "entropy_from_logits_with_chunking", False + ), + ) + + response_length = micro_batch["responses"].size(-1) + multi_modal_inputs = {} + if "multi_modal_inputs" in micro_batch.keys(): + from verl.utils.model import extract_multi_modal_inputs + + multi_modal_inputs = extract_multi_modal_inputs(micro_batch["multi_modal_inputs"]) + + with torch.autocast(device_type=self.device_name, dtype=self.param_dtype): + input_ids = micro_batch["input_ids"] + batch_size, seqlen = input_ids.shape + attention_mask = micro_batch["attention_mask"] + position_ids = micro_batch["position_ids"] + entropy = None + if position_ids.dim() == 3: # qwen2vl mrope + position_ids = position_ids.transpose(0, 1) # (bsz, 4, seqlen) -> (4, bsz, seqlen) + + if self.use_remove_padding: + input_ids_rmpad, indices, cu_seqlens, *_ = unpad_input( + input_ids.unsqueeze(-1), attention_mask + ) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + seq_idx = get_seq_idx( + cu_seqlens=cu_seqlens, + total_nnz=cu_seqlens[-1].item(), + ) + + # unpad the position_ids to align the rotary + if position_ids.dim() == 3: + position_ids_rmpad = ( + index_first_axis( + rearrange(position_ids, "c b s ... -> (b s) c ..."), indices + ) + .transpose(0, 1) + .unsqueeze(1) + ) # (4, bsz, seqlen) -> (4, 1, bsz * seqlen) + else: + position_ids_rmpad = index_first_axis( + rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), indices + ).transpose(0, 1) + + is_mask_all_zero = attention_mask.sum() == 0 + if is_mask_all_zero: + input_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=input_ids.device, + dtype=input_ids.dtype, + ) + if position_ids.dim() == 3: + position_ids_rmpad = torch.zeros( + (position_ids.shape[0], 1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + else: + position_ids_rmpad = torch.zeros( + (1, self.ulysses_sequence_parallel_size), + device=position_ids.device, + dtype=position_ids.dtype, + ) + + if "image_bound" in multi_modal_inputs: + from verl.utils.dataset.vision_utils import ( + process_multi_modal_inputs_for_minicpmo, + ) + + multi_modal_inputs = process_multi_modal_inputs_for_minicpmo( + input_ids, attention_mask, position_ids, cu_seqlens, multi_modal_inputs + ) + + # for compute the log_prob + input_ids_rmpad_rolled = torch.roll( + input_ids_rmpad, shifts=-1, dims=1 + ) # (1, total_nnz) + + # pad and slice the inputs if sp > 1 + if self.use_ulysses_sp: + is_vlm_model = hasattr( + getattr(self.actor_module, "module", self.actor_module).config, + "vision_config", + ) + if is_vlm_model: + # vlm model's inputs will be sliced after embedding + input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + else: + ( + input_ids_rmpad, + position_ids_rmpad, + pad_size, + ) = ulysses_pad_and_slice_inputs( + input_ids_rmpad, + position_ids_rmpad=position_ids_rmpad, + sp_size=self.ulysses_sequence_parallel_size, + ) + + if pad_size > 0: + seq_idx = torch.cat( + [ + seq_idx, + torch.full_like(seq_idx[:pad_size], fill_value=seq_idx[-1].item()), + ], + dim=0, + ) + cu_seqlens[-1] += pad_size + + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, + position_ids_rmpad=None, + sp_size=self.ulysses_sequence_parallel_size, + ) + + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze( + 0 + ) # ((total_nnz / sp) + pad) + + # only pass input_ids and position_ids to enable flash_attn_varlen + extra_args = { + "seq_idx": seq_idx.unsqueeze(0).to(torch.int32), + "cu_seqlens": cu_seqlens, + } + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids_rmpad, + attention_mask=None, + position_ids=position_ids_rmpad, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs.squeeze(0) # (total_nnz,) + entropy_rmpad = output.entropy.squeeze(0) # (total_nnz,) + + else: + logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) + logits_rmpad.div_(temperature) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=input_ids_rmpad_rolled, + inplace_backward=inplace_backward, + ) + + # compute entropy + if calculate_entropy: + # ((total_nnz / sp) + pad) + entropy_rmpad = ( + self.compute_entropy_from_logits(logits_rmpad) + if not self.config.entropy_checkpointing + else torch.utils.checkpoint.checkpoint( + self.compute_entropy_from_logits, logits_rmpad + ) + ) + + # Compute sum_pi_squared if requested (for optimal_token_baseline) + if calculate_sum_pi_squared: + sum_pi_squared_rmpad = ( + 
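                            # (Assumption based on the helper's name: this computes the
                            # per-token sum over the vocabulary of squared probabilities,
                            # sum_v softmax(logits)_v ** 2, used by the optimal token baseline.)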
self.calculate_sum_pi_squared_from_logits(logits_rmpad) + if not sum_pi_squared_checkpointing + else torch.utils.checkpoint.checkpoint( + self.calculate_sum_pi_squared_from_logits, logits_rmpad + ) + ) + + # gather log_prob if sp > 1 + if self.use_ulysses_sp: + # gather and unpad for the ulysses sp + log_probs = gather_outputs_and_unpad( + log_probs, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_entropy: + entropy_rmpad = gather_outputs_and_unpad( + entropy_rmpad, + gather_dim=0, + unpad_dim=0, + padding_size=pad_size, + ) + if calculate_sum_pi_squared: + sum_pi_squared_rmpad = gather_outputs_and_unpad( + sum_pi_squared_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size + ) + + if is_mask_all_zero: + log_probs = log_probs[:0] + if calculate_entropy: + entropy_rmpad = entropy_rmpad[:0] + + # pad back to (bsz, seqlen) + if calculate_entropy: + full_entropy = pad_input( + hidden_states=entropy_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + if calculate_sum_pi_squared: + full_sum_pi_squared = pad_input( + hidden_states=sum_pi_squared_rmpad.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + full_log_probs = pad_input( + hidden_states=log_probs.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen, + ) + + # only return response part: + if calculate_entropy: + entropy = full_entropy.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) + if calculate_sum_pi_squared: + # (bsz, response_length) + sum_pi_squared = full_sum_pi_squared.squeeze(-1)[:, -response_length - 1 : -1] + log_probs = full_log_probs.squeeze(-1)[ + :, -response_length - 1 : -1 + ] # (bsz, response_length) + + else: # not using rmpad and no ulysses sp + extra_args = {} + if self.use_fused_kernels: + extra_args["temperature"] = temperature + extra_args["return_dict"] = True + + output = self.actor_module( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **multi_modal_inputs, + use_cache=False, + **extra_args, + ) # prevent model thinks we are generating + + if self.use_fused_kernels: + log_probs = output.log_probs[:, -response_length - 1 : -1] + entropy = output.entropy[:, -response_length - 1 : -1] # (bsz, response_length) + + else: + logits = output.logits + + logits.div_(temperature) + logits = logits[ + :, -response_length - 1 : -1, : + ] # (bsz, response_length, vocab_size) + log_probs = logprobs_from_logits(logits, micro_batch["responses"]) + if calculate_entropy: + if not self.config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + else: + entropy = torch.utils.checkpoint.checkpoint( + verl_F.entropy_from_logits, logits + ) + # Compute sum_pi_squared if requested (for optimal_token_baseline) + if calculate_sum_pi_squared: + sum_pi_squared = ( + self.calculate_sum_pi_squared_from_logits(logits) + if not sum_pi_squared_checkpointing + else torch.utils.checkpoint.checkpoint( + self.calculate_sum_pi_squared_from_logits, logits + ) + ) + + outputs = {"log_probs": log_probs} + if calculate_entropy: + outputs["entropys"] = entropy + if calculate_sum_pi_squared: + outputs["sum_pi_squared"] = sum_pi_squared + return outputs + @GPUMemoryLogger(role="dp actor", logger=logger) def update_policy(self, data: DataProto): # noqa: C901 # make sure we are in training mode @@ -122,7 +496,12 @@ def update_policy(self, data: DataProto): # noqa: C901 # calculate the total number of response tokens in the minibatch 
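            # Hedged reading of the change below (assumes gradients are averaged, not summed,
            # across data-parallel ranks and that `policy_loss` is a token mean within the
            # micro-batch): the response-token count is all-reduced so `mini_batch_token_num`
            # becomes the global count, and each micro-batch loss is scaled by
            # cur_token_num / mini_batch_token_num * world_size. For example, with world_size=2
            # and 400 tokens globally, a micro-batch holding 50 tokens gets scale 50/400*2 = 0.25,
            # so after the cross-rank gradient average every token contributes 1/400, i.e. a
            # true global token mean.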
mini_batch_token_num = torch.sum( mini_batch.batch["response_mask"].to(get_device_id()) - ).item() + ) + torch.distributed.all_reduce( + mini_batch_token_num, op=torch.distributed.ReduceOp.SUM + ) + if mini_batch_token_num == 0: + mini_batch_token_num += 1e-6 # to avoid division by zero self.actor_optimizer.zero_grad() @@ -211,8 +590,12 @@ def update_policy(self, data: DataProto): # noqa: C901 else: # EXPERIMENTAL: fix for token-mean loss aggregation # scale microbatch loss according to the number of tokens (rather than sequences) - loss_scale = torch.sum(response_mask).item() / (mini_batch_token_num + 1e-6) - + cur_token_num = torch.sum(response_mask.to(get_device_id())) + loss_scale = ( + cur_token_num + / mini_batch_token_num + * torch.distributed.get_world_size() + ) loss = policy_loss * loss_scale micro_batch_metrics["actor/final_loss"] = loss.detach().item() if "actor/kl_loss" in micro_batch_metrics: @@ -225,6 +608,10 @@ def update_policy(self, data: DataProto): # noqa: C901 else: loss.backward() + micro_batch_metrics = { + key: (value.detach().item() if isinstance(value, torch.Tensor) else value) + for key, value in micro_batch_metrics.items() + } append_to_dict(metrics, micro_batch_metrics) grad_norm = self._optimizer_step() diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 9c93ade5350..7fd20ca650a 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -27,7 +27,6 @@ import torch from accelerate import init_empty_weights from torch.distributed.fsdp import ( - FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType, @@ -112,15 +111,7 @@ def upload_state_dict(self, global_step: int): global_step (int): The current training step number. 
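        Note:
            The full state dict is gathered via ``get_fsdp_full_state_dict`` with
            ``offload_to_cpu=True`` and ``rank0_only=True`` below, so only rank 0 is
            expected to hold the CPU copy that gets uploaded to the synchronizer.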
""" assert self.synchronizer is not None - state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None): - state_dict = self.model.state_dict() - state_dict = { - key: (value.full_tensor() if hasattr(value, "full_tensor") else value) - .detach() - .to("cpu") - for key, value in state_dict.items() - } + state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True) self._upload_state_dict(state_dict, global_step) def _save_with_thread( diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 1620bcf5987..344c6c55ee8 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -16,6 +16,7 @@ Modified from https://github.com/volcengine/verl/blob/v0.7.1/verl/workers/fsdp_workers.py """ +import builtins import datetime import json import os @@ -136,7 +137,11 @@ def __init__(self, config: DictConfig, role: str, **kwargs): timeout=datetime.timedelta(seconds=self.config.get("nccl_timeout", 600)), init_method=os.environ.get("DIST_INIT_METHOD", None), ) + + # setup logger self.logger = get_logger(f"{role}_{self.rank}", in_ray_actor=True) + # redirect built-in print to logger to capture logs + builtins.print = lambda *args, **kwargs: self.logger.info(" ".join(map(str, args))) # build device mesh for FSDP world_size = torch.distributed.get_world_size() @@ -366,11 +371,6 @@ def _build_model_optimizer( # noqa: C901 actor_model_config = AutoConfig.from_pretrained( local_path, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation ) - # TODO: VL models use VisionAttention, which directly uses flash_attention in transformers>=4.53 - # which will be patched by _ulysses_flash_attention_forward, but errorly misses position_ids - # Maybe support Ulysses in VisionAttention in the future and remove this patch - if self.ulysses_sequence_parallel_size > 1 and hasattr(actor_model_config, "vision_config"): - actor_model_config.vision_config._attn_implementation = "eager" # patch for qwen2.5-vl: when using flash_attention_3, set vision tower to use flash_attention_2 # because the vision tower does not support flash_attention_3 @@ -405,21 +405,37 @@ def _build_model_optimizer( # noqa: C901 if self.rank == 0: self.logger.info(f"Model config after override: {actor_model_config}") - # NOTE(fix me): tie_word_embedding causes meta_tensor init to hang - init_context = get_init_weight_context_manager(use_meta_tensor=False, mesh=self.device_mesh) + major_capability, _ = torch.cuda.get_device_capability(0) + use_meta = ( + ( + self.rank != 0 + if self.device_mesh is None + else self.device_mesh.get_coordinate()[-1] != 0 + ) + if self.config.actor.strategy == "fsdp2" and major_capability >= 9 + else False + ) - with init_context(), warnings.catch_warnings(): + init_context = torch.device("meta") if use_meta else torch.device("cpu") + + with init_context, warnings.catch_warnings(): warnings.simplefilter("ignore") actor_module_class = get_hf_auto_model_class(actor_model_config) - - actor_module = actor_module_class.from_pretrained( - pretrained_model_name_or_path=local_path, - torch_dtype=torch_dtype, + loading_kwargs = dict( + dtype=torch_dtype, config=actor_model_config, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation, ) + if use_meta: + actor_module = actor_module_class.from_config(**loading_kwargs) + else: + actor_module = actor_module_class.from_pretrained( + 
pretrained_model_name_or_path=local_path, + **loading_kwargs, + ) + # Apply Liger kernel to the model if use_liger is set to True if use_liger: from liger_kernel.transformers.monkey_patch import ( @@ -1207,7 +1223,11 @@ def __init__(self, config: FSDPCriticConfig): init_method=os.environ.get("DIST_INIT_METHOD", None), ) + # Setup logger self.logger = get_logger(f"critic_{self.rank}", in_ray_actor=True) + # redirect built-in print to logger to capture logs + builtins.print = lambda *args, **kwargs: self.logger.info(" ".join(map(str, args))) + self.config: FSDPCriticConfig = config # build device mesh for Ulysses Sequence Parallel diff --git a/trinity/trainer/verl/monkey_patch.py b/trinity/trainer/verl/monkey_patch.py index 52151b1f47d..126d8077e9c 100644 --- a/trinity/trainer/verl/monkey_patch.py +++ b/trinity/trainer/verl/monkey_patch.py @@ -210,12 +210,21 @@ def apply_monkey_patch( # noqa: C901 from verl.models.transformers.monkey_patch import ( _ulysses_flash_attention_forward, apply_prefix_grouper_patch, - patch_vlm_for_ulysses_input_slicing, + ) + from verl.models.transformers.monkey_patch import ( + patch_vlm_for_ulysses_input_slicing as verl_patch_vlm_for_ulysses_input_slicing, ) from verl.utils.import_utils import is_trl_available from verl.utils.transformers_compat import is_transformers_version_in_range - logger = get_logger(__name__) + logger = get_logger(__name__, in_ray_actor=True) + + def patch_vlm_for_ulysses_input_slicing(model_class: type): + if getattr(model_class, "_patch_vlm_for_ulysses_input_slicing", False): + return + + verl_patch_vlm_for_ulysses_input_slicing(model_class) + model_class._patch_vlm_for_ulysses_input_slicing = True # Apply TiledMLP monkey patch for memory-efficient MLP computation if use_tiled_mlp: @@ -309,49 +318,61 @@ def state_dict(self, *args, **kwargs): patch_vlm_for_ulysses_input_slicing(Qwen3VLMoeTextModel) elif model.config.model_type in ["qwen3_5", "qwen3_5_moe"]: - from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel + from transformers.models.qwen3_5.modeling_qwen3_5 import ( + Qwen3_5DecoderLayer, + Qwen3_5GatedDeltaNet, + Qwen3_5Model, + Qwen3_5TextModel, + Qwen3_5VisionModel, + ) from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeDecoderLayer, + Qwen3_5MoeGatedDeltaNet, + Qwen3_5MoeModel, Qwen3_5MoeTextModel, + Qwen3_5MoeVisionModel, + ) + + from trinity.common.patch.qwen3_5 import ( + decoder_layer_forward, + gate_delta_net_forward, + qwen35_model_forward, + qwen35_vision_fast_pos_embed_interpolate, ) - # Step 1: bug fix in transformers==5.2.0 - # see https://github.com/huggingface/transformers/pull/44382 - if "Qwen3_5TextDecoderLayer" in model._no_split_modules: - model._no_split_modules.remove("Qwen3_5TextDecoderLayer") - model.model._no_split_modules.remove("Qwen3_5TextDecoderLayer") - if "Qwen3_5MoeTextDecoderLayer" in model._no_split_modules: - model._no_split_modules.remove("Qwen3_5MoeTextDecoderLayer") - model.model._no_split_modules.remove("Qwen3_5MoeTextDecoderLayer") + Qwen3_5DecoderLayer.forward = decoder_layer_forward + Qwen3_5MoeDecoderLayer.forward = decoder_layer_forward + Qwen3_5GatedDeltaNet.forward = gate_delta_net_forward + Qwen3_5MoeGatedDeltaNet.forward = gate_delta_net_forward - # see https://github.com/huggingface/transformers/pull/44399 - if is_transformers_version_in_range(max_version="5.3.0"): - from trinity.common.patch.qwen3_5 import qwen35_text_forward + Qwen3_5VisionModel.fast_pos_embed_interpolate = qwen35_vision_fast_pos_embed_interpolate + 
Qwen3_5MoeVisionModel.fast_pos_embed_interpolate = qwen35_vision_fast_pos_embed_interpolate - Qwen3_5TextModel.forward = qwen35_text_forward - Qwen3_5MoeTextModel.forward = qwen35_text_forward + Qwen3_5Model.forward = qwen35_model_forward + Qwen3_5MoeModel.forward = qwen35_model_forward # Step 2: patch input for multimodal sequence parallelism if ulysses_sp_size > 1: patch_vlm_for_ulysses_input_slicing(Qwen3_5TextModel) patch_vlm_for_ulysses_input_slicing(Qwen3_5MoeTextModel) - from trinity.common.patch.qwen3_5 import ( - ulysses_gated_delta_net_forward_decorator, - ) + from trinity.common.patch.qwen3_5 import ulysses_gate_delta_net_decorator for layer in model.model.language_model.layers: if layer.layer_type == "linear_attention": - layer.linear_attn.forward = ulysses_gated_delta_net_forward_decorator( - layer.linear_attn.forward - ) + ulysses_gate_delta_net_decorator(layer.linear_attn, ulysses_sp_size) # Step 3: patch verl.utils.flops_counter - from verl.utils.flops_counter import ESTIMATE_FUNC, _estimate_qwen3_vl_flops + from verl.utils.flops_counter import ( + ESTIMATE_FUNC, + _estimate_qwen3_vl_flops, + _estimate_qwen3_vl_moe_flops, + ) ESTIMATE_FUNC.update( { "qwen3_5": _estimate_qwen3_vl_flops, - "qwen3_5_moe": _estimate_qwen3_vl_flops, + "qwen3_5_moe": _estimate_qwen3_vl_moe_flops, } ) diff --git a/trinity/trainer/verl/verl_config.py b/trinity/trainer/verl/verl_config.py index 9bdcaa8f831..4c548a4440c 100644 --- a/trinity/trainer/verl/verl_config.py +++ b/trinity/trainer/verl/verl_config.py @@ -173,6 +173,7 @@ class Actor: policy_loss: PolicyLossConfig = field(default_factory=PolicyLossConfig) profiler: dict = field(default_factory=dict) router_replay: RouterReplayConfig = field(default_factory=RouterReplayConfig) + freeze_vision_tower: bool = False # do not set loss_agg_mode: str = "token-mean" loss_scale_factor: Optional[float] = None @@ -250,6 +251,7 @@ class CriticModel: enable_gradient_checkpointing: bool = True use_remove_padding: bool = True fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + freeze_vision_tower: bool = False # rope configs rope_scaling: Optional[dict] = None diff --git a/trinity/utils/log.py b/trinity/utils/log.py index 369f334b6e6..d06e174e19e 100644 --- a/trinity/utils/log.py +++ b/trinity/utils/log.py @@ -90,6 +90,7 @@ def get_logger( # File handler (rotating file log) log_dir = os.environ.get(LOG_DIR_ENV_VAR) assert name is not None, "Logger name must be set when logging from a Ray actor" + # If LOG_DIR_ENV_VAR is not set, file logging is disabled if log_dir: if os.environ.get(LOG_NODE_IP_ENV_VAR, "0") != "0": # organize logs by node IP @@ -104,8 +105,7 @@ def get_logger( file_handler.setLevel(resolved_level) file_handler.setFormatter(formatter) logger.addHandler(file_handler) - _ray_logger_ctx.set(logger) - _ray_logger = logger - # If LOG_DIR_ENV_VAR is not set, file logging is disabled + _ray_logger_ctx.set(logger) + _ray_logger = logger return logger
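A minimal usage sketch for the logger change above (names are taken from this diff; the
concrete environment variable behind LOG_DIR_ENV_VAR is not shown here, so file logging is
simply assumed to be off when it is unset). With the relocated `_ray_logger_ctx.set`, the
named logger is cached for the Ray actor even when no log directory is configured:

    from trinity.utils.log import get_logger

    # Hypothetical name following the f"{role}_{self.rank}" pattern used by the FSDP workers.
    logger = get_logger("actor_0", in_ray_actor=True)
    logger.info("explorer step finished")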