From b5e6175291b8464a556cb76e45549fe8dbe21f73 Mon Sep 17 00:00:00 2001 From: tbqh Date: Fri, 21 Nov 2025 12:27:29 -0800 Subject: [PATCH 1/6] Pull latest benchmark_inference from lightning-thunder repo --- benchmarks/python/benchmark_inference.py | 541 ++++++++++-------- .../python/layers_for_inference_benchmark.py | 456 +++++++-------- 2 files changed, 488 insertions(+), 509 deletions(-) diff --git a/benchmarks/python/benchmark_inference.py b/benchmarks/python/benchmark_inference.py index 66a6d5d2c3d..68ff62c7af3 100644 --- a/benchmarks/python/benchmark_inference.py +++ b/benchmarks/python/benchmark_inference.py @@ -1,7 +1,3 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - """Inference benchmark focusing on throughput and latency metrics of prefill and decode phases. AutoModelForCausalLM from Hugging Face transformers is used for model implementation. @@ -13,48 +9,48 @@ - Time Between Output Tokens (TBOT) """ +# fmt: off + from __future__ import annotations from contextlib import contextmanager from dataclasses import dataclass, field -from typing import TYPE_CHECKING import argparse import json import os import statistics -import sys import time import warnings +from typing import Any +from collections.abc import Callable +from looseversion import LooseVersion import torch +import torch.distributed as dist import torch.nn as nn from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.distributed_c10d import destroy_process_group -from torch.distributed.tensor.parallel import ( - parallelize_module, - RowwiseParallel, - ColwiseParallel, -) +from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel, ColwiseParallel from tqdm import tqdm -from transformers import AutoModelForCausalLM -from transformers.models.llama4 import Llama4TextConfig -from transformers.cache_utils import HybridChunkedCache +import transformers +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.cache_utils import HybridChunkedCache, StaticCache from transformers.models.llama4.modeling_llama4 import Llama4TextMoe +from torch.distributed.tensor.placement_types import Shard +from torch.distributed.tensor import DTensor import thunder from thunder.dynamo.compiler import thunderfx -from thunder.dynamo.report import thunderfx_benchmark_report -from layers_for_inference_benchmark import ( - GroupedLinear, +from thunder.benchmarks.layers_for_inference_benchmark import ( + GroupedSwiGLU, Llama4MoE, - NVFP4InferenceGroupedLinear, - NVFP4InferenceLinear, + NVFP4InferenceGroupedSwiGLU, nvfuser_f16a_nvfp4weight_scaled_grouped_mm, - nvfuser_f16a_nvfp4weight_scaled_mm, + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_EPS, + FLOAT8_E4M3_MAX, ) -from thunder.torch.custom_op import _register_custom_op - -if TYPE_CHECKING: - from typing import Any +from thunder.tests.distributed.test_moe import GroupedLinearColwiseParallel, GroupedLinearRowwiseParallel +from thunder.transforms.cudagraph import CUDAGraphTransform +from thunder.torch.custom_op import _register_custom_op, _register_nvfuser_translator RANK = int(os.environ.get("RANK", 0)) @@ -67,54 +63,77 @@ DEVICE = torch.device("cuda", LOCAL_RANK) torch.cuda.set_device(DEVICE) -if WORLD_SIZE > 1: +if dist.is_torchelastic_launched(): mesh = init_device_mesh("cuda", (WORLD_SIZE,), mesh_dim_names=("tp",)) +else: + mesh = None LLAMA4_MAVERICK_MODEL_ID: str = "meta-llama/Llama-4-Maverick-17B-128E" 
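For orientation, the metrics named in the module docstring above relate to the measured prefill/decode times roughly as follows (a minimal sketch mirroring measure_inference_step further down; the helper name and the overall-throughput formula are illustrative, not part of the patch):

def summarize_step(prefill_time_ms, decode_time_ms, batch_size, input_length, output_length):
    # TTFT is the prefill time: latency until the first generated token.
    ttft_ms = prefill_time_ms
    # TBOT averages the remaining decode steps (output_length - 1 of them).
    tbot_ms = decode_time_ms / (output_length - 1) if output_length > 1 else 0.0
    total_ms = prefill_time_ms + decode_time_ms
    return {
        "ttft_ms": ttft_ms,
        "tbot_ms": tbot_ms,
        "latency_ms_per_token": total_ms / (output_length * batch_size),
        # Prefill and decode throughput are also reported separately (tokens/sec).
        "prefill_throughput": input_length * batch_size / prefill_time_ms * 1000,
        "decode_throughput": (output_length - 1) * batch_size / decode_time_ms * 1000 if decode_time_ms > 0 else 0.0,
    }
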
-llama_4_Maverick_17B_128E_cfg_str = r""" { - "attention_bias": false, - "attention_chunk_size": 8192, - "attention_dropout": 0.0, - "attn_scale": 0.1, - "attn_temperature_tuning": true, - "bos_token_id": 200000, - "cache_implementation": "hybrid", - "eos_token_id": [ - 200001, - 200007, - 200008 - ], - "floor_scale": 8192, - "for_llm_compressor": false, - "head_dim": 128, - "hidden_act": "silu", - "hidden_size": 5120, - "initializer_range": 0.02, - "interleave_moe_layer_step": 2, - "intermediate_size": 8192, - "intermediate_size_mlp": 16384, - "max_position_embeddings": 262144, - "model_type": "llama4_text", - "moe_layers": [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47], - "no_rope_layers": [1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0], - "num_attention_heads": 40, - "num_experts_per_tok": 1, - "num_hidden_layers": 48, - "num_key_value_heads": 8, - "num_local_experts": 128, - "output_router_logits": false, - "pad_token_id": 200018, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 500000.0, - "router_aux_loss_coef": 0.001, - "router_jitter_noise": 0.0, - "torch_dtype": "bfloat16", - "use_cache": true, - "use_qk_norm": false, - "vocab_size": 202048 -} -""" + + +# TODO: Add mm quantization once nvfuser implements nvfp4 gemm +# Register nvfp4 custom ops with Thunder and nvFuser +def _register_nvfp4_ops(): + """Register nvfp4 custom operations with Thunder.""" + # Register f16a_nvfp4weight_scaled_grouped_mm with nvfuser translator + _nvfp4_grouped_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm) + + def nvfp4_grouped_mm_translator( + activation, + fp4_weight, + weight_scaling_factor, + global_scale, + offsets, + blockscale_offsets, + problem_sizes, + *, + fd, + lc_to_nv_map, + ): + from nvfuser_direct import DataType + from thunder.executors.nvfuserex_impl import getnv + + nv_act = getnv(activation, fd, lc_to_nv_map) + nv_fp4_w = getnv(fp4_weight, fd, lc_to_nv_map) + nv_sf_w = getnv(weight_scaling_factor, fd, lc_to_nv_map) + nv_alpha = getnv(global_scale, fd, lc_to_nv_map) + nv_offsets = getnv(offsets, fd, lc_to_nv_map) + nv_blocksf_offsets = getnv(blockscale_offsets, fd, lc_to_nv_map) + nv_problem_sizes = getnv(problem_sizes, fd, lc_to_nv_map) + # dynamic shape support has some concretization issue + m_size = activation.shape[0] + k_size = activation.shape[1] + k_tile_size = k_size // 16 + + reshaped_mat1 = fd.ops.reshape(nv_act, [m_size, k_tile_size, 16]) + scale1 = fd.ops.abs(reshaped_mat1) + scale1 = fd.ops.max(scale1, 2) + scale1 = fd.ops.div(scale1, FLOAT4_E2M1_MAX) + scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX) + + broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True]) + reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1) + reshaped_scaled_mat1 = fd.ops.clamp(reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX) + + scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size]) + fp4_mat1 = fd.ops.cast(scaled_mat1, DataType.Float4_e2m1fn) + fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn) + layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, nv_offsets, nv_blocksf_offsets) + out = fd.ops.cutlass_nvfp4_grouped_mm( + fp4_mat1, + nv_fp4_w, + layout_fp8_scale1, + nv_sf_w, + nv_alpha, + # NOTE: we might need to call contiguous on problem_sizes + nv_problem_sizes, + nv_offsets, + nv_blocksf_offsets, + DataType.BFloat16, + ) + 
return out + + _register_nvfuser_translator(_nvfp4_grouped_mm_symbol, nvfp4_grouped_mm_translator) # The logic is based on https://github.com/pytorch/ao/blob/b34c1037/torchao/quantization/quant_api.py#L230 @@ -161,16 +180,18 @@ def _replace_llama4_moe(model: nn.Module) -> None: def _quantize_llama4(model: nn.Module) -> None: - """Replace linear and moe with nvfp4 inference version.""" - _replace_with_custom_fn_if_matches_filter_with_name( - model, - NVFP4InferenceLinear.from_linear, - lambda model, cur_fqn: isinstance(model, nn.Linear), - ) + """Replace linear and/or MoE with nvfp4 inference version. + + Args: + model: The model to quantize + + Note: GroupedSwiGLU is always quantized when this function is called. + """ + # Always quantize GroupedSwiGLU when this function is called _replace_with_custom_fn_if_matches_filter_with_name( model, - NVFP4InferenceGroupedLinear.from_grouped_linear, - lambda model, cur_fqn: isinstance(model, GroupedLinear), + NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu, + lambda model, cur_fqn: isinstance(model, GroupedSwiGLU), ) @@ -194,12 +215,15 @@ class InferenceBenchmarkConfig: num_layers: int | None num_iterations: int warmup_iterations: int - dtensor_single_gpu: bool - enable_nvfp4: bool # Enable NVFP4 quantization + enable_nvfp4: bool # Enable NVFP4 registration and quantize GroupedSwiGLU in MoE fx_report_folder: str | None enable_nv_linear: bool mode: str disable_moe_replacement: bool + attn_implementation: str | None + profile: bool + thunder_cache: str | None + enable_thunder_cudagraph: bool @dataclass @@ -234,42 +258,90 @@ def __init__(self, config: InferenceBenchmarkConfig): self.config = config self.metrics = InferenceMetrics() + # NOTE: Model resides on meta device model = self._load_model() + assert all(p.device == torch.device("meta") for p in model.parameters()) + + # NOTE: Replacement happens before model is materialized + # otherwise, the memory usage will be increased due to + # additional parameters materialized from the replacement module + if not self.config.disable_moe_replacement: + _replace_llama4_moe(model) + assert all(p.device == torch.device("meta") for p in model.parameters()) tp_plan = { "*.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=True), "*.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=True), "*.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=True), "*.layers.*.self_attn.o_proj": RowwiseParallel(use_local_output=True), - "*.layers.*.feed_forward.gate_proj": ColwiseParallel( - use_local_output=False - ), + "*.layers.*.feed_forward.gate_proj": ColwiseParallel(use_local_output=False), "*.layers.*.feed_forward.up_proj": ColwiseParallel(use_local_output=False), "*.layers.*.feed_forward.down_proj": RowwiseParallel(use_local_output=True), - "*.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel( - use_local_output=False - ), - "*.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel( - use_local_output=False - ), - "*.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel( - use_local_output=True - ), } - if self.config.dtensor_single_gpu or WORLD_SIZE > 1: + if not self.config.disable_moe_replacement: + tp_plan.update( + { + # Custom MoE + "*.layers.*.feed_forward.shared_experts.gate_proj": ColwiseParallel( + use_local_output=False, output_layouts=Shard(2) + ), + "*.layers.*.feed_forward.shared_experts.up_proj": ColwiseParallel( + use_local_output=False, output_layouts=Shard(2) + ), + "*.layers.*.feed_forward.shared_experts.down_proj": RowwiseParallel(), + 
"*.layers.*.feed_forward.routed_experts.gate_proj": GroupedLinearColwiseParallel( + use_local_output=False + ), + "*.layers.*.feed_forward.routed_experts.up_proj": GroupedLinearColwiseParallel( + use_local_output=False + ), + "*.layers.*.feed_forward.routed_experts.down_proj": GroupedLinearRowwiseParallel(), + } + ) + + else: + tp_plan.update( + { + # HF MoE + "*.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel(use_local_output=False), + "*.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel(use_local_output=False), + "*.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel(use_local_output=True), + # TODO:Need to write ParallelStyle for HF's grouped_mm implementation. + } + ) + + if mesh: model = parallelize_module(model, mesh, tp_plan) - # Required as that doesn't understand inference mode - for p in model.parameters(): - p.requires_grad_(False) + # Sanity check + if not self.config.disable_moe_replacement: + assert type(model.model.layers[1].feed_forward.shared_experts.gate_proj.weight) == DTensor + assert type(model.model.layers[1].feed_forward.shared_experts.up_proj.weight) == DTensor + assert type(model.model.layers[1].feed_forward.shared_experts.down_proj.weight) == DTensor + assert type(model.model.layers[1].feed_forward.routed_experts.gate_proj.weight) == DTensor + assert type(model.model.layers[1].feed_forward.routed_experts.up_proj.weight) == DTensor + assert type(model.model.layers[1].feed_forward.routed_experts.down_proj.weight) == DTensor + else: + assert type(model.model.layers[1].feed_forward.shared_expert.gate_proj.weight) == DTensor + assert type(model.model.layers[1].feed_forward.shared_expert.up_proj.weight) == DTensor + assert type(model.model.layers[1].feed_forward.shared_expert.down_proj.weight) == DTensor + + # Materialize the model on the device (after Llama4MoE replacement and sharding) + model.to_empty(device=DEVICE) + assert all(p.device == DEVICE for p in model.parameters()) + + # Required as thunder doesn't understand inference mode + # And some prims like `prims._grouped_mm` don't have grad rule defined yet. + for p in model.parameters(): + p.requires_grad_(False) + + assert all(not p.requires_grad for p in model.parameters()) # `thunderfx` seems to hide the access to vocab_size somewhere so # store it here before any compiler is applied. 
self.vocab_size = model.vocab_size - if not self.config.disable_moe_replacement: - _replace_llama4_moe(model) if self.config.enable_nvfp4: _quantize_llama4(model) self.model = self._compile_model(model) @@ -278,13 +350,25 @@ def __init__(self, config: InferenceBenchmarkConfig): def _thunder_jit_options(self) -> dict[str, Any]: # `nv_enable_linear=True` might fail with distributed run # ref: https://github.com/NVIDIA/Fuser/issues/4507 + res = {"transforms": []} if self.config.enable_nv_linear: - return {"nv_enable_linear": True, "nv_enable_matmul": True} - return {} + res["nv_enable_linear"] = True + res["nv_enable_matmul"] = True + if self.config.mode == "thunderjit": + from thunder.recipes.hf_transformers import SDPAMaskTransform + + if not hasattr(self, "_mask_transform"): + self._mask_transform = SDPAMaskTransform() + res["transforms"].append(self._mask_transform) + res["executors"] = [self._mask_transform.get_executor(), *thunder.get_default_executors()] + if self.config.enable_thunder_cudagraph: + res["transforms"].append(CUDAGraphTransform()) + if self.config.thunder_cache is not None: + res["cache"] = self.config.thunder_cache + + return res def _compile_model(self, model): - if self.config.fx_report_folder is not None: - return model match self.config.mode: case "eager": return model @@ -300,9 +384,7 @@ def _compile_model(self, model): def _load_model(self) -> torch.nn.Module: """Load the model based on configuration""" model_id = self.config.model_name - config = Llama4TextConfig.from_dict( - json.loads(llama_4_Maverick_17B_128E_cfg_str) - ) + config = AutoConfig.from_pretrained(model_id) if hasattr(config, "text_config"): config = config.text_config @@ -311,8 +393,10 @@ def _load_model(self) -> torch.nn.Module: self.hf_config = config - with DEVICE: - model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) + with torch.device("meta"): + model = AutoModelForCausalLM.from_config( + config, torch_dtype=torch.bfloat16, attn_implementation=self.config.attn_implementation + ) return model @@ -321,36 +405,39 @@ def generate_batch(self) -> tuple[torch.Tensor, HybridChunkedCache]: batch_size = self.config.batch_size input_length = self.config.input_length - input_ids = torch.randint( - 0, self.vocab_size, (batch_size, input_length), device=DEVICE - ) - past_key_values = HybridChunkedCache( - self.hf_config, - input_ids.shape[0], - input_ids.shape[1] + self.config.output_length, - ) - for layer_idx in range(self.hf_config.num_hidden_layers): - # key_states.shape[1] is used to retrieve the number of key value heads, all other dimensions can be 1 and ignored - # https://github.com/huggingface/transformers/blob/9300728665aaeb0ebf4db99f9d9fbce916b4a183/src/transformers/cache_utils.py#L1822 - dummy_key_states = torch.empty( - 1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE + input_ids = torch.randint(0, self.vocab_size, (batch_size, input_length), device=DEVICE) + if LooseVersion(transformers.__version__) >= LooseVersion("4.55"): + # Transformers deprecated HybridChunkedCache in favour of static in 4.55.x + past_key_values = StaticCache( + config=self.hf_config, + max_batch_size=input_ids.shape[0], + max_cache_len=input_ids.shape[1] + self.config.output_length, + device=DEVICE, + dtype=torch.bfloat16, + ) + else: + past_key_values = HybridChunkedCache( + self.hf_config, input_ids.shape[0], input_ids.shape[1] + self.config.output_length ) - past_key_values.initialise_cache_layer(layer_idx, dummy_key_states) + for layer_idx in 
range(self.hf_config.num_hidden_layers): + # key_states.shape[1] is used to retrieve the number of key value heads, all other dimensions can be 1 and ignored + # https://github.com/huggingface/transformers/blob/9300728665aaeb0ebf4db99f9d9fbce916b4a183/src/transformers/cache_utils.py#L1822 + dummy_key_states = torch.empty(1, self.hf_config.num_key_value_heads // WORLD_SIZE, 1, 1, device=DEVICE) + past_key_values.initialise_cache_layer(layer_idx, dummy_key_states) return input_ids, past_key_values def get_next_token( - self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache + self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache | StaticCache ) -> torch.Tensor: - outputs = self.model(input_ids, past_key_values=past_key_values, use_cache=True) + with torch.no_grad(): + outputs = self.model(input_ids, past_key_values=past_key_values, use_cache=True) logits = outputs.logits # [B, seq_len, vocab_size] next_token_logits = logits[:, -1, :] next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True) return next_token - def prefill( - self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache - ) -> torch.Tensor: + def prefill(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor: """ Prefill phase: Process the entire input prompt at once. Returns the next token. @@ -359,25 +446,24 @@ def prefill( """ return self.get_next_token(input_ids, past_key_values) - def decode_one_token( - self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache - ) -> torch.Tensor: + def decode_one_token(self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache) -> torch.Tensor: """ Decode phase: Generate a single token given the current sequence. Returns the next token. """ # input_pos: [B, 1] One token at the time - assert ( - input_ids.shape[-1] == 1 - ), f"Expected shape (B, 1), but found {input_ids.shape}" + assert input_ids.shape[-1] == 1, f"Expected shape (B, 1), but found {input_ids.shape}" return self.get_next_token(input_ids, past_key_values) - @torch.inference_mode() + # TODO: Running `torchrun --nproc-per-node 2 thunder/benchmarks/benchmark_inference.py --input-length 32 --output-length 32 --mode eager --num-iterations 10` + # with inference mode results in + # [rank1]: File "/opt/pytorch/lightning-thunder/thunder/benchmarks/layers_for_inference_benchmark.py", line 358, in grouped_mm + # [rank1]: group_outs.append(group_a @ b[idx]) + # [rank1]: ~^^^^^ + # [rank1]: RuntimeError: Cannot set version_counter for inference tensor + # @torch.inference_mode() def generate( - self, - input_ids: torch.Tensor, - max_new_tokens: int, - past_key_values: HybridChunkedCache, + self, input_ids: torch.Tensor, max_new_tokens: int, past_key_values: HybridChunkedCache ) -> dict[str, Any]: """ Generate tokens using separate prefill and decode phases. 
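The generate() path below splits generation into one full-prompt prefill call followed by single-token decode steps that reuse the KV cache built in generate_batch. A minimal sketch of that flow, assuming greedy argmax decoding (as in get_next_token above) and simple wall-clock timing with explicit synchronization; the timing details here are an assumption, not the exact implementation:

import time
import torch

def _next_token(model, ids, cache):
    # Forward pass with the cache, then greedy argmax over the last position.
    with torch.no_grad():
        logits = model(ids, past_key_values=cache, use_cache=True).logits
    return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)

def generate_greedy(model, input_ids, cache, max_new_tokens):
    torch.cuda.synchronize(); t0 = time.perf_counter()
    next_token = _next_token(model, input_ids, cache)        # prefill: whole prompt at once
    torch.cuda.synchronize(); prefill_ms = (time.perf_counter() - t0) * 1000

    generated = [next_token]
    t1 = time.perf_counter()
    for _ in range(max_new_tokens - 1):                      # decode: one token per step
        next_token = _next_token(model, next_token, cache)
        generated.append(next_token)
    torch.cuda.synchronize(); decode_ms = (time.perf_counter() - t1) * 1000
    return torch.cat(generated, dim=-1), prefill_ms, decode_ms
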
@@ -406,22 +492,15 @@ def generate( } def measure_inference_step( - self, - input_ids: torch.Tensor, - past_key_values: HybridChunkedCache, - max_new_tokens: int, + self, input_ids: torch.Tensor, past_key_values: HybridChunkedCache, max_new_tokens: int ) -> dict[str, float]: """Measure a single inference step with detailed timing using separate prefill/decode""" # Generate tokens with separate prefill/decode tracking generation_result = self.generate(input_ids, max_new_tokens, past_key_values) - total_time = ( - generation_result["prefill_time_ms"] + generation_result["decode_time_ms"] - ) + total_time = generation_result["prefill_time_ms"] + generation_result["decode_time_ms"] # Extract metrics - ttft = generation_result[ - "prefill_time_ms" - ] # Time to first token is the prefill time + ttft = generation_result["prefill_time_ms"] # Time to first token is the prefill time total_decode_time = generation_result["decode_time_ms"] avg_tbot = total_decode_time / (max_new_tokens - 1) if max_new_tokens > 1 else 0 @@ -431,14 +510,10 @@ def measure_inference_step( # Calculate separate prefill and decode throughput prefill_tokens = self.config.input_length * self.config.batch_size - prefill_throughput = ( - prefill_tokens / generation_result["prefill_time_ms"] - ) * 1000 + prefill_throughput = (prefill_tokens / generation_result["prefill_time_ms"]) * 1000 decode_tokens = (self.config.output_length - 1) * self.config.batch_size - decode_throughput = ( - (decode_tokens / total_decode_time) * 1000 if total_decode_time > 0 else 0 - ) + decode_throughput = (decode_tokens / total_decode_time) * 1000 if total_decode_time > 0 else 0 return { "ttft": ttft, @@ -451,25 +526,8 @@ def measure_inference_step( "total_decode_time": total_decode_time, } - def _run_thunderfx_benchmark_report(self): - print( - f"Running thunderfx benchmark report for {self.config.model_name} to {self.config.fx_report_folder}" - ) - print(f"Batch size: {self.config.batch_size}") - print(f"Input length: {self.config.input_length}") - print(f"Output length: {self.config.output_length}") - input_ids, past_key_values = self.generate_batch() - thunderfx_benchmark_report( - self.model, - folder_path=self.config.fx_report_folder, - compare_fusion=True, - )(input_ids, past_key_values) - def run_benchmark(self) -> InferenceMetrics: """Run the full benchmark and collect metrics""" - if self.config.fx_report_folder is not None: - self._run_thunderfx_benchmark_report() - return print(f"Running inference benchmark for {self.config.model_name}") print(f"Batch size: {self.config.batch_size}") @@ -480,20 +538,28 @@ def run_benchmark(self) -> InferenceMetrics: print(f"\nWarming up with {self.config.warmup_iterations} iterations...") input_ids, past_key_values = self.generate_batch() - for _ in tqdm(range(self.config.warmup_iterations)): + for _ in tqdm(range(self.config.warmup_iterations), disable=LOCAL_RANK != 0): past_key_values.reset() - _ = self.measure_inference_step( - input_ids, past_key_values, max_new_tokens=1 - ) + # Use output_length to warm up sufficiently. Otherwise, Thunder's + # first-run latency is terribly slow due to lack of dynamic shape + # support. 
+ _ = self.measure_inference_step(input_ids, past_key_values, self.config.output_length) print(f"\nRunning {self.config.num_iterations} benchmark iterations...") all_metrics = [] - for _ in tqdm(range(self.config.num_iterations)): + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + + for _ in tqdm(range(self.config.num_iterations), disable=LOCAL_RANK != 0): past_key_values.reset() - iter_metrics = self.measure_inference_step( - input_ids, past_key_values, self.config.output_length - ) + + if self.config.profile: + torch.cuda.cudart().cudaProfilerStart() + iter_metrics = self.measure_inference_step(input_ids, past_key_values, self.config.output_length) + if self.config.profile: + torch.cuda.cudart().cudaProfilerStop() + all_metrics.append(iter_metrics) # Track metrics @@ -508,6 +574,10 @@ def run_benchmark(self) -> InferenceMetrics: self.metrics.memory_used_gb = torch.cuda.memory_allocated() / 1e9 self.metrics.peak_memory_gb = torch.cuda.max_memory_allocated() / 1e9 + if self.config.fx_report_folder is not None and self.config.mode == "thunder": + self.model._backend.save_reproducer_to_folder(self.config.fx_report_folder) + return + return self.metrics def _calculate_aggregate_metrics(self, all_metrics: list[dict[str, Any]]): @@ -526,26 +596,20 @@ def _calculate_aggregate_metrics(self, all_metrics: list[dict[str, Any]]): self.metrics.time_to_first_token_ms = statistics.mean(ttfts) # TBOT - self.metrics.time_between_output_tokens_ms = statistics.mean( - [m["avg_tbot"] for m in all_metrics] - ) + self.metrics.time_between_output_tokens_ms = statistics.mean([m["avg_tbot"] for m in all_metrics]) # Total time self.metrics.total_time_ms = statistics.mean(total_times) # Prefill metrics prefill_throughputs = [m["prefill_throughput"] for m in all_metrics] - self.metrics.prefill_throughput_tokens_per_sec = statistics.mean( - prefill_throughputs - ) + self.metrics.prefill_throughput_tokens_per_sec = statistics.mean(prefill_throughputs) prefill_times = [m["prefill_time"] for m in all_metrics] self.metrics.prefill_time_ms = statistics.mean(prefill_times) # Decode metrics decode_throughputs = [m["decode_throughput"] for m in all_metrics] - self.metrics.decode_throughput_tokens_per_sec = statistics.mean( - decode_throughputs - ) + self.metrics.decode_throughput_tokens_per_sec = statistics.mean(decode_throughputs) decode_times = [m["total_decode_time"] for m in all_metrics] self.metrics.decode_time_ms = statistics.mean(decode_times) @@ -556,24 +620,14 @@ def print_results(self): print("=" * 60) print("\nThroughput Metrics:") - print( - f" Overall Throughput: {self.metrics.throughput_tokens_per_sec:.2f} tokens/sec" - ) - print( - f" Prefill Throughput: {self.metrics.prefill_throughput_tokens_per_sec:.2f} tokens/sec" - ) - print( - f" Decode Throughput: {self.metrics.decode_throughput_tokens_per_sec:.2f} tokens/sec" - ) + print(f" Overall Throughput: {self.metrics.throughput_tokens_per_sec:.2f} tokens/sec") + print(f" Prefill Throughput: {self.metrics.prefill_throughput_tokens_per_sec:.2f} tokens/sec") + print(f" Decode Throughput: {self.metrics.decode_throughput_tokens_per_sec:.2f} tokens/sec") print(f" Latency: {self.metrics.latency_ms_per_token:.2f} ms/token") print("\nLatency Breakdown:") - print( - f" Time to First Token (TTFT): {self.metrics.time_to_first_token_ms:.2f} ms" - ) - print( - f" Time Between Output Tokens (TBOT): {self.metrics.time_between_output_tokens_ms:.2f} ms" - ) + print(f" Time to First Token (TTFT): {self.metrics.time_to_first_token_ms:.2f} ms") + print(f" Time 
Between Output Tokens (TBOT): {self.metrics.time_between_output_tokens_ms:.2f} ms") print(f" Prefill Time: {self.metrics.prefill_time_ms:.2f} ms") print(f" Decode Time: {self.metrics.decode_time_ms:.2f} ms") print(f" Total Generation Time: {self.metrics.total_time_ms:.2f} ms") @@ -584,9 +638,7 @@ def print_results(self): if len(self.metrics.iteration_times) > 1: print("\nVariance Analysis:") - print( - f" Throughput Std Dev: {statistics.stdev(self.metrics.iteration_times):.2f} ms" - ) + print(f" Throughput Std Dev: {statistics.stdev(self.metrics.iteration_times):.2f} ms") print(f" TTFT Std Dev: {statistics.stdev(self.metrics.ttft_times):.2f} ms") def save_results(self, filename: str): @@ -620,9 +672,7 @@ def save_results(self, filename: str): print(f"\nResults saved to {filename}") -class CustomFormatter( - argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter -): +class CustomFormatter(argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter): pass @@ -632,11 +682,6 @@ def parse_args() -> argparse.Namespace: description="Thunder Inference Benchmark", formatter_class=CustomFormatter, epilog=""" -Standard Benchmark Scenarios: - summarization - Prefill-Heavy: 4,000 input → 1,000 output tokens - chat - Balanced: 1,000 input → 1,000 output tokens - reasoning - Decode-Heavy: 1,000 input → 4,000 output tokens - Examples: python benchmark_inference.py --input-length 2048 --output-length 512 --model-name meta-llama/Llama-4-Maverick-17B-128E --mode eager """, @@ -648,9 +693,7 @@ def parse_args() -> argparse.Namespace: default=LLAMA4_MAVERICK_MODEL_ID, help="Model to benchmark", ) - parser.add_argument( - "--batch-size", type=int, default=1, help="Batch size for inference" - ) + parser.add_argument("--batch-size", type=int, default=1, help="Batch size for inference") parser.add_argument( "--input-length", type=int, @@ -663,12 +706,8 @@ def parse_args() -> argparse.Namespace: default=128, help="Output sequence length", ) - parser.add_argument( - "--num-iterations", type=int, default=100, help="Number of benchmark iterations" - ) - parser.add_argument( - "--warmup-iterations", type=int, default=10, help="Number of warmup iterations" - ) + parser.add_argument("--num-iterations", type=int, default=100, help="Number of benchmark iterations") + parser.add_argument("--warmup-iterations", type=int, default=10, help="Number of warmup iterations") parser.add_argument( "--num-layers", default=2, @@ -696,15 +735,10 @@ def parse_args() -> argparse.Namespace: help="Specify the folder for thunderfx_benchmark_report.", ) - parser.add_argument( - "--dtensor-single-gpu", - action="store_true", - help="Use DTensor for single GPU", - ) parser.add_argument( "--enable-nvfp4", action="store_true", - help="Enable NVFP4 quantization for linear layers", + help="Enable NVFP4 quantization for MoE GroupedSwiGLU layers (has nvfuser grouped_mm support)", ) parser.add_argument( "--enable-nv-linear", @@ -712,17 +746,26 @@ def parse_args() -> argparse.Namespace: help="let nvfuser take care of linear and matmul, note that this might fail with distributed run. See: https://github.com/NVIDIA/Fuser/issues/4507", ) parser.add_argument( - "--thunder-trace", + "--profile", action="store_true", - help="Enable debug dump of thunder trace", + help="Wrap each non-warmup iteration with cudaProfilerStart() and cudaProfilerStop(). This allows us to run `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat: ... 
--profile` to record only the non-warmup iterations.", ) parser.add_argument( - "--save-results", action="store_true", help="Save results to JSON file" + "--thunder-trace", + action="store_true", + help="Enable debug dump of thunder trace", ) + parser.add_argument("--save-results", action="store_true", help="Save results to JSON file") + parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results") parser.add_argument( - "--output-dir", type=str, default="./results", help="Directory to save results" + "--thunder-cache", + type=str, + default=None, + help="Cache option: no caching, same input, constant values, symbolic values. See `cache` argument of `thunder.jit` for more details.", ) + parser.add_argument("--enable-thunder-cudagraph", action="store_true", help="Pass CUDAGraphTransform to Thunder") + parser.add_argument("--attn-implementation", type=str, default=None, help="Attention implementation") args = parser.parse_args() return args @@ -733,17 +776,13 @@ def main(): if args.save_results: os.makedirs(args.output_dir, exist_ok=True) - # TODO: Override the forward with nvfuser_direct based implementation like - # https://github.com/Lightning-AI/lightning-thunder/blob/8b72715d/thunder/tests/test_torch_library_custom_op.py#L250-L266 does. - # Note that the linked code is in a draft pull request of https://github.com/Lightning-AI/lightning-thunder/pull/2481 - # so we might want to do it more clumsily by copying the code in the pull request for now. + # Register NVFP4 custom ops with nvfuser translators when enabled if args.enable_nvfp4: - sym_of_nvfp4_scaled_mm = _register_custom_op( - nvfuser_f16a_nvfp4weight_scaled_mm - ) # noqa: F841 - sym_of_nvfp4_scaled_grouped_mm = _register_custom_op( - nvfuser_f16a_nvfp4weight_scaled_grouped_mm - ) # noqa: F841 + try: + _register_nvfp4_ops() + except Exception as e: + # If registration fails (e.g., nvfuser not available), warn and continue + warnings.warn(f"Failed to register nvfp4 custom ops: {e}") config = InferenceBenchmarkConfig( model_name=args.model_name, @@ -754,19 +793,17 @@ def main(): num_iterations=args.num_iterations, warmup_iterations=args.warmup_iterations, mode=args.mode, - dtensor_single_gpu=args.dtensor_single_gpu, enable_nvfp4=args.enable_nvfp4, fx_report_folder=args.fx_report_folder, enable_nv_linear=args.enable_nv_linear, disable_moe_replacement=args.disable_moe_replacement, + attn_implementation=args.attn_implementation, + profile=args.profile, + thunder_cache=args.thunder_cache, + enable_thunder_cudagraph=args.enable_thunder_cudagraph, ) benchmark = InferenceBenchmark(config) - if args.enable_nvfp4: - msg = "NVFP4 kernels are not yet available. 
`--enable-nvfp4` runs only quantization but not benchmark" - warnings.warn(msg) - sys.exit(0) - benchmark.run_benchmark() benchmark.print_results() @@ -780,9 +817,7 @@ def main(): if args.save_results: timestamp = time.strftime("%Y%m%d_%H%M%S") - filename = ( - f"thunder_inference_{args.model_name.replace('/', '_')}_{timestamp}.json" - ) + filename = f"thunder_inference_{args.model_name.replace('/', '_')}_{timestamp}.json" path = os.path.join(args.output_dir, filename) benchmark.save_results(path) @@ -793,6 +828,6 @@ def main(): except Exception: raise finally: - if WORLD_SIZE > 1: + if mesh: for process_group in mesh.get_all_groups(): - destroy_process_group(process_group) + dist.destroy_process_group(process_group) diff --git a/benchmarks/python/layers_for_inference_benchmark.py b/benchmarks/python/layers_for_inference_benchmark.py index 04ee99fa4d1..7a718fc9479 100644 --- a/benchmarks/python/layers_for_inference_benchmark.py +++ b/benchmarks/python/layers_for_inference_benchmark.py @@ -9,6 +9,11 @@ # SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +# +# NOTE: `pytorch_nvfp4_quantize` and `linear_to_swizzled_128_4` are copied from NVIDIA's Fuser's test code. + +# fmt: off + from __future__ import annotations from typing import TYPE_CHECKING import math @@ -17,6 +22,8 @@ import torch import torch.nn as nn from torch.testing._internal.common_quantized import _f32_to_floatx_unpacked +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import Replicate if TYPE_CHECKING: from transformers.models.llama4.modeling_llama4 import Llama4TextMoe @@ -24,11 +31,11 @@ __all__ = [ "GroupedLinear", + "GroupedSwiGLU", "Llama4MoE", "NVFP4InferenceGroupedLinear", - "NVFP4InferenceLinear", + "NVFP4InferenceGroupedSwiGLU", "nvfuser_f16a_nvfp4weight_scaled_grouped_mm", - "nvfuser_f16a_nvfp4weight_scaled_mm", ] @@ -64,9 +71,9 @@ def to_fp4(x: torch.Tensor) -> torch.Tensor: # Ref: https://github.com/NVIDIA/Fuser/blob/d70540f9/tests/python/utils/narrow_precision.py#L125-L148 def pytorch_nvfp4_quantize(a, a_global_scale): BLOCK_SIZE = 16 - assert ( - a.size(-1) % BLOCK_SIZE == 0 - ), "The inner-most dim must be divisible by block_size; Padding is not implemented." + assert a.size(-1) % BLOCK_SIZE == 0, ( + "The inner-most dim must be divisible by block_size; Padding is not implemented." + ) assert a.is_contiguous(), "Only contiguous tensors are supported." 
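    # Rough shape of the NVFP4 scheme used here (sketch based on the Fuser reference
    # linked above): the caller computes a global scale such as
    # (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax(weight); each contiguous block of 16
    # values along the innermost dimension gets one e4m3 scale factor; values are
    # divided by their block scale, cast to e2m1, and packed two per byte (see to_fp4).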
original_shape = a.shape @@ -101,9 +108,7 @@ def linear_to_swizzled_128_4(a_sf_linear: torch.Tensor): k_tiles = (sf_k + 4 - 1) // 4 k_padded = k_tiles * 4 if mn_padded != mn or k_padded != sf_k: - a_sf_padded = torch.empty( - mn_padded, k_padded, dtype=a_sf_linear.dtype, device=a_sf_linear.device - ) + a_sf_padded = torch.empty(mn_padded, k_padded, dtype=a_sf_linear.dtype, device=a_sf_linear.device) a_sf_padded[0:mn, 0:sf_k] = a_sf_linear else: a_sf_padded = a_sf_linear @@ -118,9 +123,7 @@ def quantize_linear_weight_to_nvfp4( weight: torch.Tensor | nn.Parameter, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """Quantize weight to nvfp4, returning (packed) e2m1 weight, e4m3 scale factor, fp32 global scale.""" - global_scale = ( - (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / weight.float().abs().amax() - ).to(torch.float32) + global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / weight.float().abs().amax()).to(torch.float32) fp4_weight, weight_scaling_factor = pytorch_nvfp4_quantize(weight, global_scale) weight_scale_interleaved = linear_to_swizzled_128_4(weight_scaling_factor) return fp4_weight, weight_scale_interleaved, global_scale @@ -163,12 +166,8 @@ def unpack_fp4_bytes(a, dtype): a = a.view(torch.uint8).flatten() upper_half_byte = (a & 0xF0) >> 4 lower_half_byte = a & 0x0F - upper_half_float = torch.tensor([e2m1_to_fp32(x) for x in upper_half_byte]).to( - a.device - ) - lower_half_float = torch.tensor([e2m1_to_fp32(x) for x in lower_half_byte]).to( - a.device - ) + upper_half_float = torch.tensor([e2m1_to_fp32(x) for x in upper_half_byte]).to(a.device) + lower_half_float = torch.tensor([e2m1_to_fp32(x) for x in lower_half_byte]).to(a.device) out = torch.stack((lower_half_float, upper_half_float), dim=-1).reshape(m, n * 2) return out @@ -186,9 +185,7 @@ def swizzled_to_linear_128_4(a_sf_swizzled: torch.Tensor, mn, k): # Ref: https://github.com/NVIDIA/Fuser/blob/d70540f9/tests/python/utils/narrow_precision.py#L85-L101 -def dequantize_to_dtype( - tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 -): +def dequantize_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16): """Dequantize the fp4 tensor back to high precision.""" # Two fp4 values are packed into one uint8. assert tensor_fp4.dtype == torch.float4_e2m1fn_x2 @@ -205,73 +202,27 @@ def dequantize_to_dtype( return out -# TODO: Update this accordingly to the progress of nvfp4 kernel implementation. -# An alternative is to use `_register_nvfuser_translator` of https://github.com/Lightning-AI/lightning-thunder/pull/2481 -# instead of updating this function itself. 
-@torch.library.custom_op("nvf_cutlass::f16a_nvfp4weight_scaled_mm", mutates_args=()) -def nvfuser_f16a_nvfp4weight_scaled_mm( - activation: torch.Tensor, - fp4_weight: torch.Tensor, - weight_scaling_factor: torch.Tensor, - weight_global_scale: torch.Tensor, - bias: torch.Tensor | None, -) -> torch.Tensor: - hp_weight = dequantize_to_dtype( - fp4_weight, - weight_scaling_factor, - weight_global_scale, - activation.dtype, - fp4_weight.device, - 16, - ) - return activation @ hp_weight + bias - - -@torch.library.register_fake("nvf_cutlass::f16a_nvfp4weight_scaled_mm") -def _( - activation: torch.Tensor, - fp4_weight: torch.Tensor, - weight_scaling_factor: torch.Tensor, - weight_global_scale: torch.Tensor, - bias: torch.Tensor | None, -) -> torch.Tensor: - return torch.empty( - (activation.size(0), fp4_weight.size(0)), - device=activation.device, - dtype=activation.dtype, - ) - - -# TODO: Update this accordingly to the progress of nvfp4 kernel implementation. -# An alternative is to use `_register_nvfuser_translator` of https://github.com/Lightning-AI/lightning-thunder/pull/2481 -# instead of updating this function itself. -@torch.library.custom_op( - "nvf_cutlass::f16a_nvfp4weight_scaled_grouped_mm", mutates_args=() -) +# NOTE: This custom op is registered with nvfuser translator in benchmark_inference.py +# using _register_nvfuser_translator. See benchmark_inference._register_nvfp4_ops(). +@torch.library.custom_op("nvf_cutlass::f16a_nvfp4weight_scaled_grouped_mm", mutates_args=()) def nvfuser_f16a_nvfp4weight_scaled_grouped_mm( activation: torch.Tensor, fp4_weight: torch.Tensor, weight_scaling_factor: torch.Tensor, weight_global_scale: torch.Tensor, - ab_strides: torch.Tensor, - c_strides: torch.Tensor, offsets: torch.Tensor, blockscale_offsets: torch.Tensor, problem_sizes: torch.Tensor, ) -> torch.Tensor: hp_weight = torch.empty( - (fp4_weight.size(0), fp4_weight.size(1), fp4_weight.size(2) * 2), + (fp4_weight.size(0), fp4_weight.size(1) * 2, fp4_weight.size(2)), device=activation.device, dtype=activation.dtype, ) for i in range(fp4_weight.size(0)): + # NOTE: dequantize here doesn't look right, since we have (g, k, n) hp_weight[i] = dequantize_to_dtype( - fp4_weight[i], - weight_scaling_factor[i], - weight_global_scale[i], - activation.dtype, - fp4_weight.device, - 16, + fp4_weight[i], weight_scaling_factor[i], weight_global_scale[i], activation.dtype, fp4_weight.device, 16 ) return grouped_mm(activation, hp_weight, offsets) @@ -282,98 +233,60 @@ def _( fp4_weight: torch.Tensor, weight_scaling_factor: torch.Tensor, weight_global_scale: torch.Tensor, - ab_strides: torch.Tensor, - c_strides: torch.Tensor, offsets: torch.Tensor, blockscale_offsets: torch.Tensor, problem_sizes: torch.Tensor, ) -> torch.Tensor: - return torch.empty( - (activation.size(0), fp4_weight.size(1)), - device=activation.device, - dtype=activation.dtype, - ) - - -class NVFP4InferenceLinear(nn.Module): - """NVFP4 Linear layer for Inference. - - Weight, its scaling factor, its global scale, and bias are registered as a buffer, not a parameter. 
- """ - - def __init__( - self, - in_features: int, - out_features: int, - *, - fp4_weight: torch.Tensor | nn.Parameter, - weight_scaling_factor: torch.Tensor | nn.Parameter, - weight_global_scale: torch.Tensor | nn.Parameter | None, - bias: torch.Tensor | nn.Parameter | None, - ) -> None: - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer("fp4_weight", fp4_weight) - self.register_buffer("weight_scaling_factor", weight_scaling_factor) - self.register_buffer("weight_global_scale", weight_global_scale) - self.register_buffer("bias", bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm( - x, - self.fp4_weight, - self.weight_scaling_factor, - self.weight_global_scale, - self.bias, + # fp4_weight shape: (groups, in_features // 2, out_features) + # Validate that activation has at least 1 dimension + if activation.ndim == 0: + raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") + + if ( + len( + { + t.device + for t in [ + activation, + fp4_weight, + weight_scaling_factor, + weight_global_scale, + offsets, + blockscale_offsets, + problem_sizes, + ] + } ) + != 1 + ): + raise ValueError("Expected all inputs to be on the same device.") - @staticmethod - def from_linear(linear: nn.Linear) -> NVFP4InferenceLinear: - weight = linear.weight - bias = linear.bias - out_features, in_features = weight.size() - ( - fp4_weight, - weight_scaling_factor, - weight_global_scale, - ) = quantize_linear_weight_to_nvfp4(weight) - return NVFP4InferenceLinear( - in_features, - out_features, - fp4_weight=fp4_weight, - weight_scaling_factor=weight_scaling_factor, - weight_global_scale=weight_global_scale, - bias=bias, - ) + # After unpacking: (groups, in_features, out_features) + # Output shape should match activation.shape[:-1] + (out_features,) + # This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs + out_features = fp4_weight.size(2) + output_shape = activation.shape[:-1] + (out_features,) + return torch.empty(output_shape, device=activation.device, dtype=torch.bfloat16) class SwiGLU(nn.Module): - def __init__( - self, hidden_size: int, intermediate_size: int, dtype: torch.dtype, device: str - ): + def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype, device: str): super().__init__() - self.gate_proj = nn.Linear( - hidden_size, intermediate_size, bias=False, dtype=dtype, device=device - ) - self.up_proj = nn.Linear( - hidden_size, intermediate_size, bias=False, dtype=dtype, device=device - ) - self.down_proj = nn.Linear( - intermediate_size, hidden_size, bias=False, dtype=dtype, device=device - ) + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - return self.down_proj( - torch.nn.functional.silu(self.gate_proj(hidden_states)) - * self.up_proj(hidden_states) - ) + return self.down_proj(torch.nn.functional.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states)) def _group_sizes_from_offsets(offsets: torch.Tensor) -> list[int]: group_sizes = [] prev = 0 + if isinstance(offsets, DTensor): + assert offsets.placements == (Replicate(),) + offsets = offsets.to_local() + for offset in 
offsets: group_sizes.append(offset - prev) prev = offset @@ -395,50 +308,40 @@ def grouped_mm(a: torch.Tensor, b: torch.Tensor, offsets: torch.Tensor) -> torch group_sizes = _group_sizes_from_offsets(offsets) group_outs = [] - for group_a, group_b in zip(a.split(group_sizes), b.unbind()): - group_outs.append(group_a @ group_b) + for idx, group_a in enumerate(a.split(group_sizes)): + group_outs.append(group_a @ b[idx]) return torch.cat(group_outs) class GroupedLinear(nn.Module): - def __init__( - self, - groups: int, - in_features: int, - out_features: int, - dtype: torch.dtype, - device: str, - ): + def __init__(self, groups: int, in_features: int, out_features: int, dtype: torch.dtype, device: str): super().__init__() - self.weight = nn.Parameter( - torch.empty(groups, in_features, out_features, dtype=dtype, device=device) - ) + self.weight = nn.Parameter(torch.empty(groups, out_features, in_features, dtype=dtype, device=device)) # Initialize the weight in the same way as nn.Linear nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - def forward( - self, hidden_states: torch.Tensor, offsets: torch.Tensor - ) -> torch.Tensor: - return grouped_mm(hidden_states, self.weight, offsets) + def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: + return grouped_mm(hidden_states, self.weight.transpose(-1, -2), offsets) @torch.inference_mode() def quantize_grouped_linear_weight_to_nvfp4( weight: torch.Tensor | nn.Parameter, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Quantize grouped linear's weight to nvfp4 Args: weight: Parameter of `GroupedLinear` of [g, n, k] - m: hidden_states.size(0) - tokens_per_expert_neg_one: Returns: - fp4_weight: [g, n, k // 2] + fp4_weight: [g, k // 2, n] scale_factors: [g, n, k // 16] global_scales: [g] - ab_strides: [g] - c_strides: [g] + + Note: + The reason we choose different layout of weight is to avoid performance + regression for bf16. 
See + https://github.com/Lightning-AI/lightning-thunder/pull/2659 """ assert weight.ndim == 3, "Weight must be a 3D tensor" @@ -446,23 +349,19 @@ def quantize_grouped_linear_weight_to_nvfp4( g, n, k = weight.size() with device: - ab_strides = torch.full((g,), k, dtype=torch.int32) - c_strides = torch.full((g,), n, dtype=torch.int32) - fp4_weight = torch.empty((g, n, k // 2), dtype=torch.float4_e2m1fn_x2) global_scales = torch.empty((g,), dtype=torch.float32) scale_factors = torch.empty((g, n, k // 16), dtype=torch.float8_e4m3fn) + weight = weight.contiguous() for i in range(g): cur_weight = weight[i] global_scales[i] = cur_weight.abs().amax() - cur_fp4_weight, cur_scale_factors = pytorch_nvfp4_quantize( - cur_weight, global_scales[i] - ) + cur_fp4_weight, cur_scale_factors = pytorch_nvfp4_quantize(cur_weight, global_scales[i]) fp4_weight[i] = cur_fp4_weight scale_factors[i] = linear_to_swizzled_128_4(cur_scale_factors) - return fp4_weight, scale_factors, global_scales, ab_strides, c_strides + return fp4_weight.transpose(-1, -2), scale_factors, global_scales class NVFP4InferenceGroupedLinear(nn.Module): @@ -471,91 +370,149 @@ def __init__( fp4_weight: torch.Tensor, weight_scaling_factor: torch.Tensor, weight_global_scale: torch.Tensor, - ab_strides: torch.Tensor, - c_strides: torch.Tensor, ) -> None: + super().__init__() self.register_buffer("fp4_weight", fp4_weight) self.register_buffer("weight_scaling_factor", weight_scaling_factor) self.register_buffer("weight_global_scale", weight_global_scale) - self.register_buffer("ab_strides", ab_strides) - self.register_buffer("c_strides", c_strides) - # TODO: Update this accordingly to the progress of nvfp4 kernel implementation. - def forward( - self, hidden_states: torch.Tensor, offsets: torch.Tensor - ) -> torch.Tensor: + @property + def out_features(self) -> int: + return self.fp4_weight.size(2) + + @property + def in_features(self) -> int: + return self.fp4_weight.size(1) * 2 + + @staticmethod + def compute_auxiliary_tensors( + hidden_states: torch.Tensor, + offsets: torch.Tensor, + out_features: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute blockscale_offsets and problem_sizes for grouped mm. + + These can be computed once and reused across multiple forward calls with the same offsets. + """ tokens_per_group = offsets[1:] - offsets[:-1] problem_sizes = torch.stack( [ tokens_per_group, - torch.full_like(tokens_per_group, hidden_states.size(0)), - torch.full_like(tokens_per_group, self.fp4_weight.size(2) * 2), + torch.full_like(tokens_per_group, out_features), + torch.full_like(tokens_per_group, hidden_states.size(1)), ], dim=1, ) - blockscale_offsets = torch.cumsum(torch.ceil(tokens_per_group, 128) * 128) + # Calculate block-scale offsets: round up to 128, then cumsum with initial 0 + rounded_tokens = ((tokens_per_group + 127) // 128) * 128 + blockscale_offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=tokens_per_group.device), + torch.cumsum(rounded_tokens, 0, dtype=torch.int32), + ] + )[0:-1] + return blockscale_offsets, problem_sizes + + # TODO: Update this accordingly to the progress of nvfp4 kernel implementation. 
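    # Example of the auxiliary tensors above (sketch): with tokens_per_group = [3, 200, 64],
    # out_features = N and hidden size K, problem_sizes is [[3, N, K], [200, N, K], [64, N, K]];
    # each group's token count is rounded up to a multiple of 128 -> [128, 256, 128], and
    # blockscale_offsets is the exclusive cumulative sum [0, 128, 384].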
+ def forward( + self, + hidden_states: torch.Tensor, + offsets: torch.Tensor, + blockscale_offsets: torch.Tensor | None = None, + problem_sizes: torch.Tensor | None = None, + ) -> torch.Tensor: + if blockscale_offsets is None or problem_sizes is None: + # Compute them if not provided (backward compatibility) + out_features = self.out_features + blockscale_offsets, problem_sizes = self.compute_auxiliary_tensors(hidden_states, offsets, out_features) return torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_grouped_mm( hidden_states, self.fp4_weight, self.weight_scaling_factor, self.weight_global_scale, - self.ab_strides, - self.c_strides, - offsets, + offsets[:-1], blockscale_offsets, problem_sizes, ) @staticmethod - def from_grouped_linear( - grouped_linear: GroupedLinear, - ) -> NVFP4InferenceGroupedLinear: + def from_grouped_linear(grouped_linear: GroupedLinear, fqn: str | None = None) -> NVFP4InferenceGroupedLinear: + """Create an NVFP4InferenceGroupedLinear from a GroupedLinear. + + Args: + grouped_linear (GroupedLinear): The source GroupedLinear. + fqn (str or None): Fully qualified name. Currently unused; reserved for future use or compatibility. + """ weight = grouped_linear.weight - ( - fp4_weight, - weight_scaling_factor, - weight_global_scale, - ab_strides, - c_strides, - ) = quantize_grouped_linear_weight_to_nvfp4(weight) + fp4_weight, weight_scaling_factor, weight_global_scale = quantize_grouped_linear_weight_to_nvfp4(weight) return NVFP4InferenceGroupedLinear( fp4_weight, weight_scaling_factor, weight_global_scale, - ab_strides=ab_strides, - c_strides=c_strides, ) class GroupedSwiGLU(nn.Module): + def __init__(self, groups: int, hidden_size: int, intermediate_size: int, dtype: torch.dtype, device: str): + super().__init__() + self.gate_proj = GroupedLinear(groups, hidden_size, intermediate_size, dtype, device) + self.up_proj = GroupedLinear(groups, hidden_size, intermediate_size, dtype, device) + self.down_proj = GroupedLinear(groups, intermediate_size, hidden_size, dtype, device) + + def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: + return self.down_proj( + torch.nn.functional.silu(self.gate_proj(hidden_states, offsets)) * self.up_proj(hidden_states, offsets), + offsets, + ) + + +class NVFP4InferenceGroupedSwiGLU(nn.Module): + """NVFP4 GroupedSwiGLU that efficiently reuses auxiliary tensor computations.""" + def __init__( self, - groups: int, - hidden_size: int, - intermediate_size: int, - dtype: torch.dtype, - device: str, + gate_proj: NVFP4InferenceGroupedLinear, + up_proj: NVFP4InferenceGroupedLinear, + down_proj: NVFP4InferenceGroupedLinear, ): super().__init__() - self.gate_proj = GroupedLinear( - groups, hidden_size, intermediate_size, dtype, device - ) - self.up_proj = GroupedLinear( - groups, hidden_size, intermediate_size, dtype, device - ) - self.down_proj = GroupedLinear( - groups, intermediate_size, hidden_size, dtype, device + self.gate_proj = gate_proj + self.up_proj = up_proj + self.down_proj = down_proj + + def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: + # Compute auxiliary tensors once for all three operations + intermediate_features = self.gate_proj.out_features + blockscale_offsets_gate, problem_sizes_gate = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors( + hidden_states, offsets, intermediate_features ) - def forward( - self, hidden_states: torch.Tensor, offsets: torch.Tensor - ) -> torch.Tensor: - return self.down_proj( - 
torch.nn.functional.silu(self.gate_proj(hidden_states, offsets)) - * self.up_proj(hidden_states, offsets), - offsets, + gate_out = self.gate_proj(hidden_states, offsets, blockscale_offsets_gate, problem_sizes_gate) + up_out = self.up_proj(hidden_states, offsets, blockscale_offsets_gate, problem_sizes_gate) + + intermediate = torch.nn.functional.silu(gate_out) * up_out + + # For down_proj, we need different problem_sizes (different output features) + hidden_features = self.down_proj.out_features + blockscale_offsets_down, problem_sizes_down = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors( + intermediate, offsets, hidden_features ) + return self.down_proj(intermediate, offsets, blockscale_offsets_down, problem_sizes_down) + + @staticmethod + def from_grouped_swiglu(grouped_swiglu: GroupedSwiGLU, fqn: str | None = None) -> NVFP4InferenceGroupedSwiGLU: + """Create an NVFP4InferenceGroupedSwiGLU from a GroupedSwiGLU. + + Args: + grouped_swiglu (GroupedSwiGLU): The source GroupedSwiGLU. + fqn (str or None): Fully qualified name. Currently unused; reserved for future use or compatibility. + """ + gate_proj = NVFP4InferenceGroupedLinear.from_grouped_linear(grouped_swiglu.gate_proj) + up_proj = NVFP4InferenceGroupedLinear.from_grouped_linear(grouped_swiglu.up_proj) + down_proj = NVFP4InferenceGroupedLinear.from_grouped_linear(grouped_swiglu.down_proj) + return NVFP4InferenceGroupedSwiGLU(gate_proj, up_proj, down_proj) + # Slightly modified version of `thunder.tests.test_networks.Llama4MoE` # to have the same singature as transformers' Llama4TextMoe -- in this file @@ -609,15 +566,9 @@ def from_transformers_llama4textmoe(moe: Llama4TextMoe) -> Llama4MoE: new_moe.gate.weight.data.copy_(moe.router.weight.data) # 4. Copy the shared expert weights - new_moe.shared_experts.gate_proj.weight.data.copy_( - moe.shared_expert.gate_proj.weight.data - ) - new_moe.shared_experts.up_proj.weight.data.copy_( - moe.shared_expert.up_proj.weight.data - ) - new_moe.shared_experts.down_proj.weight.data.copy_( - moe.shared_expert.down_proj.weight.data - ) + new_moe.shared_experts.gate_proj.weight.data.copy_(moe.shared_expert.gate_proj.weight.data) + new_moe.shared_experts.up_proj.weight.data.copy_(moe.shared_expert.up_proj.weight.data) + new_moe.shared_experts.down_proj.weight.data.copy_(moe.shared_expert.down_proj.weight.data) # 5. 
For the routed experts, we need to handle the combined gate_up_proj # to match GroupedLinear @@ -628,19 +579,17 @@ def from_transformers_llama4textmoe(moe: Llama4TextMoe) -> Llama4MoE: # Split into gate and up projections gate_proj_w, up_proj_w = moe.experts.gate_up_proj.chunk(2, dim=2) - new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w) - new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w) + new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w.transpose(-1, -2)) + new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w.transpose(-1, -2)) # Handle down_proj # HF format: (groups, intermediate_size, hidden_size) # Our format: (groups, hidden, intermediate_size) - new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj) + new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj.transpose(-1, -2)) return new_moe - def run_routed_experts( - self, hidden_states: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def run_routed_experts(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, seq_len, _ = hidden_states.size() hidden_states = hidden_states.view(-1, hidden_states.size(-1)) # [s, h] @@ -659,29 +608,24 @@ def run_routed_experts( tokens_per_expert = counts.sum(0) # [n] token_ids_sorted_by_expert_id = topk_ids.view(-1).argsort() # [s] - tokens_sorted_by_expert_id = hidden_states[ - token_ids_sorted_by_expert_id - ] # [s, h] + tokens_sorted_by_expert_id = hidden_states[token_ids_sorted_by_expert_id] # [s, h] # Without `torch.int32`, we see `RuntimeError: Offsets tensor must be integer (int32) tensor, but got torch.int64.` # from PyTorch when calling _grouped_mm. - offsets = torch.cumsum(tokens_per_expert, 0, dtype=torch.int32) # [n] - outs_sorted_by_expert_id = self.routed_experts( - tokens_sorted_by_expert_id, offsets - ) # [s, h] + # Prepend 0 to offsets for correct grouping + offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=tokens_per_expert.device), + torch.cumsum(tokens_per_expert, 0, dtype=torch.int32), + ] + )[:-1] # [n] + outs_sorted_by_expert_id = self.routed_experts(tokens_sorted_by_expert_id, offsets) # [s, h] - token_ids_sorted_by_expert_inverse_id = torch.argsort( - token_ids_sorted_by_expert_id - ) - outs_sorted_by_token_id = outs_sorted_by_expert_id[ - token_ids_sorted_by_expert_inverse_id - ] + token_ids_sorted_by_expert_inverse_id = torch.argsort(token_ids_sorted_by_expert_id) + outs_sorted_by_token_id = outs_sorted_by_expert_id[token_ids_sorted_by_expert_inverse_id] - return outs_sorted_by_token_id, router_logits + return outs_sorted_by_token_id.view(batch_size, seq_len, -1), router_logits.view(batch_size, seq_len, -1) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: outs_sorted_by_token_id, router_logits = self.run_routed_experts(hidden_states) - return ( - self.shared_experts(hidden_states) + outs_sorted_by_token_id, - router_logits, - ) + return self.shared_experts(hidden_states) + outs_sorted_by_token_id, router_logits From 4c6c47eca35dc290b273c8ac1a93a98a8d7af132 Mon Sep 17 00:00:00 2001 From: tbqh Date: Fri, 21 Nov 2025 12:29:02 -0800 Subject: [PATCH 2/6] Add references to lightning-thunder --- benchmarks/python/benchmark_inference.py | 3 +++ benchmarks/python/layers_for_inference_benchmark.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/benchmarks/python/benchmark_inference.py b/benchmarks/python/benchmark_inference.py index 68ff62c7af3..38b89d2c95c 100644 --- 
--- a/benchmarks/python/benchmark_inference.py
+++ b/benchmarks/python/benchmark_inference.py
@@ -7,6 +7,9 @@
 - Latency (ms/token)
 - Time to First Token (TTFT)
 - Time Between Output Tokens (TBOT)
+
+Pulled from the lightning-thunder repo. Reference:
+https://github.com/Lightning-AI/lightning-thunder/blob/4d3a3c3a7481efdc6a23cdeea99c3ffd31af5e78/thunder/benchmarks/benchmark_inference.py
 """

 # fmt: off
diff --git a/benchmarks/python/layers_for_inference_benchmark.py b/benchmarks/python/layers_for_inference_benchmark.py
index 7a718fc9479..25aa4a151c3 100644
--- a/benchmarks/python/layers_for_inference_benchmark.py
+++ b/benchmarks/python/layers_for_inference_benchmark.py
@@ -11,6 +11,9 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # NOTE: `pytorch_nvfp4_quantize` and `linear_to_swizzled_128_4` are copied from NVIDIA's Fuser's test code.
+#
+# Pulled from the lightning-thunder repo. Reference:
+# https://github.com/Lightning-AI/lightning-thunder/blob/4d3a3c3a7481efdc6a23cdeea99c3ffd31af5e78/thunder/benchmarks/layers_for_inference_benchmark.py

 # fmt: off

From e18159578ca1c5bef94c2ab4c2d1d6f3b670af44 Mon Sep 17 00:00:00 2001
From: tbqh
Date: Fri, 21 Nov 2025 12:33:21 -0800
Subject: [PATCH 3/6] Pull thunder PR "Use torch._grouped_mm in eager mode"

https://github.com/Lightning-AI/lightning-thunder/pull/2721
---
 benchmarks/python/layers_for_inference_benchmark.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/benchmarks/python/layers_for_inference_benchmark.py b/benchmarks/python/layers_for_inference_benchmark.py
index 25aa4a151c3..9b96146f8fc 100644
--- a/benchmarks/python/layers_for_inference_benchmark.py
+++ b/benchmarks/python/layers_for_inference_benchmark.py
@@ -297,16 +297,14 @@ def _group_sizes_from_offsets(offsets: torch.Tensor) -> list[int]:


 if LooseVersion(torch.__version__) >= LooseVersion("2.8.0"):
-    # Required otherwise, there is a graph-break.
+    # Required -- otherwise there is a graph-break.
     _grouped_mm = torch.compiler.allow_in_graph(torch._grouped_mm)
+else:
+    _grouped_mm = None


-# This function should be replaced with torch._grouped_mm. However,
-# torch._grouped_mm is yet to be usable because it requires offsets being
-# multiples of 16.
 def grouped_mm(a: torch.Tensor, b: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
-    if torch.compiler.is_compiling():
-        # NOTE: This path also works for `thunder.jit` as it has a lookaside for `torch.compiler.is_compiling`.
+    if _grouped_mm is not None:
         return _grouped_mm(a, b, offsets)

     group_sizes = _group_sizes_from_offsets(offsets)

From b21af254bada16f30810a95fac055cdc6293e2ac Mon Sep 17 00:00:00 2001
From: tbqh
Date: Fri, 21 Nov 2025 12:35:51 -0800
Subject: [PATCH 4/6] Pull thunder PR "Remove the --profile option"

https://github.com/Lightning-AI/lightning-thunder/pull/2715
---
 benchmarks/python/benchmark_inference.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/benchmarks/python/benchmark_inference.py b/benchmarks/python/benchmark_inference.py
index 38b89d2c95c..ec09745510d 100644
--- a/benchmarks/python/benchmark_inference.py
+++ b/benchmarks/python/benchmark_inference.py
@@ -224,7 +224,6 @@ class InferenceBenchmarkConfig:
     mode: str
     disable_moe_replacement: bool
     attn_implementation: str | None
-    profile: bool
    thunder_cache: str | None
     enable_thunder_cudagraph: bool

@@ -557,10 +556,17 @@ def run_benchmark(self) -> InferenceMetrics:

         for _ in tqdm(range(self.config.num_iterations), disable=LOCAL_RANK != 0):
             past_key_values.reset()
-            if self.config.profile:
+            is_under_nsys = bool(os.environ.get("NSYS_PROFILING_SESSION_ID"))
+            # Wrap each non-warmup iteration with cudaProfilerStart() and
+            # cudaProfilerStop(). This allows the user to run
+            # ```shell
+            # nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat: ...
+            # ```
+            # to record only the non-warmup iterations.
+            if is_under_nsys:
                 torch.cuda.cudart().cudaProfilerStart()
             iter_metrics = self.measure_inference_step(input_ids, past_key_values, self.config.output_length)
-            if self.config.profile:
+            if is_under_nsys:
                 torch.cuda.cudart().cudaProfilerStop()

             all_metrics.append(iter_metrics)
@@ -748,11 +754,6 @@ def parse_args() -> argparse.Namespace:
         action="store_true",
         help="let nvfuser take care of linear and matmul, note that this might fail with distributed run. See: https://github.com/NVIDIA/Fuser/issues/4507",
     )
-    parser.add_argument(
-        "--profile",
-        action="store_true",
-        help="Wrap each non-warmup iteration with cudaProfilerStart() and cudaProfilerStop(). This allows us to run `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat: ... --profile` to record only the non-warmup iterations.",
-    )

     parser.add_argument(
         "--thunder-trace",
@@ -801,7 +802,6 @@ def main():
         enable_nv_linear=args.enable_nv_linear,
         disable_moe_replacement=args.disable_moe_replacement,
         attn_implementation=args.attn_implementation,
-        profile=args.profile,
         thunder_cache=args.thunder_cache,
         enable_thunder_cudagraph=args.enable_thunder_cudagraph,
     )

From be28360677e4d7efa486cd3123911c5b74abcd79 Mon Sep 17 00:00:00 2001
From: tbqh
Date: Fri, 21 Nov 2025 12:46:01 -0800
Subject: [PATCH 5/6] Add SPDX header back to file

---
 benchmarks/python/benchmark_inference.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/benchmarks/python/benchmark_inference.py b/benchmarks/python/benchmark_inference.py
index ec09745510d..514ca0d7a3d 100644
--- a/benchmarks/python/benchmark_inference.py
+++ b/benchmarks/python/benchmark_inference.py
@@ -1,3 +1,7 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
 """Inference benchmark focusing on throughput and latency metrics of prefill and decode phases.

 AutoModelForCausalLM from Hugging Face transformers is used for model implementation.
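
Note on PATCH 3 above: the sketch below is not taken from any of the patches; it only illustrates what the eager fallback of `grouped_mm` effectively computes when `torch._grouped_mm` is unavailable. It assumes `offsets` holds the cumulative end-row index of each expert group (the `torch.cumsum` convention used before PATCH 1's change) and that `b` is laid out as `(num_groups, K, N)`; the helper name `eager_grouped_mm` is illustrative only.

```python
import torch


def eager_grouped_mm(a: torch.Tensor, b: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # a: (total_rows, K), b: (num_groups, K, N), offsets: cumulative end row of each group.
    outs = []
    start = 0
    for group_idx, end in enumerate(offsets.tolist()):
        # Rows [start, end) belong to group `group_idx` and use that group's weight matrix.
        outs.append(a[start:end] @ b[group_idx])
        start = end
    return torch.cat(outs, dim=0)
```
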
From f5d05ae2d0c189835e2f564f817b2979b5ea5a8e Mon Sep 17 00:00:00 2001
From: tbqh
Date: Fri, 21 Nov 2025 15:07:35 -0800
Subject: [PATCH 6/6] Simplify if statement

---
 benchmarks/python/benchmark_inference.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/benchmarks/python/benchmark_inference.py b/benchmarks/python/benchmark_inference.py
index 514ca0d7a3d..d383a50920c 100644
--- a/benchmarks/python/benchmark_inference.py
+++ b/benchmarks/python/benchmark_inference.py
@@ -285,7 +285,18 @@ def __init__(self, config: InferenceBenchmarkConfig):
             "*.layers.*.feed_forward.down_proj": RowwiseParallel(use_local_output=True),
         }

-        if not self.config.disable_moe_replacement:
+        if self.config.disable_moe_replacement:
+            tp_plan.update(
+                {
+                    # HF MoE
+                    "*.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel(use_local_output=False),
+                    "*.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel(use_local_output=False),
+                    "*.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel(use_local_output=True),
+                    # TODO:Need to write ParallelStyle for HF's grouped_mm implementation.
+                }
+            )
+
+        else:
             tp_plan.update(
                 {
                     # Custom MoE
@@ -306,17 +317,6 @@ def __init__(self, config: InferenceBenchmarkConfig):
                 }
             )

-        else:
-            tp_plan.update(
-                {
-                    # HF MoE
-                    "*.layers.*.feed_forward.shared_expert.gate_proj": ColwiseParallel(use_local_output=False),
-                    "*.layers.*.feed_forward.shared_expert.up_proj": ColwiseParallel(use_local_output=False),
-                    "*.layers.*.feed_forward.shared_expert.down_proj": RowwiseParallel(use_local_output=True),
-                    # TODO:Need to write ParallelStyle for HF's grouped_mm implementation.
-                }
-            )
-
         if mesh:
             model = parallelize_module(model, mesh, tp_plan)
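
Note on PATCH 6 above: a minimal sketch, not part of the patch, of how a `tp_plan` of `ColwiseParallel`/`RowwiseParallel` entries is applied with `parallelize_module`. It assumes the process group is already initialized (for example via `torchrun`) so a one-dimensional "tp" mesh can be built; `ToyFeedForward` and `shard_feed_forward` are illustrative names, not code from the benchmark.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class ToyFeedForward(nn.Module):
    # Stand-in for the per-layer feed_forward modules matched by the fnmatch-style keys above.
    def __init__(self, hidden: int = 128, intermediate: int = 256) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)

    def forward(self, x):
        return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))


def shard_feed_forward(module: nn.Module, world_size: int) -> nn.Module:
    # Colwise-shard the input projections and rowwise-shard the output projection,
    # mirroring the tp_plan entries in the patch above.
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
    tp_plan = {
        "gate_proj": ColwiseParallel(use_local_output=False),
        "up_proj": ColwiseParallel(use_local_output=False),
        "down_proj": RowwiseParallel(use_local_output=True),
    }
    return parallelize_module(module, mesh, tp_plan)
```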