From 9b92f11f5aea7517093f748903c811564125b81b Mon Sep 17 00:00:00 2001 From: jung-min Date: Mon, 2 Mar 2026 07:58:59 +0000 Subject: [PATCH 01/31] [Frontend/template] add SDPA modules --- .../torch_openreg/openreg/__init__.py | 7 +- PyTorchSimFrontend/mlir/mlir_lowering.py | 25 +- PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 664 ++++++++++++++++++ PyTorchSimFrontend/mlir/mlir_template.py | 101 ++- tests/test_sdpa.py | 84 +++ 5 files changed, 878 insertions(+), 3 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/mlir_sdpa_template.py create mode 100644 tests/test_sdpa.py diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index 8d62cee3..5a0de6c3 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -24,7 +24,7 @@ class device: def __init__(self, device): self.idx = torch.accelerator._get_device_index(device, optional=True) - self.prev_idx = -1 + self.prev_idx = -1 def __enter__(self): self.prev_idx = torch_openreg._C._exchangeDevice(self.idx) @@ -64,6 +64,11 @@ def _lazy_init(): global _initialized, _tog_simulator if is_initialized(): return + + # Replace the global C++ binding with our custom dispatcher patch + from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention + torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention + torch_openreg._C._init() register_interface_for_device(custom_device(), ExtensionDeviceInterface) _initialized = True diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index ebf0c80e..e09dcf57 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -15,6 +15,7 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.mlir.mlir_sdpa_template import MLIRFlashSDPATemplate, flash_sdpa_args from PyTorchSimFrontend import extension_config aten = torch.ops.aten @@ -38,6 +39,26 @@ def tuned_bmm(mat1, mat2, *, layout=None): return mlir_template.generate().output_node() + +def tuned_flash_sdpa( + query : TensorBox, + key : TensorBox, + value : TensorBox, + scale : float, + dropout_p : float = 0.0, + is_causal : bool = False, + return_debug_mask : bool =False) -> tuple: + + print("Enter tuned_flash_sdpa") + + N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value) + mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale) + + # _scaled_dot_product_flash_attention has to return a tuple which has 9 values + # since its backward(_scaled_dot_product_flash_attention_backward) needs that values. + # (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + return (mlir_template.generate().output_node(), None, None, None, None, None, None, None, None) + def conv_layout( x: TensorBox, weight: TensorBox, @@ -188,4 +209,6 @@ def custom_unsafe_index(x, indices): lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()}) if extension_config.CONFIG_USE_TIMING_POOLING: - lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template \ No newline at end of file + lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template + +lowerings.update({getattr(aten._scaled_dot_product_flash_attention, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_flash_attention.overloads()}) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py new file mode 100644 index 00000000..b3d88cc6 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -0,0 +1,664 @@ +import math # sqrt +import sympy + +from typing import List, Optional + +import torch +from torch import empty_strided +from torch._inductor.ir import IRNode, TensorBox, FixedLayout +from torch._inductor.virtualized import V +from torch._inductor.select_algorithm import realize_inputs +from torch.backends.cuda import flash_sdp_enabled, mem_efficient_sdp_enabled + +from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel + + +def flash_sdpa_args( + query : TensorBox, + key : TensorBox, + value : TensorBox) -> list: + """ + Arg processing for flash SDPA. + Its logic is based on: + mm_args() which is in torch._inductor.kernel.mm_common.py (142 line). + """ + + # Materialize input buffers for the codegen backend. + query, key, value = realize_inputs(query, key, value) + + # query : (n, hq, l, e) + # key : (n, h, s, e) + # value : (n, h, s, ev) + # out : (n, hq, l, ev) + # n: Batch size + # hq: query's head counts, h: key and value's head counts. + # l: target sequence lenght and s: source sequence length. + # e: embeding dimension of the query and key and ev: embeding dimension of the value. + nq, hq, l, eq = query.get_size() + nk, hk, sk, ek = key.get_size() + nk, hv, sv, ev = value.get_size() + + n = V.graph.sizevars.guard_equals(nq, nk) + n = V.graph.sizevars.guard_equals(nq, nk) + + h = V.graph.sizevars.guard_equals(hk, hv) + s = V.graph.sizevars.guard_equals(sk, sv) + e = V.graph.sizevars.guard_equals(eq, ek) + + # While there are no theoretical requirements for e == ev, + # this implementation enforces e == ev for simplicity. + # Distinct notations are still maintained to ensure future compatibility and clarity. + if e != ev: + raise NotImplementedError("Flash SDPA does not support mismatched head dimensions between query and value.") + + # Flash attention does not split tiles along the head dimension (e or ev). + # Therefore, the head dimension size must be less than or equal to the number of vlanes. + vector_lane = extension_config.vpu_num_lanes + if e > vector_lane or ev > vector_lane: + raise ValueError(f"The head dimension size must be less than or equal to the number of vlanes (e: {e}, ev: {ev}, vlanes: {vector_lane}).") + + # The aten._scaled_dot_product_flash_attention kernel does not accept an explicit enable_gqa parameter. + # Instead, the Flash SDPA implementation infers GQA usage by checking if hq != hk. + # The Flash SDPA for GQA will be implemented after implementing its native version. + if hq != h : + raise NotImplementedError("Flash SDPA for GQA is not supported yet.") + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [n, hq, l, ev] + ) + + return [n, hq, h, l, s, e, ev, layout, query, key, value] + +def validate_sdpa_input( + query : torch.Tensor, + key : torch.Tensor, + value : torch.Tensor, + attn_mask : torch.Tensor = None, + dropout_p : float = 0.0, + is_casual : bool = False, + scale : float = None, + enable_gqa : bool = False) -> None: + """ + Validates input tensors and parameters for Scaled Dot Product Attention (SDPA). + This function's logic can be found in: + https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp(504 line) + https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + """ + + # Tensor class, dtype, and device consistency + # Ensure all primary inputs are torch.Tensors + if not all(isinstance(t, torch.Tensor) for t in [query, key, value]): + raise TypeError( + f"Expected query, key and value to be Tensors, but got " + f"{type(query).__name__}, {type(key).__name__}, and {type(value).__name__}." + ) + + # Check for dtype mismatch + if query.dtype != key.dtype or query.dtype != value.dtype: + raise TypeError( + f"Expected query, key, and value to have the same dtype, " + f"but got {query.dtype}, {key.dtype}, and {value.dtype}." + ) + + # Check for device mismatch (e.g., mixing CPU and NPU) + if query.device != key.device or query.device != value.device: + raise ValueError( + f"Expected query, key, and value to be on the same device, " + f"but got {query.device}, {key.device}, and {value.device}." + ) + + # Shape and dimension validation + # SDPA typically expects 4D (B, H, S, D), but we check for at least 2D here + if any(t.dim() < 2 for t in [query, key, value]): + raise ValueError( + f"Expected query, key, and value to be at least 2D, " + f"but got Q:{query.dim()}D, K:{key.dim()}D, V:{value.dim()}D." + ) + + # Attention mask validation + if attn_mask is not None: + if not isinstance(attn_mask, torch.Tensor): + raise TypeError(f"Expected attn_mask to be a Tensor, but got {type(attn_mask).__name__}.") + + # Dtype check: floating point masks must match query dtype; bool masks are also allowed + if attn_mask.dtype.is_floating_point: + if attn_mask.dtype != query.dtype: + raise TypeError(f"Floating point attn_mask must match query dtype ({query.dtype}), but got {attn_mask.dtype}.") + elif attn_mask.dtype != torch.bool: + raise TypeError(f"attn_mask must be floating point or bool, but got {attn_mask.dtype}.") + + # Nested tensor limitation with explicit masking + if query.is_nested or key.is_nested: + raise ValueError("Nested tensors are not supported when an explicit attn_mask is set.") + + # Dropout and causal flag validation (added) + # Dropout probability must be in the range [0, 1) + if not (0.0 <= dropout_p < 1.0): + raise ValueError(f"Expected dropout_p to be in [0, 1), but got {dropout_p}.") + + # Mutual exclusivity: cannot use both explicit mask and causal flag (added) + if is_casual and attn_mask is not None: + raise ValueError("Both attn_mask and is_casual cannot be set at the same time.") + + # Scaling factor validation (added) + if scale is not None and scale <= 0.0: + raise ValueError(f"Expected scale to be a positive number, but got {scale}.") + + # GQA (Grouped Query Attention) constraints (added) + n_head_q = query.size(1) + n_head_k = key.size(1) + n_head_v = value.size(1) + + # The aten._scaled_dot_product_flash_attention kernel does not accept an explicit enable_gqa parameter. + # Instead, the Flash SDPA implementation infers GQA usage by checking if n_head_q != n_head_k. + if not enable_gqa and n_head_q != n_head_k: + raise ValueError(f"Query and Key must have the same number of heads when enable_gqa is false (Q:{n_head_q} vs K:{n_head_k}).") + + if enable_gqa: + if n_head_q == n_head_k: + raise ValueError(f"enable_gqa Query and Key ") + + if n_head_k != n_head_v: + raise ValueError(f"Key and Value must have the same number of heads (K:{n_head_k} vs V:{n_head_v}).") + + # Query heads must be an integer multiple of key heads for grouping + if n_head_q % n_head_k != 0: + raise ValueError( + f"Number of query heads ({n_head_q}) must be divisible by " + f"number of key heads ({n_head_k}) for GQA." + ) + +def convert_boolean_attn_mask(attn_mask: torch.Tensor, target_dtype: torch.dtype) -> float: + """ + Equivalent to the C++ 'convert_boolean_attn_mask' function. + Converts a boolean mask to a floating-point mask for SDPA. + """ + + if attn_mask is not None and attn_mask.dtype == torch.bool: + + new_mask = torch.zeros_like(attn_mask, dtype=target_dtype) + minus_inf = torch.finfo(target_dtype).min + new_mask.masked_fill_(attn_mask.logical_not(), minus_inf) + + return new_mask + + return attn_mask + +def calculate_scale(query: torch.Tensor, scale: float) -> float: + """ + Calculate the scaling factor based on the head dimension if scale is None + Otherwise, use the provided scale. + """ + if scale is None: + return 1.0 / math.sqrt(query.size(-1)) + else: + return scale + +def patched_scaled_dot_product_attention( + query_ : torch.Tensor, + key : torch.Tensor, + value : torch.Tensor, + dropout_p : float = 0.0, + is_casual : bool = False, + attn_mask_ : torch.Tensor = None, + scale_ : float = None, + enable_gqa : bool = None, + orig_fn = torch._C._nn.scaled_dot_product_attention) -> torch.Tensor : + """ + Custom patch for Scaled Dot Product Attention (SDPA) to intercept high-level calls. + For NPU devices, it redirects execution to specific ATen kernels based on global flags. + For all devices, it maintains parity with the original dispatcher logic found in: + https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp + + This function acts as a custom override that replaces the default PyTorch SDPA implementation, + invoked via 'PyTorchSim/PyTorchSimDevice/torch_openreg/openreg/__init__.py'. + """ + + # Device-specific Dispatching: redirect to specialized kernels if on NPU + if "npu" in str(query_.device): + + validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_casual, scale_, enable_gqa) + attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype) + + # Kernel selection logic: emulate C++ dispatcher priority + # Selection priority(can be changed): flash attention > memory efficient > math (cuDNN is not supported) + aten = torch.ops.aten + scale = calculate_scale(query_, scale_) + + if flash_sdp_enabled(): + # Skip padding query, key and value for alignment. + dispatch_kwargs = { + "dropout_p" : dropout_p, + "is_causal" : is_casual, + "return_debug_mask" : False, + "scale" : scale + } + + out_lse_softmax = aten._scaled_dot_product_flash_attention( + query_, key, value, **dispatch_kwargs + ) + + return out_lse_softmax[0] + elif mem_efficient_sdp_enabled(): + # out_and_lse = aten._scaled_dot_product_efficient_attention(...) + # return out_and_lse[0] + raise NotImplementedError("Memory efficient SDPA is not implemented yet.") + else: + dispatch_kwargs = { + "attn_mask" : attn_mask, + "dropout_p" : dropout_p, + "is_causal" : is_casual, + "dropout_mask" : None, + "scale": scale, + "enable_gqa" : enable_gqa + } + + out_lse_softmax = aten._scaled_dot_product_attention_math( + query_, + key, + value, + **dispatch_kwargs) + + return out_lse_softmax[0] + else: + # Fallback: Delegate to the original C++ Dispatcher for other devices + return orig_fn(query_, key, value) + +FLASH_SDPA_TEMPLATE = r""" +// SDPA kernel +// b = {{ b }} +// l = {{ l }} +// s = {{ s }} +// e = {{ e }} +// tile_l = {{ tile_l }} +// tile_s = {{ tile_s }} +// tile_e = {{ tile_e }} +// subtile_l = {{ subtile_l }} +// subtile_s = {{ subtile_s }} +// subtile_e = {{ subtile_e }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[query, key, value], outputs=[out], names_str="query, key, value, out", input_reorder=input_reorder)}} { + // Inputs + {{ kernel.def_sram_buffer("query", q_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} + + // Output + {{ kernel.def_sram_buffer("out", out_tile_desc, indent_size=2) }} + + // Intermediate buffers + {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} + + // Constants + %c0 = arith.constant 0.0 : {{ data_stype }} + %c1 = arith.constant 1.0 : {{ data_stype }} + %c_scale = arith.constant {{ scale }} : {{ data_stype }} + %c_neg_inf = arith.constant -1.0e+30 : {{ data_stype }} + + %v0_c = arith.constant dense<0.0> : vector<{{ chunk_size }}x{{ data_stype }}> + %v0_l = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> + %v0_s = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> + %v0_2x = arith.constant dense<0.0> : vector<2x{{ data_stype }}> + + %v_neg_inf_c = arith.constant dense<-1.0e+30> : vector<{{ chunk_size }}x{{ data_stype }}> + %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2x{{ data_stype }}> + + %v_scale = vector.broadcast %c_scale : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> + + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %index0 = 0 to {{ b }} { + affine.for %index3 = 0 to 1 step 1 { + affine.for %index1 = 0 to {{ l }} step {{ tile_l }} { + {{ kernel.def_dma_op("MVIN", "query", q_idx, q_tile_desc, subtile_size=[1, subtile_l, subtile_e], indent_size=8) }} + + affine.vector_store %v0_l, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> + affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + affine.vector_store %v0_2x, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + %qt_buffer2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ q_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> + %ot_buffer2D = memref.reinterpret_cast %out_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ out_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> + + affine.for %index2 = 0 to {{ s }} step {{ tile_s }} { + {{ kernel.def_dma_op("MVIN", "key", k_idx, k_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10) }} + {{ kernel.def_dma_op("MVIN", "value", v_idx, v_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10) }} + + affine.vector_store %v0_s, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> + + %k_buffer2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1> + %vt_buffer2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1> + + + // key @ query.t and scaling. + linalg.matmul + ins(%k_buffer2D, %qt_buffer2D : memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1>, memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) + outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(data_stype) }}) + + %raw_mul_vec = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + %scaled_mul_vec = arith.mulf %raw_mul_vec, %v_scale : vector<{{ tile_s }}x{{ data_stype }}> + affine.vector_store %scaled_mul_vec, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + + + // Find new max. + %old_max = affine.vector_load %max_buffer[0,0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + %chunk_max_res = affine.for %index5 = 0 to {{ tile_s }} step {{ chunk_size }} iter_args(%iter_max=%v_neg_inf_c) -> (vector<{{ chunk_size }}x{{ data_stype }}>) { + %chunk_val = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> + %local_max = arith.maximumf %chunk_val, %iter_max : vector<{{ chunk_size }}x{{ data_stype }}> + affine.yield %local_max : vector<{{ chunk_size }}x{{ data_stype }}> + } + + %max_cast = vector.shape_cast %chunk_max_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> + %max_reduced_1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> + %max_shuffled = vector.shuffle %max_reduced_1, %max_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> + %max_reduced_2 = arith.maximumf %max_reduced_1, %max_shuffled : vector<2x{{ data_stype }}> + + %new_max = arith.maximumf %max_reduced_2, %old_max : vector<2x{{ data_stype }}> + affine.vector_store %new_max, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + + // Compute rescale factors: exp(old_max - new_max) + %max_diff = arith.subf %old_max, %new_max : vector<2x{{ data_stype }}> + %max_diff_scalar = vector.extract %max_diff[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + + %rescale_bcast_e = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> + %exp_rescale_e = math.exp %rescale_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> + + %rescale_bcast_2 = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<2x{{ data_stype }}> + %exp_rescale_2 = math.exp %rescale_bcast_2 : vector<2x{{ data_stype }}> + + + // Rescale previous out and sum accumulators + %old_out = affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + %rescaled_out = arith.mulf %exp_rescale_e, %old_out : vector<{{ tile_e }}x{{ data_stype }}> + affine.vector_store %rescaled_out, %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + + %old_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + %rescaled_sum = arith.mulf %old_sum, %exp_rescale_2 : vector<2x{{ data_stype }}> + + + // Shift scores and apply exp: exp(x - new_max) + %scaled_scores_reload = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + %new_max_scalar = vector.extract %new_max[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + %new_max_bcast = vector.broadcast %new_max_scalar : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> + + %shifted_scores = arith.subf %scaled_scores_reload, %new_max_bcast : vector<{{ tile_s }}x{{ data_stype }}> + %exp_scores = math.exp %shifted_scores : vector<{{ tile_s }}x{{ data_stype }}> + affine.vector_store %exp_scores, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + + + // accumulate current sum + %chunk_sum_res = affine.for %index5 = 0 to {{ tile_s }} step {{ chunk_size }} iter_args(%iter_sum=%v0_c) -> (vector<{{ chunk_size }}x{{ data_stype }}>) { + %chunk_exp = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> + %local_sum = arith.addf %chunk_exp, %iter_sum : vector<{{ chunk_size }}x{{ data_stype }}> + affine.yield %local_sum : vector<{{ chunk_size }}x{{ data_stype }}> + } + + %zero_2x = vector.broadcast %c0 : {{ data_stype }} to vector<2x{{ data_stype }}> + %sum_cast = vector.shape_cast %chunk_sum_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> + %sum_reduced_1 = vector.multi_reduction , %sum_cast, %zero_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> + %sum_shuffled = vector.shuffle %sum_reduced_1, %sum_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> + %sum_reduced_2 = arith.addf %sum_reduced_1, %sum_shuffled : vector<2x{{ data_stype }}> + + %new_sum = arith.addf %sum_reduced_2, %rescaled_sum : vector<2x{{ data_stype }}> + affine.vector_store %new_sum, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + + // value.t @ mul + linalg.matmul + { idx_map = array } + ins(%vt_buffer2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(data_stype) }}) + outs(%ot_buffer2D : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) + } + + // out @ row_sum^(-1) + %final_row_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + %one_2x = vector.broadcast %c1 : {{ data_stype }} to vector<2x{{ data_stype }}> + + %reciprocal_row_sum_2x = arith.divf %one_2x, %final_row_sum : vector<2x{{ data_stype }}> + %reciprocal_scalar = vector.extract %reciprocal_row_sum_2x[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + %reciprocal_bcast_e = vector.broadcast %reciprocal_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> + + %accumulated_out = affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + %stable_final_out = arith.mulf %accumulated_out, %reciprocal_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> + affine.vector_store %stable_final_out, %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + + {{ kernel.store_output(indent_size=8) }} + } { accumulation_loop=true } + } { outer_loop=true } + } { outer_loop=true } + return +} +""" + +class MLIRFlashSDPATemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, scale, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.scale = scale + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, + **kwargs): + + # Except for kernel, other arguments are usually None. + query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + + if tile_info is None: + tile_l, tile_s, tile_e, subtile_l, subtile_s, subtile_e = self.select_tile(kernel, l, s, e, n_extra_node, 0, n_prologue_node)[0] + else: + tile_l, tile_s, tile_e, subtile_l, subtile_s, subtile_e = tile_info + + TOG_latency = l if tile_l > l else tile_l + kernel.loop_size = [TOG_latency, tile_s, tile_e] + + # Select template code + # Other templates will be added according to situations. + nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] + if nr_reduction_nodes: + raise NotImplementedError("FLASH_SDPA_REDUCTION_TEMPLATE is not implemented yet.") + elif prologue_nodes: + raise NotImplementedError("FLASH_SDPA_PROLOGUE_TEMPLATE is not implemented yet.") + else: + template = FLASH_SDPA_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2", "index3": "index3"} + nr_rdim = 0 + + # Prepare tile descriptors for input and output tensors. + # Intermediate buffers (transient data) do not require DRAM settings(dram stride and dram indices) + # as they are not synchronized with external DRAM. + # DRAM and SRAM tile shapes must match. + vlane_stride = 1 + + # (n, l, s, e, ev) + loop_dim = [sympy.Symbol("index0"), sympy.Symbol("index1"), sympy.Symbol("index2"), sympy.Symbol("index3")] + + + # Hardware constraint: The tile split axis is restricted. + # To accommodate this, we compute (key @ query.t) instead of (query @ key.t). + # SRAM settings + vlane_split_axis = 1 + q_tile_size = [1, tile_l, tile_e] + q_tile_stride = [0, tile_e, 1] + q_tile_desc = mlir_common.MLIRMultiDimTile(q_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + q_tile_desc.set_tile_size_stride(q_tile_size, q_tile_stride) + q_tile_desc.set_name("q_buffer") + q_tile_desc.offset = query.get_layout().offset + # DRAM settings + q_stride = q_tensor.stride() + q_idx = [loop_dim[0]*q_stride[0], loop_dim[1]*q_stride[1], loop_dim[3]*q_stride[2]] # To keep index arguemnt order, we used index_list + + # Since we use a weight-stationary approach in the Systolic Array (SA), + # the split axis of the first operand differs from a standard linear algebra matmul. + # The first operand (key) must be split along the column axis. + # This logic aligns with the relationship between the dot product's summation direction and the hardware's accumulation direction in the SA. + # SRAM settings + vlane_split_axis = 2 + k_tile_size = [1, tile_s, tile_e] + k_tile_stride = [0, 1, tile_s] + k_tile_desc = mlir_common.MLIRMultiDimTile(k_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + k_tile_desc.set_tile_size_stride(k_tile_size, k_tile_stride) + k_tile_desc.set_name("k_buffer") + k_tile_desc.offset = key.get_layout().offset + # DRAM settings + k_stride = k_tensor.stride() + k_idx = [loop_dim[0]*k_stride[0], loop_dim[2]*k_stride[1], loop_dim[3]*k_stride[2]] + + # Since we compute mul = key @ query.t, we perform out.t = (value.t @ Softmax(mul).t).t, + # which simplifies to (value.t @ Softmax(mul)) + # SRAM settings + vlane_split_axis = 1 + v_tile_size = [1, tile_s, tile_e] + v_tile_stride = [0, tile_e, 1] + v_tile_desc = mlir_common.MLIRMultiDimTile(v_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + v_tile_desc.set_tile_size_stride(v_tile_size, v_tile_stride) + v_tile_desc.set_name("v_buffer") + v_tile_desc.offset = value.get_layout().offset + # DRAM settings + v_stride = v_tensor.stride() + v_idx = [loop_dim[0]*v_stride[0], loop_dim[2]*v_stride[1], loop_dim[3]*v_stride[2]] # To keep index arguemnt order, we used index_list + + # Output is also stored in transposed format to match the value.t @ Softmax(mul) operation. + # SRAM settings + vlane_split_axis = 1 + out_tile_size = [1, tile_l, tile_e] + out_tile_stride=[0, tile_e, 1] + out_tile_desc = mlir_common.MLIRMultiDimTile(out_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + out_tile_desc.set_tile_size_stride(out_tile_size, out_tile_stride) + out_tile_desc.set_name("out_buffer") + # DRAM settings + out_stride = out.get_layout().stride[1:] + out_idx = [loop_dim[0]*out_stride[0], loop_dim[1]*out_stride[1], loop_dim[3]*out_stride[2]] + + # Intermediate buffers + + # For mul = key @ query.t + vlane_split_axis = 1 + mul_tile_size = [tile_s, tile_l] + mul_tile_stride = [tile_l, 1] + mul_tile_desc = mlir_common.MLIRMultiDimTile(mul_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + mul_tile_desc.set_tile_size_stride(mul_tile_size, mul_tile_stride) + mul_tile_desc.set_name("mul_buffer") + #FIXME. What is the offset? -> It doesn't matter at this time. + + # For storing maximum values per row + vlane_split_axis = 0 + max_size = [tile_l, 2] + max_stride = [2, 1] + max_desc = mlir_common.MLIRMultiDimTile(max_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + max_desc.set_tile_size_stride(max_size, max_stride) + max_desc.set_name("max_buffer") + + # For storing summation per row + vlane_split_axis = 0 + sum_size = [tile_l, 2] + sum_stride = [2, 1] + sum_desc = mlir_common.MLIRMultiDimTile(sum_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + sum_desc.set_tile_size_stride(sum_size, sum_stride) + sum_desc.set_name("sum_buffer") + + # For reduction + chunk_size = 16 + + kernel.render_options = dict( + KERNEL_NAME = self.name, + kernel = kernel, + b = b, + l = l, + s = s, + e = e, # Input sizes (dram) + tile_l = tile_l, + tile_s = tile_s, + tile_e = tile_e, # Tile sizes (sram) + subtile_l = subtile_l, + subtile_s = subtile_s, + subtile_e = subtile_e, # Subtile sizes (sram) + data_stype="f32", + query = query, + key = key, + value = value, + out = out, # Inputs and output (dram) + q_idx = q_idx, + k_idx = k_idx, + v_idx = v_idx, + out_idx = out_idx, # Strides (dram) + q_tile_desc = q_tile_desc, + k_tile_desc = k_tile_desc, + v_tile_desc = v_tile_desc, + mul_tile_desc = mul_tile_desc, + out_tile_desc = out_tile_desc, # Tile descriptions (sram) + max_desc = max_desc, + sum_desc = sum_desc, # Intermediate buffer descriptions (sram) + scale = self.scale, + chunk_size = chunk_size, + input_reorder = self.input_reorder # ETC + ) + + kernel.epilogue_info = dict( + output_node = self.output_node.name, + sram_var = "out_buffer", + dram_var = "out", + dram_idx = out_idx, + dram_tile_desc = out_tile_desc, + nr_rdim = nr_rdim, + r_dim_size = 0, + dim_aliasing = epilogue_dim_aliasing + ) + + code = self._template_from_string(template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["l"], kernel.render_options["s"], kernel.render_options["e"]], [kernel.render_options["tile_l"], kernel.render_options["tile_s"], kernel.render_options["tile_e"]]) + return code + + def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + + query = self.input_nodes[0] + key = self.input_nodes[1] + value = self.input_nodes[2] + out = self.output_node + + q_tensor = empty_strided(query.layout.size, query.layout.stride) + k_tensor = empty_strided(key.layout.size, key.layout.stride) + v_tensor = empty_strided(value.layout.size, value.layout.stride) + out_tensor = empty_strided(out.layout.size, out.layout.stride) + + # Flatten batch and head dimensions (n, h) into a single dimension (b = n*h) + q_tensor = q_tensor.view([-1, q_tensor.shape[-2], q_tensor.shape[-1]]) + k_tensor = k_tensor.view([-1, k_tensor.shape[-2], k_tensor.shape[-1]]) + v_tensor = v_tensor.view([-1, v_tensor.shape[-2], v_tensor.shape[-1]]) + out_tensor = out_tensor.view([-1, out_tensor.shape[-2], out_tensor.shape[-1]]) + + b, l, s, e, ev = q_tensor.size(0), q_tensor.size(1), k_tensor.size(1), k_tensor.size(2), v_tensor.size(2) + + n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 + + return query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node + + # Reuse the existing function in MLIRBMMTemplate. + def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_node): + + # FIXME: Update the method for getting tile candidates once TestDmaFineGrained oass works correctly with Flash Attention. + # tile_candidates = kernel.flash_sdpa_mapping(l, s, e, n_extra_node=n_extra_node) + tile_candidates = [[kernel.vector_lane, kernel.vector_lane, e]] + + for idx, (tile_l, tile_s, tile_e) in enumerate(tile_candidates): + subtile_l = tile_l if (tile_l < kernel.vector_lane) or n_prologue_node else kernel.vector_lane + subtile_s = tile_s # if (tile_s < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + subtile_e = tile_e # if (tile_e < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + + tile_candidates[idx] = tile_l,tile_s,tile_e,subtile_l,subtile_s,subtile_e + + return tile_candidates diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index b864e5f2..23f5e3dc 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -387,6 +387,100 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) tile_candidates = [v for _, v in tile_candidates] return tile_candidates + + # Flash Attention requires more SRAM compared to standard GEMM. + # Total buffers needed: query, key, value, out, mul, max, sum + # Tensor Shapes: + # query (tile_l, tile_e), key (tile_s, tile_e), value (tile_s, tile_e), mul (tile_s, tile_l), out(tile_l, tile_e) + # max, sum : (tile_l, 2) + def flash_sdpa_mapping(self, l, s, e, n_extra_node=0, n_prologue_node=0, pad_e=True, min_tile=False, is_conv=False): + tile_candidates = [] + + spad_size_per_lane = self.spad_info["spad_size"] + spad_size = spad_size_per_lane * self.vector_lane + + # Double buffering + max_spad_per_lane = spad_size_per_lane // 2 + max_spad_size = spad_size // 2 + + # Padding for utilization + minimum_tile_size = 8 + minimum_n_tile = self.num_cores if min_tile else 1 + l_pad_factor = self.vector_lane if l > self.vector_lane else minimum_tile_size + s_pad_factor = self.vector_lane if s > self.vector_lane else minimum_tile_size + + pad = lambda x, factor: ((x + factor - 1) // factor) * factor + l_padded = pad(l, l_pad_factor) + s_padded = pad(s, s_pad_factor) + + # Calculate the total number of vector-sized blocks + l_idx = l_padded // self.vector_lane + s_idx = s_padded // self.vector_lane + + # Generate candidates for the number of blocks per tile + l_tile_range = sympy.divisors(l_idx) if l > self.vector_lane else [1] + s_tile_range = sympy.divisors(s_idx) if s > self.vector_lane else [1] + + # Convert block count to actual tile size + maximize_i_j = 1 + max_used_spad_size = 0 + + # Flash Attention does not tile along the head dimension (e or ev). + tile_e = e + + for i in l_tile_range: + tile_l = i * self.vector_lane if l > self.vector_lane else l_padded + for j in s_tile_range: + tile_s = j * self.vector_lane if s > self.vector_lane else s_padded + + # Calculate used spad size + used_spad_size = ( + tile_l * tile_e * (1 + n_prologue_node) # query + + tile_s * tile_e # key + + tile_s * tile_e # value + + tile_s * tile_l # mul + + tile_l * tile_e * (1 + n_extra_node) # out + + (tile_l * 2) * 2 # max, sum + ) * self.precision + + # Calculate used spad size per lane. + query_per_lane = tile_e * (1+n_prologue_node) + key_per_lane = tile_s + value_per_lane = tile_e + mul_per_lane = tile_s + out_per_lane = tile_e * (1 + n_extra_node) + vec_per_lane = 2 * 2 + + used_spad_per_lane = ( + query_per_lane + + key_per_lane + + value_per_lane + + mul_per_lane + + out_per_lane + + vec_per_lane + ) * self.precision + + # Add the validated candidate to the list if it passes all hardware constraints. + n_tile = math.ceil(l / max(tile_l, 128)) * math.ceil(s / max(tile_s, 128)) + check_spad_size = (used_spad_size < max_spad_size and used_spad_per_lane < max_spad_per_lane) + + if (check_spad_size + and max_used_spad_size < used_spad_size # SRAM utilization + and maximize_i_j <= tile_l * tile_s # Larger tile + and n_tile >= minimum_n_tile # Pallelism + and max(tile_s, 128) // max(tile_l, 128) < 10): # Balanced Shape + max_used_spad_size = used_spad_size + maximize_i_j = tile_l * tile_s + + if check_spad_size: + tile_candidates.append((used_spad_size, (tile_l, tile_s, tile_e))) + + # Sort by used_spad_size. + # tile_candidates[0] is the best solution we have. + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + + return tile_candidates def meta_kernel(self): kernel_arg_attributes = self.kernel_arg_attributes @@ -827,7 +921,12 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): # Prepare code block with self: - dtype = self.named_nodes[dram_name].get_layout().dtype + try: + dtype = self.named_nodes[dram_name].get_layout().dtype + except (KeyError, AttributeError, TypeError): + import torch + dtype = torch.float32 + tile_shape = tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[dtype]) buffer_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, id, forced_name=dram_name) code = f"%{tile_desc.name} = memref.get_global @{buffer_name} : {tile_shape}" diff --git a/tests/test_sdpa.py b/tests/test_sdpa.py new file mode 100644 index 00000000..9c921eb4 --- /dev/null +++ b/tests/test_sdpa.py @@ -0,0 +1,84 @@ +import sys +import math +import torch +import inspect +from typing import List +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.fx.passes.graph_drawer import FxGraphDrawer +from torch._inductor.decomposition import decompositions + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + message = f"|{name} Test Passed|" + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + +def test_scaled_dot_product_attention(device, backends="flash"): + torch.manual_seed(0) + n_batch_list = [1, 4, 8, 16] + n_head_list = [1, 4, 8, 12] + n_token_list = [128, 256, 512, 1024] + head_dim_list = [32, 64, 128] + + for n_batch in n_batch_list: + for n_head in n_head_list: + for n_token in n_token_list: + for head_dim in head_dim_list: + # Inputs + query = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) + key = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) + value = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) + + query = query.to(device=device) + key = key.to(device=device) + value = value.to(device=device) + + # With NPU + if backends == "flash": + backends = [SDPBackend.FLASH_ATTENTION] + elif backends == "math": + backends = [SDPBackend.MATH] + elif backends == "memory_efficient": + backends = [SDPBackend.EFFICIENT_ATTENTION] + else: + backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION] + + with sdpa_kernel(backends=backends): + opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) + out = opt_fn(query, key, value) + + out = out.to(device) + + # With CPU + device = torch.device('cpu') + query = query.to(device=device) + key = key.to(device=device) + value = value.to(device=device) + cpu_out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) + + name = f"SDPA(n_batch: {n_batch}, n_head: {n_head}, n_token: {n_token}, head_dim: {head_dim})" + test_result(name, out, cpu_out) + + print("All tests passed!") + +def clear_caches(): + import os + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache + from torch._inductor.codecache import FxGraphCache + AOTAutogradCache.clear() + torch._dynamo.reset() + os.environ["TORCHINDUCTOR_CACHE"] = "0" + FxGraphCache.clear() + +if __name__ == "__main__": + clear_caches() + + device = torch.device('npu:0') + test_scaled_dot_product_attention(device, backends="flash") + \ No newline at end of file From fc247be17221f2b6aa8c52228a2e86b7315ef78d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EC=9E=AC=EA=B7=A0?= Date: Mon, 2 Mar 2026 00:28:31 +0900 Subject: [PATCH 02/31] [Template] Add cat & sort template + Multi-output (WIP) --- .../torch_openreg/openreg/__init__.py | 49 +++ PyTorchSimFrontend/mlir/mlir_cat_template.py | 167 +++++++++++ PyTorchSimFrontend/mlir/mlir_common.py | 6 +- PyTorchSimFrontend/mlir/mlir_lowering.py | 281 +++++++++++++++++- PyTorchSimFrontend/mlir/mlir_sort_template.py | 253 ++++++++++++++++ PyTorchSimFrontend/mlir/mlir_template.py | 30 +- tests/DeepSeek/test_deepseek_v3_base.py | 170 +++++++++-- tests/test_cat.py | 89 ++++++ tests/test_sort.py | 112 +++++++ 9 files changed, 1121 insertions(+), 36 deletions(-) create mode 100644 PyTorchSimFrontend/mlir/mlir_cat_template.py create mode 100644 PyTorchSimFrontend/mlir/mlir_sort_template.py create mode 100644 tests/test_cat.py create mode 100644 tests/test_sort.py diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index f5aabc18..5603a4f7 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -256,6 +256,52 @@ def launch_model(model, *args, stream_index=0, timestamp=0, **kwargs): from .random import * # noqa: F403 from .amp import * +def _precheck_cat_out_args(args, kwargs): + tensors = args[0] if len(args) > 0 else kwargs.get("tensors") + dim = args[1] if len(args) > 1 else kwargs.get("dim", 0) + out = kwargs.get("out", args[2] if len(args) > 2 else None) + + if out is None: + return + if not isinstance(tensors, (list, tuple)) or len(tensors) == 0: + raise RuntimeError("aten::cat.out requires non-empty tensor list") + if not all(isinstance(t, torch.Tensor) for t in tensors): + raise RuntimeError("aten::cat.out tensors must be Tensor values") + if not isinstance(out, torch.Tensor): + raise RuntimeError("aten::cat.out out must be a Tensor") + + rank = tensors[0].dim() + if rank == 0: + raise RuntimeError("aten::cat.out does not support scalar inputs") + if dim < 0: + dim += rank + if dim < 0 or dim >= rank: + raise RuntimeError(f"aten::cat.out dim out of range: dim={dim}, rank={rank}") + if any(t.dim() != rank for t in tensors): + raise RuntimeError("aten::cat.out inputs must have the same rank") + if any(t.dtype != tensors[0].dtype for t in tensors): + raise RuntimeError("aten::cat.out inputs must have the same dtype") + if out.dim() != rank: + raise RuntimeError("aten::cat.out out rank mismatch") + + for d in range(rank): + if d == dim: + continue + base = tensors[0].shape[d] + if any(t.shape[d] != base for t in tensors[1:]): + raise RuntimeError( + f"aten::cat.out non-concatenated dimension mismatch at dim={d}" + ) + if out.shape[d] != base: + raise RuntimeError(f"aten::cat.out out shape mismatch at dim={d}") + + expected = sum(t.shape[dim] for t in tensors) + if out.shape[dim] != expected: + raise RuntimeError( + f"aten::cat.out out concatenated dimension mismatch at dim={dim}: " + f"expected {expected}, got {out.shape[dim]}" + ) + def eager_to_compile(op_name): """ Register an eager mode operation as a graph-based implementation using torch.compile(). @@ -267,6 +313,9 @@ def eager_to_compile(op_name): torch.npu.eager_to_compile("aten::mul.Tensor") """ def wrapper(*args, **kwargs): + if op_name == "aten::cat.out": + _precheck_cat_out_args(args, kwargs) + @torch.compile(dynamic=False) def dummy_graph(*args, **kwargs): # Convert "aten::mul.Tensor" -> torch.ops.aten.mul.Tensor diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py new file mode 100644 index 00000000..996af1de --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -0,0 +1,167 @@ +from typing import List, Optional, cast + +import sympy +from torch._inductor.ir import Buffer, IRNode +from torch._inductor.virtualized import V + +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X0, X1], outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X0", X0_TILE_DESC, id=0, indent_size=2) }} + {{ kernel.def_sram_buffer("X1", X1_TILE_DESC, id=1, indent_size=2) }} + {{ kernel.def_sram_buffer(OUT_DVAR, Y_TILE_DESC, id=2, indent_size=2) }} + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %cat_block = 0 to 1 step 1 { +{% if DIM == 0 %} + affine.for %index0 = 0 to {{ X0_ROWS }} step 1 { + affine.for %index1 = 0 to {{ COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }} + } + } + + affine.for %index2 = 0 to {{ X1_ROWS }} step 1 { + affine.for %index3 = 0 to {{ COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }} + } + } +{% else %} + affine.for %index0 = 0 to {{ ROWS }} step 1 { + affine.for %index1 = 0 to {{ X0_COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }} + } + affine.for %index3 = 0 to {{ X1_COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }} + } + } +{% endif %} + } { outer_loop=true } + return +} +""" + + +class MLIRCatTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.dim = dim + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + is_out_variant = template_buffer_node is not None + if is_out_variant: + self.output_node = template_buffer_node + # cat template currently emits a single output buffer and does not + # support epilogue output remapping. + + def _unwrap_node(n): + return n.node if hasattr(n, "node") else n + + x0 = _unwrap_node(self.input_nodes[0]) + x1 = _unwrap_node(self.input_nodes[1]) + y = _unwrap_node(self.output_node) + + def _as_int(v): + try: + return int(v) + except Exception: + return int(V.graph.sizevars.size_hint(v)) + + x0_rows = _as_int(x0.get_size()[0]) + x1_rows = _as_int(x1.get_size()[0]) + x0_cols = _as_int(x0.get_size()[1]) + x1_cols = _as_int(x1.get_size()[1]) + y_cols = _as_int(y.get_size()[1]) + kernel.loop_size = None + + # 2D cat template with contiguous layout. + x0_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + x0_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + x0_tile_desc.set_name("x0_cat_tile") + x1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + x1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + x1_tile_desc.set_name("x1_cat_tile") + y_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + y_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + y_tile_desc.set_name("y_cat_tile") + + if self.dim == 0: + # Flattened offsets for dim=0 cat. + x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")] + x1_idx = [sympy.Symbol("index2") * x1_cols, sympy.Symbol("index3")] + y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")] + y1_idx = [(sympy.Symbol("index2") + x0_rows) * y_cols, sympy.Symbol("index3")] + else: + # Flattened offsets for dim=1 cat. + x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")] + x1_idx = [sympy.Symbol("index0") * x1_cols, sympy.Symbol("index3")] + y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")] + y1_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index3") + x0_cols] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X0=x0, + X1=x1, + Y=y, + OUT_DVAR="out_ptr1" if is_out_variant else "Y", + NAMES_STR="X0, X1, out_ptr1" if is_out_variant else "X0, X1, Y", + DIM=self.dim, + X0_ROWS=x0_rows, + X1_ROWS=x1_rows, + ROWS=x0_rows, + X0_COLS=x0_cols, + X1_COLS=x1_cols, + COLS=x0_cols, + X0_TILE_DESC=x0_tile_desc, + X1_TILE_DESC=x1_tile_desc, + Y_TILE_DESC=y_tile_desc, + X0_IDX=x0_idx, + X1_IDX=x1_idx, + Y0_IDX=y0_idx, + Y1_IDX=y1_idx, + input_reorder=self.input_reorder, + ) + # Needed when epilogue fusion requests set_ranges(). + kernel.dim_aliasing = {"index0": "index0", "index1": "index1"} + + if hasattr(self.output_node, "node") and hasattr(self.output_node.node, "get_name"): + output_node_name = self.output_node.node.get_name() + elif hasattr(self.output_node, "get_name"): + output_node_name = self.output_node.get_name() + else: + output_node_name = self.output_node.name + + if hasattr(y, "get_numel"): + y_numel = y.get_numel() + elif hasattr(y, "node") and hasattr(y.node, "get_numel"): + y_numel = y.node.get_numel() + else: + y_numel = None + + kernel.epilogue_info = dict( + output_node=output_node_name, + sram_var="y_cat_tile", + dram_var=kernel.render_options["OUT_DVAR"], + dram_tile_desc=y_tile_desc, + ) + if y_numel is not None: + kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": y_numel} + + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + return code diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 34b185b8..256d7101 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -173,7 +173,11 @@ def get_mlir_shape(info): def mlir_argdefs(self, extra_node=dict()): buffer_types = {} for x in V.graph.buffers: - if not isinstance(x.layout, MultiOutputLayout): # FIXME: MultiOutputLayout should be handled + if isinstance(x.layout, MultiOutputLayout): + # MultiOutput kernel containers own concrete output nodes in `outputs`. + for out in getattr(x, "outputs", []): + buffer_types[out.get_name()] = [out.get_dtype(), out.get_numel(), out.get_size(), out.get_stride()] + else: buffer_types[x.get_name()] = [x.get_dtype(), x.get_numel(), x.get_size(), x.get_stride()] for name, val in V.graph.graph_inputs.items(): if isinstance(val, sympy.Expr): diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index ebf0c80e..0f28f03b 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -15,10 +15,15 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.mlir.mlir_cat_template import MLIRCatTemplate +from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate from PyTorchSimFrontend import extension_config aten = torch.ops.aten aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") +_orig_cat_default_lowering = lowerings.get(aten.cat.default) +_orig_cat_out_lowering = lowerings.get(aten.cat.out) +_orig_sort_values_stable_lowering = lowerings.get(aten.sort.values_stable) def tuned_mm(mat1, mat2, * ,layout=None): m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) @@ -181,11 +186,285 @@ def custom_unsafe_index(x, indices): x.realize() return index_impl(x, indices, check=False) + +def _cat_layout(tensors: Sequence[TensorBox], dim: int) -> ir.Layout: + with V.graph.fake_mode: + output = torch.ops.aten.cat( + [ir.ir_node_to_tensor(t, guard_shape=True) for t in tensors], + dim, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) + return ir.FixedLayout( + tensors[0].get_device(), + tensors[0].get_dtype(), + sizes, + stride, + ) + + +def _can_use_cat_template(tensors: Sequence[TensorBox], dim: int) -> bool: + # Current template specialization: 2 inputs, rank-2, dim in {0, 1}. + if len(tensors) != 2: + return False + if not all(hasattr(t, "get_size") and hasattr(t, "get_dtype") and hasattr(t, "realize") for t in tensors): + return False + if tensors[0].get_dtype() != tensors[1].get_dtype(): + return False + rank0 = len(tensors[0].get_size()) + rank1 = len(tensors[1].get_size()) + if rank0 != 2 or rank1 != 2: + return False + if dim < 0: + dim += rank0 + if dim not in (0, 1): + return False + + if dim == 0: + cols0 = tensors[0].get_size()[1] + cols1 = tensors[1].get_size()[1] + return V.graph.sizevars.statically_known_equals(cols0, cols1) + + rows0 = tensors[0].get_size()[0] + rows1 = tensors[1].get_size()[0] + return V.graph.sizevars.statically_known_equals(rows0, rows1) + + +def _cat_fallback(reason: str, tensors: Sequence[TensorBox], dim: int): + # Non-template cases delegate to the original lowering path. + return _orig_cat_default_lowering(tensors, dim) + + +def _custom_cat_impl(tensors: Sequence[TensorBox], dim: int = 0): + if _orig_cat_default_lowering is None: + raise RuntimeError("Original aten.cat.default lowering is missing") + if len(tensors) > 0: + rank = len(tensors[0].get_size()) + if dim < 0: + dim += rank + if not _can_use_cat_template(tensors, dim): + return _cat_fallback("default-path", tensors, dim) + + for t in tensors: + t.realize() + layout = _cat_layout(tensors, dim) + mlir_template = MLIRCatTemplate(list(tensors), layout, dim=dim) + return mlir_template.generate().output_node() + + +def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): + return _custom_cat_impl(tensors, dim) + + +def custom_cat_out(tensors: Sequence[TensorBox], dim: int = 0, out: Optional[TensorBox] = None): + if _orig_cat_out_lowering is None: + raise RuntimeError("Original aten.cat.out lowering is missing") + if out is None: + return _orig_cat_out_lowering(tensors, dim, out) + + copy_default_lowering = lowerings.get(aten.copy_.default) + slice_tensor_lowering = lowerings.get(aten.slice.Tensor) + if copy_default_lowering is None or slice_tensor_lowering is None: + raise RuntimeError("cat.out lowering requires aten.copy_.default and aten.slice.Tensor lowerings") + + # Lower cat.out as a sequence of slice+copy ops so each piece still runs + # through the existing compiled/simulated kernel path. + if len(tensors) == 0: + raise RuntimeError("cat.out requires at least one input tensor") + if not all(hasattr(t, "get_size") and hasattr(t, "get_dtype") and hasattr(t, "realize") for t in tensors): + raise RuntimeError("cat.out inputs must be tensor-like values") + rank = len(tensors[0].get_size()) + if rank == 0: + raise RuntimeError("cat.out does not support scalar inputs") + if dim < 0: + dim = dim + rank + if dim < 0 or dim >= rank: + raise RuntimeError(f"cat.out dim out of range: dim={dim}, rank={rank}") + if any(len(t.get_size()) != rank for t in tensors): + raise RuntimeError("cat.out inputs must have the same rank") + if any(t.get_dtype() != tensors[0].get_dtype() for t in tensors): + raise RuntimeError("cat.out inputs must have the same dtype") + # cat semantics: all non-cat dimensions must be equal. + for i in range(rank): + if i == dim: + continue + base = tensors[0].get_size()[i] + if any(not V.graph.sizevars.statically_known_equals(base, t.get_size()[i]) for t in tensors[1:]): + raise RuntimeError(f"cat.out non-concatenated dimension mismatch at dim={i}") + + # Output shape must match concatenated shape. + if not hasattr(out, "get_size"): + raise RuntimeError("cat.out output must be tensor-like") + out_sizes = list(out.get_size()) + if len(out_sizes) != rank: + raise RuntimeError("cat.out output rank mismatch") + for i in range(rank): + if i == dim: + continue + if not V.graph.sizevars.statically_known_equals(out_sizes[i], tensors[0].get_size()[i]): + raise RuntimeError(f"cat.out output shape mismatch at dim={i}") + expected_cat = sum(t.get_size()[dim] for t in tensors) + if not V.graph.sizevars.statically_known_equals(out_sizes[dim], expected_cat): + raise RuntimeError(f"cat.out output concatenated dimension mismatch at dim={dim}") + + if isinstance(out, TensorBox): + out.realize() + + offset = 0 + for src in tensors: + src.realize() + end = offset + src.get_size()[dim] + dst_view = slice_tensor_lowering(out, dim, offset, end, 1) + copy_default_lowering(dst_view, src) + offset = end + return out + + +def _custom_sort_values_impl( + self: TensorBox, + dim: int = -1, + descending: bool = False, + values: Optional[TensorBox] = None, + indices: Optional[TensorBox] = None, + stable: Optional[bool] = None, +): + if values is None or indices is None: + raise RuntimeError("sort.values* lowering requires both out tensors: values, indices") + + def _normalize_dim(rank: int, d: int) -> int: + return d + rank if d < 0 else d + + if not hasattr(self, "get_size"): + raise RuntimeError("sort.values* lowering requires TensorBox input") + + rank = len(self.get_size()) + norm_dim = _normalize_dim(rank, dim) + if norm_dim < 0 or norm_dim >= rank: + raise RuntimeError(f"sort.values* dim out of range: dim={dim}, rank={rank}") + if rank != 2: + raise RuntimeError(f"sort.values* lowering currently supports rank-2 only, got rank={rank}") + if norm_dim not in (0, 1): + raise RuntimeError(f"sort.values* lowering currently supports dim in {{0,1}} only, got dim={norm_dim}") + + self.realize() + if isinstance(values, TensorBox): + values.realize() + if isinstance(indices, TensorBox): + indices.realize() + + value_layout, _ = _sort_layouts(self, norm_dim, descending) + mlir_template = MLIRSortTemplate( + [self], + value_layout, + dim=norm_dim, + descending=descending, + stable=True if stable is None else stable, + indices_node=indices, + ) + sorted_values = mlir_template.generate(template_buffer_node=values, epilogue_nodes=[indices]).output_node() + return sorted_values, indices + + +def _sort_layouts(x: TensorBox, dim: int, descending: bool): + with V.graph.fake_mode: + v, i = torch.ops.aten.sort( + ir.ir_node_to_tensor(x, guard_shape=True), + dim, + descending, + ) + v_sizes = ir.convert_shape_to_inductor(v.size()) + v_stride = ir.convert_shape_to_inductor(v.stride()) + i_sizes = ir.convert_shape_to_inductor(i.size()) + i_stride = ir.convert_shape_to_inductor(i.stride()) + + value_layout = ir.FixedLayout(x.get_device(), x.get_dtype(), v_sizes, v_stride) + index_layout = ir.FixedLayout(x.get_device(), torch.int64, i_sizes, i_stride) + return value_layout, index_layout + + +def custom_sort_stable( + self: TensorBox, + *, + stable: Optional[bool] = None, + dim: int = -1, + descending: bool = False, +): + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + if empty_strided_lowering is None: + if _orig_sort_values_stable_lowering is None: + raise RuntimeError("sort.stable lowering requires aten.empty_strided.default") + return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=True) + + rank = len(self.get_size()) if hasattr(self, "get_size") else 0 + norm_dim = dim + rank if dim < 0 else dim + if rank > 0 and (norm_dim < 0 or norm_dim >= rank): + raise RuntimeError(f"sort.stable dim out of range: dim={dim}, rank={rank}") + + # Template specialization supports rank-2 and dim in {0,1}. + if rank == 2 and norm_dim not in (0, 1): + if _orig_sort_values_stable_lowering is None: + raise RuntimeError("Original aten.sort.values_stable lowering is missing") + return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=True) + + try: + value_layout, index_layout = _sort_layouts(self, norm_dim, descending) + values = empty_strided_lowering( + list(value_layout.size), + list(value_layout.stride), + dtype=value_layout.dtype, + device=self.get_device(), + ) + indices = empty_strided_lowering( + list(index_layout.size), + list(index_layout.stride), + dtype=index_layout.dtype, + device=self.get_device(), + ) + return _custom_sort_values_impl( + self=self, + dim=dim, + descending=descending, + values=values, + indices=indices, + stable=True if stable is None else stable, + ) + except Exception: + if _orig_sort_values_stable_lowering is None: + raise + return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=stable) + + +def custom_sort_values_stable( + self: TensorBox, + *, + stable: Optional[bool] = None, + dim: int = -1, + descending: bool = False, + values: Optional[TensorBox] = None, + indices: Optional[TensorBox] = None, +): + return _custom_sort_values_impl( + self=self, + dim=dim, + descending=descending, + values=values, + indices=indices, + stable=stable, + ) + + lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()}) + +lowerings.update({aten.cat.default: custom_cat_default}) +lowerings.update({aten.cat.out: custom_cat_out}) + +lowerings.update({aten.sort.stable: custom_sort_stable}) +lowerings.update({aten.sort.values_stable: custom_sort_values_stable}) + if extension_config.CONFIG_USE_TIMING_POOLING: - lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template \ No newline at end of file + lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template diff --git a/PyTorchSimFrontend/mlir/mlir_sort_template.py b/PyTorchSimFrontend/mlir/mlir_sort_template.py new file mode 100644 index 00000000..d12c7570 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_sort_template.py @@ -0,0 +1,253 @@ +from typing import List, Optional + +import sympy +from torch._inductor.ir import IRNode +from torch._inductor.virtualized import V + +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X, YI], outputs=[YV], names_str=NAMES_STR, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("YI", YI_TILE_DESC, id=1, indent_size=2) }} + {{ kernel.def_sram_buffer(OUT_DVAR, YV_TILE_DESC, id=2, indent_size=2) }} + {{ kernel.def_local_vars(indent_size=2) }} + + %c0 = arith.constant 0 : index + %c_cols = arith.constant {{ COLS }} : index + + affine.for %sort_block = 0 to 1 step 1 { + // Initialize output value/index buffers. + affine.for %row = 0 to {{ ROWS }} step 1 { + affine.for %col = 0 to {{ COLS }} step 1 { + {{ kernel.def_dma_op("MVIN", "X", INIT_X_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, INIT_YV_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} +{% if DIM == 1 %} + %idx_i64 = arith.index_cast %col : index to {{ YI_ELEM_TYPE }} +{% else %} + %idx_i64 = arith.index_cast %row : index to {{ YI_ELEM_TYPE }} +{% endif %} + memref.store %idx_i64, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", INIT_YI_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} + } + } + +{% if DIM == 1 %} + // Stable bubble sort on each row (dim=1). + affine.for %row = 0 to {{ ROWS }} step 1 { + affine.for %pass = 0 to {{ COLS }} step 1 { + affine.for %j = 0 to {{ COLS_MINUS1 }} step 1 { + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + +{% if DESCENDING %} + %need_swap = arith.cmpf olt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% else %} + %need_swap = arith.cmpf ogt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% endif %} + scf.if %need_swap { + memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + {{ kernel.def_dma_op("MVIN", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + } + } + } + } +{% else %} + // Stable bubble sort on each column (dim=0). + affine.for %col = 0 to {{ COLS }} step 1 { + affine.for %pass = 0 to {{ ROWS }} step 1 { + affine.for %i = 0 to {{ ROWS_MINUS1 }} step 1 { + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} + %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + +{% if DESCENDING %} + %need_swap = arith.cmpf olt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% else %} + %need_swap = arith.cmpf ogt, %lhs, %rhs : {{ YV_ELEM_TYPE }} +{% endif %} + scf.if %need_swap { + memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + {{ kernel.def_dma_op("MVIN", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + {{ kernel.def_dma_op("MVIN", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + + memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + + memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} + {{ kernel.def_dma_op("MVOUT", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} + } + } + } + } +{% endif %} + } { outer_loop=true } + return +} +""" + + +class MLIRSortTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim, descending=False, stable=False, indices_node=None, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.dim = dim + self.descending = descending + self.stable = stable + self.indices_node = indices_node + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + if template_buffer_node is not None: + self.output_node = template_buffer_node + if self.indices_node is None: + raise RuntimeError("MLIRSortTemplate requires indices output node") + + x = self.input_nodes[0] + yv = self.output_node + yi = self.indices_node + + def _as_int(v): + try: + return int(v) + except Exception: + return int(V.graph.sizevars.size_hint(v)) + + x_size = x.get_size() + if len(x_size) != 2: + raise RuntimeError("MLIRSortTemplate currently supports rank-2 input only") + if self.dim not in (0, 1): + raise RuntimeError(f"MLIRSortTemplate currently supports dim in {{0,1}} only, got dim={self.dim}") + + rows = _as_int(x_size[0]) + cols = _as_int(x_size[1]) + cols_minus1 = max(0, cols - 1) + rows_minus1 = max(0, rows - 1) + + x_dtype = x.get_dtype() + yv_dtype = yv.get_dtype() + yi_dtype = yi.get_dtype() + if x_dtype != yv_dtype: + raise RuntimeError("sort template requires input/value dtype match") + + yi_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yi_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yi_tile_desc.set_name("yi_sort_tile") + yv_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yv_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yv_tile_desc.set_name("yv_sort_tile") + # Neighbor element descriptors use DRAM offset to preserve affine stride metadata. + yv_s1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yv_s1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yv_s1_tile_desc.set_name("yv_sort_tile") + yi_s1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) + yi_s1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) + yi_s1_tile_desc.set_name("yi_sort_tile") + if int(self.dim) == 1: + yv_s1_tile_desc.offset = sympy.Integer(1) + yi_s1_tile_desc.offset = sympy.Integer(1) + else: + yv_s1_tile_desc.offset = sympy.Integer(cols) + yi_s1_tile_desc.offset = sympy.Integer(cols) + + row = sympy.Symbol("row") + col = sympy.Symbol("col") + i = sympy.Symbol("i") + j = sympy.Symbol("j") + + init_x_idx = [row * cols, col] + init_yv_idx = [row * cols, col] + init_yi_idx = [row * cols, col] + + d1_s0_idx = [row * cols, j] + d1_s1_idx = [row * cols, j] + + d0_s0_idx = [i * cols, col] + d0_s1_idx = [i * cols, col] + + kernel.loop_size = None + numel = rows * cols + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + X=x, + YV=yv, + YI=yi, + OUT_DVAR="YV", + NAMES_STR="X, YI, YV", + ROWS=rows, + COLS=cols, + COLS_MINUS1=cols_minus1, + ROWS_MINUS1=rows_minus1, + DIM=int(self.dim), + DESCENDING=bool(self.descending), + YI_TILE_DESC=yi_tile_desc, + YV_TILE_DESC=yv_tile_desc, + YI_S1_TILE_DESC=yi_s1_tile_desc, + YV_S1_TILE_DESC=yv_s1_tile_desc, + INIT_X_IDX=init_x_idx, + INIT_YV_IDX=init_yv_idx, + INIT_YI_IDX=init_yi_idx, + D1_S0_IDX=d1_s0_idx, + D1_S1_IDX=d1_s1_idx, + D0_S0_IDX=d0_s0_idx, + D0_S1_IDX=d0_s1_idx, + YV_ELEM_TYPE=mlir_common.DTYPE_TO_MLIR[yv_dtype], + YI_ELEM_TYPE=mlir_common.DTYPE_TO_MLIR[yi_dtype], + X_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[x_dtype]}>", + YV_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[yv_dtype]}>", + YI_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[yi_dtype]}>", + YV_TILE_MEMREF_TYPE=yv_tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[yv_dtype]), + YI_TILE_MEMREF_TYPE=yi_tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[yi_dtype]), + X_TILE_DESC=yv_tile_desc, + input_reorder=self.input_reorder, + ) + + output_node_name = yv.get_name() if hasattr(yv, "get_name") else yv.name + kernel.epilogue_info = dict( + output_node=output_node_name, + sram_var="yv_sort_tile", + dram_var=kernel.render_options["OUT_DVAR"], + dram_tile_desc=yv_tile_desc, + ) + kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": yv.get_numel()} + kernel.exception_nodes["YI"] = {"numel": yi.get_numel()} + + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + return code diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index b1c756ba..76b0ef71 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -403,7 +403,7 @@ def call_kernel(self, kernel_name): _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( - kernel_name if self.outer_func_name is None else "wrapper_" + kernel_name, call_args) + kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args) def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): with self as kernel: @@ -628,8 +628,26 @@ def def_kernel( self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): - arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) - return f"({', '.join(arg_defs)})" + arg_defs, call_args, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) + output_names = names[len(inputs) : len(inputs) + len(outputs)] + out_ptr_idx = 0 + renamed_arg_defs = [] + for outer, arg_def in zip(call_args, arg_defs): + raw_symbol = arg_def.split(":", 1)[0].strip().lstrip("%") + if outer in self.kernel_group.args.input_buffers: + symbol = self.kernel_group.args.input_buffers[outer] + elif outer in self.kernel_group.args.output_buffers: + symbol = self.kernel_group.args.output_buffers[outer] + elif raw_symbol.startswith("out_ptr") and out_ptr_idx < len(output_names): + symbol = output_names[out_ptr_idx] + out_ptr_idx += 1 + elif outer in self.kernel_group.args.sizevars: + symbol = self.kernel_group.args.sizevars[outer] + else: + symbol = raw_symbol + _, arg_type = arg_def.split(":", 1) + renamed_arg_defs.append(f"%{symbol}:{arg_type}") + return f"({', '.join(renamed_arg_defs)})" assert "" not in self.render_hooks self.render_hooks[""] = hook @@ -1151,6 +1169,8 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): super().__init__(name) self.input_nodes = [node for node in input_nodes if node is not None] self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + # Multi-output templates can override this with explicit output buffers. + self.output_nodes = [self.output_node] self.input_reorder = input_reorder self.layout = layout @@ -1166,10 +1186,12 @@ def generate(self, **kwargs) -> ChoiceCaller: kernel_hash_name = f"mlir_{self.name}_{next(self.index_counter)}" extra_args = [] # create the BenchmarkRequest + output_nodes = getattr(self, "output_nodes", None) or [self.output_node] + bmreq = MLIRBenchmarkRequest( kernel_name=kernel_name, input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + output_tensor_meta=TensorMeta.from_irnodes(output_nodes), extra_args=extra_args, source_code=code, ) diff --git a/tests/DeepSeek/test_deepseek_v3_base.py b/tests/DeepSeek/test_deepseek_v3_base.py index b8402c8b..ade787c5 100644 --- a/tests/DeepSeek/test_deepseek_v3_base.py +++ b/tests/DeepSeek/test_deepseek_v3_base.py @@ -1,8 +1,55 @@ import os import sys import argparse +import copy +from pathlib import Path import torch +# recursive compile for some ops that are caused by graph break +torch.npu.register_eager_to_compile([ + "aten::zero_", + "aten::sum.IntList_out", + "aten::mul.out", + "aten::floor_divide", + "aten::floor_divide.Tensor", + "aten::floor_divide.Scalar", + "aten::cat.out", + "aten::sort.values_stable", +]) + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + out_cpu = out.cpu() + max_diff = (out_cpu - cpu_out).abs().max().item() + mean_diff = (out_cpu - cpu_out).abs().mean().item() + if torch.allclose(out_cpu, cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("NPU out: ", out_cpu) + print("CPU out: ", cpu_out) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + exit(1) + + +def _extract_logits(output): + if isinstance(output, torch.Tensor): + return output + if hasattr(output, "logits"): + return output.logits + if isinstance(output, (list, tuple)) and len(output) > 0 and isinstance(output[0], torch.Tensor): + return output[0] + raise TypeError(f"Unsupported output type for comparison: {type(output)}") + def _dtype_from_str(name: str) -> torch.dtype: return { @@ -81,7 +128,7 @@ def _maybe_scale_config(config, scale=1.0, max_layers=None): def _apply_preset(scale, max_layers, batch, seq_len, preset): if preset == "tiny": - return 0.03, 4, 1, min(seq_len, 16) + return 0.03, 1, 1, min(seq_len, 16) if preset == "small": return 0.07, 8, 1, min(seq_len, 32) if preset == "medium": @@ -89,8 +136,58 @@ def _apply_preset(scale, max_layers, batch, seq_len, preset): return scale, max_layers, batch, seq_len +def _togsim_log_count() -> int: + log_dir = Path("togsim_results") + if not log_dir.exists(): + return 0 + return len(list(log_dir.glob("*.log"))) + + +def _assert_simulation_happened(before_count: int, case_name: str): + after_count = _togsim_log_count() + if after_count <= before_count: + raise RuntimeError( + f"{case_name}: TOGSim log count did not increase " + f"(before={before_count}, after={after_count})" + ) + print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") + + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + before = _togsim_log_count() + out = opt_fn(x, y) + _assert_simulation_happened(before, "cat.default") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + before = _togsim_log_count() + out = opt_fn(x, y, out_buf) + _assert_simulation_happened(before, "cat.out") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + @torch.no_grad() -def run_deep_seek_v3_base_test( +def run_deepseek_v3_base( model_id, device, init_mode="config-random", @@ -120,7 +217,6 @@ def run_deep_seek_v3_base_test( # (call .to_dict()), so only disable it for pretrained loading path. if init_mode == "pretrained" and getattr(config, "quantization_config", None) is not None: config.quantization_config = None - config = _maybe_scale_config(config, scale=scale, max_layers=max_layers) if init_mode == "config-random": @@ -141,7 +237,6 @@ def run_deep_seek_v3_base_test( else: raise ValueError(f"Unsupported init mode: {init_mode}") - model = model.to(device) model_params = sum(p.numel() for p in model.parameters()) print("init mode:", init_mode) print("scaled hidden_size:", getattr(config, "hidden_size", "n/a")) @@ -157,23 +252,33 @@ def run_deep_seek_v3_base_test( revision=revision, ) encoded = tokenizer(prompt, return_tensors="pt") - input_ids = encoded["input_ids"].to(device) + cpu_input_ids = encoded["input_ids"].cpu() else: vocab_size = getattr(config, "vocab_size", None) if vocab_size is None: raise ValueError("Config has no vocab_size; use --use-tokenizer or pass a model with vocab_size.") - input_ids = _build_random_inputs(batch, seq_len, vocab_size, device) + cpu_input_ids = _build_random_inputs(batch, seq_len, vocab_size, torch.device("cpu")) + input_ids = cpu_input_ids.to(device) - if compile_model: - model = torch.compile(model, dynamic=False) + # CPU version + model_cpu = copy.deepcopy(model).cpu().eval() + cpu_out = _extract_logits(model_cpu(cpu_input_ids)) - out = model(input_ids) - logits = out.logits + # NPU version + model_npu = copy.deepcopy(model_cpu).to(device).eval() + if compile_model: + model_npu = torch.compile(model_npu, dynamic=False) + npu_out = _extract_logits(model_npu(input_ids)) + + # Campare results + test_result( + "DeepSeek V3 Base", + npu_out, + cpu_out, + rtol=3e-1, + atol=2e-1, + ) - print("logits shape:", tuple(logits.shape)) - print("logits dtype:", logits.dtype) - print("logits max:", logits.max().item()) - if __name__ == "__main__": parser = argparse.ArgumentParser(description="DeepSeek V3 download-based test") @@ -181,7 +286,7 @@ def run_deep_seek_v3_base_test( parser.add_argument("--revision", type=str, default=None) parser.add_argument("--trust-remote-code", action="store_true", default=True) parser.add_argument("--init-mode", type=str, default="config-random", choices=["config-random", "pretrained"]) - parser.add_argument("--preset", type=str, default="tiny", choices=["none", "tiny", "small", "medium"]) + parser.add_argument("--preset", type=str, default="small", choices=["none", "tiny", "small", "medium"]) parser.add_argument("--scale", type=float, default=1.0) parser.add_argument("--max-layers", type=int, default=None) parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) @@ -190,6 +295,7 @@ def run_deep_seek_v3_base_test( parser.add_argument("--use-tokenizer", action="store_true") parser.add_argument("--prompt", type=str, default="Hello, DeepSeek V3") parser.add_argument("--compile", action="store_true", default=True) + parser.add_argument("--test", type=str, default="e2e", choices=["all", "e2e", "cat"]) args = parser.parse_args() @@ -203,18 +309,22 @@ def run_deep_seek_v3_base_test( device = torch.device("npu:0") - run_deep_seek_v3_base_test( - model_id=args.model_id, - device=device, - init_mode=args.init_mode, - scale=args.scale, - max_layers=args.max_layers, - dtype=args.dtype, - batch=args.batch, - seq_len=args.seq_len, - use_tokenizer=args.use_tokenizer, - prompt=args.prompt, - trust_remote_code=args.trust_remote_code, - revision=args.revision, - compile_model=args.compile, - ) + if args.test in ("all", "cat"): + test_cat_default(device) + test_cat_out(device) + if args.test in ("all", "e2e"): + run_deepseek_v3_base( + model_id=args.model_id, + device=device, + init_mode=args.init_mode, + scale=args.scale, + max_layers=args.max_layers, + dtype=args.dtype, + batch=args.batch, + seq_len=args.seq_len, + use_tokenizer=args.use_tokenizer, + prompt=args.prompt, + trust_remote_code=args.trust_remote_code, + revision=args.revision, + compile_model=args.compile, + ) diff --git a/tests/test_cat.py b/tests/test_cat.py new file mode 100644 index 00000000..32573a05 --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,89 @@ +import argparse +from pathlib import Path + +import torch + + +def _test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + return + + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + raise RuntimeError(f"{name} mismatch") + + +def _togsim_log_count() -> int: + log_dir = Path("togsim_results") + if not log_dir.exists(): + return 0 + return len(list(log_dir.glob("*.log"))) + + +def _assert_simulation_happened(before_count: int, case_name: str): + after_count = _togsim_log_count() + if after_count <= before_count: + raise RuntimeError( + f"{case_name}: TOGSim log count did not increase " + f"(before={before_count}, after={after_count})" + ) + print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") + + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + before = _togsim_log_count() + out = opt_fn(x, y) + _assert_simulation_happened(before, "cat.default") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + before = _togsim_log_count() + out = opt_fn(x, y, out_buf) + _assert_simulation_happened(before, "cat.out") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run cat simulation tests") + parser.add_argument( + "--case", + choices=["default", "out", "all"], + default="all", + help="Which cat case to run", + ) + args = parser.parse_args() + + device = torch.device("npu:0") + + if args.case in ("default", "all"): + test_cat_default(device) + if args.case in ("out", "all"): + test_cat_out(device) diff --git a/tests/test_sort.py b/tests/test_sort.py new file mode 100644 index 00000000..2b070223 --- /dev/null +++ b/tests/test_sort.py @@ -0,0 +1,112 @@ +import argparse +import torch +import torch._dynamo +import torch.utils.cpp_extension + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + + +def test_equal(name, out, cpu_out): + if torch.equal(out.cpu(), cpu_out): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + + +def _normalize_dim(dim: int, rank: int) -> int: + d = dim if dim >= 0 else rank + dim + if d < 0 or d >= rank: + raise ValueError(f"dim out of range: dim={dim}, rank={rank}") + return d + + +def test_sort_stable(device, size=(128, 128), dim=-1, descending=False): + _normalize_dim(dim, len(size)) + + def sort_stable_fn(x): + return torch.sort(x, stable=True, dim=dim, descending=descending) + + x = torch.randn(size, dtype=torch.float32) + x_npu = x.to(device=device) + + opt_sort = torch.compile(dynamic=False)(sort_stable_fn) + out_values, out_indices = opt_sort(x_npu) + + ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending) + + test_result("Sort.stable/values", out_values, ref_values) + test_equal("Sort.stable/indices", out_indices, ref_indices) + + +def test_sort_values_stable(device, size=(128, 128), dim=-1, descending=False): + _normalize_dim(dim, len(size)) + + def sort_out_fn(x): + out_values = torch.empty_like(x, device=x.device) + out_indices = torch.empty_like(x, dtype=torch.int64, device=x.device) + return torch.sort(x, stable=True, dim=dim, descending=descending, out=(out_values, out_indices)) + + x = torch.randn(size, dtype=torch.float32) + x_npu = x.to(device=device) + + opt_sort = sort_out_fn# torch.compile(dynamic=False)(sort_out_fn) + out_values, out_indices = opt_sort(x_npu) + + ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending) + + test_result("Sort.values_stable/values", out_values, ref_values) + test_equal("Sort.values_stable/indices", out_indices, ref_indices) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run sort tests") + parser.add_argument("--shape", type=str, default="(128,128)") + parser.add_argument("--dim", type=int, default=0) + parser.add_argument("--descending", action="store_true") + parser.add_argument( + "--mode", + type=str, + default="all", + choices=["all", "default", "values"], + ) + args = parser.parse_args() + + shape = tuple(map(int, args.shape.strip("()").split(","))) + + from Scheduler.scheduler import PyTorchSimRunner + + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + # Register recursive-compile bridge only when values_stable path is explicitly tested. + if args.mode in ("all", "values"): + torch.npu.register_eager_to_compile([ + "aten::sort.values_stable", + ]) + + if args.mode in ("all", "default"): + test_sort_stable(device, size=shape, dim=args.dim, descending=args.descending) + if args.mode in ("all", "values"): + test_sort_values_stable(device, size=shape, dim=args.dim, descending=args.descending) From f615178ae581236a1b4d1018f9b458b2c552179f Mon Sep 17 00:00:00 2001 From: jung-min Date: Wed, 4 Mar 2026 07:57:47 +0000 Subject: [PATCH 03/31] [Fix] Prevent fallback to eager mode after reaching compilation limit (7) --- tests/test_sdpa.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/tests/test_sdpa.py b/tests/test_sdpa.py index 9c921eb4..6ffd6f2e 100644 --- a/tests/test_sdpa.py +++ b/tests/test_sdpa.py @@ -14,6 +14,7 @@ def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("-" * len(message)) print(message) print("-" * len(message)) + pass else: print("custom out: ", out.cpu()) print("cpu out: ", cpu_out) @@ -31,35 +32,25 @@ def test_scaled_dot_product_attention(device, backends="flash"): for n_token in n_token_list: for head_dim in head_dim_list: # Inputs + clear_caches() query = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) key = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) value = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) + # With NPU query = query.to(device=device) key = key.to(device=device) value = value.to(device=device) - # With NPU - if backends == "flash": - backends = [SDPBackend.FLASH_ATTENTION] - elif backends == "math": - backends = [SDPBackend.MATH] - elif backends == "memory_efficient": - backends = [SDPBackend.EFFICIENT_ATTENTION] - else: - backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION] - - with sdpa_kernel(backends=backends): - opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) - out = opt_fn(query, key, value) - + opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) + out = opt_fn(query, key, value) out = out.to(device) # With CPU - device = torch.device('cpu') - query = query.to(device=device) - key = key.to(device=device) - value = value.to(device=device) + cpu_device = torch.device('cpu') + query = query.to(device=cpu_device) + key = key.to(device=cpu_device) + value = value.to(device=cpu_device) cpu_out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) name = f"SDPA(n_batch: {n_batch}, n_head: {n_head}, n_token: {n_token}, head_dim: {head_dim})" @@ -76,9 +67,7 @@ def clear_caches(): os.environ["TORCHINDUCTOR_CACHE"] = "0" FxGraphCache.clear() -if __name__ == "__main__": - clear_caches() - +if __name__ == "__main__": device = torch.device('npu:0') test_scaled_dot_product_attention(device, backends="flash") \ No newline at end of file From 8ca5d02d599d06725b90963ee44701cb50e8f444 Mon Sep 17 00:00:00 2001 From: jung-min Date: Wed, 4 Mar 2026 08:09:28 +0000 Subject: [PATCH 04/31] [FIX] Add idx_map to the first matmul for logical consistency --- PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index b3d88cc6..49c6c6bb 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -339,6 +339,7 @@ def patched_scaled_dot_product_attention( // key @ query.t and scaling. linalg.matmul + { idx_map = array } ins(%k_buffer2D, %qt_buffer2D : memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1>, memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(data_stype) }}) @@ -451,7 +452,7 @@ def render(self, prologue_nodes: Optional[List[IRNode]] = None, tile_info = None, **kwargs): - + # Except for kernel, other arguments are usually None. query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) From 41288bc2d300305d91559ae49a67f11984f789c0 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Tue, 3 Mar 2026 16:40:57 +0900 Subject: [PATCH 05/31] [Template] Polish template kernel of cat operation --- .../torch_openreg/openreg/__init__.py | 49 --- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 3 + PyTorchSimFrontend/mlir/mlir_cat_template.py | 369 ++++++++++++------ PyTorchSimFrontend/mlir/mlir_conv_common.py | 3 + PyTorchSimFrontend/mlir/mlir_gemm_template.py | 3 + PyTorchSimFrontend/mlir/mlir_lowering.py | 118 +----- PyTorchSimFrontend/mlir/mlir_scheduling.py | 22 +- PyTorchSimFrontend/mlir/mlir_template.py | 43 +- tests/test_cat.py | 143 +++++-- 9 files changed, 424 insertions(+), 329 deletions(-) diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index 5603a4f7..f5aabc18 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -256,52 +256,6 @@ def launch_model(model, *args, stream_index=0, timestamp=0, **kwargs): from .random import * # noqa: F403 from .amp import * -def _precheck_cat_out_args(args, kwargs): - tensors = args[0] if len(args) > 0 else kwargs.get("tensors") - dim = args[1] if len(args) > 1 else kwargs.get("dim", 0) - out = kwargs.get("out", args[2] if len(args) > 2 else None) - - if out is None: - return - if not isinstance(tensors, (list, tuple)) or len(tensors) == 0: - raise RuntimeError("aten::cat.out requires non-empty tensor list") - if not all(isinstance(t, torch.Tensor) for t in tensors): - raise RuntimeError("aten::cat.out tensors must be Tensor values") - if not isinstance(out, torch.Tensor): - raise RuntimeError("aten::cat.out out must be a Tensor") - - rank = tensors[0].dim() - if rank == 0: - raise RuntimeError("aten::cat.out does not support scalar inputs") - if dim < 0: - dim += rank - if dim < 0 or dim >= rank: - raise RuntimeError(f"aten::cat.out dim out of range: dim={dim}, rank={rank}") - if any(t.dim() != rank for t in tensors): - raise RuntimeError("aten::cat.out inputs must have the same rank") - if any(t.dtype != tensors[0].dtype for t in tensors): - raise RuntimeError("aten::cat.out inputs must have the same dtype") - if out.dim() != rank: - raise RuntimeError("aten::cat.out out rank mismatch") - - for d in range(rank): - if d == dim: - continue - base = tensors[0].shape[d] - if any(t.shape[d] != base for t in tensors[1:]): - raise RuntimeError( - f"aten::cat.out non-concatenated dimension mismatch at dim={d}" - ) - if out.shape[d] != base: - raise RuntimeError(f"aten::cat.out out shape mismatch at dim={d}") - - expected = sum(t.shape[dim] for t in tensors) - if out.shape[dim] != expected: - raise RuntimeError( - f"aten::cat.out out concatenated dimension mismatch at dim={dim}: " - f"expected {expected}, got {out.shape[dim]}" - ) - def eager_to_compile(op_name): """ Register an eager mode operation as a graph-based implementation using torch.compile(). @@ -313,9 +267,6 @@ def eager_to_compile(op_name): torch.npu.eager_to_compile("aten::mul.Tensor") """ def wrapper(*args, **kwargs): - if op_name == "aten::cat.out": - _precheck_cat_out_args(args, kwargs) - @torch.compile(dynamic=False) def dummy_graph(*args, **kwargs): # Convert "aten::mul.Tensor" -> torch.ops.aten.mul.Tensor diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 178ea987..9398f90c 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -154,6 +154,9 @@ class MLIRBMMTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 996af1de..d68af7d4 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -1,8 +1,9 @@ -from typing import List, Optional, cast +from typing import List, Optional +import math +import itertools import sympy -from torch._inductor.ir import Buffer, IRNode -from torch._inductor.virtualized import V +from torch._inductor.ir import IRNode from PyTorchSimFrontend.mlir import mlir_common from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel @@ -10,40 +11,28 @@ TEMPLATE = r""" {{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X0, X1], outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} { - {{ kernel.def_sram_buffer("X0", X0_TILE_DESC, id=0, indent_size=2) }} - {{ kernel.def_sram_buffer("X1", X1_TILE_DESC, id=1, indent_size=2) }} - {{ kernel.def_sram_buffer(OUT_DVAR, Y_TILE_DESC, id=2, indent_size=2) }} +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=INPUT_NAMES, outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} { +{%- for buffer_name, tile_desc in UNIQUE_BUFFER_TILE_DESCS.items() %} + {{ kernel.def_sram_buffer(buffer_name, tile_desc, indent_size=2) }} +{%- endfor %} {{ kernel.def_local_vars(indent_size=2) }} affine.for %cat_block = 0 to 1 step 1 { -{% if DIM == 0 %} - affine.for %index0 = 0 to {{ X0_ROWS }} step 1 { - affine.for %index1 = 0 to {{ COLS }} step 1 { - {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }} - } - } - - affine.for %index2 = 0 to {{ X1_ROWS }} step 1 { - affine.for %index3 = 0 to {{ COLS }} step 1 { - {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }} - } - } -{% else %} - affine.for %index0 = 0 to {{ ROWS }} step 1 { - affine.for %index1 = 0 to {{ X0_COLS }} step 1 { - {{ kernel.def_dma_op("MVIN", "X0", X0_IDX, X0_TILE_DESC, indent_size=8) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y0_IDX, X0_TILE_DESC, indent_size=8) }} - } - affine.for %index3 = 0 to {{ X1_COLS }} step 1 { - {{ kernel.def_dma_op("MVIN", "X1", X1_IDX, X1_TILE_DESC, indent_size=8) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, Y1_IDX, X1_TILE_DESC, indent_size=8) }} - } - } -{% endif %} +{%- for d in range(RANK-1) %} + affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ TILE_SIZES[d] }} { +{%- endfor %} +{%- for i in range(NUM_INPUTS) %} + // Input tensor{{ i }} + affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} { + %index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }}) + {{ kernel.def_dma_op("MVIN", INPUT_BUFFER_NAMES[i], INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} + } { inner_loop=true } +{%- endfor %} + +{%- for d in range(RANK-1) %} + } { outer_loop=true } +{%- endfor %} } { outer_loop=true } return } @@ -51,8 +40,8 @@ class MLIRCatTemplate(MLIRTemplate): - def __init__(self, input_nodes, layout, dim, input_reorder=None): - super().__init__("kernel", input_nodes, layout, input_reorder) + def __init__(self, input_nodes, layout, dim): + super().__init__("kernel", input_nodes, layout) self.dim = dim def render( @@ -66,87 +55,248 @@ def render( is_out_variant = template_buffer_node is not None if is_out_variant: self.output_node = template_buffer_node - # cat template currently emits a single output buffer and does not - # support epilogue output remapping. - - def _unwrap_node(n): - return n.node if hasattr(n, "node") else n - - x0 = _unwrap_node(self.input_nodes[0]) - x1 = _unwrap_node(self.input_nodes[1]) - y = _unwrap_node(self.output_node) - - def _as_int(v): - try: - return int(v) - except Exception: - return int(V.graph.sizevars.size_hint(v)) - - x0_rows = _as_int(x0.get_size()[0]) - x1_rows = _as_int(x1.get_size()[0]) - x0_cols = _as_int(x0.get_size()[1]) - x1_cols = _as_int(x1.get_size()[1]) - y_cols = _as_int(y.get_size()[1]) - kernel.loop_size = None - - # 2D cat template with contiguous layout. - x0_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - x0_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - x0_tile_desc.set_name("x0_cat_tile") - x1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - x1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - x1_tile_desc.set_name("x1_cat_tile") - y_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - y_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - y_tile_desc.set_name("y_cat_tile") - if self.dim == 0: - # Flattened offsets for dim=0 cat. - x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")] - x1_idx = [sympy.Symbol("index2") * x1_cols, sympy.Symbol("index3")] - y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")] - y1_idx = [(sympy.Symbol("index2") + x0_rows) * y_cols, sympy.Symbol("index3")] - else: - # Flattened offsets for dim=1 cat. - x0_idx = [sympy.Symbol("index0") * x0_cols, sympy.Symbol("index1")] - x1_idx = [sympy.Symbol("index0") * x1_cols, sympy.Symbol("index3")] - y0_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index1")] - y1_idx = [sympy.Symbol("index0") * y_cols, sympy.Symbol("index3") + x0_cols] + # Extract info + input_nodes = self.input_nodes + y = self.output_node + num_inputs = len(self.input_nodes) + rank = len(y.get_size()) + + input_sizes = [x.get_size() for x in input_nodes] + output_sizes = [sz for dim, sz in enumerate(y.get_size()) if dim != self.dim] + output_dim = [dim for dim, sz in enumerate(y.get_size()) if dim != self.dim] + tile_sizes = tile_info if tile_info is not None else [1] * len(output_sizes) + output_strides = y.get_layout().stride + + # Calculate input tile sizes + input_tile_sizes_dim = self._calculate_input_tile_sizes( + kernel, input_sizes, tile_sizes, num_inputs, rank + ) + buffer_name_to_template_name, input_buffer_names = self._build_buffer_mapping(input_nodes) + input_tile_descs, unique_tile_descs = self._build_tile_descriptors( + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names + ) + y_tile_desc = self._build_output_tile_desc( + kernel, input_tile_sizes_dim, tile_sizes, rank + ) + + input_idxs, output_idxs, cumulative_offsets = self._build_index_expressions( + input_nodes, input_sizes, output_strides, rank, num_inputs + ) + + # Map unique buffer names to their tile descriptors for template + unique_buffer_tile_descs = {} + for actual_name, template_name in buffer_name_to_template_name.items(): + if actual_name in unique_tile_descs: + unique_buffer_tile_descs[template_name] = unique_tile_descs[actual_name] + + names_str = ", ".join(input_buffer_names + ["out_ptr1" if is_out_variant else "Y"]) + indent_size = 2 + (rank - 1) * 2 + 4 kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, - X0=x0, - X1=x1, Y=y, OUT_DVAR="out_ptr1" if is_out_variant else "Y", - NAMES_STR="X0, X1, out_ptr1" if is_out_variant else "X0, X1, Y", + NAMES_STR=names_str, + INPUT_NAMES=input_nodes, + INPUT_BUFFER_NAMES=input_buffer_names, + NUM_INPUTS=num_inputs, + RANK=rank, DIM=self.dim, - X0_ROWS=x0_rows, - X1_ROWS=x1_rows, - ROWS=x0_rows, - X0_COLS=x0_cols, - X1_COLS=x1_cols, - COLS=x0_cols, - X0_TILE_DESC=x0_tile_desc, - X1_TILE_DESC=x1_tile_desc, - Y_TILE_DESC=y_tile_desc, - X0_IDX=x0_idx, - X1_IDX=x1_idx, - Y0_IDX=y0_idx, - Y1_IDX=y1_idx, + INPUT_SIZES=input_sizes, + OUTPUT_SIZES=output_sizes, + OUTPUT_DIM=output_dim, + TILE_SIZES=tile_sizes, + INPUT_TILE_SIZES_DIM=input_tile_sizes_dim, + INPUT_TILE_DESCS=input_tile_descs, + UNIQUE_BUFFER_TILE_DESCS=unique_buffer_tile_descs, + INPUT_IDXS=input_idxs, + OUTPUT_IDXS=output_idxs, + CUMULATIVE_OFFSETS=cumulative_offsets, + INDENT_SIZE=indent_size, input_reorder=self.input_reorder, ) - # Needed when epilogue fusion requests set_ranges(). - kernel.dim_aliasing = {"index0": "index0", "index1": "index1"} - if hasattr(self.output_node, "node") and hasattr(self.output_node.node, "get_name"): - output_node_name = self.output_node.node.get_name() - elif hasattr(self.output_node, "get_name"): - output_node_name = self.output_node.get_name() - else: - output_node_name = self.output_node.name + self._setup_epilogue_info(kernel, y) + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + return code + + def get_tile_candidates( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, + ): + """Generate tile candidates for cat operation. Concat dimension always has tile size 1.""" + if template_buffer_node is not None: + self.output_node = template_buffer_node + + y = self.output_node + num_inputs = len(self.input_nodes) + output_sizes = [sz for dim, sz in enumerate(y.get_size()) if dim != self.dim] + num_non_dim_dims = len(output_sizes) + + if num_non_dim_dims == 0: + return [[1]] + + tile_candidates = [] + dim_tile_candidates = [] + + for dim_size in output_sizes: + dim_candidates = [] + max_tile = min(dim_size, kernel.spad_info["spad_size"] // (kernel.vector_lane * kernel.precision * 2 * num_inputs)) + + for mult in range(1, max_tile // kernel.vector_lane + 1): + tile = mult * kernel.vector_lane + if tile <= dim_size: + dim_candidates.append(tile) + if max_tile > 0: + for exp in range(int(math.log2(max_tile)) + 1): + tile = 2 ** exp + if tile <= dim_size and tile not in dim_candidates: + dim_candidates.append(tile) + + if dim_size not in dim_candidates: + dim_candidates.append(dim_size) + + dim_tile_candidates.append(sorted(set(dim_candidates))[:5]) + + for tile_combo in itertools.product(*dim_tile_candidates): + total_elements = math.prod(tile_combo) + total_spad_needed = total_elements * (num_inputs + 1) * kernel.precision + + if total_spad_needed <= kernel.spad_info["spad_size"] * kernel.vector_lane: + tile_candidates.append(list(tile_combo)) + + if not tile_candidates: + tile_candidates = [[1] * num_non_dim_dims] + + tile_candidates.sort(key=lambda x: -math.prod(x)) + return tile_candidates[:4] + + def _calculate_input_tile_sizes( + self, kernel, input_sizes, tile_sizes, num_inputs, rank + ): + """Calculate tile sizes for concat dimension for each input.""" + non_dim_tile_elements = math.prod(tile_sizes) if tile_sizes else 1 + non_dim_tile_spad = non_dim_tile_elements * kernel.precision + max_spad_per_input = kernel.spad_info["spad_size"] * kernel.vector_lane // 2 + extra_concat_input = math.ceil(max_spad_per_input / non_dim_tile_spad) - num_inputs + + input_tile_sizes_dim = [] + for i in range(num_inputs): + input_dim_size = input_sizes[i][self.dim] + if extra_concat_input > 0 and non_dim_tile_elements > 0: + max_tile_dim = min(input_dim_size, extra_concat_input) + extra_concat_input -= max_tile_dim + else: + max_tile_dim = 1 + input_tile_sizes_dim.append(max_tile_dim) + return input_tile_sizes_dim + + def _build_buffer_mapping(self, input_nodes): + """Map actual buffer names to template buffer names """ + buffer_name_to_template_name = {} + input_buffer_names = [] + for x in input_nodes: + actual_name = x.get_name() + template_name = buffer_name_to_template_name.setdefault( + actual_name, f"X{len(buffer_name_to_template_name)}" + ) + input_buffer_names.append(template_name) + return buffer_name_to_template_name, input_buffer_names + + def _build_tile_descriptors( + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names + ): + """Build tile descriptors for each input.""" + input_tile_descs = [] + unique_tile_descs = {} + + for i, x in enumerate(input_nodes): + # Build full tile size list for this input + full_tile_sizes = [] + tile_size_idx = 0 + for d in range(rank): + if d != self.dim: + full_tile_sizes.append(tile_sizes[tile_size_idx]) + tile_size_idx += 1 + else: + full_tile_sizes.append(input_tile_sizes_dim[i]) + + tile_desc = mlir_common.MLIRMultiDimTile( + full_tile_sizes, + kernel.vector_lane, + vlane_split_axis=rank - 1, + vlane_stride=1 + ) + tile_desc.set_tile_size(full_tile_sizes) + template_buffer_name = input_buffer_names[i] + tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") + input_tile_descs.append(tile_desc) + + # Store unique tile desc by actual buffer name + actual_name = x.get_name() + if actual_name not in unique_tile_descs: + unique_tile_descs[actual_name] = tile_desc + + return input_tile_descs, unique_tile_descs + + def _build_index_expressions( + self, input_nodes, input_sizes, output_strides, rank, num_inputs + ): + """Build index expressions for input and output.""" + input_idxs = [] + output_idxs = [] + cumulative_offsets = [0] + for i in range(num_inputs - 1): + cumulative_offsets.append(cumulative_offsets[-1] + input_sizes[i][self.dim]) + + for i, x in enumerate(input_nodes): + x_stride = x.get_layout().stride + input_idx = [] + output_idx = [] + for d in range(rank): + if d != self.dim: + input_idx_symbol = sympy.Symbol(f"index{d}") + output_idx_symbol = sympy.Symbol(f"index{d}") + else: + input_idx_symbol = sympy.Symbol(f"index_local{self.dim}_{i}") + output_idx_symbol = sympy.Symbol(f"index{self.dim}_{i}") + input_idx.append(input_idx_symbol * x_stride[d]) + output_idx.append(output_idx_symbol * output_strides[d]) + input_idxs.append(input_idx) + output_idxs.append(output_idx) + + return input_idxs, output_idxs, cumulative_offsets + + def _build_output_tile_desc(self, kernel, input_tile_sizes_dim, tile_sizes, rank): + """Build output tile descriptor.""" + max_output_tile_dim = max(input_tile_sizes_dim) if input_tile_sizes_dim else 1 + output_full_tile_sizes = [] + tile_size_idx = 0 + for d in range(rank): + if d != self.dim: + output_full_tile_sizes.append(tile_sizes[tile_size_idx]) + tile_size_idx += 1 + else: + output_full_tile_sizes.append(max_output_tile_dim) + + y_tile_desc = mlir_common.MLIRMultiDimTile( + output_full_tile_sizes, + kernel.vector_lane, + vlane_split_axis=rank - 1, + vlane_stride=1 + ) + y_tile_desc.set_tile_size(output_full_tile_sizes) + y_tile_desc.set_name("y_cat_tile") + return y_tile_desc + + def _setup_epilogue_info(self, kernel, y): + """Setup epilogue information.""" if hasattr(y, "get_numel"): y_numel = y.get_numel() elif hasattr(y, "node") and hasattr(y.node, "get_numel"): @@ -154,14 +304,5 @@ def _as_int(v): else: y_numel = None - kernel.epilogue_info = dict( - output_node=output_node_name, - sram_var="y_cat_tile", - dram_var=kernel.render_options["OUT_DVAR"], - dram_tile_desc=y_tile_desc, - ) if y_numel is not None: kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": y_numel} - - code = self._template_from_string(TEMPLATE).render(**kernel.render_options) - return code diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py index f8566b6d..f72a7663 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_common.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -12,6 +12,9 @@ class MLIRConvCommonTemplate(MLIRTemplate): WRAPPER_TEMPLATE = None def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = False + self.support_reduction_fusion = False self.stride = kwargs["stride"] self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 0158caa6..5b116807 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -105,6 +105,9 @@ class MLIRGemmTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index 0f28f03b..d7aee715 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -202,48 +202,9 @@ def _cat_layout(tensors: Sequence[TensorBox], dim: int) -> ir.Layout: stride, ) - -def _can_use_cat_template(tensors: Sequence[TensorBox], dim: int) -> bool: - # Current template specialization: 2 inputs, rank-2, dim in {0, 1}. - if len(tensors) != 2: - return False - if not all(hasattr(t, "get_size") and hasattr(t, "get_dtype") and hasattr(t, "realize") for t in tensors): - return False - if tensors[0].get_dtype() != tensors[1].get_dtype(): - return False - rank0 = len(tensors[0].get_size()) - rank1 = len(tensors[1].get_size()) - if rank0 != 2 or rank1 != 2: - return False - if dim < 0: - dim += rank0 - if dim not in (0, 1): - return False - - if dim == 0: - cols0 = tensors[0].get_size()[1] - cols1 = tensors[1].get_size()[1] - return V.graph.sizevars.statically_known_equals(cols0, cols1) - - rows0 = tensors[0].get_size()[0] - rows1 = tensors[1].get_size()[0] - return V.graph.sizevars.statically_known_equals(rows0, rows1) - - -def _cat_fallback(reason: str, tensors: Sequence[TensorBox], dim: int): - # Non-template cases delegate to the original lowering path. - return _orig_cat_default_lowering(tensors, dim) - - -def _custom_cat_impl(tensors: Sequence[TensorBox], dim: int = 0): - if _orig_cat_default_lowering is None: - raise RuntimeError("Original aten.cat.default lowering is missing") - if len(tensors) > 0: - rank = len(tensors[0].get_size()) - if dim < 0: - dim += rank - if not _can_use_cat_template(tensors, dim): - return _cat_fallback("default-path", tensors, dim) +def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): + if tensors and dim < 0: + dim += len(tensors[0].get_size()) for t in tensors: t.realize() @@ -251,75 +212,6 @@ def _custom_cat_impl(tensors: Sequence[TensorBox], dim: int = 0): mlir_template = MLIRCatTemplate(list(tensors), layout, dim=dim) return mlir_template.generate().output_node() - -def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): - return _custom_cat_impl(tensors, dim) - - -def custom_cat_out(tensors: Sequence[TensorBox], dim: int = 0, out: Optional[TensorBox] = None): - if _orig_cat_out_lowering is None: - raise RuntimeError("Original aten.cat.out lowering is missing") - if out is None: - return _orig_cat_out_lowering(tensors, dim, out) - - copy_default_lowering = lowerings.get(aten.copy_.default) - slice_tensor_lowering = lowerings.get(aten.slice.Tensor) - if copy_default_lowering is None or slice_tensor_lowering is None: - raise RuntimeError("cat.out lowering requires aten.copy_.default and aten.slice.Tensor lowerings") - - # Lower cat.out as a sequence of slice+copy ops so each piece still runs - # through the existing compiled/simulated kernel path. - if len(tensors) == 0: - raise RuntimeError("cat.out requires at least one input tensor") - if not all(hasattr(t, "get_size") and hasattr(t, "get_dtype") and hasattr(t, "realize") for t in tensors): - raise RuntimeError("cat.out inputs must be tensor-like values") - rank = len(tensors[0].get_size()) - if rank == 0: - raise RuntimeError("cat.out does not support scalar inputs") - if dim < 0: - dim = dim + rank - if dim < 0 or dim >= rank: - raise RuntimeError(f"cat.out dim out of range: dim={dim}, rank={rank}") - if any(len(t.get_size()) != rank for t in tensors): - raise RuntimeError("cat.out inputs must have the same rank") - if any(t.get_dtype() != tensors[0].get_dtype() for t in tensors): - raise RuntimeError("cat.out inputs must have the same dtype") - # cat semantics: all non-cat dimensions must be equal. - for i in range(rank): - if i == dim: - continue - base = tensors[0].get_size()[i] - if any(not V.graph.sizevars.statically_known_equals(base, t.get_size()[i]) for t in tensors[1:]): - raise RuntimeError(f"cat.out non-concatenated dimension mismatch at dim={i}") - - # Output shape must match concatenated shape. - if not hasattr(out, "get_size"): - raise RuntimeError("cat.out output must be tensor-like") - out_sizes = list(out.get_size()) - if len(out_sizes) != rank: - raise RuntimeError("cat.out output rank mismatch") - for i in range(rank): - if i == dim: - continue - if not V.graph.sizevars.statically_known_equals(out_sizes[i], tensors[0].get_size()[i]): - raise RuntimeError(f"cat.out output shape mismatch at dim={i}") - expected_cat = sum(t.get_size()[dim] for t in tensors) - if not V.graph.sizevars.statically_known_equals(out_sizes[dim], expected_cat): - raise RuntimeError(f"cat.out output concatenated dimension mismatch at dim={dim}") - - if isinstance(out, TensorBox): - out.realize() - - offset = 0 - for src in tensors: - src.realize() - end = offset + src.get_size()[dim] - dst_view = slice_tensor_lowering(out, dim, offset, end, 1) - copy_default_lowering(dst_view, src) - offset = end - return out - - def _custom_sort_values_impl( self: TensorBox, dim: int = -1, @@ -459,9 +351,7 @@ def custom_sort_values_stable( lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()}) - -lowerings.update({aten.cat.default: custom_cat_default}) -lowerings.update({aten.cat.out: custom_cat_out}) +lowerings.update({getattr(aten.cat, overload): custom_cat_default for overload in aten.cat.overloads()}) lowerings.update({aten.sort.stable: custom_sort_stable}) lowerings.update({aten.sort.values_stable: custom_sort_values_stable}) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index af960533..2f9c9704 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -44,12 +44,10 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule # Case 3: Prologue(Pointwise) + Tempalte if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - target_node = base_template_node2[0].node - # Currently only BMM, MM support prologue fusion - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + + # Check if template supports prologue fusion + if not getattr(target_node.template, 'support_prologue_fusion', False): return False if len(node1.read_writes.writes) != 1: @@ -129,12 +127,14 @@ def can_fuse_horizontal(self, node1, node2): if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction(): # Don't fuse maxpool template code from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate template_node = base_template_node1[0] epilogue_node = node2 + # Check if template supports epilogue fusion + if not getattr(template_node.node.template, 'support_epilogue_fusion', False): + return False + if isinstance(template_node.node.template, MLIRMaxPoolTemplate): return False @@ -161,7 +161,7 @@ def can_fuse_horizontal(self, node1, node2): # Revert act_node.group : simplify_and_reorder() modified _body, _size, group if template_node.group != epilogue_node.group: # We don't fuse this case... - if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: + if getattr(template_node.node.template, 'support_prologue_fusion', False) and template_node.group[1][0][0] == 1: return False if list(template_node.group[1][0]) != list(epilogue_node.get_nodes()[0].node.data.get_size()): @@ -171,10 +171,10 @@ def can_fuse_horizontal(self, node1, node2): # Case 2: Tempalte + Reduction fusion if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate target_node = base_template_node1[0].node - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + + # Check if template supports reduction fusion + if not getattr(target_node.template, 'support_reduction_fusion', False): return False size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 76b0ef71..04d327f8 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -14,7 +14,7 @@ from unittest.mock import patch from torch._inductor.codegen.common import KernelTemplate, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller, ir_node_to_tensor from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -124,6 +124,7 @@ def __init__(self, self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") self.global_vars = IndentedBuffer() self.exception_nodes = {} + self.epilogue_info = {} # Reduction data structure self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False @@ -403,7 +404,7 @@ def call_kernel(self, kernel_name): _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( - kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args) + kernel_name if self.outer_func_name is None else "wrapper_" + kernel_name, call_args) def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): with self as kernel: @@ -460,11 +461,11 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ } node.codegen((vars, reduction_vars)) - # Codegen epilogue nodes - tile_desc = kernel.set_tile_size(kernel.epilogue_info) - kernel.kernel_group.set_tile_info(tile_desc) - kernel.call_ranges = None if epilogue_nodes: + # Codegen epilogue nodes + tile_desc = kernel.set_tile_size(kernel.epilogue_info) + kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None with kernel.epilogue_buffer_group.as_local(): _, (group, reduction_group) = max( epilogue_nodes, key=lambda x: int(x.is_reduction()) @@ -625,7 +626,9 @@ def def_kernel( extra_node[node.get_name()] = node.node else: extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] + + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): arg_defs, call_args, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) @@ -688,7 +691,8 @@ def def_conv_kernel( self.kernel_group.args.output_buffers[node.get_name()] = name self.store_buffer_names.add(node.get_name()) #TODO: Is this enough not calling store() in mlir_common.py? self.extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed def kernel_hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) @@ -1146,6 +1150,15 @@ def set_tile_size(self, template_fusion_info, prologue=False): return tile_desc class MLIRTemplateCaller(CUDATemplateCaller): + def __init__(self, name, category, input_nodes, layout, make_kernel_render, supports_epilogue_fusion, template, info_kwargs, description): + bmreq = MLIRBenchmarkRequest( + kernel_name=name, + input_tensor_meta=list(), + output_tensor_meta=list(), + extra_args=[], + source_code="", + ) + super().__init__(name, category, input_nodes, layout, make_kernel_render, bmreq, supports_epilogue_fusion, template, info_kwargs, description) def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})" @@ -1173,6 +1186,10 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): self.output_nodes = [self.output_node] self.input_reorder = input_reorder self.layout = layout + # Fusion support flags (default to False) + self.support_epilogue_fusion = False + self.support_prologue_fusion = False + self.support_reduction_fusion = False def generate(self, **kwargs) -> ChoiceCaller: kernel_name = f"mlir_{self.name}" @@ -1184,18 +1201,9 @@ def generate(self, **kwargs) -> ChoiceCaller: code = self.render(kernel=kernel, **kwargs) kernel_hash_name = f"mlir_{self.name}_{next(self.index_counter)}" - extra_args = [] # create the BenchmarkRequest output_nodes = getattr(self, "output_nodes", None) or [self.output_node] - bmreq = MLIRBenchmarkRequest( - kernel_name=kernel_name, - input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(output_nodes), - extra_args=extra_args, - source_code=code, - ) - def make_kernel_render( template_node: TemplateBuffer, prologue_nodes: Optional[List[IRNode]] = None, @@ -1236,7 +1244,6 @@ def make_kernel_render( self.input_nodes, self.output_node.get_layout(), make_kernel_render, - bmreq, False, # supports_epilogue_fusion self, kwargs, diff --git a/tests/test_cat.py b/tests/test_cat.py index 32573a05..62de6759 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -20,24 +20,6 @@ def _test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): print("cpu out: ", cpu_out) raise RuntimeError(f"{name} mismatch") - -def _togsim_log_count() -> int: - log_dir = Path("togsim_results") - if not log_dir.exists(): - return 0 - return len(list(log_dir.glob("*.log"))) - - -def _assert_simulation_happened(before_count: int, case_name: str): - after_count = _togsim_log_count() - if after_count <= before_count: - raise RuntimeError( - f"{case_name}: TOGSim log count did not increase " - f"(before={before_count}, after={after_count})" - ) - print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") - - def test_cat_default(device): def cat_default_fn(a, b): return torch.cat([a, b], dim=0) @@ -46,9 +28,7 @@ def cat_default_fn(a, b): y = torch.randn(6, 16, device=device) opt_fn = torch.compile(dynamic=False)(cat_default_fn) - before = _togsim_log_count() out = opt_fn(x, y) - _assert_simulation_happened(before, "cat.default") cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) _test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) @@ -63,19 +43,122 @@ def cat_out_fn(a, b, out): out_buf = torch.empty(14, 16, device=device) opt_fn = torch.compile(dynamic=False)(cat_out_fn) - before = _togsim_log_count() out = opt_fn(x, y, out_buf) - _assert_simulation_happened(before, "cat.out") cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) _test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) +def test_cat_4d_dim0(device): + def cat_4d_dim0_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(3, 3, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim0_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.4d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim1(device): + def cat_4d_dim1_fn(a, b): + return torch.cat([a, b], dim=1) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 5, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim1_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=1) + _test_result("cat.4d.dim1", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim2(device): + def cat_4d_dim2_fn(a, b): + return torch.cat([a, b], dim=2) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 6, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim2_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=2) + _test_result("cat.4d.dim2", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim3(device): + def cat_4d_dim3_fn(a, b): + return torch.cat([a, b], dim=3) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 4, 7, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim3_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=3) + _test_result("cat.4d.dim3", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_three_inputs(device): + def cat_three_inputs_fn(a, b, c): + return torch.cat([a, b, c], dim=0) + + x = torch.randn(4, 16, device=device) + y = torch.randn(5, 16, device=device) + z = torch.randn(3, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_three_inputs_fn) + + out = opt_fn(x, y, z) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=0) + _test_result("cat.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_four_inputs(device): + def cat_four_inputs_fn(a, b, c, d): + return torch.cat([a, b, c, d], dim=0) + + x = torch.randn(3, 16, device=device) + y = torch.randn(4, 16, device=device) + z = torch.randn(5, 16, device=device) + w = torch.randn(2, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_four_inputs_fn) + + out = opt_fn(x, y, z, w) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu(), w.cpu()], dim=0) + _test_result("cat.four_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_three_inputs(device): + def cat_4d_three_inputs_fn(a, b, c): + return torch.cat([a, b, c], dim=1) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 4, 4, 5, device=device) + z = torch.randn(2, 5, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_three_inputs_fn) + + out = opt_fn(x, y, z) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=1) + _test_result("cat.4d.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run cat simulation tests") parser.add_argument( "--case", - choices=["default", "out", "all"], + choices=[ + "default", "out", "4d_dim0", "4d_dim1", "4d_dim2", "4d_dim3", + "three_inputs", "four_inputs", "4d_three_inputs", "all" + ], default="all", help="Which cat case to run", ) @@ -87,3 +170,17 @@ def cat_out_fn(a, b, out): test_cat_default(device) if args.case in ("out", "all"): test_cat_out(device) + if args.case in ("4d_dim0", "all"): + test_cat_4d_dim0(device) + if args.case in ("4d_dim1", "all"): + test_cat_4d_dim1(device) + if args.case in ("4d_dim2", "all"): + test_cat_4d_dim2(device) + if args.case in ("4d_dim3", "all"): + test_cat_4d_dim3(device) + if args.case in ("three_inputs", "all"): + test_cat_three_inputs(device) + if args.case in ("four_inputs", "all"): + test_cat_four_inputs(device) + if args.case in ("4d_three_inputs", "all"): + test_cat_4d_three_inputs(device) From 434bbb10793a68172e49e107bc3b639fd3b86264 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 4 Mar 2026 20:02:14 +0900 Subject: [PATCH 06/31] [WIP] --- PyTorchSimFrontend/mlir/mlir_cat_template.py | 13 ------------- PyTorchSimFrontend/mlir/mlir_template.py | 2 +- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index d68af7d4..5062e629 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -118,7 +118,6 @@ def render( input_reorder=self.input_reorder, ) - self._setup_epilogue_info(kernel, y) code = self._template_from_string(TEMPLATE).render(**kernel.render_options) return code @@ -294,15 +293,3 @@ def _build_output_tile_desc(self, kernel, input_tile_sizes_dim, tile_sizes, rank y_tile_desc.set_tile_size(output_full_tile_sizes) y_tile_desc.set_name("y_cat_tile") return y_tile_desc - - def _setup_epilogue_info(self, kernel, y): - """Setup epilogue information.""" - if hasattr(y, "get_numel"): - y_numel = y.get_numel() - elif hasattr(y, "node") and hasattr(y.node, "get_numel"): - y_numel = y.node.get_numel() - else: - y_numel = None - - if y_numel is not None: - kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": y_numel} diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 04d327f8..59610228 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -813,7 +813,7 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com if dram_var in self.exception_nodes: numel = self.exception_nodes[dram_var]["numel"] else: - numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + numel = self.named_nodes[dram_var].get_numel() mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] dram_shape = f"memref<{numel}x{mlir_dtype}>" dram_stride = [] From 5295dfb5a16e21fda57b12d73906c1bd290c4f94 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 4 Mar 2026 22:13:26 +0900 Subject: [PATCH 07/31] [Template] Delay def_dma_op codegen def_dma_op find data node using dram_var. But it can't locate the proper node when output buffer has not been created. --- PyTorchSimFrontend/mlir/mlir_template.py | 146 +++++++++++++---------- 1 file changed, 81 insertions(+), 65 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 59610228..7c52bfe6 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -112,7 +112,8 @@ def __init__(self, self.outer_func_name = outer_func_name self.outer_func_render = outer_func_render self.kernel_arg_attributes = kernel_arg_attributes - self.render_hooks = OrderedDict() + self.render_hooks = OrderedDict() # Stores {key: (priority, hook)} + self.dma_op_counter = itertools.count() # Add counter for unique DMA op keys self.buffer_names = dict() self.render_options = dict() self.tile_size = [] @@ -555,7 +556,7 @@ def template_store(): dram_var = self.epilogue_info["dram_var"] index_list = self.epilogue_info["dram_idx"] tile_desc = self.epilogue_info["dram_tile_desc"] - code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc) + code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc, lazy_mode=False) self.cse.generate(self.dma_stores, code, assignment = False) body = IndentedBuffer() @@ -653,7 +654,7 @@ def hook(): return f"({', '.join(renamed_arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (5, hook) # Default priority 5 return "" # This function is a temporal function for convolution because currently convolution kernel is not considering padding. @@ -700,7 +701,7 @@ def kernel_hook(): return f"({', '.join(arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = kernel_hook + self.render_hooks[""] = (5, kernel_hook) # Default priority 5 return "" # This function is for convolution wrapper function finalizing. @@ -711,7 +712,7 @@ def wrapper_hook(): return f"({', '.join(wrapper_arg_defs)})" if "" not in self.render_hooks: - self.render_hooks[""] = wrapper_hook + self.render_hooks[""] = (5, wrapper_hook) # Default priority 5 return "" def get_conv_inputs(self): @@ -720,15 +721,15 @@ def get_conv_inputs(self): def get_conv_outputs(self): return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} - def load_input(self, indent_size: int = 0): + def load_input(self, indent_size: int = 0, priority: int = 1): def hook(): code = IndentedBuffer() prologue_code = self.codegen_prologue_body() if prologue_code.getvalue(): input_dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) weight_dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) if (self.prologue_info["is_input_fused"]): code.splice(input_dma_code) code.splice(prologue_code) @@ -739,58 +740,63 @@ def hook(): code.splice(input_dma_code) else: dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) code = textwrap.indent(code.getvalue(), " "*indent_size).strip() return code assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def store_output(self, indent_size: int = 0): + def store_output(self, indent_size: int = 0, priority: int = 1): def hook(): epilogue_code = self.codegen_epilogue_body() return textwrap.indent(epilogue_code.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def reduction_output(self, indent_size: int = 0): + def reduction_output(self, indent_size: int = 0, priority: int = 5): def hook(): return textwrap.indent(self.reductions_suffix.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (priority, hook) return "" + def _sort_hooks_by_priority(self): + """Sort hooks by priority (lower priority executes first).""" + sorted_hooks = OrderedDict() + for key, (priority, hook) in sorted(self.render_hooks.items(), key=lambda x: x[1][0]): + sorted_hooks[key] = hook + return sorted_hooks + def def_function(self): _, call_args, _, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: partial_code, function_name = self.outer_func_render(input_args=call_args) + return PartialRender( partial_code, - self.render_hooks, + self._sort_hooks_by_priority(), ), function_name else: return None, None - def def_global_vars(self): + def def_global_vars(self, priority: int = 10): key = "" def hook(): return textwrap.indent(self.global_vars.getvalue(), "").strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key - def def_local_vars(self, indent_size=0): + def def_local_vars(self, indent_size=0, priority: int = 10): key = "" def hook(): code = IndentedBuffer() @@ -799,52 +805,62 @@ def hook(): code.splice(self.alloc_buffer) return textwrap.indent(code.getvalue(), " "*indent_size).strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, - subtile_size:list=[], async_type=None, indent_size=0): - # Prepare code block - local_code = IndentedBuffer() - with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): - index_var = self.parse_index_list(index_list, offset=tile_desc.offset) - node_layout = self.named_nodes[dram_var].get_layout() - if dram_var in self.exception_nodes: - numel = self.exception_nodes[dram_var]["numel"] - else: - numel = self.named_nodes[dram_var].get_numel() - mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] - dram_shape = f"memref<{numel}x{mlir_dtype}>" - dram_stride = [] - for idx in index_list: - if idx.is_Mul: - dram_stride.append(int(idx.args[0])) - elif idx == sympy.Symbol("c0"): - dram_stride.append(0) - elif not idx.is_Number: - dram_stride.append(1) + subtile_size:list=[], async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True): + def generate_dma_code(): + """Internal method to generate DMA code directly.""" + local_code = IndentedBuffer() + with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): + index_var = self.parse_index_list(index_list, offset=tile_desc.offset) + node_layout = self.named_nodes[dram_var].get_layout() + if dram_var in self.exception_nodes: + numel = self.exception_nodes[dram_var]["numel"] else: - dram_stride.append(0) - - sram_var = tile_desc.get_name() - tile_shape = tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = tile_desc.get_tile_stride() - vlane_split_axis = tile_desc.vmap.vlane_split_axis - vlane_stride = tile_desc.vmap.vlane_stride - - zero_cse = self.get_const_cse(0, "index") - sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - - attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] - if subtile_size: - attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") - attribute = " {" + ", ".join(attribute_parts) + "}" - code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, "") - local_code.writeline(code) - local_code.writeline(attribute) - return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] + dram_shape = f"memref<{numel}x{mlir_dtype}>" + dram_stride = [] + for idx in index_list: + if idx.is_Mul: + dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + dram_stride.append(0) + elif not idx.is_Number: + dram_stride.append(1) + else: + dram_stride.append(0) + + sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + tile_stride = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vmap.vlane_split_axis + vlane_stride = tile_desc.vmap.vlane_stride + + zero_cse = self.get_const_cse(0, "index") + sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) + + attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] + if subtile_size: + attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") + attribute = " {" + ", ".join(attribute_parts) + "}" + code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, "") + local_code.writeline(code) + local_code.writeline(attribute) + return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + + if not lazy_mode: + # Immediate mode: generate code directly and return it + return generate_dma_code() + + # Lazy mode: register hook and return key + dma_op_id = next(self.dma_op_counter) + key = f"" + self.render_hooks[key] = (priority, generate_dma_code) + return key def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): # Prepare code block @@ -862,7 +878,7 @@ def render(self, template, kwargs, define_function=None): return PartialRender( code, - self.render_hooks, + self._sort_hooks_by_priority(), ) def get_spad_size_per_lane(self, tile_m, tile_n): From 61caebd5708ca21a88950d4d5073445891ea32f1 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 00:12:49 +0900 Subject: [PATCH 08/31] [Template/Cat] Fix apply offset setting --- PyTorchSimFrontend/mlir/mlir_cat_template.py | 80 +++++++++----------- 1 file changed, 37 insertions(+), 43 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 5062e629..5aaf3e71 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -26,7 +26,7 @@ affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} { %index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }}) {{ kernel.def_dma_op("MVIN", INPUT_BUFFER_NAMES[i], INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} + {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], OUTPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} } { inner_loop=true } {%- endfor %} @@ -52,10 +52,6 @@ def render( tile_info=None, **kwargs, ): - is_out_variant = template_buffer_node is not None - if is_out_variant: - self.output_node = template_buffer_node - # Extract info input_nodes = self.input_nodes y = self.output_node @@ -73,11 +69,8 @@ def render( kernel, input_sizes, tile_sizes, num_inputs, rank ) buffer_name_to_template_name, input_buffer_names = self._build_buffer_mapping(input_nodes) - input_tile_descs, unique_tile_descs = self._build_tile_descriptors( - kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names - ) - y_tile_desc = self._build_output_tile_desc( - kernel, input_tile_sizes_dim, tile_sizes, rank + input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, y ) input_idxs, output_idxs, cumulative_offsets = self._build_index_expressions( @@ -90,14 +83,14 @@ def render( if actual_name in unique_tile_descs: unique_buffer_tile_descs[template_name] = unique_tile_descs[actual_name] - names_str = ", ".join(input_buffer_names + ["out_ptr1" if is_out_variant else "Y"]) + names_str = ", ".join(input_buffer_names + ["Y"]) indent_size = 2 + (rank - 1) * 2 + 4 kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, Y=y, - OUT_DVAR="out_ptr1" if is_out_variant else "Y", + OUT_DVAR="Y", NAMES_STR=names_str, INPUT_NAMES=input_nodes, INPUT_BUFFER_NAMES=input_buffer_names, @@ -110,6 +103,7 @@ def render( TILE_SIZES=tile_sizes, INPUT_TILE_SIZES_DIM=input_tile_sizes_dim, INPUT_TILE_DESCS=input_tile_descs, + OUTPUT_TILE_DESCS=output_tile_descs, UNIQUE_BUFFER_TILE_DESCS=unique_buffer_tile_descs, INPUT_IDXS=input_idxs, OUTPUT_IDXS=output_idxs, @@ -209,14 +203,16 @@ def _build_buffer_mapping(self, input_nodes): return buffer_name_to_template_name, input_buffer_names def _build_tile_descriptors( - self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, output_node ): - """Build tile descriptors for each input.""" + """Build tile descriptors for each input and output.""" input_tile_descs = [] + output_tile_descs = [] unique_tile_descs = {} + output_offset = output_node.get_layout().offset for i, x in enumerate(input_nodes): - # Build full tile size list for this input + x_offset = x.get_layout().offset full_tile_sizes = [] tile_size_idx = 0 for d in range(rank): @@ -226,23 +222,37 @@ def _build_tile_descriptors( else: full_tile_sizes.append(input_tile_sizes_dim[i]) - tile_desc = mlir_common.MLIRMultiDimTile( + # Input tile descriptor + input_tile_desc = mlir_common.MLIRMultiDimTile( full_tile_sizes, kernel.vector_lane, vlane_split_axis=rank - 1, vlane_stride=1 ) - tile_desc.set_tile_size(full_tile_sizes) + input_tile_desc.set_tile_size(full_tile_sizes) template_buffer_name = input_buffer_names[i] - tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") - input_tile_descs.append(tile_desc) + input_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") + input_tile_desc.offset = x_offset + input_tile_descs.append(input_tile_desc) + + # Output tile descriptor (same as input but with output offset) + output_tile_desc = mlir_common.MLIRMultiDimTile( + full_tile_sizes, + kernel.vector_lane, + vlane_split_axis=rank - 1, + vlane_stride=1 + ) + output_tile_desc.set_tile_size(full_tile_sizes) + output_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") + output_tile_desc.offset = output_offset + output_tile_descs.append(output_tile_desc) # Store unique tile desc by actual buffer name actual_name = x.get_name() if actual_name not in unique_tile_descs: - unique_tile_descs[actual_name] = tile_desc + unique_tile_descs[actual_name] = input_tile_desc - return input_tile_descs, unique_tile_descs + return input_tile_descs, output_tile_descs, unique_tile_descs def _build_index_expressions( self, input_nodes, input_sizes, output_strides, rank, num_inputs @@ -256,6 +266,12 @@ def _build_index_expressions( for i, x in enumerate(input_nodes): x_stride = x.get_layout().stride + x_offset = x.get_layout().offset + if hasattr(x, 'data') and hasattr(x.data, 'dims'): + # In case of PermuteView, the stride is permuted + perm_dims = x.data.dims + x_stride = [x_stride[perm_dims[d]] for d in range(rank)] + input_idx = [] output_idx = [] for d in range(rank): @@ -271,25 +287,3 @@ def _build_index_expressions( output_idxs.append(output_idx) return input_idxs, output_idxs, cumulative_offsets - - def _build_output_tile_desc(self, kernel, input_tile_sizes_dim, tile_sizes, rank): - """Build output tile descriptor.""" - max_output_tile_dim = max(input_tile_sizes_dim) if input_tile_sizes_dim else 1 - output_full_tile_sizes = [] - tile_size_idx = 0 - for d in range(rank): - if d != self.dim: - output_full_tile_sizes.append(tile_sizes[tile_size_idx]) - tile_size_idx += 1 - else: - output_full_tile_sizes.append(max_output_tile_dim) - - y_tile_desc = mlir_common.MLIRMultiDimTile( - output_full_tile_sizes, - kernel.vector_lane, - vlane_split_axis=rank - 1, - vlane_stride=1 - ) - y_tile_desc.set_tile_size(output_full_tile_sizes) - y_tile_desc.set_name("y_cat_tile") - return y_tile_desc From 47684a75942bf9d35e19a7a79a1862418c5649a6 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 17:44:32 +0900 Subject: [PATCH 09/31] [TOGSim] Add help print --- TOGSim/src/DMA.cc | 2 +- TOGSim/src/helper/CommandLineParser.cc | 6 +++++- TOGSim/src/helper/CommandLineParser.h | 8 +++++++- TOGSim/src/main.cc | 13 +++++++++---- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/TOGSim/src/DMA.cc b/TOGSim/src/DMA.cc index f8f21025..fefee6d2 100644 --- a/TOGSim/src/DMA.cc +++ b/TOGSim/src/DMA.cc @@ -12,7 +12,7 @@ void DMA::issue_tile(std::shared_ptr inst) { _current_inst = std::move(inst); std::vector& tile_size = _current_inst->get_tile_size(); if (tile_size.size() <= 0 || tile_size.size() > get_max_dim()) { - spdlog::error("[DMA {}] issued tile is not supported format..", _id); + spdlog::error("[DMA {}] issued tile is not supported format.. tile.size: {}, tile_size: [{}]", _id, tile_size.size(), fmt::join(tile_size, ", ")); exit(EXIT_FAILURE); } _finished = false; diff --git a/TOGSim/src/helper/CommandLineParser.cc b/TOGSim/src/helper/CommandLineParser.cc index 66aebbe1..9cd177ac 100644 --- a/TOGSim/src/helper/CommandLineParser.cc +++ b/TOGSim/src/helper/CommandLineParser.cc @@ -12,9 +12,13 @@ void CommandLineParser::parse(int argc, char **argv) noexcept(false) { po::notify(variables_map); } +void CommandLineParser::print_help_message() const noexcept { + std::cout << options_description << std::endl; +} + void CommandLineParser::print_help_message_if_required() const noexcept { if (variables_map.count("help") > 0) { - std::cout << options_description << std::endl; + print_help_message(); exit(0); } } diff --git a/TOGSim/src/helper/CommandLineParser.h b/TOGSim/src/helper/CommandLineParser.h index 39174d5d..b41eabf3 100644 --- a/TOGSim/src/helper/CommandLineParser.h +++ b/TOGSim/src/helper/CommandLineParser.h @@ -19,7 +19,7 @@ class CommandLineParser { * Command Line Parser constructor */ CommandLineParser() noexcept { - options_description.add_options()("help", "Prints help message"); + options_description.add_options()("help,h", "Prints help message"); } /** @@ -38,6 +38,12 @@ class CommandLineParser { */ void print_help_message_if_required() const noexcept; + /** + * Prints the help message. + * (Can be called to show help for invalid options) + */ + void print_help_message() const noexcept; + /** * Add a new command line argument option. * (Should be called before `parse` method is called) diff --git a/TOGSim/src/main.cc b/TOGSim/src/main.cc index 7c596af5..cda8f986 100644 --- a/TOGSim/src/main.cc +++ b/TOGSim/src/main.cc @@ -96,19 +96,24 @@ int main(int argc, char** argv) { // parse command line argumnet CommandLineParser cmd_parser = CommandLineParser(); cmd_parser.add_command_line_option( - "config", "Path for hardware configuration file"); + "config", "Path for hardware configuration file (.yml)"); cmd_parser.add_command_line_option( - "models_list", "Path for the models list file (can be FIFO or regular file)"); + "models_list", "Path for the trace file (.trace)"); cmd_parser.add_command_line_option( "log_level", "Set for log level [trace, debug, info], default = info"); try { cmd_parser.parse(argc, argv); } catch (const CommandLineParser::ParsingError& e) { spdlog::error( - "Command line argument parrsing error captured. Error message: {}", + "Command line argument parsing error captured. Error message: {}", e.what()); - throw(e); + std::cerr << std::endl; + cmd_parser.print_help_message(); + exit(1); } + + // Check if help was requested + cmd_parser.print_help_message_if_required(); std::string level = "info"; cmd_parser.set_if_defined("log_level", &level); From a24f1f1081a4ce7e5e09a59f61763850d11d994f Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 17:45:00 +0900 Subject: [PATCH 10/31] [Template/Cat] Limit maximum rank of tile --- PyTorchSimFrontend/mlir/mlir_cat_template.py | 52 +++++++++++++++----- 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 5aaf3e71..2a00ce95 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -64,17 +64,30 @@ def render( tile_sizes = tile_info if tile_info is not None else [1] * len(output_sizes) output_strides = y.get_layout().stride + excluded_dims = list() + max_tiled_dims = 4 - 1 + if len(tile_sizes) > max_tiled_dims: + # Create index:tile_size dictionary and sort by tile_size + dim_tile_dict = {idx: sz for idx, sz in enumerate(tile_sizes)} + sorted_dims = sorted(dim_tile_dict.items(), key=lambda x: x[1], reverse=True) + # Keep top 4 dimensions, exclude the rest + excluded_dims = [idx for idx, _ in sorted_dims[max_tiled_dims:]] + for idx in excluded_dims: + tile_sizes[idx] = 1 + # Calculate input tile sizes input_tile_sizes_dim = self._calculate_input_tile_sizes( kernel, input_sizes, tile_sizes, num_inputs, rank ) buffer_name_to_template_name, input_buffer_names = self._build_buffer_mapping(input_nodes) input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( - kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, y + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, y, + excluded_dims=excluded_dims ) input_idxs, output_idxs, cumulative_offsets = self._build_index_expressions( - input_nodes, input_sizes, output_strides, rank, num_inputs + input_nodes, input_sizes, output_strides, rank, num_inputs, + excluded_dims=excluded_dims ) # Map unique buffer names to their tile descriptors for template @@ -203,9 +216,12 @@ def _build_buffer_mapping(self, input_nodes): return buffer_name_to_template_name, input_buffer_names def _build_tile_descriptors( - self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, output_node + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, output_node, excluded_dims=None ): """Build tile descriptors for each input and output.""" + if excluded_dims is None: + excluded_dims = set() + input_tile_descs = [] output_tile_descs = [] unique_tile_descs = {} @@ -217,16 +233,21 @@ def _build_tile_descriptors( tile_size_idx = 0 for d in range(rank): if d != self.dim: - full_tile_sizes.append(tile_sizes[tile_size_idx]) + # Skip excluded dimensions + if tile_size_idx not in excluded_dims: + full_tile_sizes.append(tile_sizes[tile_size_idx]) tile_size_idx += 1 else: full_tile_sizes.append(input_tile_sizes_dim[i]) + # Calculate vlane_split_axis for reduced dimensions + vlane_split_axis = len(full_tile_sizes) - 1 + # Input tile descriptor input_tile_desc = mlir_common.MLIRMultiDimTile( full_tile_sizes, kernel.vector_lane, - vlane_split_axis=rank - 1, + vlane_split_axis=vlane_split_axis, vlane_stride=1 ) input_tile_desc.set_tile_size(full_tile_sizes) @@ -239,7 +260,7 @@ def _build_tile_descriptors( output_tile_desc = mlir_common.MLIRMultiDimTile( full_tile_sizes, kernel.vector_lane, - vlane_split_axis=rank - 1, + vlane_split_axis=vlane_split_axis, vlane_stride=1 ) output_tile_desc.set_tile_size(full_tile_sizes) @@ -255,9 +276,12 @@ def _build_tile_descriptors( return input_tile_descs, output_tile_descs, unique_tile_descs def _build_index_expressions( - self, input_nodes, input_sizes, output_strides, rank, num_inputs + self, input_nodes, input_sizes, output_strides, rank, num_inputs, excluded_dims=None ): """Build index expressions for input and output.""" + if excluded_dims is None: + excluded_dims = set() + input_idxs = [] output_idxs = [] cumulative_offsets = [0] @@ -274,15 +298,21 @@ def _build_index_expressions( input_idx = [] output_idx = [] + tile_size_idx = 0 for d in range(rank): if d != self.dim: - input_idx_symbol = sympy.Symbol(f"index{d}") - output_idx_symbol = sympy.Symbol(f"index{d}") + # Skip excluded dimensions + if tile_size_idx not in excluded_dims: + input_idx_symbol = sympy.Symbol(f"index{d}") + output_idx_symbol = sympy.Symbol(f"index{d}") + input_idx.append(input_idx_symbol * x_stride[d]) + output_idx.append(output_idx_symbol * output_strides[d]) + tile_size_idx += 1 else: input_idx_symbol = sympy.Symbol(f"index_local{self.dim}_{i}") output_idx_symbol = sympy.Symbol(f"index{self.dim}_{i}") - input_idx.append(input_idx_symbol * x_stride[d]) - output_idx.append(output_idx_symbol * output_strides[d]) + input_idx.append(input_idx_symbol * x_stride[d]) + output_idx.append(output_idx_symbol * output_strides[d]) input_idxs.append(input_idx) output_idxs.append(output_idx) From 4e4300e2cda61dcc5eeec103c91fe5ef13ff3a73 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 20:22:10 +0900 Subject: [PATCH 11/31] [Template/Cat] Refactor cat + Support explicit dram+stride in def_dma_op --- .github/workflows/pytorchsim_test.yml | 21 + PyTorchSimFrontend/mlir/mlir_cat_template.py | 401 ++++++++++--------- PyTorchSimFrontend/mlir/mlir_template.py | 48 ++- tests/test_cat.py | 16 +- 4 files changed, 288 insertions(+), 198 deletions(-) diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index 9589384b..eaaa7e50 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -163,6 +163,27 @@ jobs: -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_conv2d.py + test_cat: + name: Run test_cat.py + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_cat.py + run: | + echo "Running test_cat.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/test_cat.py + test_matmul: name: Run test_matmul.py runs-on: self-hosted diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 2a00ce95..6eb60198 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Set import math import itertools @@ -23,10 +23,12 @@ {%- endfor %} {%- for i in range(NUM_INPUTS) %} // Input tensor{{ i }} - affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUT_SIZES[i][DIM] }} step {{ INPUT_TILE_SIZES_DIM[i] }} { - %index{{ DIM }}_{{i}} = affine.apply affine_map<(d0) -> (d0 + {{ CUMULATIVE_OFFSETS[i] }})> (%index_local{{ DIM }}_{{ i }}) - {{ kernel.def_dma_op("MVIN", INPUT_BUFFER_NAMES[i], INPUT_IDXS[i], INPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, OUTPUT_IDXS[i], OUTPUT_TILE_DESCS[i], indent_size=INDENT_SIZE) }} + affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUTS[i].sizes[DIM] }} step {{ INPUTS[i].tile_size_dim }} { + %index{{ DIM }}_{{ i }} = affine.apply affine_map<(d0) -> (d0 + {{ INPUTS[i].cum_offset }})> (%index_local{{ DIM }}_{{ i }}) + %input_dram_offset_{{ i }} = affine.apply {{ INPUTS[i].offset_map }}({{ INPUTS[i].offset_vars }}) + %output_dram_offset_{{ i }} = affine.apply {{ OUTPUTS[i].offset_map }}({{ OUTPUTS[i].offset_vars }}) + {{ kernel.def_dma_op("MVIN", INPUTS[i].dram_name, [], INPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=INPUTS[i].dram_strides, dram_offset="input_dram_offset_" ~ i) }} + {{ kernel.def_dma_op("MVOUT", "Y", [], OUTPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=OUTPUTS[i].dram_strides, dram_offset="output_dram_offset_" ~ i) }} } { inner_loop=true } {%- endfor %} @@ -52,81 +54,84 @@ def render( tile_info=None, **kwargs, ): - # Extract info input_nodes = self.input_nodes y = self.output_node - num_inputs = len(self.input_nodes) + num_inputs = len(input_nodes) rank = len(y.get_size()) input_sizes = [x.get_size() for x in input_nodes] - output_sizes = [sz for dim, sz in enumerate(y.get_size()) if dim != self.dim] - output_dim = [dim for dim, sz in enumerate(y.get_size()) if dim != self.dim] - tile_sizes = tile_info if tile_info is not None else [1] * len(output_sizes) + output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] + output_dim = [d for d, _ in enumerate(y.get_size()) if d != self.dim] output_strides = y.get_layout().stride - excluded_dims = list() - max_tiled_dims = 4 - 1 - if len(tile_sizes) > max_tiled_dims: - # Create index:tile_size dictionary and sort by tile_size - dim_tile_dict = {idx: sz for idx, sz in enumerate(tile_sizes)} - sorted_dims = sorted(dim_tile_dict.items(), key=lambda x: x[1], reverse=True) - # Keep top 4 dimensions, exclude the rest - excluded_dims = [idx for idx, _ in sorted_dims[max_tiled_dims:]] - for idx in excluded_dims: - tile_sizes[idx] = 1 - - # Calculate input tile sizes + tile_sizes = list(tile_info) if tile_info is not None else [1] * len(output_sizes) + excluded_dims = self._compute_excluded_dims(tile_sizes) + input_tile_sizes_dim = self._calculate_input_tile_sizes( kernel, input_sizes, tile_sizes, num_inputs, rank ) - buffer_name_to_template_name, input_buffer_names = self._build_buffer_mapping(input_nodes) + buffer_name_to_template_name, input_dram_names = self._build_buffer_mapping(input_nodes) input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( - kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, y, - excluded_dims=excluded_dims + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_dram_names, y, excluded_dims=excluded_dims ) - - input_idxs, output_idxs, cumulative_offsets = self._build_index_expressions( - input_nodes, input_sizes, output_strides, rank, num_inputs, - excluded_dims=excluded_dims + (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) = self._build_dma_info( + input_nodes, input_sizes, output_strides, input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=excluded_dims ) - # Map unique buffer names to their tile descriptors for template - unique_buffer_tile_descs = {} - for actual_name, template_name in buffer_name_to_template_name.items(): - if actual_name in unique_tile_descs: - unique_buffer_tile_descs[template_name] = unique_tile_descs[actual_name] - - names_str = ", ".join(input_buffer_names + ["Y"]) + unique_buffer_tile_descs = { + buffer_name_to_template_name[name]: desc + for name, desc in unique_tile_descs.items() + } + names_str = ", ".join(input_dram_names + ["Y"]) indent_size = 2 + (rank - 1) * 2 + 4 + inputs_info = [ + dict( + dram_name = input_dram_names[i], + sizes = input_sizes[i], + tile_size_dim= input_tile_sizes_dim[i], + tile_desc = input_tile_descs[i], + offset_map = input_offset_maps[i], + offset_vars = input_offset_var_strs[i], + dram_strides = input_dram_strides[i], + cum_offset = cumulative_offsets[i], + ) + for i in range(num_inputs) + ] + outputs_info = [ + dict( + tile_desc = output_tile_descs[i], + offset_map = output_offset_maps[i], + offset_vars = output_offset_var_strs[i], + dram_strides = output_dram_strides[i], + ) + for i in range(num_inputs) + ] + kernel.render_options = dict( - KERNEL_NAME=self.name, - kernel=kernel, - Y=y, - OUT_DVAR="Y", - NAMES_STR=names_str, - INPUT_NAMES=input_nodes, - INPUT_BUFFER_NAMES=input_buffer_names, - NUM_INPUTS=num_inputs, - RANK=rank, - DIM=self.dim, - INPUT_SIZES=input_sizes, - OUTPUT_SIZES=output_sizes, - OUTPUT_DIM=output_dim, - TILE_SIZES=tile_sizes, - INPUT_TILE_SIZES_DIM=input_tile_sizes_dim, - INPUT_TILE_DESCS=input_tile_descs, - OUTPUT_TILE_DESCS=output_tile_descs, - UNIQUE_BUFFER_TILE_DESCS=unique_buffer_tile_descs, - INPUT_IDXS=input_idxs, - OUTPUT_IDXS=output_idxs, - CUMULATIVE_OFFSETS=cumulative_offsets, - INDENT_SIZE=indent_size, - input_reorder=self.input_reorder, + KERNEL_NAME = self.name, + kernel = kernel, + NUM_INPUTS = num_inputs, + NAMES_STR = names_str, + Y = y, + INPUT_NAMES = input_nodes, + RANK = rank, + DIM = self.dim, + OUTPUT_SIZES = output_sizes, + OUTPUT_DIM = output_dim, + TILE_SIZES = tile_sizes, + UNIQUE_BUFFER_TILE_DESCS = unique_buffer_tile_descs, + INPUTS = inputs_info, + OUTPUTS = outputs_info, + INDENT_SIZE = indent_size, + input_reorder = self.input_reorder, ) - code = self._template_from_string(TEMPLATE).render(**kernel.render_options) - return code + return self._template_from_string(TEMPLATE).render(**kernel.render_options) def get_tile_candidates( self, @@ -141,179 +146,217 @@ def get_tile_candidates( y = self.output_node num_inputs = len(self.input_nodes) - output_sizes = [sz for dim, sz in enumerate(y.get_size()) if dim != self.dim] - num_non_dim_dims = len(output_sizes) + output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] - if num_non_dim_dims == 0: + if not output_sizes: return [[1]] - tile_candidates = [] - dim_tile_candidates = [] + max_tile_total = kernel.spad_info["spad_size"] // ( + kernel.vector_lane * kernel.precision * 2 * num_inputs + ) + dim_tile_candidates = [] for dim_size in output_sizes: - dim_candidates = [] - max_tile = min(dim_size, kernel.spad_info["spad_size"] // (kernel.vector_lane * kernel.precision * 2 * num_inputs)) - + max_tile = min(dim_size, max_tile_total) + candidates = set() for mult in range(1, max_tile // kernel.vector_lane + 1): - tile = mult * kernel.vector_lane - if tile <= dim_size: - dim_candidates.append(tile) - + t = mult * kernel.vector_lane + if t <= dim_size: + candidates.add(t) if max_tile > 0: for exp in range(int(math.log2(max_tile)) + 1): - tile = 2 ** exp - if tile <= dim_size and tile not in dim_candidates: - dim_candidates.append(tile) - - if dim_size not in dim_candidates: - dim_candidates.append(dim_size) - - dim_tile_candidates.append(sorted(set(dim_candidates))[:5]) - - for tile_combo in itertools.product(*dim_tile_candidates): - total_elements = math.prod(tile_combo) - total_spad_needed = total_elements * (num_inputs + 1) * kernel.precision - - if total_spad_needed <= kernel.spad_info["spad_size"] * kernel.vector_lane: - tile_candidates.append(list(tile_combo)) + t = 2 ** exp + if t <= dim_size: + candidates.add(t) + candidates.add(dim_size) + dim_tile_candidates.append(sorted(candidates)[:5]) + + tile_candidates = [ + list(combo) + for combo in itertools.product(*dim_tile_candidates) + if math.prod(combo) * (num_inputs + 1) * kernel.precision + <= kernel.spad_info["spad_size"] * kernel.vector_lane + ] if not tile_candidates: - tile_candidates = [[1] * num_non_dim_dims] + tile_candidates = [[1] * len(output_sizes)] tile_candidates.sort(key=lambda x: -math.prod(x)) return tile_candidates[:4] - def _calculate_input_tile_sizes( - self, kernel, input_sizes, tile_sizes, num_inputs, rank - ): - """Calculate tile sizes for concat dimension for each input.""" + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _compute_excluded_dims(self, tile_sizes: list) -> list: + """Return non-tiled dimension indices when rank exceeds the 4-dim limit.""" + max_tiled = 3 + if len(tile_sizes) <= max_tiled: + return [] + sorted_dims = sorted(enumerate(tile_sizes), key=lambda x: x[1], reverse=True) + excluded = [idx for idx, _ in sorted_dims[max_tiled:]] + for idx in excluded: + tile_sizes[idx] = 1 + return excluded + + def _calculate_input_tile_sizes(self, kernel, input_sizes, tile_sizes, num_inputs, rank): + """Calculate tile sizes along the concat dimension for each input.""" non_dim_tile_elements = math.prod(tile_sizes) if tile_sizes else 1 - non_dim_tile_spad = non_dim_tile_elements * kernel.precision max_spad_per_input = kernel.spad_info["spad_size"] * kernel.vector_lane // 2 - extra_concat_input = math.ceil(max_spad_per_input / non_dim_tile_spad) - num_inputs + extra_concat = math.ceil(max_spad_per_input / (non_dim_tile_elements * kernel.precision)) - num_inputs input_tile_sizes_dim = [] for i in range(num_inputs): - input_dim_size = input_sizes[i][self.dim] - if extra_concat_input > 0 and non_dim_tile_elements > 0: - max_tile_dim = min(input_dim_size, extra_concat_input) - extra_concat_input -= max_tile_dim + if extra_concat > 0 and non_dim_tile_elements > 0: + tile_dim = min(input_sizes[i][self.dim], extra_concat) + extra_concat -= tile_dim else: - max_tile_dim = 1 - input_tile_sizes_dim.append(max_tile_dim) + tile_dim = 1 + input_tile_sizes_dim.append(tile_dim) return input_tile_sizes_dim def _build_buffer_mapping(self, input_nodes): - """Map actual buffer names to template buffer names """ - buffer_name_to_template_name = {} - input_buffer_names = [] + """Map actual buffer names to short template names (X0, X1, ...).""" + name_map = {} + template_names = [] for x in input_nodes: - actual_name = x.get_name() - template_name = buffer_name_to_template_name.setdefault( - actual_name, f"X{len(buffer_name_to_template_name)}" - ) - input_buffer_names.append(template_name) - return buffer_name_to_template_name, input_buffer_names + actual = x.get_name() + template = name_map.setdefault(actual, f"X{len(name_map)}") + template_names.append(template) + return name_map, template_names def _build_tile_descriptors( - self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, input_buffer_names, output_node, excluded_dims=None + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_buffer_names, output_node, excluded_dims=None ): - """Build tile descriptors for each input and output.""" + """Build tile descriptors for every input (and its paired output).""" if excluded_dims is None: excluded_dims = set() - input_tile_descs = [] - output_tile_descs = [] - unique_tile_descs = {} + def make_tile_desc(tile_sz, vector_lane, name, offset): + desc = mlir_common.MLIRMultiDimTile( + tile_sz, vector_lane, + vlane_split_axis=len(tile_sz) - 1, + vlane_stride=1 + ) + desc.set_tile_size(tile_sz) + desc.set_name(name) + desc.offset = offset + return desc + output_offset = output_node.get_layout().offset + input_tile_descs, output_tile_descs, unique_tile_descs = [], [], {} for i, x in enumerate(input_nodes): - x_offset = x.get_layout().offset - full_tile_sizes = [] - tile_size_idx = 0 + # Collect tile sizes for tiled dimensions only (skip excluded non-concat dims) + tile_sz = [] + tile_idx = 0 for d in range(rank): if d != self.dim: - # Skip excluded dimensions - if tile_size_idx not in excluded_dims: - full_tile_sizes.append(tile_sizes[tile_size_idx]) - tile_size_idx += 1 + if tile_idx not in excluded_dims: + tile_sz.append(tile_sizes[tile_idx]) + tile_idx += 1 else: - full_tile_sizes.append(input_tile_sizes_dim[i]) + tile_sz.append(input_tile_sizes_dim[i]) - # Calculate vlane_split_axis for reduced dimensions - vlane_split_axis = len(full_tile_sizes) - 1 + sram_name = f"{input_buffer_names[i].lower()}_cat_tile" + input_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, x.get_layout().offset)) + output_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, output_offset)) - # Input tile descriptor - input_tile_desc = mlir_common.MLIRMultiDimTile( - full_tile_sizes, - kernel.vector_lane, - vlane_split_axis=vlane_split_axis, - vlane_stride=1 - ) - input_tile_desc.set_tile_size(full_tile_sizes) - template_buffer_name = input_buffer_names[i] - input_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") - input_tile_desc.offset = x_offset - input_tile_descs.append(input_tile_desc) - - # Output tile descriptor (same as input but with output offset) - output_tile_desc = mlir_common.MLIRMultiDimTile( - full_tile_sizes, - kernel.vector_lane, - vlane_split_axis=vlane_split_axis, - vlane_stride=1 - ) - output_tile_desc.set_tile_size(full_tile_sizes) - output_tile_desc.set_name(f"{template_buffer_name.lower()}_cat_tile") - output_tile_desc.offset = output_offset - output_tile_descs.append(output_tile_desc) - - # Store unique tile desc by actual buffer name actual_name = x.get_name() if actual_name not in unique_tile_descs: - unique_tile_descs[actual_name] = input_tile_desc + unique_tile_descs[actual_name] = input_tile_descs[-1] return input_tile_descs, output_tile_descs, unique_tile_descs - def _build_index_expressions( - self, input_nodes, input_sizes, output_strides, rank, num_inputs, excluded_dims=None + def _build_dma_info( + self, input_nodes, input_sizes, output_strides, + input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=None ): - """Build index expressions for input and output.""" + """Build per-input DRAM offset affine maps and tile strides. + + Three stride concepts are maintained: + + * layout_strides (internal) - raw DRAM buffer strides for every rank + dimension, used to compute the flat base-address affine map. + These reflect how the tensor is physically laid out in DRAM. + * dram_strides (returned, ``def_dma_op dram_stride=``) - stride in + DRAM per *tiled* dimension (excluded dims removed). The DMA engine + uses these to walk DRAM when loading/storing a tile. + * sram_strides (inside ``def_dma_op``, from tile_desc) - stride in + SRAM per tiled dimension. The DMA engine uses these to place data + into the SRAM tile buffer. + + Returns: + input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets + """ if excluded_dims is None: excluded_dims = set() - input_idxs = [] - output_idxs = [] + def make_affine_map(idx_syms, strides, layout_offset): + terms = [] + for j, s in enumerate(strides): + s = int(s) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(layout_offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(len(idx_syms))) + return f"affine_map<({dim_str}) -> ({' + '.join(terms) if terms else '0'})>" + cumulative_offsets = [0] for i in range(num_inputs - 1): cumulative_offsets.append(cumulative_offsets[-1] + input_sizes[i][self.dim]) + input_offset_maps, input_offset_var_strs, input_dram_strides = [], [], [] + output_offset_maps, output_offset_var_strs, output_dram_strides = [], [], [] + for i, x in enumerate(input_nodes): x_stride = x.get_layout().stride - x_offset = x.get_layout().offset if hasattr(x, 'data') and hasattr(x.data, 'dims'): - # In case of PermuteView, the stride is permuted - perm_dims = x.data.dims - x_stride = [x_stride[perm_dims[d]] for d in range(rank)] + # PermuteView: re-order strides according to the permutation + perm = x.data.dims + x_stride = [x_stride[perm[d]] for d in range(rank)] + + in_syms, in_layout_strides, in_dram_strides = [], [], [] + out_syms, out_layout_strides, out_dram_strides = [], [], [] + tile_idx = 0 - input_idx = [] - output_idx = [] - tile_size_idx = 0 for d in range(rank): if d != self.dim: - # Skip excluded dimensions - if tile_size_idx not in excluded_dims: - input_idx_symbol = sympy.Symbol(f"index{d}") - output_idx_symbol = sympy.Symbol(f"index{d}") - input_idx.append(input_idx_symbol * x_stride[d]) - output_idx.append(output_idx_symbol * output_strides[d]) - tile_size_idx += 1 + in_syms.append(sympy.Symbol(f"index{d}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{d}")) + out_layout_strides.append(int(output_strides[d])) + if tile_idx not in excluded_dims: + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + tile_idx += 1 else: - input_idx_symbol = sympy.Symbol(f"index_local{self.dim}_{i}") - output_idx_symbol = sympy.Symbol(f"index{self.dim}_{i}") - input_idx.append(input_idx_symbol * x_stride[d]) - output_idx.append(output_idx_symbol * output_strides[d]) - input_idxs.append(input_idx) - output_idxs.append(output_idx) - - return input_idxs, output_idxs, cumulative_offsets + in_syms.append(sympy.Symbol(f"index_local{self.dim}_{i}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{self.dim}_{i}")) + out_layout_strides.append(int(output_strides[d])) + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + + input_offset_maps.append(make_affine_map(in_syms, in_layout_strides, input_tile_descs[i].offset)) + input_offset_var_strs.append(", ".join(f"%{s}" for s in in_syms)) + input_dram_strides.append(in_dram_strides) + + output_offset_maps.append(make_affine_map(out_syms, out_layout_strides, output_tile_descs[i].offset)) + output_offset_var_strs.append(", ".join(f"%{s}" for s in out_syms)) + output_dram_strides.append(out_dram_strides) + + return (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 7c52bfe6..9cc79e0a 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -809,12 +809,18 @@ def hook(): return key def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, - subtile_size:list=[], async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True): + subtile_size:list=[], async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True, + dram_stride:list=None, dram_offset=None): + # Todo. Remove legacy behavior (i.e., index_list parsing) def generate_dma_code(): """Internal method to generate DMA code directly.""" local_code = IndentedBuffer() with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): - index_var = self.parse_index_list(index_list, offset=tile_desc.offset) + if dram_offset is not None: + # Use explicitly provided offset (pre-computed MLIR SSA variable name) + index_var = dram_offset + else: + index_var = self.parse_index_list(index_list, offset=tile_desc.offset) node_layout = self.named_nodes[dram_var].get_layout() if dram_var in self.exception_nodes: numel = self.exception_nodes[dram_var]["numel"] @@ -822,27 +828,33 @@ def generate_dma_code(): numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] dram_shape = f"memref<{numel}x{mlir_dtype}>" - dram_stride = [] - for idx in index_list: - if idx.is_Mul: - dram_stride.append(int(idx.args[0])) - elif idx == sympy.Symbol("c0"): - dram_stride.append(0) - elif not idx.is_Number: - dram_stride.append(1) - else: - dram_stride.append(0) - sram_var = tile_desc.get_name() - tile_shape = tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = tile_desc.get_tile_stride() - vlane_split_axis = tile_desc.vmap.vlane_split_axis - vlane_stride = tile_desc.vmap.vlane_stride + if dram_stride is not None: + # Use explicitly provided dram_stride + _dram_stride = dram_stride + else: + # Extract dram_stride from index_list (legacy behavior) + _dram_stride = [] + for idx in index_list: + if idx.is_Mul: + _dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + _dram_stride.append(0) + elif not idx.is_Number: + _dram_stride.append(1) + else: + _dram_stride.append(0) + + sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + sram_strides = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vmap.vlane_split_axis + vlane_stride = tile_desc.vmap.vlane_stride zero_cse = self.get_const_cse(0, "index") sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] + attribute_parts = [f"dram_stride={_dram_stride}", f"sram_stride={sram_strides}", "padding=0"] if subtile_size: attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") attribute = " {" + ", ".join(attribute_parts) + "}" diff --git a/tests/test_cat.py b/tests/test_cat.py index 62de6759..97fcc754 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -150,13 +150,25 @@ def cat_4d_three_inputs_fn(a, b, c): cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=1) _test_result("cat.4d.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) +def test_cat_5d(device, dim=0): + def cat_5d_fn(a, b): + return torch.cat([a, b], dim=dim) + + x = torch.randn(2, 3, 4, 5, 6, device=device) + y = torch.randn(3, 3, 4, 5, 6, device=device) + opt_fn = torch.compile(dynamic=False)(cat_5d_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=dim) + _test_result("cat.5d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run cat simulation tests") parser.add_argument( "--case", choices=[ - "default", "out", "4d_dim0", "4d_dim1", "4d_dim2", "4d_dim3", + "default", "out", "4d_dim0", "4d_dim1", "4d_dim2", "4d_dim3", "5d" "three_inputs", "four_inputs", "4d_three_inputs", "all" ], default="all", @@ -184,3 +196,5 @@ def cat_4d_three_inputs_fn(a, b, c): test_cat_four_inputs(device) if args.case in ("4d_three_inputs", "all"): test_cat_4d_three_inputs(device) + if args.case in ("5d", "all"): + test_cat_5d(device) From 3d9cb387b2ba27853efb983241fa4450c3174d9d Mon Sep 17 00:00:00 2001 From: jung-min Date: Thu, 5 Mar 2026 11:45:36 +0000 Subject: [PATCH 12/31] [Frontend/template] Connect SDPA template to NPU using Torch OpenReg --- PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp | 34 +--- PyTorchSimDevice/csrc/aten/native/Extra.cpp | 51 +---- .../torch_openreg/openreg/__init__.py | 4 +- PyTorchSimFrontend/mlir/mlir_lowering.py | 14 +- PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 186 +----------------- 5 files changed, 14 insertions(+), 275 deletions(-) diff --git a/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp index 04ba6d48..f048f878 100644 --- a/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp +++ b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -40,36 +41,6 @@ void wrapper_quantize_tensor_per_tensor_affine_stub( rtensor, qtensor, scale, zero_point); } -std::tuple< - at::Tensor, - at::Tensor, - at::Tensor, - at::Tensor, - c10::SymInt, - c10::SymInt, - at::Tensor, - at::Tensor, - at::Tensor> -wrapper__scaled_dot_product_fused_attention_overrideable( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const std::optional& attn_bias, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - return at::native::openreg::_scaled_dot_product_fused_attention_overrideable( - query, - key, - value, - attn_bias, - dropout_p, - is_causal, - return_debug_mask, - scale); -} - std::tuple wrapper_scaled_dot_product_fused_attention_overrideable_backward( const at::Tensor& grad_out, @@ -172,9 +143,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("abs.out", &wrapper_abs_out); m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor); m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice); - m.impl( - "_scaled_dot_product_fused_attention_overrideable", - &wrapper__scaled_dot_product_fused_attention_overrideable); m.impl( "_scaled_dot_product_fused_attention_overrideable_backward", &wrapper_scaled_dot_product_fused_attention_overrideable_backward); diff --git a/PyTorchSimDevice/csrc/aten/native/Extra.cpp b/PyTorchSimDevice/csrc/aten/native/Extra.cpp index 711d114c..aaf28e1a 100644 --- a/PyTorchSimDevice/csrc/aten/native/Extra.cpp +++ b/PyTorchSimDevice/csrc/aten/native/Extra.cpp @@ -19,7 +19,8 @@ int64_t _fused_sdp_choice( bool is_causal, std::optional scale, bool enable_gqa) { - auto backend = sdp::SDPBackend::math; + + auto backend = sdp::SDPBackend::overrideable; return static_cast(backend); } @@ -29,54 +30,6 @@ void quantize_tensor_per_tensor_affine_stub( double scale, int64_t zero_point) {} -std::tuple< - at::Tensor, - at::Tensor, - at::Tensor, - at::Tensor, - c10::SymInt, - c10::SymInt, - at::Tensor, - at::Tensor, - at::Tensor> -_scaled_dot_product_fused_attention_overrideable( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const std::optional& attn_bias, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_v = value.size(3); - const int64_t max_seqlen_q = query.size(2); - const int64_t max_seqlen_kv = key.size(2); - - auto opts = query.options(); - auto output = - at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts); - auto logsumexp = - at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - auto debug_attn_mask = at::empty( - {batch_size, num_heads, max_seqlen_q, max_seqlen_kv}, - opts.dtype(at::kFloat)); - auto philox_seed = at::empty({}, at::dtype(at::kLong)); - auto philox_offset = at::empty({}, at::dtype(at::kLong)); - - return std::make_tuple( - output, - logsumexp, - at::Tensor(), - at::Tensor(), - max_seqlen_q, - max_seqlen_kv, - philox_seed, - philox_offset, - debug_attn_mask); -} - std::tuple _scaled_dot_product_fused_attention_overrideable_backward( const at::Tensor& grad_out, diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index 5a0de6c3..9d10f90e 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -66,8 +66,8 @@ def _lazy_init(): return # Replace the global C++ binding with our custom dispatcher patch - from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention - torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention + # from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention + # torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention torch_openreg._C._init() register_interface_for_device(custom_device(), ExtensionDeviceInterface) diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index e09dcf57..a6b2478c 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -15,7 +15,7 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate -from PyTorchSimFrontend.mlir.mlir_sdpa_template import MLIRFlashSDPATemplate, flash_sdpa_args +from PyTorchSimFrontend.mlir.mlir_sdpa_template import MLIRFlashSDPATemplate, flash_sdpa_args, calculate_scale from PyTorchSimFrontend import extension_config aten = torch.ops.aten @@ -44,14 +44,16 @@ def tuned_flash_sdpa( query : TensorBox, key : TensorBox, value : TensorBox, - scale : float, + attn_bias : Optional[TensorBox] = None, dropout_p : float = 0.0, is_causal : bool = False, - return_debug_mask : bool =False) -> tuple: + return_debug_mask : bool = False, + scale : Optional[float] = None) -> tuple: - print("Enter tuned_flash_sdpa") - + + scale = calculate_scale(query, scale) N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value) + mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale) # _scaled_dot_product_flash_attention has to return a tuple which has 9 values @@ -211,4 +213,4 @@ def custom_unsafe_index(x, indices): if extension_config.CONFIG_USE_TIMING_POOLING: lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template -lowerings.update({getattr(aten._scaled_dot_product_flash_attention, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_flash_attention.overloads()}) \ No newline at end of file +lowerings.update({getattr(aten._scaled_dot_product_fused_attention_overrideable, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_fused_attention_overrideable.overloads()}) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index 49c6c6bb..05030f27 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -73,121 +73,6 @@ def flash_sdpa_args( ) return [n, hq, h, l, s, e, ev, layout, query, key, value] - -def validate_sdpa_input( - query : torch.Tensor, - key : torch.Tensor, - value : torch.Tensor, - attn_mask : torch.Tensor = None, - dropout_p : float = 0.0, - is_casual : bool = False, - scale : float = None, - enable_gqa : bool = False) -> None: - """ - Validates input tensors and parameters for Scaled Dot Product Attention (SDPA). - This function's logic can be found in: - https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp(504 line) - https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - """ - - # Tensor class, dtype, and device consistency - # Ensure all primary inputs are torch.Tensors - if not all(isinstance(t, torch.Tensor) for t in [query, key, value]): - raise TypeError( - f"Expected query, key and value to be Tensors, but got " - f"{type(query).__name__}, {type(key).__name__}, and {type(value).__name__}." - ) - - # Check for dtype mismatch - if query.dtype != key.dtype or query.dtype != value.dtype: - raise TypeError( - f"Expected query, key, and value to have the same dtype, " - f"but got {query.dtype}, {key.dtype}, and {value.dtype}." - ) - - # Check for device mismatch (e.g., mixing CPU and NPU) - if query.device != key.device or query.device != value.device: - raise ValueError( - f"Expected query, key, and value to be on the same device, " - f"but got {query.device}, {key.device}, and {value.device}." - ) - - # Shape and dimension validation - # SDPA typically expects 4D (B, H, S, D), but we check for at least 2D here - if any(t.dim() < 2 for t in [query, key, value]): - raise ValueError( - f"Expected query, key, and value to be at least 2D, " - f"but got Q:{query.dim()}D, K:{key.dim()}D, V:{value.dim()}D." - ) - - # Attention mask validation - if attn_mask is not None: - if not isinstance(attn_mask, torch.Tensor): - raise TypeError(f"Expected attn_mask to be a Tensor, but got {type(attn_mask).__name__}.") - - # Dtype check: floating point masks must match query dtype; bool masks are also allowed - if attn_mask.dtype.is_floating_point: - if attn_mask.dtype != query.dtype: - raise TypeError(f"Floating point attn_mask must match query dtype ({query.dtype}), but got {attn_mask.dtype}.") - elif attn_mask.dtype != torch.bool: - raise TypeError(f"attn_mask must be floating point or bool, but got {attn_mask.dtype}.") - - # Nested tensor limitation with explicit masking - if query.is_nested or key.is_nested: - raise ValueError("Nested tensors are not supported when an explicit attn_mask is set.") - - # Dropout and causal flag validation (added) - # Dropout probability must be in the range [0, 1) - if not (0.0 <= dropout_p < 1.0): - raise ValueError(f"Expected dropout_p to be in [0, 1), but got {dropout_p}.") - - # Mutual exclusivity: cannot use both explicit mask and causal flag (added) - if is_casual and attn_mask is not None: - raise ValueError("Both attn_mask and is_casual cannot be set at the same time.") - - # Scaling factor validation (added) - if scale is not None and scale <= 0.0: - raise ValueError(f"Expected scale to be a positive number, but got {scale}.") - - # GQA (Grouped Query Attention) constraints (added) - n_head_q = query.size(1) - n_head_k = key.size(1) - n_head_v = value.size(1) - - # The aten._scaled_dot_product_flash_attention kernel does not accept an explicit enable_gqa parameter. - # Instead, the Flash SDPA implementation infers GQA usage by checking if n_head_q != n_head_k. - if not enable_gqa and n_head_q != n_head_k: - raise ValueError(f"Query and Key must have the same number of heads when enable_gqa is false (Q:{n_head_q} vs K:{n_head_k}).") - - if enable_gqa: - if n_head_q == n_head_k: - raise ValueError(f"enable_gqa Query and Key ") - - if n_head_k != n_head_v: - raise ValueError(f"Key and Value must have the same number of heads (K:{n_head_k} vs V:{n_head_v}).") - - # Query heads must be an integer multiple of key heads for grouping - if n_head_q % n_head_k != 0: - raise ValueError( - f"Number of query heads ({n_head_q}) must be divisible by " - f"number of key heads ({n_head_k}) for GQA." - ) - -def convert_boolean_attn_mask(attn_mask: torch.Tensor, target_dtype: torch.dtype) -> float: - """ - Equivalent to the C++ 'convert_boolean_attn_mask' function. - Converts a boolean mask to a floating-point mask for SDPA. - """ - - if attn_mask is not None and attn_mask.dtype == torch.bool: - - new_mask = torch.zeros_like(attn_mask, dtype=target_dtype) - minus_inf = torch.finfo(target_dtype).min - new_mask.masked_fill_(attn_mask.logical_not(), minus_inf) - - return new_mask - - return attn_mask def calculate_scale(query: torch.Tensor, scale: float) -> float: """ @@ -195,79 +80,10 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: Otherwise, use the provided scale. """ if scale is None: - return 1.0 / math.sqrt(query.size(-1)) + return 1.0 / math.sqrt(query.layout.size[-1]) else: return scale -def patched_scaled_dot_product_attention( - query_ : torch.Tensor, - key : torch.Tensor, - value : torch.Tensor, - dropout_p : float = 0.0, - is_casual : bool = False, - attn_mask_ : torch.Tensor = None, - scale_ : float = None, - enable_gqa : bool = None, - orig_fn = torch._C._nn.scaled_dot_product_attention) -> torch.Tensor : - """ - Custom patch for Scaled Dot Product Attention (SDPA) to intercept high-level calls. - For NPU devices, it redirects execution to specific ATen kernels based on global flags. - For all devices, it maintains parity with the original dispatcher logic found in: - https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp - - This function acts as a custom override that replaces the default PyTorch SDPA implementation, - invoked via 'PyTorchSim/PyTorchSimDevice/torch_openreg/openreg/__init__.py'. - """ - - # Device-specific Dispatching: redirect to specialized kernels if on NPU - if "npu" in str(query_.device): - - validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_casual, scale_, enable_gqa) - attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype) - - # Kernel selection logic: emulate C++ dispatcher priority - # Selection priority(can be changed): flash attention > memory efficient > math (cuDNN is not supported) - aten = torch.ops.aten - scale = calculate_scale(query_, scale_) - - if flash_sdp_enabled(): - # Skip padding query, key and value for alignment. - dispatch_kwargs = { - "dropout_p" : dropout_p, - "is_causal" : is_casual, - "return_debug_mask" : False, - "scale" : scale - } - - out_lse_softmax = aten._scaled_dot_product_flash_attention( - query_, key, value, **dispatch_kwargs - ) - - return out_lse_softmax[0] - elif mem_efficient_sdp_enabled(): - # out_and_lse = aten._scaled_dot_product_efficient_attention(...) - # return out_and_lse[0] - raise NotImplementedError("Memory efficient SDPA is not implemented yet.") - else: - dispatch_kwargs = { - "attn_mask" : attn_mask, - "dropout_p" : dropout_p, - "is_causal" : is_casual, - "dropout_mask" : None, - "scale": scale, - "enable_gqa" : enable_gqa - } - - out_lse_softmax = aten._scaled_dot_product_attention_math( - query_, - key, - value, - **dispatch_kwargs) - - return out_lse_softmax[0] - else: - # Fallback: Delegate to the original C++ Dispatcher for other devices - return orig_fn(query_, key, value) FLASH_SDPA_TEMPLATE = r""" // SDPA kernel From 591e8a98cdb7a734f58c3e2afff6b252f5b86bee Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 5 Mar 2026 23:16:40 +0900 Subject: [PATCH 13/31] [Templte/Cat] Apply copy operation when node has view --- PyTorchSimFrontend/mlir/mlir_cat_template.py | 11 +++------- PyTorchSimFrontend/mlir/mlir_lowering.py | 23 +++++++++++++++++--- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 6eb60198..7bee54ac 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -161,14 +161,14 @@ def get_tile_candidates( candidates = set() for mult in range(1, max_tile // kernel.vector_lane + 1): t = mult * kernel.vector_lane - if t <= dim_size: + if t <= dim_size and dim_size % t == 0: candidates.add(t) if max_tile > 0: for exp in range(int(math.log2(max_tile)) + 1): t = 2 ** exp - if t <= dim_size: + if t <= dim_size and dim_size % t == 0: candidates.add(t) - candidates.add(dim_size) + candidates.add(dim_size) # dim_size always divides itself dim_tile_candidates.append(sorted(candidates)[:5]) tile_candidates = [ @@ -322,11 +322,6 @@ def make_affine_map(idx_syms, strides, layout_offset): for i, x in enumerate(input_nodes): x_stride = x.get_layout().stride - if hasattr(x, 'data') and hasattr(x.data, 'dims'): - # PermuteView: re-order strides according to the permutation - perm = x.data.dims - x_stride = [x_stride[perm[d]] for d in range(rank)] - in_syms, in_layout_strides, in_dram_strides = [], [], [] out_syms, out_layout_strides, out_dram_strides = [], [], [] tile_idx = 0 diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index d7aee715..e5df4b78 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -1,3 +1,4 @@ +import math from typing import List, Optional, Sequence import torch @@ -205,11 +206,27 @@ def _cat_layout(tensors: Sequence[TensorBox], dim: int) -> ir.Layout: def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): if tensors and dim < 0: dim += len(tensors[0].get_size()) - + copy_default_lowering = lowerings.get(aten.copy_.default) + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + new_tensors = [] for t in tensors: t.realize() - layout = _cat_layout(tensors, dim) - mlir_template = MLIRCatTemplate(list(tensors), layout, dim=dim) + # If the tensor is backed by a view (ReinterpretView, PermuteView, etc.), + # materialise it into a fresh contiguous FixedLayout buffer so the cat + # kernel always receives plain, dense strides. + if isinstance(t.data, ir.BaseView): + sizes = list(t.get_size()) + strides = [math.prod(sizes[i + 1:]) for i in range(len(sizes))] + new_buf = empty_strided_lowering( + sizes, strides, dtype=t.get_dtype(), device=t.get_device() + ) + tt = copy_default_lowering(new_buf, t) + else: + tt = t + new_tensors.append(tt) + + layout = _cat_layout(new_tensors, dim) + mlir_template = MLIRCatTemplate(list(new_tensors), layout, dim=dim) return mlir_template.generate().output_node() def _custom_sort_values_impl( From dab34954d61d5558658684dcb1415fa75c3c6935 Mon Sep 17 00:00:00 2001 From: jung-min Date: Sat, 7 Mar 2026 10:11:57 +0000 Subject: [PATCH 14/31] [Refactor] Refactored TopK test code for the OpenReg device --- tests/test_topk.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/test_topk.py b/tests/test_topk.py index c8565310..caf56779 100644 --- a/tests/test_topk.py +++ b/tests/test_topk.py @@ -31,21 +31,11 @@ def topk_fn(a): opt_topk = torch.compile(dynamic=False)(topk_fn) res_values, res_indices = opt_topk(x) - ref_values, ref_indices = torch.topk(x.cpu(), k, dim=dim, largest=largest, sorted=sorted) test_result("TopK/values", res_values, ref_values) test_result("TopK/indices", res_indices, ref_indices) if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") - parser.add_argument('--shape', type=str, default="(512,768)") - args = parser.parse_args() - shape = tuple(map(int, args.shape.strip('()').split(','))) - - from Scheduler.scheduler import ExecutionEngine - module = ExecutionEngine.setup_device() - device = module.custom_device() + device = torch.device('npu:0') test_topk(device, (128, 128), k=2, dim=-1) \ No newline at end of file From a15f5d2128429c5fa9580e8eb2b1f625a55f054d Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 11 Mar 2026 11:03:29 +0900 Subject: [PATCH 15/31] [Template/Sort] Add template code for Bitonic sort --- PyTorchSimFrontend/mlir/mlir_lowering.py | 133 +--- PyTorchSimFrontend/mlir/mlir_ops.py | 76 ++- PyTorchSimFrontend/mlir/mlir_sort_template.py | 627 ++++++++++++------ tests/test_sort.py | 128 ++-- 4 files changed, 591 insertions(+), 373 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index e5df4b78..36e9955b 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -17,13 +17,11 @@ from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate from PyTorchSimFrontend.mlir.mlir_cat_template import MLIRCatTemplate -from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate +from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate, MLIRStableSortTemplate from PyTorchSimFrontend import extension_config aten = torch.ops.aten aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") -_orig_cat_default_lowering = lowerings.get(aten.cat.default) -_orig_cat_out_lowering = lowerings.get(aten.cat.out) _orig_sort_values_stable_lowering = lowerings.get(aten.sort.values_stable) def tuned_mm(mat1, mat2, * ,layout=None): @@ -229,48 +227,35 @@ def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): mlir_template = MLIRCatTemplate(list(new_tensors), layout, dim=dim) return mlir_template.generate().output_node() -def _custom_sort_values_impl( - self: TensorBox, +def custom_sort_default( + value: TensorBox, dim: int = -1, descending: bool = False, - values: Optional[TensorBox] = None, - indices: Optional[TensorBox] = None, stable: Optional[bool] = None, ): - if values is None or indices is None: - raise RuntimeError("sort.values* lowering requires both out tensors: values, indices") + if dim < 0: + dim += len(value.get_size()) - def _normalize_dim(rank: int, d: int) -> int: - return d + rank if d < 0 else d + value.realize() - if not hasattr(self, "get_size"): - raise RuntimeError("sort.values* lowering requires TensorBox input") - - rank = len(self.get_size()) - norm_dim = _normalize_dim(rank, dim) - if norm_dim < 0 or norm_dim >= rank: - raise RuntimeError(f"sort.values* dim out of range: dim={dim}, rank={rank}") - if rank != 2: - raise RuntimeError(f"sort.values* lowering currently supports rank-2 only, got rank={rank}") - if norm_dim not in (0, 1): - raise RuntimeError(f"sort.values* lowering currently supports dim in {{0,1}} only, got dim={norm_dim}") - - self.realize() - if isinstance(values, TensorBox): - values.realize() - if isinstance(indices, TensorBox): - indices.realize() - - value_layout, _ = _sort_layouts(self, norm_dim, descending) - mlir_template = MLIRSortTemplate( - [self], + value_layout, index_layout = _sort_layouts(value, dim, descending) + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + indices = empty_strided_lowering( + value.get_size(), + index_layout.stride, + dtype=torch.int64, + device=value.get_device(), + ) + stable_required = True if stable is None else stable + sort_template_cls = MLIRStableSortTemplate if stable_required else MLIRSortTemplate + mlir_template = sort_template_cls( + [value, indices], value_layout, - dim=norm_dim, + dim=dim, descending=descending, - stable=True if stable is None else stable, - indices_node=indices, + stable=stable_required, ) - sorted_values = mlir_template.generate(template_buffer_node=values, epilogue_nodes=[indices]).output_node() + sorted_values = mlir_template.generate(template_buffer_node=value).output_node() return sorted_values, indices @@ -290,78 +275,6 @@ def _sort_layouts(x: TensorBox, dim: int, descending: bool): index_layout = ir.FixedLayout(x.get_device(), torch.int64, i_sizes, i_stride) return value_layout, index_layout - -def custom_sort_stable( - self: TensorBox, - *, - stable: Optional[bool] = None, - dim: int = -1, - descending: bool = False, -): - empty_strided_lowering = lowerings.get(aten.empty_strided.default) - if empty_strided_lowering is None: - if _orig_sort_values_stable_lowering is None: - raise RuntimeError("sort.stable lowering requires aten.empty_strided.default") - return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=True) - - rank = len(self.get_size()) if hasattr(self, "get_size") else 0 - norm_dim = dim + rank if dim < 0 else dim - if rank > 0 and (norm_dim < 0 or norm_dim >= rank): - raise RuntimeError(f"sort.stable dim out of range: dim={dim}, rank={rank}") - - # Template specialization supports rank-2 and dim in {0,1}. - if rank == 2 and norm_dim not in (0, 1): - if _orig_sort_values_stable_lowering is None: - raise RuntimeError("Original aten.sort.values_stable lowering is missing") - return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=True) - - try: - value_layout, index_layout = _sort_layouts(self, norm_dim, descending) - values = empty_strided_lowering( - list(value_layout.size), - list(value_layout.stride), - dtype=value_layout.dtype, - device=self.get_device(), - ) - indices = empty_strided_lowering( - list(index_layout.size), - list(index_layout.stride), - dtype=index_layout.dtype, - device=self.get_device(), - ) - return _custom_sort_values_impl( - self=self, - dim=dim, - descending=descending, - values=values, - indices=indices, - stable=True if stable is None else stable, - ) - except Exception: - if _orig_sort_values_stable_lowering is None: - raise - return _orig_sort_values_stable_lowering(self, dim=dim, descending=descending, stable=stable) - - -def custom_sort_values_stable( - self: TensorBox, - *, - stable: Optional[bool] = None, - dim: int = -1, - descending: bool = False, - values: Optional[TensorBox] = None, - indices: Optional[TensorBox] = None, -): - return _custom_sort_values_impl( - self=self, - dim=dim, - descending=descending, - values=values, - indices=indices, - stable=stable, - ) - - lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) @@ -369,9 +282,7 @@ def custom_sort_values_stable( lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()}) lowerings.update({getattr(aten.cat, overload): custom_cat_default for overload in aten.cat.overloads()}) - -lowerings.update({aten.sort.stable: custom_sort_stable}) -lowerings.update({aten.sort.values_stable: custom_sort_values_stable}) +lowerings.update({getattr(aten.sort, overload): custom_sort_default for overload in aten.sort.overloads()}) if extension_config.CONFIG_USE_TIMING_POOLING: lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py index 9edd2e44..ace4f9ea 100644 --- a/PyTorchSimFrontend/mlir/mlir_ops.py +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -182,7 +182,7 @@ def to_dtype(operand, dst_mlir_dtype, *args, **kwargs): # Case A: Integer -> Float if src_type_char == "i" and dst_type_char == "f": - op_str = f"arith.sitofp %{operand} : {src_shape} to {shape}" + op_str = f"arith.uitofp %{operand} : {src_shape} to {shape}" # Case B: Float -> Integer elif src_type_char == "f" and dst_type_char == "i": op_str = f"arith.fptosi %{operand} : {src_shape} to {shape}" @@ -1142,6 +1142,80 @@ def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_nam line = reduction_combine_vec(red_type, value, init, axis=0, shape=new_vshape, reduced_shape=final_reduced_shape) return line, [red_size, type_name] + @staticmethod + def vector_shuffle(operand, indices, operand2=None, *args, **kwargs): + tile_size1, dtype1 = V.kernel.var_info[operand] + if operand2 is None: + operand2 = operand + tile_size2, dtype2 = V.kernel.var_info[operand2] + if dtype1 != dtype2: + raise ValueError( + f"vector_shuffle expects same element type, got {dtype1} and {dtype2}" + ) + total_size = tile_size1 + tile_size2 + for idx in indices: + if idx < -1 or idx >= total_size: + raise ValueError( + f"vector_shuffle index out of range: {idx}, expected in [-1, {total_size - 1}]" + ) + vt1 = f"vector<{tile_size1}x{dtype1}>" + vt2 = f"vector<{tile_size2}x{dtype1}>" + idx_str = ", ".join(str(i) for i in indices) + op_str = f"vector.shuffle %{operand}, %{operand2} [{idx_str}]" + return format_mlir_op(op_str, f"{vt1}, {vt2}", **kwargs), [len(indices), dtype1] + + @staticmethod + def constant_mask(select_min, N, *args, **kwargs): + vals = ", ".join("true" if x else "false" for x in select_min) + op_str = f"arith.constant dense<[{vals}]>" + return format_mlir_op(op_str, f"vector<{N}xi1>", **kwargs), [N, "i1"] + + @staticmethod + def bitonic_sort(operand, descending=False, *args, **kwargs): + def _compute_bitonic_stages(N: int, descending: bool): + assert N >= 2 and (N & (N - 1)) == 0, "N must be power-of-2 >= 2" + stages = [] + size = 2 + while size <= N: + stride = size // 2 + while stride >= 1: + merged_shuffle = list(range(N)) + merged_mask = [None] * N + + for start in range(0, N, size): + blk_dir = "ASCENDING" if (start // size) % 2 == 0 else "DESCENDING" + for i in range(start, start + size - stride, stride * 2): + for j in range(stride): + a, b = i + j, i + j + stride + merged_shuffle[a] = b + merged_shuffle[b] = a + if blk_dir == "ASCENDING": + merged_mask[a] = True # a = min + merged_mask[b] = False # b = max + else: + merged_mask[a] = False # a = max + merged_mask[b] = True # b = min + select_min = [bool(x) if x is not None else False for x in merged_mask] + if descending: + select_min = [not x for x in select_min] + stages.append({ + "shuffle": merged_shuffle, + "select_min": select_min, + }) + stride //= 2 + size *= 2 + return stages + + tile_size, _ = V.kernel.var_info[operand] + cur = operand + for stage in _compute_bitonic_stages(tile_size, descending): + mask = ops.constant_mask(stage["select_min"], tile_size) + shuffled = ops.vector_shuffle(cur, stage["shuffle"]) + vmin = ops.minimum(cur, shuffled) + vmax = ops.maximum(cur, shuffled) + cur = ops.where(mask, vmin, vmax) + return cur, V.kernel.var_info[cur] + @staticmethod def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, **kwargs): if compute_vec_size == 1: diff --git a/PyTorchSimFrontend/mlir/mlir_sort_template.py b/PyTorchSimFrontend/mlir/mlir_sort_template.py index d12c7570..24b3a460 100644 --- a/PyTorchSimFrontend/mlir/mlir_sort_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sort_template.py @@ -1,130 +1,189 @@ from typing import List, Optional +import contextlib -import sympy -from torch._inductor.ir import IRNode -from torch._inductor.virtualized import V +from torch._inductor.ir import Buffer, IRNode +from torch._inductor.virtualized import _ops as ops +from torch._inductor.codegen import common from PyTorchSimFrontend.mlir import mlir_common from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel +from PyTorchSimFrontend.mlir.mlir_common import LoopLevel + +VECTOR_SIZE = 16 TEMPLATE = r""" {{kernel.def_global_vars()}} +// chunk index -> element index +#map_chunk_to_elem = affine_map<(d0) -> (d0 * {{ VECTOR_SIZE }})> -func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X, YI], outputs=[YV], names_str=NAMES_STR, input_reorder=input_reorder)}} { - {{ kernel.def_sram_buffer("YI", YI_TILE_DESC, id=1, indent_size=2) }} - {{ kernel.def_sram_buffer(OUT_DVAR, YV_TILE_DESC, id=2, indent_size=2) }} +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X, XI], outputs=[YV], names_str=NAMES_STR, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_TILE_DESC, id=0, indent_size=2) }} + {{ kernel.def_sram_buffer("XI", XI_TILE_DESC, id=1, indent_size=2) }} + {{ kernel.def_sram_buffer("YV", YV_TILE_DESC, id=2, indent_size=2) }} {{ kernel.def_local_vars(indent_size=2) }} - %c0 = arith.constant 0 : index - %c_cols = arith.constant {{ COLS }} : index affine.for %sort_block = 0 to 1 step 1 { - // Initialize output value/index buffers. - affine.for %row = 0 to {{ ROWS }} step 1 { - affine.for %col = 0 to {{ COLS }} step 1 { - {{ kernel.def_dma_op("MVIN", "X", INIT_X_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, INIT_YV_IDX, X_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} -{% if DIM == 1 %} - %idx_i64 = arith.index_cast %col : index to {{ YI_ELEM_TYPE }} -{% else %} - %idx_i64 = arith.index_cast %row : index to {{ YI_ELEM_TYPE }} -{% endif %} - memref.store %idx_i64, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} - {{ kernel.def_dma_op("MVOUT", "YI", INIT_YI_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=8) }} - } - } - -{% if DIM == 1 %} - // Stable bubble sort on each row (dim=1). - affine.for %row = 0 to {{ ROWS }} step 1 { - affine.for %pass = 0 to {{ COLS }} step 1 { - affine.for %j = 0 to {{ COLS_MINUS1 }} step 1 { - {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} - %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} - - {{ kernel.def_dma_op("MVIN", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} - %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} - -{% if DESCENDING %} - %need_swap = arith.cmpf olt, %lhs, %rhs : {{ YV_ELEM_TYPE }} -{% else %} - %need_swap = arith.cmpf ogt, %lhs, %rhs : {{ YV_ELEM_TYPE }} -{% endif %} - scf.if %need_swap { - memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - - memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D1_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - - {{ kernel.def_dma_op("MVIN", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} - - {{ kernel.def_dma_op("MVIN", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} - - memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} - {{ kernel.def_dma_op("MVOUT", "YI", D1_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - - memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} - {{ kernel.def_dma_op("MVOUT", "YI", D1_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - } - } - } - } -{% else %} - // Stable bubble sort on each column (dim=0). - affine.for %col = 0 to {{ COLS }} step 1 { - affine.for %pass = 0 to {{ ROWS }} step 1 { - affine.for %i = 0 to {{ ROWS_MINUS1 }} step 1 { - {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} - %lhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} - - {{ kernel.def_dma_op("MVIN", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=10) }} - %rhs = memref.load %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} - -{% if DESCENDING %} - %need_swap = arith.cmpf olt, %lhs, %rhs : {{ YV_ELEM_TYPE }} -{% else %} - %need_swap = arith.cmpf ogt, %lhs, %rhs : {{ YV_ELEM_TYPE }} -{% endif %} - scf.if %need_swap { - memref.store %rhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S0_IDX, YV_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - - memref.store %lhs, %yv_sort_tile[%c0, %c0] : {{ YV_TILE_MEMREF_TYPE }} - {{ kernel.def_dma_op("MVOUT", OUT_DVAR, D0_S1_IDX, YV_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - - {{ kernel.def_dma_op("MVIN", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - %li = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} - - {{ kernel.def_dma_op("MVIN", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - %ri = memref.load %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} - - memref.store %ri, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} - {{ kernel.def_dma_op("MVOUT", "YI", D0_S0_IDX, YI_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - - memref.store %li, %yi_sort_tile[%c0, %c0] : {{ YI_TILE_MEMREF_TYPE }} - {{ kernel.def_dma_op("MVOUT", "YI", D0_S1_IDX, YI_S1_TILE_DESC, subtile_size=[1, 1], async_type=0, indent_size=12) }} - } - } - } - } -{% endif %} + {%- for d in range(RANK-1) %} + affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ STEP_SIZES[d] }} { + {%- endfor %} + + %x_dram_offset = affine.apply {{ X_OFFSET_MAP }}({{ OUTER_VARS }}) + %xi_dram_offset = affine.apply {{ XI_OFFSET_MAP }}({{ OUTER_VARS }}) + %yv_dram_offset = affine.apply {{ YV_OFFSET_MAP }}({{ OUTER_VARS }}) + {{ kernel.def_dma_op("MVIN", "X", [], X_TILE_DESC, indent_size=INDENT_SIZE, dram_stride=X_DRAM_STRIDE, dram_offset="x_dram_offset") }} + + // SIMD local sort + loop-based chunk merge. +{{ BITONIC_BODY }} + + {{ kernel.def_dma_op("MVOUT", "XI", [], XI_TILE_DESC, indent_size=INDENT_SIZE, dram_stride=XI_DRAM_STRIDE, dram_offset="xi_dram_offset") }} + {{ kernel.def_dma_op("MVOUT", "YV", [], YV_TILE_DESC, indent_size=INDENT_SIZE, dram_stride=YV_DRAM_STRIDE, dram_offset="yv_dram_offset") }} + {%- for d in range(RANK-1) %} + } { outer_loop=true } + {%- endfor %} } { outer_loop=true } return } """ +def _make_offset_map(outer_dims, all_strides, layout_offset): + """Build an affine_map over outer-dim loop variables that computes the flat DRAM offset.""" + terms = [] + for j, d in enumerate(outer_dims): + s = int(all_strides[d]) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(layout_offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + nd = len(outer_dims) + dim_str = ", ".join(f"d{j}" for j in range(nd)) + expr = " + ".join(terms) if terms else "0" + return f"affine_map<({dim_str}) -> ({expr})>" + + +def _compute_bitonic_stages(n: int, descending: bool): + stages = [] + size = 2 + while size <= n: + stride = size // 2 + while stride >= 1: + merged_shuffle = list(range(n)) + merged_mask = [None] * n + for start in range(0, n, size): + blk_dir = "ASCENDING" if (start // size) % 2 == 0 else "DESCENDING" + for i in range(start, start + size - stride, stride * 2): + for j2 in range(stride): + a, b = i + j2, i + j2 + stride + merged_shuffle[a] = b + merged_shuffle[b] = a + if blk_dir == "ASCENDING": + merged_mask[a] = True + merged_mask[b] = False + else: + merged_mask[a] = False + merged_mask[b] = True + select_min = [bool(x) if x is not None else False for x in merged_mask] + if descending: + select_min = [not x for x in select_min] + stages.append({"shuffle": merged_shuffle, "select_min": select_min}) + stride //= 2 + size *= 2 + return stages + + +def _pair_less_equal(left_v, right_v, left_i, right_i): + cmp_val = ops.lt(left_v, right_v) + cmp_eq = ops.eq(left_v, right_v) + cmp_idx = ops.le(left_i, right_i) + return ops.or_(cmp_val, ops.and_(cmp_eq, cmp_idx)) + + +def _pair_greater_equal(left_v, right_v, left_i, right_i): + cmp_val = ops.gt(left_v, right_v) + cmp_eq = ops.eq(left_v, right_v) + cmp_idx = ops.le(left_i, right_i) + return ops.or_(cmp_val, ops.and_(cmp_eq, cmp_idx)) + + +def _bitonic_sort_pair(values, indices, vector_size: int, descending: bool, stable_sort: bool): + cur_v = values + cur_i = indices + for stage_desc in _compute_bitonic_stages(vector_size, descending): + mask = ops.constant_mask(stage_desc["select_min"], vector_size) + shuf_v = ops.vector_shuffle(cur_v, stage_desc["shuffle"]) + shuf_i = ops.vector_shuffle(cur_i, stage_desc["shuffle"]) + if stable_sort: + # `cmp` drives the "min side" selection in the bitonic network. + # For descending stable sort, tie elements with smaller original index + # must stay earlier, so the min side should treat larger index as smaller. + if descending: + cmp_val = ops.lt(cur_v, shuf_v) + cmp_eq = ops.eq(cur_v, shuf_v) + cmp_idx = ops.ge(cur_i, shuf_i) + cmp = ops.or_(cmp_val, ops.and_(cmp_eq, cmp_idx)) + else: + cmp = _pair_less_equal(cur_v, shuf_v, cur_i, shuf_i) + else: + cmp = ops.le(cur_v, shuf_v) + min_v = ops.where(cmp, cur_v, shuf_v) + min_i = ops.where(cmp, cur_i, shuf_i) + max_v = ops.where(cmp, shuf_v, cur_v) + max_i = ops.where(cmp, shuf_i, cur_i) + cur_v = ops.where(mask, min_v, max_v) + cur_i = ops.where(mask, min_i, max_i) + return cur_v, cur_i + + +def _merge_sorted_pair_vectors( + left_norm, + left_idx_norm, + right_norm, + right_idx_norm, + ascending: bool, + stable_sort: bool, + vector_size: int, + rev_indices, +): + right_pair = ops.vector_shuffle(right_norm, rev_indices, right_norm) + right_idx_pair = ops.vector_shuffle(right_idx_norm, rev_indices, right_idx_norm) + if ascending: + cmp = ( + _pair_less_equal(left_norm, right_pair, left_idx_norm, right_idx_pair) + if stable_sort + else ops.le(left_norm, right_pair) + ) + else: + cmp = ( + _pair_greater_equal(left_norm, right_pair, left_idx_norm, right_idx_pair) + if stable_sort + else ops.ge(left_norm, right_pair) + ) + left_merge = ops.where(cmp, left_norm, right_pair) + left_idx_merge = ops.where(cmp, left_idx_norm, right_idx_pair) + right_merge = ops.where(cmp, right_pair, left_norm) + right_idx_merge = ops.where(cmp, right_idx_pair, left_idx_norm) + return left_merge, left_idx_merge, right_merge, right_idx_merge + + class MLIRSortTemplate(MLIRTemplate): - def __init__(self, input_nodes, layout, dim, descending=False, stable=False, indices_node=None, input_reorder=None): + def __init__(self, input_nodes, layout, dim, descending=False, stable=False, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) self.dim = dim self.descending = descending self.stable = stable - self.indices_node = indices_node + self.use_stable_sort = False + self.output_nodes = [ + Buffer(name="buf_out_values", layout=layout), + ] + self.output_node = self.output_nodes[0] def render( self, @@ -135,119 +194,281 @@ def render( **kwargs, ): if template_buffer_node is not None: + self.output_nodes[0] = template_buffer_node self.output_node = template_buffer_node - if self.indices_node is None: - raise RuntimeError("MLIRSortTemplate requires indices output node") x = self.input_nodes[0] - yv = self.output_node - yi = self.indices_node - - def _as_int(v): - try: - return int(v) - except Exception: - return int(V.graph.sizevars.size_hint(v)) - - x_size = x.get_size() - if len(x_size) != 2: - raise RuntimeError("MLIRSortTemplate currently supports rank-2 input only") - if self.dim not in (0, 1): - raise RuntimeError(f"MLIRSortTemplate currently supports dim in {{0,1}} only, got dim={self.dim}") - - rows = _as_int(x_size[0]) - cols = _as_int(x_size[1]) - cols_minus1 = max(0, cols - 1) - rows_minus1 = max(0, rows - 1) - - x_dtype = x.get_dtype() - yv_dtype = yv.get_dtype() - yi_dtype = yi.get_dtype() - if x_dtype != yv_dtype: - raise RuntimeError("sort template requires input/value dtype match") - - yi_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - yi_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - yi_tile_desc.set_name("yi_sort_tile") - yv_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - yv_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - yv_tile_desc.set_name("yv_sort_tile") - # Neighbor element descriptors use DRAM offset to preserve affine stride metadata. - yv_s1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - yv_s1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - yv_s1_tile_desc.set_name("yv_sort_tile") - yi_s1_tile_desc = mlir_common.MLIRMultiDimTile([1, 1], kernel.vector_lane, vlane_split_axis=1, vlane_stride=1) - yi_s1_tile_desc.set_tile_size_stride([1, 1], [1, 1]) - yi_s1_tile_desc.set_name("yi_sort_tile") - if int(self.dim) == 1: - yv_s1_tile_desc.offset = sympy.Integer(1) - yi_s1_tile_desc.offset = sympy.Integer(1) + xi = self.input_nodes[1] + yv = self.output_nodes[0] + # XI is updated in-place by the sort kernel, so mark it as an inout arg. + kernel.kernel_group.args.make_inplace(xi.get_name(), xi.get_name()) + sort_size = int(x.get_size()[self.dim]) + vector_size = VECTOR_SIZE + if sort_size <= 0: + raise NotImplementedError("Sort size must be > 0") + if sort_size < vector_size or sort_size % vector_size != 0: + raise NotImplementedError( + f"Sort size must be a multiple of vector size (sort_size={sort_size}, vector_size={vector_size})" + ) + num_chunks = sort_size // vector_size + if num_chunks & (num_chunks - 1): + raise NotImplementedError( + f"Loop-based bitonic chunk merge requires power-of-two chunk count (num_chunks={num_chunks})" + ) + + # --- N-D generalization: outer loops over all non-sort dims --- + rank = len(x.get_size()) + sort_dim = self.dim if self.dim >= 0 else self.dim + rank + if sort_dim < 0 or sort_dim >= rank: + raise NotImplementedError(f"Invalid sort dim for rank-{rank} tensor (dim={self.dim})") + x_layout = x.get_layout() + xi_layout = xi.get_layout() + yv_layout = yv.get_layout() + + if rank == 1: + # Edge case for 1D tensor + output_sizes = [1] + output_dim = [0] + step_sizes = [1] + tile_sizes = [1, sort_size] + x_dram_stride = [int(x_layout.stride[sort_dim]), int(x_layout.stride[sort_dim])] + xi_dram_stride = [int(xi_layout.stride[sort_dim]), int(xi_layout.stride[sort_dim])] + yv_dram_stride = [int(yv_layout.stride[sort_dim]), int(yv_layout.stride[sort_dim])] + template_rank = 2 else: - yv_s1_tile_desc.offset = sympy.Integer(cols) - yi_s1_tile_desc.offset = sympy.Integer(cols) - - row = sympy.Symbol("row") - col = sympy.Symbol("col") - i = sympy.Symbol("i") - j = sympy.Symbol("j") - - init_x_idx = [row * cols, col] - init_yv_idx = [row * cols, col] - init_yi_idx = [row * cols, col] + output_sizes = [sz for d, sz in enumerate(yv.get_size()) if d != sort_dim] + output_dim = [d for d, _ in enumerate(yv.get_size()) if d != sort_dim] + step_sizes = [1] * len(output_sizes) + + tile_dim = max(output_dim, key=lambda d: int(yv.get_size()[d])) + tile_sizes = [min(kernel.vector_lane, int(yv.get_size()[tile_dim])), sort_size] + step_sizes[output_dim.index(tile_dim)] = tile_sizes[0] + + x_dram_stride = [int(x_layout.stride[tile_dim]), int(x_layout.stride[sort_dim])] + xi_dram_stride = [int(xi_layout.stride[tile_dim]), int(xi_layout.stride[sort_dim])] + yv_dram_stride = [int(yv_layout.stride[tile_dim]), int(yv_layout.stride[sort_dim])] + template_rank = rank + + x_offset_map = _make_offset_map(output_dim, x_layout.stride, x_layout.offset) + xi_offset_map = _make_offset_map(output_dim, xi_layout.stride, xi_layout.offset) + yv_offset_map = _make_offset_map(output_dim, yv_layout.stride, yv_layout.offset) + outer_vars = ", ".join(f"%index{d}" for d in output_dim) + + # indent for DMA ops = 2 (inside func) + 2 per outer loop + indent_size = 2 + len(output_dim) * 2 + 4 + + vlane_stride = 1 + vlane_split_axis = 0 + x_tile_desc = mlir_common.MLIRMultiDimTile(tile_sizes, kernel.vector_lane, vlane_split_axis, vlane_stride) + x_tile_desc.set_tile_size_stride(tile_sizes, [sort_size, 1]) + x_tile_desc.set_name("X_buffer") + x_tile_desc.offset = x_layout.offset + + xi_tile_desc = mlir_common.MLIRMultiDimTile(tile_sizes, kernel.vector_lane, vlane_split_axis, vlane_stride) + xi_tile_desc.set_tile_size_stride(tile_sizes, [sort_size, 1]) + xi_tile_desc.set_name("XI_buffer") + xi_tile_desc.offset = xi_layout.offset + + yv_tile_desc = mlir_common.MLIRMultiDimTile(tile_sizes, kernel.vector_lane, vlane_split_axis, vlane_stride) + yv_tile_desc.set_tile_size_stride(tile_sizes, [sort_size, 1]) + yv_tile_desc.set_name("YV_buffer") + yv_tile_desc.offset = yv_layout.offset + + data_stype = mlir_common.DTYPE_TO_MLIR[x.get_dtype()] + idx_stype = mlir_common.DTYPE_TO_MLIR[xi.get_dtype()] + + elem_memref_t = f"memref<1x{sort_size}x{data_stype}, 1>" + rev_indices = list(range(vector_size - 1, -1, -1)) + + bitonic_body = mlir_common.ParallelLoopBuffer(initial_indent=2) + bitonic_body.tabwidth = 2 + # 1) Local SIMD sort per chunk. + init_cse = common.CSE(kernel.newvar_prefix, kernel.suffix, name_prefix="sort_init") + with kernel, kernel.override_buffer_cse(buffer=bitonic_body, cse=init_cse): + bitonic_body.writelines(LoopLevel("chunk", num_chunks).lines()) + with bitonic_body.indent(attribute="{inner_loop=true}"): + bitonic_body.writeline("%elem = affine.apply #map_chunk_to_elem(%chunk)") + x_chunk = ops._load( + vector_size, + data_stype, + "X_buffer", + "%t_const0, %elem", + x_tile_desc.get_mlir_shape(data_stype), + ) + idx_step_index = kernel.register_var_cse("idx_step_index", vector_size, "index") + bitonic_body.writeline(f"%{idx_step_index} = vector.step : vector<{vector_size}xindex>") + idx_step = ops.index_cast(idx_step_index, idx_stype) + idx_base = kernel.register_var_cse("idx_base", 1, idx_stype) + bitonic_body.writeline(f"%{idx_base} = arith.index_cast %elem : index to {idx_stype}") + idx_base_vec = ops.broadcast(idx_base, vector_size) + idx_chunk = ops.add(idx_base_vec, idx_step) + yv_chunk, yi_chunk = _bitonic_sort_pair( + x_chunk, idx_chunk, vector_size, descending=self.descending, stable_sort=self.use_stable_sort + ) + ops._store( + yv_chunk, + "YV_buffer", + "%t_const0, %elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + ops._store( + yi_chunk, + "XI_buffer", + "%t_const0, %elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + + # 2) Chunk-level bitonic merge (loop form). + stage = 0 + k = 2 + while k <= num_chunks: + j = k // 2 + while j >= 1: + for block_start, is_even_block in ((0, True), (k, False)): + if block_start >= num_chunks: + continue + asc_dir = is_even_block if not self.descending else (not is_even_block) + stage_cse = common.CSE(kernel.newvar_prefix, kernel.suffix, name_prefix=f"sort_stage_{stage}") + with kernel, kernel.override_buffer_cse(buffer=bitonic_body, cse=stage_cse): + stage_loops = [ + LoopLevel("base", num_chunks, start=block_start, step=2 * k), + LoopLevel("p", k, step=2 * j), + LoopLevel("q", j), + ] + with contextlib.ExitStack() as stack: + for loop in stage_loops: + bitonic_body.writelines(loop.lines()) + stack.enter_context(bitonic_body.indent(attribute="{inner_loop=true}")) + + bitonic_body.writeline( + f"%left_elem = affine.apply affine_map<(d0, d1, d2) -> ((d0 + d1 + d2) * {vector_size})>(%base, %p, %q)" + ) + bitonic_body.writeline( + f"%right_elem = affine.apply affine_map<(d0, d1, d2) -> ((d0 + d1 + d2 + {j}) * {vector_size})>(%base, %p, %q)" + ) + + left_vec = ops._load( + vector_size, + data_stype, + "YV_buffer", + "%t_const0, %left_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + right_vec = ops._load( + vector_size, + data_stype, + "YV_buffer", + "%t_const0, %right_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + left_idx = ops._load( + vector_size, + idx_stype, + "XI_buffer", + "%t_const0, %left_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + right_idx = ops._load( + vector_size, + idx_stype, + "XI_buffer", + "%t_const0, %right_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + norm_desc = not asc_dir + left_norm, left_idx_norm = _bitonic_sort_pair( + left_vec, left_idx, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + right_norm, right_idx_norm = _bitonic_sort_pair( + right_vec, right_idx, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + left_merge, left_idx_merge, right_merge, right_idx_merge = _merge_sorted_pair_vectors( + left_norm, + left_idx_norm, + right_norm, + right_idx_norm, + ascending=asc_dir, + stable_sort=self.use_stable_sort, + vector_size=vector_size, + rev_indices=rev_indices, + ) + left_new, left_idx_new = _bitonic_sort_pair( + left_merge, left_idx_merge, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + right_new, right_idx_new = _bitonic_sort_pair( + right_merge, right_idx_merge, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + ops._store( + left_new, + "YV_buffer", + "%t_const0, %left_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + ops._store( + right_new, + "YV_buffer", + "%t_const0, %right_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + ops._store( + left_idx_new, + "XI_buffer", + "%t_const0, %left_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + ops._store( + right_idx_new, + "XI_buffer", + "%t_const0, %right_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + stage += 1 + j //= 2 + k *= 2 - d1_s0_idx = [row * cols, j] - d1_s1_idx = [row * cols, j] - - d0_s0_idx = [i * cols, col] - d0_s1_idx = [i * cols, col] - - kernel.loop_size = None - numel = rows * cols kernel.render_options = dict( KERNEL_NAME=self.name, + NAMES_STR="X, XI, YV", kernel=kernel, X=x, + XI=xi, YV=yv, - YI=yi, - OUT_DVAR="YV", - NAMES_STR="X, YI, YV", - ROWS=rows, - COLS=cols, - COLS_MINUS1=cols_minus1, - ROWS_MINUS1=rows_minus1, - DIM=int(self.dim), - DESCENDING=bool(self.descending), - YI_TILE_DESC=yi_tile_desc, + X_TILE_DESC=x_tile_desc, + XI_TILE_DESC=xi_tile_desc, YV_TILE_DESC=yv_tile_desc, - YI_S1_TILE_DESC=yi_s1_tile_desc, - YV_S1_TILE_DESC=yv_s1_tile_desc, - INIT_X_IDX=init_x_idx, - INIT_YV_IDX=init_yv_idx, - INIT_YI_IDX=init_yi_idx, - D1_S0_IDX=d1_s0_idx, - D1_S1_IDX=d1_s1_idx, - D0_S0_IDX=d0_s0_idx, - D0_S1_IDX=d0_s1_idx, - YV_ELEM_TYPE=mlir_common.DTYPE_TO_MLIR[yv_dtype], - YI_ELEM_TYPE=mlir_common.DTYPE_TO_MLIR[yi_dtype], - X_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[x_dtype]}>", - YV_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[yv_dtype]}>", - YI_MEMREF_TYPE=f"memref<{numel}x{mlir_common.DTYPE_TO_MLIR[yi_dtype]}>", - YV_TILE_MEMREF_TYPE=yv_tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[yv_dtype]), - YI_TILE_MEMREF_TYPE=yi_tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[yi_dtype]), - X_TILE_DESC=yv_tile_desc, + SORT_SIZE=sort_size, + VECTOR_SIZE=vector_size, + DATA_STYPE=data_stype, + IDX_STYPE=idx_stype, + ELEM_MEMREF_T=elem_memref_t, + BITONIC_BODY=bitonic_body.getvalue().rstrip(), input_reorder=self.input_reorder, + # N-D generalization + RANK = template_rank, + OUTPUT_SIZES = output_sizes, + OUTPUT_DIM = output_dim, + STEP_SIZES = step_sizes, + OUTER_VARS = outer_vars, + X_OFFSET_MAP = x_offset_map, + XI_OFFSET_MAP = xi_offset_map, + YV_OFFSET_MAP = yv_offset_map, + X_DRAM_STRIDE = x_dram_stride, + XI_DRAM_STRIDE = xi_dram_stride, + YV_DRAM_STRIDE = yv_dram_stride, + INDENT_SIZE = indent_size, ) - - output_node_name = yv.get_name() if hasattr(yv, "get_name") else yv.name - kernel.epilogue_info = dict( - output_node=output_node_name, - sram_var="yv_sort_tile", - dram_var=kernel.render_options["OUT_DVAR"], - dram_tile_desc=yv_tile_desc, - ) - kernel.exception_nodes[kernel.render_options["OUT_DVAR"]] = {"numel": yv.get_numel()} - kernel.exception_nodes["YI"] = {"numel": yi.get_numel()} - code = self._template_from_string(TEMPLATE).render(**kernel.render_options) return code + + +class MLIRStableSortTemplate(MLIRSortTemplate): + def __init__(self, input_nodes, layout, dim, descending=False, stable=True, input_reorder=None): + super().__init__( + input_nodes=input_nodes, + layout=layout, + dim=dim, + descending=descending, + stable=stable, + input_reorder=input_reorder, + ) + self.use_stable_sort = True diff --git a/tests/test_sort.py b/tests/test_sort.py index 2b070223..05afe92b 100644 --- a/tests/test_sort.py +++ b/tests/test_sort.py @@ -1,7 +1,5 @@ import argparse import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -34,63 +32,85 @@ def test_equal(name, out, cpu_out): print("cpu out:", cpu_out) raise SystemExit(1) - -def _normalize_dim(dim: int, rank: int) -> int: - d = dim if dim >= 0 else rank + dim - if d < 0 or d >= rank: - raise ValueError(f"dim out of range: dim={dim}, rank={rank}") - return d - - -def test_sort_stable(device, size=(128, 128), dim=-1, descending=False): - _normalize_dim(dim, len(size)) - - def sort_stable_fn(x): - return torch.sort(x, stable=True, dim=dim, descending=descending) - - x = torch.randn(size, dtype=torch.float32) - x_npu = x.to(device=device) - - opt_sort = torch.compile(dynamic=False)(sort_stable_fn) - out_values, out_indices = opt_sort(x_npu) - - ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending) - - test_result("Sort.stable/values", out_values, ref_values) - test_equal("Sort.stable/indices", out_indices, ref_indices) - - -def test_sort_values_stable(device, size=(128, 128), dim=-1, descending=False): - _normalize_dim(dim, len(size)) - - def sort_out_fn(x): - out_values = torch.empty_like(x, device=x.device) - out_indices = torch.empty_like(x, dtype=torch.int64, device=x.device) - return torch.sort(x, stable=True, dim=dim, descending=descending, out=(out_values, out_indices)) +def test_sort(device, size=(128, 128), dim=-1, descending=False, stable=True): + def sort_test(x): + return torch.sort(x, dim=dim, descending=descending, stable=stable) x = torch.randn(size, dtype=torch.float32) x_npu = x.to(device=device) - opt_sort = sort_out_fn# torch.compile(dynamic=False)(sort_out_fn) + opt_sort = torch.compile(dynamic=False)(sort_test) out_values, out_indices = opt_sort(x_npu) + ref_values, ref_indices = torch.sort(x, stable=stable, dim=dim, descending=descending) - ref_values, ref_indices = torch.sort(x, stable=True, dim=dim, descending=descending) - - test_result("Sort.values_stable/values", out_values, ref_values) - test_equal("Sort.values_stable/indices", out_indices, ref_indices) - + prefix = "Sort.stable" if stable else "Sort.unstable" + test_result(f"{prefix}/values size={size}, dim={dim}, desc={descending}", out_values, ref_values) + if stable: + test_result(f"{prefix}/indices size={size}, dim={dim}, desc={descending}", out_indices, ref_indices) + else: + # Unstable sort does not guarantee tie ordering; validate index-value consistency instead. + gathered = torch.gather(x, dim, out_indices.cpu()) + test_result(f"{prefix}/indices_gather size={size}, dim={dim}, desc={descending}", gathered, out_values.cpu()) + + +def test_sort_stable_suite(device): + # Keep sort-axis sizes compatible with backend constraints (vector-size multiple). + cases = [ + {"size": (64,), "dim": 0, "descending": False}, # 1D + {"size": (4, 64), "dim": 1, "descending": True}, # 2D, last dim + {"size": (2, 8, 32), "dim": 2, "descending": False}, # 3D, last dim + {"size": (2, 16, 4), "dim": 1, "descending": True}, # 3D, middle dim + {"size": (2, 4, 8, 32), "dim": 3, "descending": False}, # 4D, last dim + {"size": (4, 2, 32, 8), "dim": 2, "descending": True}, # 4D, inner dim + ] + for case in cases: + test_sort( + device=device, + size=case["size"], + dim=case["dim"], + descending=case["descending"], + stable=True, + ) + + +def test_sort_duplicate_cases(device): + duplicate_cases = [ + {"size": (64,), "dim": 0, "descending": False}, + {"size": (4, 64), "dim": 1, "descending": True}, + {"size": (2, 8, 32), "dim": 2, "descending": False}, + ] + for case in duplicate_cases: + base = torch.arange(case["size"][case["dim"]], dtype=torch.int64) % 7 + view_shape = [1] * len(case["size"]) + view_shape[case["dim"]] = case["size"][case["dim"]] + x = base.view(view_shape).expand(case["size"]).to(torch.float32) + noise = torch.randn(case["size"], dtype=torch.float32) * 0.0 + x = x + noise + + def sort_test(inp): + return torch.sort(inp, dim=case["dim"], descending=case["descending"], stable=True) + + out_values, out_indices = torch.compile(dynamic=False)(sort_test)(x.to(device=device)) + ref_values, ref_indices = torch.sort( + x, dim=case["dim"], descending=case["descending"], stable=True + ) + test_result(f"Sort.dup/stable_values {case}", out_values, ref_values) + test_equal(f"Sort.dup/stable_indices {case}", out_indices, ref_indices) + + def sort_test_unstable(inp): + return torch.sort(inp, dim=case["dim"], descending=case["descending"], stable=False) + + out_values_u, out_indices_u = torch.compile(dynamic=False)(sort_test_unstable)(x.to(device=device)) + ref_values_u, _ = torch.sort(x, dim=case["dim"], descending=case["descending"], stable=False) + test_result(f"Sort.dup/unstable_values {case}", out_values_u, ref_values_u) + gathered_u = torch.gather(x, case["dim"], out_indices_u.cpu()) + test_result(f"Sort.dup/unstable_gather {case}", gathered_u, out_values_u.cpu()) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run sort tests") - parser.add_argument("--shape", type=str, default="(128,128)") + parser.add_argument("--shape", type=str, default="(64, 32, 16)") parser.add_argument("--dim", type=int, default=0) parser.add_argument("--descending", action="store_true") - parser.add_argument( - "--mode", - type=str, - default="all", - choices=["all", "default", "values"], - ) args = parser.parse_args() shape = tuple(map(int, args.shape.strip("()").split(","))) @@ -100,13 +120,5 @@ def sort_out_fn(x): module = PyTorchSimRunner.setup_device() device = module.custom_device() - # Register recursive-compile bridge only when values_stable path is explicitly tested. - if args.mode in ("all", "values"): - torch.npu.register_eager_to_compile([ - "aten::sort.values_stable", - ]) - - if args.mode in ("all", "default"): - test_sort_stable(device, size=shape, dim=args.dim, descending=args.descending) - if args.mode in ("all", "values"): - test_sort_values_stable(device, size=shape, dim=args.dim, descending=args.descending) + test_sort_stable_suite(device) + test_sort_duplicate_cases(device) \ No newline at end of file From 752cbb834df7705fe12ec18da281d5b76032034e Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 11 Mar 2026 11:55:53 +0900 Subject: [PATCH 16/31] [Template] Use buffer type instead of hard-coded type --- PyTorchSimFrontend/extension_codecache.py | 21 +++------- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 39 +++++++++++-------- .../mlir/mlir_caller_codegen.py | 4 -- PyTorchSimFrontend/mlir/mlir_common.py | 2 +- PyTorchSimFrontend/mlir/mlir_conv_common.py | 6 +++ .../mlir/mlir_conv_mt_template.py | 18 +++++---- .../mlir/mlir_conv_sb_template.py | 18 +++++---- .../mlir/mlir_conv_sbs_template.py | 18 +++++---- PyTorchSimFrontend/mlir/mlir_conv_template.py | 18 +++++---- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 20 +++++++--- Simulator/simulator.py | 3 +- 11 files changed, 92 insertions(+), 75 deletions(-) diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index d6b47123..8454dee6 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -67,9 +67,10 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ - -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \ + -mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \ + -filetype=obj \ {'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \ - -O2 {filename}.ll -o {filename}.s + -O2 {filename}.ll -o {filename}.o """, ).strip()] @@ -109,9 +110,10 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ - -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \ + -mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \ + -filetype=obj \ {'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \ - -O2 {sample_filename}.ll -o {sample_filename}.s + -O2 {sample_filename}.ll -o {sample_filename}.o """, ).strip()] @@ -180,17 +182,6 @@ def load(cls, source_code, val_llvm_caller.generate_wrapper_file(write_path, validation_wrapper_name) val_llvm_caller.compile_wih_kernel(write_path, key, validation_wrapper_name, validation_binary_name, new_link_option) - - stack_size = val_llvm_caller.parse_stack_sizes(f"{write_path}/{key}.s", vlenb=vlenb) - spad_size = val_llvm_caller.get_spad_size(validation_binary_path) - spad_usage = stack_size + spad_size # Spad usage per lane - if extension_config.CONFIG_SPAD_INFO["spad_size"] < spad_usage: - logger.debug( - f"Scratchpad size exceeded: required {spad_usage} bytes, " - f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available." - ) - raise SpadOverflowError() - # Skip if TOG file already exists if os.path.isfile(tog_path): return key diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 9398f90c..417d97cd 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -26,20 +26,20 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0 = 0 to {{ B }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1> {% if Bias -%} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { @@ -74,20 +74,20 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0 = 0 to {{ B }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> {% if Bias -%} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { {{kernel.load_input(indent_size=10)}} @@ -120,21 +120,21 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0=0 to {{ B }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_N }}x{{ TILE_M }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> {% if Bias -%} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} // Why not N,M? Currently, dma-fine-grained pass assume M->N order... {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0, 0] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0, 0] : memref<1x{{ TILE_N }}x{{ TILE_M }}x{{DATA_STYPE}}, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_K], indent_size=10) }} @@ -237,6 +237,7 @@ def render(self, else: Bias_idx = None + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -245,7 +246,7 @@ def render(self, SUB_TILE_M=SUB_TILE_M, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, - DATA_STYPE="f32", + DATA_STYPE=data_stype, X = X, W = W,Y = Y, Bias = Bias, X_idx = X_idx, W_idx = W_idx, @@ -319,6 +320,12 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + dtype_infos = [("X", X.get_dtype()), ("W", W.get_dtype()), ("Y", Y.get_dtype())] + if Bias is not None: + dtype_infos.append(("Bias", Bias.get_dtype())) + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype BMM is not implemented yet ({dtype_desc})") W_tensor = empty_strided(W.layout.size, W.layout.stride) X_tensor = empty_strided(X.layout.size, X.layout.stride) diff --git a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py index 06d41ea2..7c842272 100644 --- a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py +++ b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py @@ -182,22 +182,18 @@ def add_extention(self, name, extension): def compile_wih_kernel(self, write_path, llvm_name, wrapper_name, binary_name, link_option=""): main_path = os.path.join(write_path, self.add_extention(wrapper_name, 'c')) main_obj_path = os.path.join(write_path, self.add_extention(wrapper_name, 'o')) - kernel_path = os.path.join(write_path, self.add_extention(llvm_name, 's')) kernel_obj_path = os.path.join(write_path, self.add_extention(llvm_name, 'o')) main_compile = f'riscv64-unknown-elf-gcc -march=rv64gcv -c {main_path} -o {main_obj_path}' - kernel_compile = f'clang -c --target="riscv64" -march=rv64gcv -O2 -nostdlib {kernel_path} -o {kernel_obj_path}' target = os.path.join(write_path, binary_name) link = f'riscv64-unknown-elf-gcc -march=rv64gcv {main_obj_path} {kernel_obj_path} -o {target} -lm {link_option}' main_compile_cmd = shlex.split(main_compile) - kernel_compile_cmd = shlex.split(kernel_compile) link_cmd = shlex.split(link) try: subprocess.check_call(main_compile_cmd) - subprocess.check_call(kernel_compile_cmd) subprocess.check_call(link_cmd) except subprocess.CalledProcessError as e: print("Command failed with exit code", e.returncode) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 256d7101..3c408681 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -67,7 +67,7 @@ DTYPE_TO_C = { torch.float32: "float", torch.float64: "double", - torch.float16: "half", + torch.float16: "uint16_t", torch.int64: "int64_t", torch.int32: "int32_t", torch.int16: "int16_t", diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py index f72a7663..91e200a8 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_common.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -52,6 +52,12 @@ def extract_info(self, kernel, template_buffer_node, epilogue_nodes): X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + dtype_infos = [("X", X.get_dtype()), ("W", W.get_dtype()), ("Y", Y.get_dtype())] + if Bias is not None: + dtype_infos.append(("Bias", Bias.get_dtype())) + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype Conv is not implemented yet ({dtype_desc})") if epilogue_nodes is not None: extra_node_rw = { diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index da2bc829..e91014fa 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -47,7 +47,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{- kernel.def_local_vars(indent_size=2) }} @@ -59,7 +59,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %tile_k = 0 to {{ I_C * K_W }} step {{ TILE_K }} { @@ -71,16 +71,16 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to 1 { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ TILE_O_W }} { %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_o_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -179,6 +179,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -220,7 +222,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index cc284522..db2c64db 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -48,7 +48,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{- kernel.def_local_vars(indent_size=2) }} affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { @@ -58,7 +58,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { @@ -72,16 +72,16 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -178,6 +178,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -219,7 +221,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 6d768bf2..95db53c3 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -48,7 +48,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{- kernel.def_local_vars(indent_size=2) }} @@ -59,7 +59,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { @@ -72,16 +72,16 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -179,6 +179,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -220,7 +222,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index e2cd61fd..3666b3c9 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -48,7 +48,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} @@ -60,7 +60,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { @@ -74,17 +74,17 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ TILE_O_W }} { %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %tile_i_w = affine.apply #map_I_W(%tile_o_w, %tile_k_w) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_i_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -183,6 +183,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -224,7 +226,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 5b116807..eb391dba 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -27,14 +27,14 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}>{% endif %} {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { {%- if Bias %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { {% if prologue_nodes -%} @@ -77,16 +77,16 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} {{ kernel.def_local_vars(indent_size=2) }} affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { - %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> {%- if Bias %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}x{{DATA_STYPE}}, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_K], indent_size=8) }} @@ -187,6 +187,8 @@ def render(self, else: Bias_idx = None + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -197,7 +199,7 @@ def render(self, SUB_TILE_M=SUB_TILE_M, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, - DATA_STYPE="f32", + DATA_STYPE=data_stype, X = X, W = W, Y = Y, Bias = Bias, X_idx = X_idx, @@ -280,6 +282,12 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): # Extract input arguments info X, W, Y = self.input_nodes[0], self.input_nodes[1], self.output_node + dtype_infos = [("X", X.get_dtype()), ("W", W.get_dtype()), ("Y", Y.get_dtype())] + if len(self.input_nodes) > 2: + dtype_infos.append(("Bias", self.input_nodes[2].get_dtype())) + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype GEMM is not implemented yet ({dtype_desc})") X_tensor = empty_strided(X.layout.size, X.layout.stride) W_tensor = empty_strided(W.layout.size, W.layout.stride) if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: diff --git a/Simulator/simulator.py b/Simulator/simulator.py index 13f2b4f0..f24835ba 100644 --- a/Simulator/simulator.py +++ b/Simulator/simulator.py @@ -68,6 +68,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch.uint8: np.uint8, torch.bool: np.uint8, torch.bfloat16: np.float16, + torch.float16: np.float16, } class FunctionalSimulator(): @@ -143,7 +144,7 @@ def run_spike(self, args, arg_attributes, runtime_path, binary, vectorlane_size= base_path= f"--base-path={runtime_path}" os.makedirs(os.path.join(runtime_path, "indirect_access"), exist_ok=True) os.makedirs(os.path.join(runtime_path, "dma_access"), exist_ok=True) - run = f'spike --isa rv64gcv --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_path} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' + run = f'spike --isa rv64gcv_zfh --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_path} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' if not silent_mode: logger.debug(f"[Spike] cmd> {run}") logger.info("[Spike] Running Spike simulator") From 7af91dedeca74703c35ec9446ec167fcb8e4ec88 Mon Sep 17 00:00:00 2001 From: HamHyungkyu Date: Thu, 12 Mar 2026 10:09:40 +0900 Subject: [PATCH 17/31] [Frontend] Fix incorrect constant key usage and boolean scientific-notation edge case --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 10 +++++----- PyTorchSimFrontend/mlir/mlir_ops.py | 4 ++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index d6ddb025..43cb65a4 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -1423,11 +1423,11 @@ def get_const_cse(self, value, dtype="index") -> common.CSEVariable: value = float(value) else: value = int(value) - - if value not in self.consts: - self.consts[str(value)+dtype] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") - self.register_var_info(self.consts[str(value)+dtype], [1, dtype]) - return self.consts[str(value)+dtype] + key = str(value)+dtype + if key not in self.consts: + self.consts[key] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") + self.register_var_info(self.consts[key], [1, dtype]) + return self.consts[key] def get_tag_cse(self, value=None, shape="memref<1xi32>"): if value is None: diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py index ace4f9ea..76a0e273 100644 --- a/PyTorchSimFrontend/mlir/mlir_ops.py +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -59,6 +59,10 @@ def constant(value, src_type, *args, **kwargs): str_val = str(value) if "inf" == str_val or "-inf" == str_val or "nan" == str_val: value = f"0x{mlir_common.MLIR_INF[str_val][src_type]:x}" + elif isinstance(value, bool): + value = 1 if value else 0 + if src_type[0] == "f": + value = format(float(value), ".20f") # scientific notation check elif "e" in str_val: value = format(float(value), ".20f") From 7bad17ae337873511a8b4e584d73767da56145bb Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Wed, 11 Mar 2026 19:51:41 +0900 Subject: [PATCH 18/31] [Fix] Refactor MLIR precision handling to be dtype-driven --- PyTorchSimFrontend/extension_config.py | 4 +- PyTorchSimFrontend/mlir/mlir_bmm_template.py | 10 +++-- PyTorchSimFrontend/mlir/mlir_cat_template.py | 20 +++++++--- PyTorchSimFrontend/mlir/mlir_common.py | 7 +++- PyTorchSimFrontend/mlir/mlir_conv_common.py | 11 +++--- .../mlir/mlir_conv_mt_template.py | 10 ++--- .../mlir/mlir_conv_sb_template.py | 8 ++-- .../mlir/mlir_conv_sbs_template.py | 8 ++-- PyTorchSimFrontend/mlir/mlir_conv_template.py | 8 ++-- PyTorchSimFrontend/mlir/mlir_gemm_template.py | 10 +++-- PyTorchSimFrontend/mlir/mlir_template.py | 38 +++++++++---------- README.md | 1 - 12 files changed, 76 insertions(+), 59 deletions(-) diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index eff6f573..fe8cc380 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -31,8 +31,6 @@ def __getattr__(name): "spad_size" : config_yaml["vpu_spad_size_kb_per_lane"] << 10 # Note: spad size per lane } - if name == "CONFIG_PRECISION": - return 4 # 32bit if name == "CONFIG_NUM_CORES": return config_yaml["num_cores"] if name == "vpu_vector_length_bits": @@ -132,7 +130,7 @@ def load_plan_from_module(module_path): CONFIG_USE_TIMING_POOLING = int(os.environ.get('TORCHSIM_USE_TIMING_POOLING', default=0)) -CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0)) +CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=1)) def setup_logger(name=None, level=None): diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 417d97cd..c5fd902f 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -166,8 +166,9 @@ def render(self, tile_info = None, **kwargs): X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) if tile_info is None: - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node)[0] + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node, precision_bytes)[0] else: TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info @@ -350,10 +351,11 @@ def get_tile_candidates(self, prologue_nodes: Optional[List[IRNode]] = None, **kwargs): X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) - return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) + return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node, precision_bytes) - def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): - tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node, precision_bytes): + tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, precision_bytes=precision_bytes) for idx, (TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or n_prologue_node else kernel.vector_lane SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py index 7bee54ac..7abdfee6 100644 --- a/PyTorchSimFrontend/mlir/mlir_cat_template.py +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -56,6 +56,11 @@ def render( ): input_nodes = self.input_nodes y = self.output_node + dtype_infos = [("Y", y.get_dtype())] + [(f"X{i}", x.get_dtype()) for i, x in enumerate(input_nodes)] + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype Cat is not implemented yet ({dtype_desc})") + precision_bytes = mlir_common.get_dtype_nbytes(y.get_dtype()) num_inputs = len(input_nodes) rank = len(y.get_size()) @@ -68,7 +73,7 @@ def render( excluded_dims = self._compute_excluded_dims(tile_sizes) input_tile_sizes_dim = self._calculate_input_tile_sizes( - kernel, input_sizes, tile_sizes, num_inputs, rank + kernel, input_sizes, tile_sizes, num_inputs, rank, precision_bytes ) buffer_name_to_template_name, input_dram_names = self._build_buffer_mapping(input_nodes) input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( @@ -145,6 +150,11 @@ def get_tile_candidates( self.output_node = template_buffer_node y = self.output_node + dtype_infos = [("Y", y.get_dtype())] + [(f"X{i}", x.get_dtype()) for i, x in enumerate(self.input_nodes)] + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype Cat is not implemented yet ({dtype_desc})") + precision_bytes = mlir_common.get_dtype_nbytes(y.get_dtype()) num_inputs = len(self.input_nodes) output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] @@ -152,7 +162,7 @@ def get_tile_candidates( return [[1]] max_tile_total = kernel.spad_info["spad_size"] // ( - kernel.vector_lane * kernel.precision * 2 * num_inputs + kernel.vector_lane * precision_bytes * 2 * num_inputs ) dim_tile_candidates = [] @@ -174,7 +184,7 @@ def get_tile_candidates( tile_candidates = [ list(combo) for combo in itertools.product(*dim_tile_candidates) - if math.prod(combo) * (num_inputs + 1) * kernel.precision + if math.prod(combo) * (num_inputs + 1) * precision_bytes <= kernel.spad_info["spad_size"] * kernel.vector_lane ] @@ -199,11 +209,11 @@ def _compute_excluded_dims(self, tile_sizes: list) -> list: tile_sizes[idx] = 1 return excluded - def _calculate_input_tile_sizes(self, kernel, input_sizes, tile_sizes, num_inputs, rank): + def _calculate_input_tile_sizes(self, kernel, input_sizes, tile_sizes, num_inputs, rank, precision_bytes): """Calculate tile sizes along the concat dimension for each input.""" non_dim_tile_elements = math.prod(tile_sizes) if tile_sizes else 1 max_spad_per_input = kernel.spad_info["spad_size"] * kernel.vector_lane // 2 - extra_concat = math.ceil(max_spad_per_input / (non_dim_tile_elements * kernel.precision)) - num_inputs + extra_concat = math.ceil(max_spad_per_input / (non_dim_tile_elements * precision_bytes)) - num_inputs input_tile_sizes_dim = [] for i in range(num_inputs): diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 3c408681..9f5dc6ab 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -90,6 +90,12 @@ "index": 64 } +def get_dtype_nbytes(dtype): + mlir_dtype = DTYPE_TO_MLIR.get(dtype) + if mlir_dtype is None or mlir_dtype not in MLIR_TO_BIT: + raise NotImplementedError(f"Unsupported dtype for precision calculation: {dtype}") + return MLIR_TO_BIT[mlir_dtype] // 8 + DTYPE_LOWP_FP = [ torch.bfloat16, torch.float16, @@ -579,7 +585,6 @@ def __init__(self): # Default HW setting self.vector_lane = extension_config.vpu_num_lanes self.spad_info = extension_config.CONFIG_SPAD_INFO - self.precision = extension_config.CONFIG_PRECISION self.num_cores = extension_config.CONFIG_NUM_CORES self.vlen = extension_config.vpu_vector_length_bits diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py index 91e200a8..386e9bd5 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_common.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -2,7 +2,7 @@ import math from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs, get_dtype_nbytes from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode @@ -40,7 +40,7 @@ def render(self, **kwargs): raise NotImplementedError() - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): raise NotImplementedError() def extract_info(self, kernel, template_buffer_node, epilogue_nodes): @@ -58,6 +58,7 @@ def extract_info(self, kernel, template_buffer_node, epilogue_nodes): if len({dtype for _, dtype in dtype_infos}) != 1: dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) raise NotImplementedError(f"Mixed dtype Conv is not implemented yet ({dtype_desc})") + precision_bytes = get_dtype_nbytes(X.get_dtype()) if epilogue_nodes is not None: extra_node_rw = { @@ -75,7 +76,7 @@ def extract_info(self, kernel, template_buffer_node, epilogue_nodes): PADDING_W=self.padding[1] STRIDE_H=self.stride[0] STRIDE_W=self.stride[1] - return X,W,Y,Bias,n_extra_node,BATCH,I_C,I_H,I_W,O_C,K_H,K_W,O_H,O_W,PADDING_H,PADDING_W,STRIDE_H,STRIDE_W + return X,W,Y,Bias,n_extra_node,BATCH,I_C,I_H,I_W,O_C,K_H,K_W,O_H,O_W,PADDING_H,PADDING_W,STRIDE_H,STRIDE_W,precision_bytes def get_tile_candidates(self, kernel: MLIRTemplateKernel, @@ -83,8 +84,8 @@ def get_tile_candidates(self, epilogue_nodes: Optional[List[IRNode]] = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) - return self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + return self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes) def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index e91014fa..8b8288a8 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -131,12 +131,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -170,7 +170,7 @@ def render(self, Y_tile_desc.set_name("output_buffer") Y_dim = [Symbol("tile_m"), Symbol("tile_n"), Symbol("o_h"), Symbol("o_w")] Y_idx = [Y_dim[0]*O_C*O_H*O_W, Y_dim[1]*O_H*O_W, Y_dim[2]*O_W, Y_dim[3]] - + # Extract Bias info Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) @@ -239,8 +239,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index db2c64db..92efff66 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -132,12 +132,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -238,8 +238,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) # TODO: implement K_W for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 95db53c3..dfd418d9 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -132,12 +132,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -239,8 +239,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) # TODO: implement K_W for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index 3666b3c9..178ba7c6 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -136,12 +136,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info TOG_latency = BATCH if TILE_M > BATCH else TILE_M @@ -243,8 +243,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index eb391dba..9c61c3d9 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -117,8 +117,9 @@ def render(self, tile_info = None, **kwargs): X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) if tile_info is None: - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node)[0] + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node, precision_bytes)[0] else: TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info @@ -274,7 +275,8 @@ def get_tile_candidates(self, prologue_nodes: Optional[List[IRNode]] = None, **kwargs): X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) - return self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) + return self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node, precision_bytes) def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): if template_buffer_node is not None: @@ -307,7 +309,7 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] return X,W,Y,M,N,K,n_epilogue_node,n_prologue_node,len(n_extra_read) - def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node, precision_bytes): data = {} gemm_shape = f"{M}_{N}_{K}" if "external" in extension_config.codegen_mapping_strategy: @@ -327,7 +329,7 @@ def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_no else: # case 2: use heuristic mapping min_tile = (n_extra_node + n_prologue_node) == 0 - tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True) + tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True, precision_bytes=precision_bytes) # Edge case if (M == 0) or (N == 0) or (K == 0): diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 9cc79e0a..81b3d606 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -150,10 +150,10 @@ def add_loop_info(self, mat_size, tile_size): for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): self.loop_info[f"index{idx}"] = [0, loop_size, stride] - def gemmini_gemm_mapping(self, M, N, K): + def gemmini_gemm_mapping(self, M, N, K, precision_bytes=4): spad_size = self.spad_info["spad_size"] * self.vector_lane num_cores = self.num_cores - precision = self.precision + precision = precision_bytes dim_I, dim_J, dim_K = M, N, K dim = self.vector_lane @@ -205,7 +205,7 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K - def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False, is_conv=False): + def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False, is_conv=False, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -233,11 +233,11 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_M = i * self.vector_lane if M > self.vector_lane else M_padded for j in tile_N_range: tile_N = j * self.vector_lane if N > self.vector_lane else N_padded - used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: dir_path = f"{extension_config.CONFIG_TORCHSIM_DIR}/validation/gemm_candidates" @@ -259,11 +259,11 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_M = i * self.vector_lane if M > self.vector_lane else M_padded for j in tile_N_range: tile_N = j * self.vector_lane if N > self.vector_lane else N_padded - used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes n_tile = math.ceil(M / max(tile_M, 128)) * math.ceil(N / max(tile_N, 128)) check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and max(tile_N, 128) // max(tile_M, 128) < 10: @@ -277,7 +277,7 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_candidates = [v for _, v in tile_candidates] return tile_candidates - def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -285,7 +285,7 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation max_spad_per_lane = spad_size_per_lane // 2 # double buffer max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] + M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True, precision_bytes=precision_bytes)[0] max_k_h_w = 1 # maximize kernel size max_o_h_w = 1 # maximize output size K = min(K, self.vector_lane) @@ -298,11 +298,11 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation weight_size = k_w * k_h * K * N input_size = i_w * i_h * M * K output_size = o_w * o_h * M * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, k_w, o_h, o_w, M, N, K))) @@ -318,7 +318,7 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation tile_candidates = [v for _, v in tile_candidates] return tile_candidates - def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -326,7 +326,7 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] + M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True, precision_bytes=precision_bytes)[0] max_k_h_w = K_W for o_h in sympy.divisors(O_H): for o_w in sympy.divisors(O_W): @@ -336,11 +336,11 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, weight_size = 1 * k_h * K * N input_size = i_w * i_h * M * K output_size = o_w * o_h * M * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(1 * k_h * K, N) input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, K_W, o_h, o_w, M, N, K))) @@ -354,7 +354,7 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, tile_candidates = [v for _, v in tile_candidates] return tile_candidates - def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -362,7 +362,7 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] + M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True, precision_bytes=precision_bytes)[0] max_k_h_w = 1 for o_h in sympy.divisors(O_H): for k_h in sympy.divisors(K_H): @@ -372,11 +372,11 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio weight_size = k_w * k_h * K * N input_size = i_w * i_h * k_w * K output_size = M * o_h * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * k_w, K) output_size_per_lane = self.get_spad_size_per_lane(M * o_h * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, k_w, o_h, M, M, N, K))) diff --git a/README.md b/README.md index 4a3ef145..f55995c9 100644 --- a/README.md +++ b/README.md @@ -396,7 +396,6 @@ export TORCHSIM_USE_TIMING_POOLING=0 # use lightweight pooling for timing "icnt_injection_ports_per_core" : 16 // Interconnect injection ports per core "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", // Booksim2 config file path - "precision" : 4, // Element's precision in tensor (Byte) "scheduler" : "simple", // Scheduler type (Now, only support simple scheduler) "num_partition" : 2, // Multi-core Partitioning "partition": { // allocate request queue index From fadba78ef71f69992b321c9318a23a1377506121 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 12 Mar 2026 14:42:37 +0900 Subject: [PATCH 19/31] [Fix] malloc size align + fix origin info --- AsmParser/tog_generator.py | 4 ++-- PyTorchSimFrontend/extension_codecache.py | 21 +++++++++++++++++++ PyTorchSimFrontend/mlir/mlir_autotune.py | 2 +- .../mlir/mlir_codegen_backend.py | 3 ++- PyTorchSimFrontend/mlir/mlir_scheduling.py | 6 ++++-- 5 files changed, 30 insertions(+), 6 deletions(-) diff --git a/AsmParser/tog_generator.py b/AsmParser/tog_generator.py index 5f586d99..a12460e3 100644 --- a/AsmParser/tog_generator.py +++ b/AsmParser/tog_generator.py @@ -37,7 +37,7 @@ class tog_generator: StonneTraceCompute= 6 StonneTraceLoad = 7 StonneTraceStore = 8 - def __init__(self, origins="Unknown") -> None: + def __init__(self, origins={"Unknown"}) -> None: self.module_name = "tile_operation_graph" self.module = None self.raw_graph = {} @@ -226,7 +226,7 @@ def generate_tile_graph(self, name="tile_graph", cycle_list=list, x_offset=int, offset = w_offset if is_preload else x_offset iter_node.torchsim_overlapping_cycle = max(iter_node.torchsim_cycle - offset, 0) - origin_info = "_".join(map(str, self.origins)) + origin_info = self.origins if isinstance(self.origins, str) else "_".join(map(str, self.origins)) onnx_node_list = [node.to_onnx() for node in node_list] # Exclude root node dump_onnx_graph(name, onnx_node_list, vector_lane, origin_info, stonneGraph=stonneGraph) diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 8454dee6..b1c457d3 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -72,6 +72,14 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): {'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \ -O2 {filename}.ll -o {filename}.o """, + ).strip(), + re.sub(r"[ \n]+", " ", + f""" + {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ + -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ + -mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \ + -O2 {filename}.ll -o {filename}.s + """, ).strip()] def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_size, vlen=256): @@ -168,11 +176,13 @@ def load(cls, source_code, opt_cmd = shlex.split(cmds[0]) translate_cmd = shlex.split(cmds[1]) llc_cmd = shlex.split(cmds[2]) + llc_asm_cmd = shlex.split(cmds[3]) with lock: try: subprocess.check_call(opt_cmd) subprocess.check_call(translate_cmd) subprocess.check_call(llc_cmd) + subprocess.check_call(llc_asm_cmd) except subprocess.CalledProcessError as e: logger.error(f"Command failed with exit code {e.returncode}") logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") @@ -182,6 +192,17 @@ def load(cls, source_code, val_llvm_caller.generate_wrapper_file(write_path, validation_wrapper_name) val_llvm_caller.compile_wih_kernel(write_path, key, validation_wrapper_name, validation_binary_name, new_link_option) + + stack_size = val_llvm_caller.parse_stack_sizes(f"{write_path}/{key}.s", vlenb=vlenb) + spad_size = val_llvm_caller.get_spad_size(validation_binary_path) + spad_usage = stack_size + spad_size # Spad usage per lane + if extension_config.CONFIG_SPAD_INFO["spad_size"] < spad_usage: + logger.debug( + f"Scratchpad size exceeded: required {spad_usage} bytes, " + f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available." + ) + raise SpadOverflowError() + # Skip if TOG file already exists if os.path.isfile(tog_path): return key diff --git a/PyTorchSimFrontend/mlir/mlir_autotune.py b/PyTorchSimFrontend/mlir/mlir_autotune.py index 4503584c..caf4d6da 100644 --- a/PyTorchSimFrontend/mlir/mlir_autotune.py +++ b/PyTorchSimFrontend/mlir/mlir_autotune.py @@ -85,7 +85,7 @@ def cached_run_fn(*args, **kwargs): self.source_code, vectorlane_size=self.extra_args["vector_lane"], loop_size=None, spad_info=self.extra_args["spad_info"], vlen=self.extra_args["vlen"], arg_attributes=self.extra_args["arg_attributes"], - origins="Unknown", silent_mode=True, + origins=self.extra_args["origins"], silent_mode=True, autotune=self.extra_args['autotune']) args = [ diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 43cb65a4..24d6636a 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -285,7 +285,7 @@ def __init__(self, kernel_group, reason=None): self.gem5_header = IndentedBuffer() self.header.writeline("#include ") self.header.writeline("#include ") - self.header.writeline("void* __wrap_malloc(size_t size) { return sbrk(size); }") + self.header.writeline("void* __wrap_malloc(size_t size) { size = (size + 511UL) & ~511UL; return sbrk(size); }") # Align to 512 bytes self.header.writeline("void __wrap_free(void *ptr) { return; }") self.reduction_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") self.spad_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="spad") @@ -1060,6 +1060,7 @@ def run_bench(self, nodes, kernel_name, src_code): "vlen" : self.vlen, "arg_attributes" : arg_attributes, "autotune" : True, + "origins" : {str(i) for node in nodes for i in node.node.origins}, }, source_code=src_code, ) diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 2f9c9704..22d1011b 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -276,7 +276,7 @@ def codegen_node(self, _node): MLIRScheduling.count += 1 src_code, meta_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) kernel_name = self.define_kernel(src_code, meta_code, kernel_name_candidate, ex_kernel.vector_lane, - ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) + ex_kernel.spad_info, origins={str(i) for node in nodes for i in node.node.origins}) ex_kernel.call_kernel(kernel_name) _, args, _, _ = ex_kernel.args.mlir_argdefs() args = ", ".join(args) @@ -332,8 +332,10 @@ def codegen_template(self, template_node, epilogue_nodes, prologue_nodes): src_code, meta_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) with kernel: + all_nodes = [template_node] + (epilogue_nodes or []) + (prologue_nodes or []) + origins = {str(i) for n in all_nodes for i in n.node.origins} kernel_name = self.define_kernel(src_code, meta_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, - kernel.loop_size, origins={str(i) for i in template_node.node.origins}) + kernel.loop_size, origins=origins) self.define_function(kernel) kernel.call_kernel(kernel_name) From 0189ab978fbe3ce02e72bb77f66c2bd10342babe Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 12 Mar 2026 15:06:28 +0900 Subject: [PATCH 20/31] [TOGSim] Fix local/remote memory stat --- TOGSim/src/Simulator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TOGSim/src/Simulator.cc b/TOGSim/src/Simulator.cc index b5b9c778..d7fe9f1b 100644 --- a/TOGSim/src/Simulator.cc +++ b/TOGSim/src/Simulator.cc @@ -121,7 +121,7 @@ void Simulator::icnt_cycle() { front->set_core_id(core_id); if (!_icnt->is_full(port_id, front)) { int node_id = _dram->get_channel_id(front) / _config.dram_channels_per_partitions; - if (core_id == node_id) + if (get_partition_id(core_id) == node_id) _cores[core_id]->inc_numa_local_access(); else _cores[core_id]->inc_numa_remote_access(); From 5268be2df8352f3470bee4e60739b9467fa07ca8 Mon Sep 17 00:00:00 2001 From: HamHyungkyu Date: Thu, 12 Mar 2026 19:30:04 +0900 Subject: [PATCH 21/31] [Frontend/template] add SPDA decode GQA template imlementation --- .../mlir/mlir_codegen_backend.py | 7 +- PyTorchSimFrontend/mlir/mlir_lowering.py | 37 +- PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 888 +++++++++++++++++- PyTorchSimFrontend/mlir/mlir_template.py | 4 +- tests/test_sdpa.py | 57 +- 5 files changed, 973 insertions(+), 20 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 24d6636a..38125e31 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -470,7 +470,12 @@ def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSE new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx]) indices.append(str(new_arg)) elif not arg.is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg))) + try: + new_arg = sympy.Symbol(str(self.convert_index(arg))) + #not implemented case + except NotImplementedError: + print(f"Not implemented case: {arg}") + raise NotImplementedError(f"Not implemented case: {arg}") new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx]) indices.append(str(new_arg)) else: diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index 9d49f212..ac7eb853 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -16,9 +16,15 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate -from PyTorchSimFrontend.mlir.mlir_sdpa_template import MLIRFlashSDPATemplate, flash_sdpa_args, calculate_scale from PyTorchSimFrontend.mlir.mlir_cat_template import MLIRCatTemplate from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate, MLIRStableSortTemplate +from PyTorchSimFrontend.mlir.mlir_sdpa_template import ( + MLIRFlashSDPATemplate, + MLIRDecodeGQASDPAPartialTemplate, + MLIRDecodeGQASDPAReduceTemplate, + flash_sdpa_args, + calculate_scale, +) from PyTorchSimFrontend import extension_config aten = torch.ops.aten @@ -58,6 +64,35 @@ def tuned_flash_sdpa( scale = calculate_scale(query, scale) N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value) + # Decode-only GQA fast path: q is (B,Hq,1,Dh), B==1, Hq!=H, Hq%H==0. + # Always use the 2-kernel decode path: + # 1) block partials over (kv head, sequence block) + # 2) reduce/merge across blocks + # This keeps KV shared across qsub, avoids dh0-outer duplication, and + # stores compact partials instead of full score/prob tensors in DRAM. + if L == 1 and Hq != H and N == 1 and (Hq % H) == 0: + g = Hq // H + vector_lane = extension_config.vpu_num_lanes + tile_e = vector_lane + dh_tiles = E // tile_e + decode_gqa_block_size = 512 + BlkS = decode_gqa_block_size if S >= decode_gqa_block_size else int(S) + # Padding-based tail handling: allow S not divisible by BlkS. + nblk = (S + BlkS - 1) // BlkS + HgDhTiles = H * g * dh_tiles + tile_pack = tile_e * 2 + + partial_layout = ir.FixedLayout( + query.get_device(), + torch.float32, + [HgDhTiles, nblk, tile_pack], + ) + partial_tmpl = MLIRDecodeGQASDPAPartialTemplate([query, key, value], partial_layout, scale, BlkS=BlkS) + partial = partial_tmpl.generate().output_node() + reduce_tmpl = MLIRDecodeGQASDPAReduceTemplate([partial], layout, BlkS=BlkS) + out_node = reduce_tmpl.generate().output_node() + return (out_node, None, None, None, None, None, None, None, None) + mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale) # _scaled_dot_product_flash_attention has to return a tuple which has 9 values diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index 05030f27..1cd810e8 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -48,23 +48,28 @@ def flash_sdpa_args( s = V.graph.sizevars.guard_equals(sk, sv) e = V.graph.sizevars.guard_equals(eq, ek) - # While there are no theoretical requirements for e == ev, - # this implementation enforces e == ev for simplicity. - # Distinct notations are still maintained to ensure future compatibility and clarity. + # While there are no theoretical requirements for e == ev, + # this implementation currently enforces e == ev for simplicity. if e != ev: - raise NotImplementedError("Flash SDPA does not support mismatched head dimensions between query and value.") - - # Flash attention does not split tiles along the head dimension (e or ev). - # Therefore, the head dimension size must be less than or equal to the number of vlanes. - vector_lane = extension_config.vpu_num_lanes - if e > vector_lane or ev > vector_lane: - raise ValueError(f"The head dimension size must be less than or equal to the number of vlanes (e: {e}, ev: {ev}, vlanes: {vector_lane}).") + raise NotImplementedError( + "Flash SDPA currently requires matching head dimensions between query and value (e == ev)." + ) + + # Support head dimensions larger than vector lanes by tiling e/ev. + # For now, require multiples of vector lanes (covers 64/128 with vlanes=16). + vector_lane = extension_config.vpu_num_lanes + if (e % vector_lane) != 0: + raise NotImplementedError( + f"Flash SDPA currently requires e to be a multiple of vlanes (e: {e}, vlanes: {vector_lane})." + ) - # The aten._scaled_dot_product_flash_attention kernel does not accept an explicit enable_gqa parameter. - # Instead, the Flash SDPA implementation infers GQA usage by checking if hq != hk. - # The Flash SDPA for GQA will be implemented after implementing its native version. - if hq != h : - raise NotImplementedError("Flash SDPA for GQA is not supported yet.") + # Minimal GQA support (single-batch only for now). + # We map each query head to a KV head by grouping: hq = g * h. + if hq != h: + if n != 1: + raise NotImplementedError("Flash SDPA GQA is currently supported only for n == 1.") + if (hq % h) != 0: + raise NotImplementedError(f"Flash SDPA GQA requires hq % h == 0 (hq: {hq}, h: {h}).") layout = FixedLayout( query.get_device(), @@ -479,3 +484,856 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no tile_candidates[idx] = tile_l,tile_s,tile_e,subtile_l,subtile_s,subtile_e return tile_candidates + + +# --------------------------- +# Decode-only GQA SDPA (Lq == 1) +# --------------------------- + +DECODE_GQA_SDPA_TEMPLATE = r""" +// Decode GQA SDPA kernel (Lq == 1) +// B = {{ B }} +// Hq = {{ Hq }} +// H = {{ H }} +// g = {{ g }} +// S = {{ S }} +// Dh = {{ Dh }} +// BlkS = {{ BlkS }} +// tile_s = {{ tile_s }} +// tile_e = {{ tile_e }} +// dh_tiles = {{ dh_tiles }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[query, key, value], outputs=[out], names_str="query, key, value, out", input_reorder=input_reorder)}} { + // IO buffers follow input dtype (fp16/bf16/f32) + {{ kernel.def_sram_buffer("query", q_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} + // Softmax output used for SV matmul (io dtype) + {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("score", score_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("prob", prob_desc, indent_size=2) }} + // Accumulator in fp32 (stable) + {{ kernel.def_sram_buffer("out_acc", out_acc_tile_desc, indent_size=2) }} + // Temp output in io dtype for SV matmul result + {{ kernel.def_sram_buffer("out_io", out_io_tile_desc, indent_size=2) }} + // Softmax running stats in fp32 + {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} + + %c0 = arith.constant 0.0 : {{ acc_stype }} + %c1 = arith.constant 1.0 : {{ acc_stype }} + %c_scale = arith.constant {{ scale }} : {{ acc_stype }} + %c_neg_inf = arith.constant -1.0e+30 : {{ acc_stype }} + + %v0_e_acc = arith.constant dense<0.0> : vector<{{ tile_e }}x{{ acc_stype }}> + %v0_e_io = arith.constant dense<0.0> : vector<{{ tile_e }}x{{ io_stype }}> + %v0_2x = arith.constant dense<0.0> : vector<2x{{ acc_stype }}> + %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2x{{ acc_stype }}> + %v0_s_acc = arith.constant dense<0.0> : vector<{{ tile_s }}x{{ acc_stype }}> + + %v_scale = vector.broadcast %c_scale : {{ acc_stype }} to vector<{{ tile_s }}x{{ acc_stype }}> + + {{ kernel.def_local_vars(indent_size=2) }} + + // kv_head parallelism is the natural unit for GQA reuse + affine.for %kv = 0 to {{ H }} { + // Process S in blocks (BlkS). Sequential inside a core. + affine.for %blk = 0 to {{ S }} step {{ BlkS }} { + // Initialize per-qsub accumulators for this (kv, blk) + affine.for %qsub = 0 to {{ g }} { + affine.vector_store %v_neg_inf_2x, %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> + affine.vector_store %v0_2x, %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> + affine.for %dht = 0 to {{ dh_tiles }} { + affine.vector_store %v0_e_acc, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> + } + } + + affine.for %s0 = %blk to (%blk + {{ BlkS }}) step {{ tile_s }} { + // Accumulate score per qsub so K tiles can be shared across qsub. + affine.for %qsub = 0 to {{ g }} { + affine.vector_store %v0_s_acc, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> + } + + affine.for %k0 = 0 to {{ Dh }} step {{ tile_e }} { + // Load K slice once for all qsub. + {{ kernel.def_dma_op("MVIN", "key", kk_idx, k_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=1) }} + %k2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1> + + affine.for %qsub = 0 to {{ g }} { + {{ kernel.def_dma_op("MVIN", "query", qk_idx, q_tile_desc, subtile_size=[1, 1, tile_e], indent_size=12) }} + %q2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ q_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> + + // mul = k @ q -> (tile_s x 1) in io dtype, then upcast and accumulate. + linalg.matmul + { idx_map = array } + ins(%k2D, %q2D : memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1>, memref<{{ tile_e }}x1x{{ io_stype }}, 1>) + outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(io_stype) }}) + + %raw_mul_io = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + %raw_mul = arith.extf %raw_mul_io : vector<{{ tile_s }}x{{ io_stype }}> to vector<{{ tile_s }}x{{ acc_stype }}> + %old_score = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> + %new_score = arith.addf %old_score, %raw_mul : vector<{{ tile_s }}x{{ acc_stype }}> + affine.vector_store %new_score, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> + } { accumulation_loop=true } + } { accumulation_loop=true } + + affine.for %qsub = 0 to {{ g }} { + %score_acc = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> + // scale after full Dh reduction + %scaled_mul_vec = arith.mulf %score_acc, %v_scale : vector<{{ tile_s }}x{{ acc_stype }}> + + // Online softmax update (max/sum/out) identical to FLASH_SDPA_TEMPLATE but specialized to Lq==1. + %old_max = affine.vector_load %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> + // Reduce max over tile_s + %max_init = vector.broadcast %c_neg_inf : {{ acc_stype }} to vector<{{ tile_s }}x{{ acc_stype }}> + %local_max_vec = arith.maximumf %scaled_mul_vec, %max_init : vector<{{ tile_s }}x{{ acc_stype }}> + %max_cast = vector.shape_cast %local_max_vec : vector<{{ tile_s }}x{{ acc_stype }}> to vector<{{ tile_s // 2 }}x2x{{ acc_stype }}> + %max_red1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<{{ tile_s // 2 }}x2x{{ acc_stype }}> to vector<2x{{ acc_stype }}> + %max_shuf = vector.shuffle %max_red1, %max_red1 [1, 0] : vector<2x{{ acc_stype }}>, vector<2x{{ acc_stype }}> + %max_red2 = arith.maximumf %max_red1, %max_shuf : vector<2x{{ acc_stype }}> + %new_max = arith.maximumf %max_red2, %old_max : vector<2x{{ acc_stype }}> + affine.vector_store %new_max, %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> + + // rescale = exp(old_max - new_max) + %max_diff = arith.subf %old_max, %new_max : vector<2x{{ acc_stype }}> + %max_diff_scalar = vector.extract %max_diff[0] : {{ acc_stype }} from vector<2x{{ acc_stype }}> + %rescale_e = vector.broadcast %max_diff_scalar : {{ acc_stype }} to vector<{{ tile_e }}x{{ acc_stype }}> + %exp_rescale_e = math.exp %rescale_e : vector<{{ tile_e }}x{{ acc_stype }}> + %rescale_2 = vector.broadcast %max_diff_scalar : {{ acc_stype }} to vector<2x{{ acc_stype }}> + %exp_rescale_2 = math.exp %rescale_2 : vector<2x{{ acc_stype }}> + + // out *= rescale + %old_out = affine.vector_load %out_acc_buffer[%qsub, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> + %rescaled_out = arith.mulf %exp_rescale_e, %old_out : vector<{{ tile_e }}x{{ acc_stype }}> + affine.vector_store %rescaled_out, %out_acc_buffer[%qsub, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> + + // sum *= rescale + %old_sum = affine.vector_load %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> + %rescaled_sum = arith.mulf %old_sum, %exp_rescale_2 : vector<2x{{ acc_stype }}> + + // exp(score - new_max) + %new_max_scalar = vector.extract %new_max[0] : {{ acc_stype }} from vector<2x{{ acc_stype }}> + %new_max_bcast = vector.broadcast %new_max_scalar : {{ acc_stype }} to vector<{{ tile_s }}x{{ acc_stype }}> + %shifted = arith.subf %scaled_mul_vec, %new_max_bcast : vector<{{ tile_s }}x{{ acc_stype }}> + %exp_scores = math.exp %shifted : vector<{{ tile_s }}x{{ acc_stype }}> + // For SV matmul: downcast softmax output to io dtype (common in practice) + %exp_scores_io = arith.truncf %exp_scores : vector<{{ tile_s }}x{{ acc_stype }}> to vector<{{ tile_s }}x{{ io_stype }}> + affine.vector_store %exp_scores_io, %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + + // sum += reduce(exp_scores) + %sum_cast = vector.shape_cast %exp_scores : vector<{{ tile_s }}x{{ acc_stype }}> to vector<{{ tile_s // 2 }}x2x{{ acc_stype }}> + %zero_2x = vector.broadcast %c0 : {{ acc_stype }} to vector<2x{{ acc_stype }}> + %sum_red1 = vector.multi_reduction , %sum_cast, %zero_2x [0] : vector<{{ tile_s // 2 }}x2x{{ acc_stype }}> to vector<2x{{ acc_stype }}> + %sum_shuf = vector.shuffle %sum_red1, %sum_red1 [1, 0] : vector<2x{{ acc_stype }}>, vector<2x{{ acc_stype }}> + %sum_red2 = arith.addf %sum_red1, %sum_shuf : vector<2x{{ acc_stype }}> + %new_sum = arith.addf %sum_red2, %rescaled_sum : vector<2x{{ acc_stype }}> + affine.vector_store %new_sum, %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> + + } { accumulation_loop=true } + + // 2) SV accumulation: for each output dh tile, load V once and share across qsub. + affine.for %dht = 0 to {{ dh_tiles }} { + %dh0 = affine.apply affine_map<(d0) -> (d0 * {{ tile_e }})>(%dht) + {{ kernel.def_dma_op("MVIN", "value", v_idx, v_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=0) }} + %v2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1> + + affine.for %qsub = 0 to {{ g }} { + %prob_vec = affine.vector_load %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + affine.vector_store %prob_vec, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + affine.vector_store %v0_e_io, %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> + %out_io_2D = memref.reinterpret_cast %out_io_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> + linalg.matmul + { idx_map = array } + ins(%v2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(io_stype) }}) + outs(%out_io_2D : memref<{{ tile_e }}x1x{{ io_stype }}, 1>) + + %out_io_vec = affine.vector_load %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> + %out_io_f32 = arith.extf %out_io_vec : vector<{{ tile_e }}x{{ io_stype }}> to vector<{{ tile_e }}x{{ acc_stype }}> + %out_acc_vec = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> + %out_acc_new = arith.addf %out_acc_vec, %out_io_f32 : vector<{{ tile_e }}x{{ acc_stype }}> + affine.vector_store %out_acc_new, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> + } { accumulation_loop=true } + } { accumulation_loop=true } + } { accumulation_loop=true } + + // finalize per-qsub for this (kv, blk) and store out for all dh tiles + affine.for %qsub = 0 to {{ g }} { + %final_sum = affine.vector_load %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> + %one_2x = vector.broadcast %c1 : {{ acc_stype }} to vector<2x{{ acc_stype }}> + %inv_sum_2x = arith.divf %one_2x, %final_sum : vector<2x{{ acc_stype }}> + %inv_sum = vector.extract %inv_sum_2x[0] : {{ acc_stype }} from vector<2x{{ acc_stype }}> + %inv_bcast = vector.broadcast %inv_sum : {{ acc_stype }} to vector<{{ tile_e }}x{{ acc_stype }}> + + affine.for %dht = 0 to {{ dh_tiles }} { + %dh0 = affine.apply affine_map<(d0) -> (d0 * {{ tile_e }})>(%dht) + %acc_out = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> + %final_out_acc = arith.mulf %acc_out, %inv_bcast : vector<{{ tile_e }}x{{ acc_stype }}> + %final_out_io = arith.truncf %final_out_acc : vector<{{ tile_e }}x{{ acc_stype }}> to vector<{{ tile_e }}x{{ io_stype }}> + affine.vector_store %final_out_io, %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> + {{ kernel.store_output(indent_size=10) }} + } + } { outer_loop=true } + } { outer_loop=true } + } { outer_loop=true } + + return +} +""" + + +class MLIRDecodeGQASDPATemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, scale, BlkS: int = 1024, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.scale = scale + self.BlkS = BlkS + + def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue_nodes=None, prologue_nodes=None, tile_info=None, **kwargs): + # Decode-only: q is (B,Hq,1,Dh) + query, key, value, out = self.input_nodes[0], self.input_nodes[1], self.input_nodes[2], self.output_node + + # Materialize tensors for stride metadata + q_tensor4 = empty_strided(query.layout.size, query.layout.stride) + k_tensor4 = empty_strided(key.layout.size, key.layout.stride) + v_tensor4 = empty_strided(value.layout.size, value.layout.stride) + + B, Hq, Lq, Dh = q_tensor4.shape + Bk, H, S, Dhk = k_tensor4.shape + assert B == 1, "Decode GQA template currently supports B==1" + assert Lq == 1, "Decode GQA template requires Lq==1" + assert Dh == Dhk + g = Hq // H + BlkS = min(int(self.BlkS), int(S)) + + # Use 3D views to match the existing SDPA indexing scheme + # q: (Hq, 1, Dh), k/v: (H, S, Dh), out: (Hq, 1, Dh) + q_tensor = q_tensor4.view(Hq, 1, Dh) + k_tensor = k_tensor4.view(H, S, Dh) + v_tensor = v_tensor4.view(H, S, Dh) + + tile_s = kernel.vector_lane + tile_e = kernel.vector_lane + dh_tiles = int(Dh) // int(tile_e) + + io_stype = mlir_common.DTYPE_TO_MLIR[query.get_dtype()] + acc_stype = "f32" + + # SRAM tiles: q(1x1xtile_e), k/v(1xtile_sxtile_e), mul(tile_sx1) in io dtype. + # out_acc in f32; out_io temp in io dtype. + vlane_stride = 1 + q_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) + q_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) + q_tile_desc.set_name("q_buffer") + q_tile_desc.offset = query.get_layout().offset + + k_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 2, vlane_stride) + k_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [0, 1, tile_s]) + k_tile_desc.set_name("k_buffer") + k_tile_desc.offset = key.get_layout().offset + + v_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 1, vlane_stride) + v_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [0, tile_e, 1]) + v_tile_desc.set_name("v_buffer") + v_tile_desc.offset = value.get_layout().offset + + mul_tile_desc = mlir_common.MLIRMultiDimTile([tile_s, 1], kernel.vector_lane, 1, vlane_stride) + mul_tile_desc.set_tile_size_stride([tile_s, 1], [1, 1]) + mul_tile_desc.set_name("mul_buffer") + + score_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) + score_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) + score_desc.set_name("score_buffer") + + prob_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) + prob_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) + prob_desc.set_name("prob_buffer") + + # Per-qsub accumulators so KV tiles can be shared across qsub + out_acc_tile_desc = mlir_common.MLIRMultiDimTile([g, dh_tiles, tile_e], kernel.vector_lane, 2, vlane_stride) + out_acc_tile_desc.set_tile_size_stride([g, dh_tiles, tile_e], [dh_tiles * tile_e, tile_e, 1]) + out_acc_tile_desc.set_name("out_acc_buffer") + + out_io_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) + out_io_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) + out_io_tile_desc.set_name("out_io_buffer") + + max_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) + max_desc.set_tile_size_stride([g, 2], [2, 1]) + max_desc.set_name("max_buffer") + + sum_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) + sum_desc.set_tile_size_stride([g, 2], [2, 1]) + sum_desc.set_name("sum_buffer") + + # Indices + kv = sympy.Symbol("kv") + qsub = sympy.Symbol("qsub") + dh0 = sympy.Symbol("dh0") + k0 = sympy.Symbol("k0") + s0 = sympy.Symbol("s0") + q_head = kv * g + qsub + + q_stride = q_tensor.stride() + k_stride = k_tensor.stride() + v_stride = v_tensor.stride() + # out is (B,Hq,1,Dh) but we address it as (Hq,1,Dh) + out_tensor = empty_strided(out.get_layout().size, out.get_layout().stride).view(Hq, 1, Dh) + out_stride = out_tensor.stride() + + # QK indices use k0 reduction over Dh + qk_idx = [q_head * q_stride[0], sympy.Integer(0), k0 * q_stride[2]] + kk_idx = [kv * k_stride[0], s0 * k_stride[1], k0 * k_stride[2]] + # V and output use dh0 tile offset + v_idx = [kv * v_stride[0], s0 * v_stride[1], dh0 * v_stride[2]] + out_idx = [q_head * out_stride[0], sympy.Integer(0), dh0 * out_stride[2]] + + kernel.loop_size = [tile_s, tile_e, 1] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + B=B, + Hq=Hq, + H=H, + g=g, + S=S, + Dh=Dh, + dh_tiles=dh_tiles, + BlkS=BlkS, + tile_s=tile_s, + tile_e=tile_e, + io_stype=io_stype, + acc_stype=acc_stype, + scale=self.scale, + query=query, + key=key, + value=value, + out=out, + q_tile_desc=q_tile_desc, + k_tile_desc=k_tile_desc, + v_tile_desc=v_tile_desc, + out_acc_tile_desc=out_acc_tile_desc, + out_io_tile_desc=out_io_tile_desc, + mul_tile_desc=mul_tile_desc, + score_desc=score_desc, + prob_desc=prob_desc, + max_desc=max_desc, + sum_desc=sum_desc, + qk_idx=qk_idx, + kk_idx=kk_idx, + v_idx=v_idx, + out_idx=out_idx, + input_reorder=self.input_reorder, + ) + + kernel.epilogue_info = dict( + output_node=self.output_node.name, + sram_var="out_io_buffer", + dram_var="out", + dram_idx=out_idx, + dram_tile_desc=out_io_tile_desc, + nr_rdim=0, + r_dim_size=0, + dim_aliasing={"kv": "kv", "qsub": "qsub", "dh0": "dh0", "s0": "s0"}, + ) + + return self._template_from_string(DECODE_GQA_SDPA_TEMPLATE).render(**kernel.render_options) + + +# --------------------------- +# Decode-only GQA SDPA: 2-kernel pipeline (partial blocks + reduce) +# --------------------------- + +DECODE_GQA_SDPA_PARTIAL_TEMPLATE = r""" +// Decode GQA SDPA partial kernel (per sequence block) +// Produces partials per (kv,qsub,dh_tile,blk): +// - first half lanes: o_j (tile_e) +// - second half lanes: [m_j, l_j, 0, 0, ...] (tile_e) +// QK/softmax is computed once per (kv,qsub,s0) over full Dh using k0 reduction. +// SV then reuses those probabilities across all dh tiles. +// H = {{ H }}, g = {{ g }}, Dh = {{ Dh }}, dh_tiles = {{ dh_tiles }}, S = {{ S }}, BlkS = {{ BlkS }}, nblk = {{ nblk }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[query, key, value], outputs=[partial], names_str="query, key, value, partial", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("query", q_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("score", score_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("prob", prob_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("out_io", out_io_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("out_acc", out_acc_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("partial", partial_tile_desc, indent_size=2) }} + + %c0 = arith.constant 0.0 : f32 + %c_scale = arith.constant {{ scale }} : f32 + %c_neg_inf = arith.constant -1.0e+30 : f32 + + %v0_e = arith.constant dense<0.0> : vector<{{ tile_e }}xf32> + %v0_e_io = arith.constant dense<0.0> : vector<{{ tile_e }}x{{ io_stype }}> + %v0_s = arith.constant dense<0.0> : vector<{{ tile_s }}xf32> + %v0_2x = arith.constant dense<0.0> : vector<2xf32> + %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2xf32> + %v_scale = vector.broadcast %c_scale : f32 to vector<{{ tile_s }}xf32> + + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %kv = 0 to {{ H }} { + affine.for %blk = 0 to {{ nblk }} step 1 { + // Reset per-block accumulators for all qsub/dh tiles. + affine.for %qsub = 0 to {{ g }} { + affine.vector_store %v_neg_inf_2x, %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + affine.vector_store %v0_2x, %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + affine.for %dht = 0 to {{ dh_tiles }} { + affine.vector_store %v0_e, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + } + } + + affine.for %s0 = ({{ BlkS }} * %blk) to ({{ BlkS }} * (%blk + 1)) step {{ tile_s }} { + // Accumulate score per qsub so K tiles can be shared across qsub. + affine.for %qsub = 0 to {{ g }} { + affine.vector_store %v0_s, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> + } + + affine.for %k0 = 0 to {{ Dh }} step {{ tile_e }} { + {{ kernel.def_dma_op("MVIN", "key", kk_idx, k_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=1) }} + %k2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1> + + affine.for %qsub = 0 to {{ g }} { + {{ kernel.def_dma_op("MVIN", "query", qk_idx, q_tile_desc, subtile_size=[1, 1, tile_e], indent_size=12) }} + %q2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ q_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> + linalg.matmul + { idx_map = array } + ins(%k2D, %q2D : memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1>, memref<{{ tile_e }}x1x{{ io_stype }}, 1>) + outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(io_stype) }}) + %raw_mul_io = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + %raw_mul = arith.extf %raw_mul_io : vector<{{ tile_s }}x{{ io_stype }}> to vector<{{ tile_s }}xf32> + %old_score = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> + %new_score = arith.addf %old_score, %raw_mul : vector<{{ tile_s }}xf32> + affine.vector_store %new_score, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> + } { accumulation_loop=true } + } { accumulation_loop=true } + + // Softmax once per qsub; persist probabilities in SRAM for all SV dh tiles. + affine.for %qsub = 0 to {{ g }} { + %score = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> + %scaled = arith.mulf %score, %v_scale : vector<{{ tile_s }}xf32> + + %old_max = affine.vector_load %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + %max_init = vector.broadcast %c_neg_inf : f32 to vector<{{ tile_s }}xf32> + %local_max_vec = arith.maximumf %scaled, %max_init : vector<{{ tile_s }}xf32> + %max_cast = vector.shape_cast %local_max_vec : vector<{{ tile_s }}xf32> to vector<{{ tile_s // 2 }}x2xf32> + %max_red1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<{{ tile_s // 2 }}x2xf32> to vector<2xf32> + %max_shuf = vector.shuffle %max_red1, %max_red1 [1, 0] : vector<2xf32>, vector<2xf32> + %max_red2 = arith.maximumf %max_red1, %max_shuf : vector<2xf32> + %new_max = arith.maximumf %max_red2, %old_max : vector<2xf32> + affine.vector_store %new_max, %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + + %max_diff = arith.subf %old_max, %new_max : vector<2xf32> + %max_diff_scalar = vector.extract %max_diff[0] : f32 from vector<2xf32> + %rescale_e = vector.broadcast %max_diff_scalar : f32 to vector<{{ tile_e }}xf32> + %exp_rescale_e = math.exp %rescale_e : vector<{{ tile_e }}xf32> + %rescale_2 = vector.broadcast %max_diff_scalar : f32 to vector<2xf32> + %exp_rescale_2 = math.exp %rescale_2 : vector<2xf32> + + %old_sum = affine.vector_load %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + %rescaled_sum = arith.mulf %old_sum, %exp_rescale_2 : vector<2xf32> + + affine.for %dht = 0 to {{ dh_tiles }} { + %old_out = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + %rescaled_out = arith.mulf %exp_rescale_e, %old_out : vector<{{ tile_e }}xf32> + affine.vector_store %rescaled_out, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + } + + %new_max_scalar = vector.extract %new_max[0] : f32 from vector<2xf32> + %new_max_bcast = vector.broadcast %new_max_scalar : f32 to vector<{{ tile_s }}xf32> + %shifted = arith.subf %scaled, %new_max_bcast : vector<{{ tile_s }}xf32> + %exp_scores = math.exp %shifted : vector<{{ tile_s }}xf32> + %exp_scores_io = arith.truncf %exp_scores : vector<{{ tile_s }}xf32> to vector<{{ tile_s }}x{{ io_stype }}> + affine.vector_store %exp_scores_io, %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + + %sum_cast = vector.shape_cast %exp_scores : vector<{{ tile_s }}xf32> to vector<{{ tile_s // 2 }}x2xf32> + %zero_2x = vector.broadcast %c0 : f32 to vector<2xf32> + %sum_red1 = vector.multi_reduction , %sum_cast, %zero_2x [0] : vector<{{ tile_s // 2 }}x2xf32> to vector<2xf32> + %sum_shuf = vector.shuffle %sum_red1, %sum_red1 [1, 0] : vector<2xf32>, vector<2xf32> + %sum_red2 = arith.addf %sum_red1, %sum_shuf : vector<2xf32> + %new_sum = arith.addf %sum_red2, %rescaled_sum : vector<2xf32> + affine.vector_store %new_sum, %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + } { accumulation_loop=true } + + // For each output dh tile, load V once and share it across qsub. + affine.for %dht = 0 to {{ dh_tiles }} { + %dh0 = affine.apply affine_map<(d0) -> (d0 * {{ tile_e }})>(%dht) + {{ kernel.def_dma_op("MVIN", "value", v_idx, v_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=0) }} + %v2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1> + + affine.for %qsub = 0 to {{ g }} { + %prob_vec = affine.vector_load %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + affine.vector_store %prob_vec, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + affine.vector_store %v0_e_io, %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> + %out_io_2D = memref.reinterpret_cast %out_io_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> + linalg.matmul + { idx_map = array } + ins(%v2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(io_stype) }}) + outs(%out_io_2D : memref<{{ tile_e }}x1x{{ io_stype }}, 1>) + + %out_io_vec = affine.vector_load %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> + %out_io_f32 = arith.extf %out_io_vec : vector<{{ tile_e }}x{{ io_stype }}> to vector<{{ tile_e }}xf32> + %out_acc_vec = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + %out_acc_new = arith.addf %out_acc_vec, %out_io_f32 : vector<{{ tile_e }}xf32> + affine.vector_store %out_acc_new, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + } { accumulation_loop=true } + } { accumulation_loop=true } + } { accumulation_loop=true } + + // Store packed partials for all qsub/dh tiles. + affine.for %qsub = 0 to {{ g }} { + %final_max = affine.vector_load %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + %m_scalar = vector.extract %final_max[0] : f32 from vector<2xf32> + %final_sum = affine.vector_load %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + %l_scalar = vector.extract %final_sum[0] : f32 from vector<2xf32> + %ml_vec = vector.broadcast %c0 : f32 to vector<{{ tile_e }}xf32> + %ml0 = vector.insert %m_scalar, %ml_vec[0] : f32 into vector<{{ tile_e }}xf32> + %ml1 = vector.insert %l_scalar, %ml0[1] : f32 into vector<{{ tile_e }}xf32> + + affine.for %dht = 0 to {{ dh_tiles }} { + %out_vec = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + %packed = vector.concat %out_vec, %ml1 : vector<{{ tile_pack }}xf32> + affine.vector_store %packed, %partial_buffer[0, 0, 0] : {{ partial_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_pack }}xf32> + {{ kernel.store_output(indent_size=10) }} + } + } { outer_loop=true } + } { outer_loop=true } + } { outer_loop=true } + return +} +""" + + +DECODE_GQA_SDPA_REDUCE_TEMPLATE = r""" +// Decode GQA SDPA reduce kernel: merge partials across blocks +// Input partial shape: (HgDhTiles, nblk, tile_pack) +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[partial], outputs=[out], names_str="partial, out", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("partial", partial_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("out_acc", out_acc_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} + + %c0 = arith.constant 0.0 : f32 + %c1 = arith.constant 1.0 : f32 + %c_neg_inf = arith.constant -1.0e+30 : f32 + %v0_e = arith.constant dense<0.0> : vector<{{ tile_e }}xf32> + %v0_2x = arith.constant dense<0.0> : vector<2xf32> + %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2xf32> + + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %gh = 0 to {{ HgDhTiles }} { + // reset merged accumulators + affine.vector_store %v0_e, %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + affine.vector_store %v0_2x, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + + affine.for %blk = 0 to {{ nblk }} { + {{ kernel.def_dma_op("MVIN", "partial", partial_idx, partial_tile_desc, subtile_size=[1, 1, tile_pack], indent_size=8) }} + %p = affine.vector_load %partial_buffer[0, 0, 0] : {{ partial_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_pack }}xf32> + %p2 = vector.shape_cast %p : vector<{{ tile_pack }}xf32> to vector<2x{{ tile_e }}xf32> + %o_j = vector.extract %p2[0] : vector<{{ tile_e }}xf32> from vector<2x{{ tile_e }}xf32> + %ml_j = vector.extract %p2[1] : vector<{{ tile_e }}xf32> from vector<2x{{ tile_e }}xf32> + %m_j = vector.extract %ml_j[0] : f32 from vector<{{ tile_e }}xf32> + %l_j = vector.extract %ml_j[1] : f32 from vector<{{ tile_e }}xf32> + + %old_max = affine.vector_load %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + %m_old = vector.extract %old_max[0] : f32 from vector<2xf32> + %m_new = arith.maximumf %m_old, %m_j : f32 + %m_new2 = vector.broadcast %m_new : f32 to vector<2xf32> + affine.vector_store %m_new2, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + + %diff_old = arith.subf %m_old, %m_new : f32 + %diff_j = arith.subf %m_j, %m_new : f32 + %scale_old = math.exp %diff_old : f32 + %scale_j = math.exp %diff_j : f32 + %scale_old_e = vector.broadcast %scale_old : f32 to vector<{{ tile_e }}xf32> + %scale_j_e = vector.broadcast %scale_j : f32 to vector<{{ tile_e }}xf32> + + %o_old = affine.vector_load %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + %o_old_rs = arith.mulf %o_old, %scale_old_e : vector<{{ tile_e }}xf32> + %o_j_rs = arith.mulf %o_j, %scale_j_e : vector<{{ tile_e }}xf32> + %o_new = arith.addf %o_old_rs, %o_j_rs : vector<{{ tile_e }}xf32> + affine.vector_store %o_new, %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + + %old_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + %l_old = vector.extract %old_sum[0] : f32 from vector<2xf32> + %l_new = arith.addf (arith.mulf %l_old, %scale_old : f32), (arith.mulf %l_j, %scale_j : f32) : f32 + %l_new2 = vector.broadcast %l_new : f32 to vector<2xf32> + affine.vector_store %l_new2, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + } { accumulation_loop=true } + + // finalize: out = o / l + %sum2 = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + %l = vector.extract %sum2[0] : f32 from vector<2xf32> + %inv = arith.divf %c1, %l : f32 + %inv_e = vector.broadcast %inv : f32 to vector<{{ tile_e }}xf32> + %o = affine.vector_load %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + %out_f32 = arith.mulf %o, %inv_e : vector<{{ tile_e }}xf32> + %out_io = arith.truncf %out_f32 : vector<{{ tile_e }}xf32> to vector<{{ tile_e }}x{{ io_stype }}> + affine.vector_store %out_io, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> + {{ kernel.store_output(indent_size=4) }} + } { outer_loop=true } + return +} +""" + + +class MLIRDecodeGQASDPAPartialTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, scale, BlkS: int = 1024, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.scale = scale + self.BlkS = BlkS + + def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue_nodes=None, prologue_nodes=None, tile_info=None, **kwargs): + query, key, value = self.input_nodes[0], self.input_nodes[1], self.input_nodes[2] + partial = self.output_node + + q_tensor4 = empty_strided(query.layout.size, query.layout.stride) + k_tensor4 = empty_strided(key.layout.size, key.layout.stride) + v_tensor4 = empty_strided(value.layout.size, value.layout.stride) + B, Hq, Lq, Dh = q_tensor4.shape + _, H, S, _ = k_tensor4.shape + assert B == 1 and Lq == 1 + g = Hq // H + BlkS = min(int(self.BlkS), int(S)) + nblk = (int(S) + int(BlkS) - 1) // int(BlkS) + + io_stype = mlir_common.DTYPE_TO_MLIR[query.get_dtype()] + tile_s = kernel.vector_lane + tile_e = kernel.vector_lane + tile_pack = tile_e * 2 + + # Use 3D views for indices + q_tensor = q_tensor4.view(Hq, 1, Dh) + k_tensor = k_tensor4.view(H, S, Dh) + v_tensor = v_tensor4.view(H, S, Dh) + + # Flatten (kv,qsub,dh_tile) into GH = H*g*(Dh/tile_e) + dh_tiles = int(Dh) // int(tile_e) + HgDhTiles = int(H) * int(g) * int(dh_tiles) + + # tile descs + vlane_stride = 1 + q_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) + q_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) + q_tile_desc.set_name("q_buffer") + q_tile_desc.offset = query.get_layout().offset + + k_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 2, vlane_stride) + k_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [0, 1, tile_s]) + k_tile_desc.set_name("k_buffer") + k_tile_desc.offset = key.get_layout().offset + + v_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 1, vlane_stride) + v_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [0, tile_e, 1]) + v_tile_desc.set_name("v_buffer") + v_tile_desc.offset = value.get_layout().offset + + mul_tile_desc = mlir_common.MLIRMultiDimTile([tile_s, 1], kernel.vector_lane, 1, vlane_stride) + mul_tile_desc.set_tile_size_stride([tile_s, 1], [1, 1]) + mul_tile_desc.set_name("mul_buffer") + + score_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) + score_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) + score_desc.set_name("score_buffer") + + prob_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) + prob_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) + prob_desc.set_name("prob_buffer") + + # Per-qsub, per-dh-tile accumulators so QK is computed once and SV expands across dh tiles. + out_acc_tile_desc = mlir_common.MLIRMultiDimTile([g, dh_tiles, tile_e], kernel.vector_lane, 2, vlane_stride) + out_acc_tile_desc.set_tile_size_stride([g, dh_tiles, tile_e], [dh_tiles * tile_e, tile_e, 1]) + out_acc_tile_desc.set_name("out_acc_buffer") + + max_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) + max_desc.set_tile_size_stride([g, 2], [2, 1]) + max_desc.set_name("max_buffer") + + sum_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) + sum_desc.set_tile_size_stride([g, 2], [2, 1]) + sum_desc.set_name("sum_buffer") + + out_io_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) + out_io_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) + out_io_tile_desc.set_name("out_io_buffer") + + partial_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_pack], kernel.vector_lane, 1, vlane_stride) + partial_tile_desc.set_tile_size_stride([1, 1, tile_pack], [0, tile_pack, 1]) + partial_tile_desc.set_name("partial_buffer") + + # Indices + kv = sympy.Symbol("kv") + qsub = sympy.Symbol("qsub") + dht = sympy.Symbol("dht") + dh0 = sympy.Symbol("dh0") + k0 = sympy.Symbol("k0") + blk = sympy.Symbol("blk") + s0 = sympy.Symbol("s0") + q_head = kv * g + qsub + + q_stride = q_tensor.stride() + k_stride = k_tensor.stride() + v_stride = v_tensor.stride() + + qk_idx = [q_head * q_stride[0], sympy.Integer(0), k0 * q_stride[2]] + kk_idx = [kv * k_stride[0], s0 * k_stride[1], k0 * k_stride[2]] + v_idx = [kv * v_stride[0], s0 * v_stride[1], dh0 * v_stride[2]] + + # partial tensor is view(HgDhTiles, nblk, tile_pack) contiguous + p_tensor = empty_strided(partial.get_layout().size, partial.get_layout().stride).view(HgDhTiles, nblk, tile_pack) + p_stride = p_tensor.stride() + # group head index: ((kv*g + qsub)*dh_tiles + dht) + gh = (kv * g + qsub) * dh_tiles + dht + partial_idx = [gh * p_stride[0], blk * p_stride[1], sympy.Integer(0)] + + kernel.loop_size = [tile_s, tile_e, tile_pack] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + H=H, + g=g, + Dh=Dh, + S=S, + BlkS=BlkS, + nblk=nblk, + tile_s=tile_s, + tile_e=tile_e, + dh_tiles=dh_tiles, + tile_pack=tile_pack, + io_stype=io_stype, + scale=self.scale, + query=query, + key=key, + value=value, + partial=partial, + q_tile_desc=q_tile_desc, + k_tile_desc=k_tile_desc, + v_tile_desc=v_tile_desc, + mul_tile_desc=mul_tile_desc, + score_desc=score_desc, + prob_desc=prob_desc, + out_io_tile_desc=out_io_tile_desc, + out_acc_tile_desc=out_acc_tile_desc, + max_desc=max_desc, + sum_desc=sum_desc, + partial_tile_desc=partial_tile_desc, + qk_idx=qk_idx, + kk_idx=kk_idx, + v_idx=v_idx, + partial_idx=partial_idx, + input_reorder=self.input_reorder, + ) + + kernel.epilogue_info = dict( + output_node=self.output_node.name, + sram_var="partial_buffer", + dram_var="partial", + dram_idx=partial_idx, + dram_tile_desc=partial_tile_desc, + nr_rdim=0, + r_dim_size=0, + dim_aliasing={"kv": "kv", "qsub": "qsub", "dht": "dht", "dh0": "dh0", "k0": "k0", "blk": "blk", "s0": "s0"}, + ) + return self._template_from_string(DECODE_GQA_SDPA_PARTIAL_TEMPLATE).render(**kernel.render_options) + + +class MLIRDecodeGQASDPAReduceTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, BlkS: int = 1024, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.BlkS = BlkS + + def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue_nodes=None, prologue_nodes=None, tile_info=None, **kwargs): + partial = self.input_nodes[0] + out = self.output_node + + tile_e = kernel.vector_lane + tile_pack = tile_e * 2 + + # Infer sizes from partial layout: (HgDhTiles, nblk, tile_pack) + HgDhTiles, nblk, _ = partial.get_size() + io_stype = mlir_common.DTYPE_TO_MLIR[out.get_dtype()] + + vlane_stride = 1 + partial_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_pack], kernel.vector_lane, 1, vlane_stride) + partial_tile_desc.set_tile_size_stride([1, 1, tile_pack], [0, tile_pack, 1]) + partial_tile_desc.set_name("partial_buffer") + partial_tile_desc.offset = partial.get_layout().offset + + out_acc_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) + out_acc_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) + out_acc_tile_desc.set_name("out_acc_buffer") + + max_desc = mlir_common.MLIRMultiDimTile([1, 2], kernel.vector_lane, 0, vlane_stride) + max_desc.set_tile_size_stride([1, 2], [2, 1]) + max_desc.set_name("max_buffer") + + sum_desc = mlir_common.MLIRMultiDimTile([1, 2], kernel.vector_lane, 0, vlane_stride) + sum_desc.set_tile_size_stride([1, 2], [2, 1]) + sum_desc.set_name("sum_buffer") + + out_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) + out_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) + out_tile_desc.set_name("out_buffer") + + # Indexing: partial is already 3D; out is (Hq,1,Dh) but view as (Hq*Dh/tile_e, 1, tile_e) + p_tensor = empty_strided(partial.get_layout().size, partial.get_layout().stride) + p_stride = p_tensor.stride() + gh = sympy.Symbol("gh") + blk = sympy.Symbol("blk") + partial_idx = [gh * p_stride[0], blk * p_stride[1], sympy.Integer(0)] + + # out view + out_tensor4 = empty_strided(out.get_layout().size, out.get_layout().stride) + B, Hq, Lq, Dh = out_tensor4.shape + assert B == 1 and Lq == 1 + dh_tiles = int(Dh) // int(tile_e) + out_tensor = out_tensor4.view(Hq * dh_tiles, 1, tile_e) + o_stride = out_tensor.stride() + out_idx = [gh * o_stride[0], sympy.Integer(0), sympy.Integer(0)] + + kernel.loop_size = [tile_pack, tile_e, 1] + + kernel.render_options = dict( + KERNEL_NAME=self.name, + kernel=kernel, + HgDhTiles=HgDhTiles, + nblk=nblk, + tile_e=tile_e, + tile_pack=tile_pack, + io_stype=io_stype, + partial=partial, + out=out, + partial_tile_desc=partial_tile_desc, + out_acc_tile_desc=out_acc_tile_desc, + max_desc=max_desc, + sum_desc=sum_desc, + out_tile_desc=out_tile_desc, + partial_idx=partial_idx, + out_idx=out_idx, + input_reorder=self.input_reorder, + ) + + kernel.epilogue_info = dict( + output_node=self.output_node.name, + sram_var="out_buffer", + dram_var="out", + dram_idx=out_idx, + dram_tile_desc=out_tile_desc, + nr_rdim=0, + r_dim_size=0, + dim_aliasing={"gh": "gh", "blk": "blk"}, + ) + return self._template_from_string(DECODE_GQA_SDPA_REDUCE_TEMPLATE).render(**kernel.render_options) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index b2df1d06..53db988b 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -904,7 +904,7 @@ def hook(): def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, subtile_size:list=[], async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True, - dram_stride:list=None, dram_offset=None): + dram_stride:list=None, dram_offset=None, padding: int = 0): # Todo. Remove legacy behavior (i.e., index_list parsing) def generate_dma_code(): """Internal method to generate DMA code directly.""" @@ -948,7 +948,7 @@ def generate_dma_code(): zero_cse = self.get_const_cse(0, "index") sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - attribute_parts = [f"dram_stride={_dram_stride}", f"sram_stride={sram_strides}", "padding=0"] + attribute_parts = [f"dram_stride={_dram_stride}", f"sram_stride={sram_strides}", f"padding={int(padding)}"] if subtile_size: attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") attribute = " {" + ", ".join(attribute_parts) + "}" diff --git a/tests/test_sdpa.py b/tests/test_sdpa.py index 6ffd6f2e..ed7ae8f8 100644 --- a/tests/test_sdpa.py +++ b/tests/test_sdpa.py @@ -58,6 +58,60 @@ def test_scaled_dot_product_attention(device, backends="flash"): print("All tests passed!") +def test_scaled_dot_product_attention_gqa_single_batch(device): + """ + Focused GQA testcases for single-batch (n==1). + Shapes: + q: (B, Hq, Lq, Dh) + k: (B, H, S, Dh) + v: (B, H, S, Dh) + """ + torch.manual_seed(0) + + B = 1 + # Decode-focused: include a larger S to hit BlkS logic + seq_len_list = [128, 256, 1024] + head_dim_list = [64, 128] + # GQA ratios requested: Hq / H in {4, 5, 8, 16}. + # Keep H=1 to directly realize those ratios. + gqa_ratios = [4, 5, 8, 16] + H = 1 + + for seq_len in seq_len_list: + for head_dim in head_dim_list: + for ratio in gqa_ratios: + Hq = ratio * H + + clear_caches() + # Decode shape: Lq == 1 + q = torch.rand(B, Hq, 1, head_dim, dtype=torch.float32) + k = torch.rand(B, H, seq_len, head_dim, dtype=torch.float32) + v = torch.rand(B, H, seq_len, head_dim, dtype=torch.float32) + + # NPU + q_npu = q.to(device=device) + k_npu = k.to(device=device) + v_npu = v.to(device=device) + opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) + out = opt_fn(q_npu, k_npu, v_npu, attn_mask=None, dropout_p=0.0, is_causal=True, enable_gqa=True) + + # CPU reference + cpu_device = torch.device("cpu") + cpu_out = F.scaled_dot_product_attention( + q.to(device=cpu_device), + k.to(device=cpu_device), + v.to(device=cpu_device), + attn_mask=None, + dropout_p=0.0, + is_causal=True, + enable_gqa=True, + ) + + name = f"SDPA-GQA(B: {B}, Hq: {Hq}, H: {H}, S: {seq_len}, head_dim: {head_dim})" + test_result(name, out, cpu_out) + + print("All GQA single-batch tests passed!") + def clear_caches(): import os from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache @@ -69,5 +123,6 @@ def clear_caches(): if __name__ == "__main__": device = torch.device('npu:0') - test_scaled_dot_product_attention(device, backends="flash") + # test_scaled_dot_product_attention(device, backends="flash") + test_scaled_dot_product_attention_gqa_single_batch(device) \ No newline at end of file From 59bd8f8ddc9ff86f35a45a347d7f5c7d5fe8bf7a Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 12 Mar 2026 21:29:16 +0900 Subject: [PATCH 22/31] WIP --- PyTorchSimFrontend/mlir/mlir_lowering.py | 1 + PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 660 +++++++++++------- 2 files changed, 398 insertions(+), 263 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index ac7eb853..7b2c07bf 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -89,6 +89,7 @@ def tuned_flash_sdpa( ) partial_tmpl = MLIRDecodeGQASDPAPartialTemplate([query, key, value], partial_layout, scale, BlkS=BlkS) partial = partial_tmpl.generate().output_node() + partial.realize() reduce_tmpl = MLIRDecodeGQASDPAReduceTemplate([partial], layout, BlkS=BlkS) out_node = reduce_tmpl.generate().output_node() return (out_node, None, None, None, None, None, None, None, None) diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index 1cd810e8..077a8cd2 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -16,17 +16,87 @@ from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel +def _make_offset_map_with_sym(strides, sym_dim, sym_stride, offset=0): + """Like _make_offset_map but injects a block symbol ``s`` into dimension ``sym_dim``. + + The effective index for that dimension becomes ``d{sym_dim} + sym_stride * s``. + Use this to keep ``affine.for`` bounds static and encode the block contribution + directly inside the ``affine.apply`` call that computes the DRAM offset. + + Args: + strides: per-dimension DRAM strides. + sym_dim: which dimension carries the block symbol. + sym_stride: multiplier for the symbol (1 for abs-position loops like FLASH + ``%blk``; ``BlkS`` for block-index loops like PARTIAL ``%blk``). + offset: constant layout offset. + + Returns: + MLIR affine_map string with one symbol, e.g. + ``affine_map<(d0, d1, d2)[s] -> (d0 * 8192 + (d1 + 128 * s) * 64 + d2)>`` + """ + n = len(strides) + terms = [] + for j, sv in enumerate(strides): + sv = int(sv) + if sv == 0: + continue + if j == sym_dim: + inner = f"d{j} + s" if sym_stride == 1 else f"d{j} + {sym_stride} * s" + terms.append(f"({inner})" if sv == 1 else f"({inner}) * {sv}") + else: + terms.append(f"d{j}" if sv == 1 else f"d{j} * {sv}") + try: + off = int(offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(n)) + expr = " + ".join(terms) if terms else "0" + return f"affine_map<({dim_str})[s] -> ({expr})>" + + +def _make_offset_map(strides, offset=0): + """Generate an MLIR affine_map string for a flat DRAM base-address. + + Args: + strides: list of integer per-dimension strides. + A stride of 0 means the dimension does not contribute. + offset: constant layout offset (e.g. from IRNode.get_layout().offset). + + Returns: + MLIR affine_map string, e.g. ``affine_map<(d0, d1) -> (d0 * 128 + d1)>`` + """ + n = len(strides) + terms = [] + for j, s in enumerate(strides): + s = int(s) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(n)) + expr = " + ".join(terms) if terms else "0" + return f"affine_map<({dim_str}) -> ({expr})>" + + def flash_sdpa_args( - query : TensorBox, - key : TensorBox, + query : TensorBox, + key : TensorBox, value : TensorBox) -> list: """ Arg processing for flash SDPA. - Its logic is based on: + Its logic is based on: mm_args() which is in torch._inductor.kernel.mm_common.py (142 line). """ - # Materialize input buffers for the codegen backend. + # Materialize input buffers for the codegen backend. query, key, value = realize_inputs(query, key, value) # query : (n, hq, l, e) @@ -43,7 +113,7 @@ def flash_sdpa_args( n = V.graph.sizevars.guard_equals(nq, nk) n = V.graph.sizevars.guard_equals(nq, nk) - + h = V.graph.sizevars.guard_equals(hk, hv) s = V.graph.sizevars.guard_equals(sk, sv) e = V.graph.sizevars.guard_equals(eq, ek) @@ -62,7 +132,7 @@ def flash_sdpa_args( raise NotImplementedError( f"Flash SDPA currently requires e to be a multiple of vlanes (e: {e}, vlanes: {vector_lane})." ) - + # Minimal GQA support (single-batch only for now). # We map each query head to a KV head by grouping: hq = g * h. if hq != h: @@ -70,14 +140,14 @@ def flash_sdpa_args( raise NotImplementedError("Flash SDPA GQA is currently supported only for n == 1.") if (hq % h) != 0: raise NotImplementedError(f"Flash SDPA GQA requires hq % h == 0 (hq: {hq}, h: {h}).") - + layout = FixedLayout( query.get_device(), query.get_dtype(), [n, hq, l, ev] ) - return [n, hq, h, l, s, e, ev, layout, query, key, value] + return [n, hq, h, l, s, e, ev, layout, query, key, value] def calculate_scale(query: torch.Tensor, scale: float) -> float: """ @@ -109,7 +179,7 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: {{ kernel.def_sram_buffer("query", q_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} - + // Output {{ kernel.def_sram_buffer("out", out_tile_desc, indent_size=2) }} @@ -117,7 +187,7 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} - + // Constants %c0 = arith.constant 0.0 : {{ data_stype }} %c1 = arith.constant 1.0 : {{ data_stype }} @@ -133,33 +203,36 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2x{{ data_stype }}> %v_scale = vector.broadcast %c_scale : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> - - {{ kernel.def_local_vars(indent_size=2) }} - + + {{ kernel.def_local_vars(indent_size=2) }} + affine.for %index0 = 0 to {{ b }} { affine.for %index3 = 0 to 1 step 1 { affine.for %index1 = 0 to {{ l }} step {{ tile_l }} { - {{ kernel.def_dma_op("MVIN", "query", q_idx, q_tile_desc, subtile_size=[1, subtile_l, subtile_e], indent_size=8) }} - + %q_dram_offset = affine.apply {{ q_offset_map }}(%index0, %index1, %index3) + {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, subtile_size=[1, subtile_l, subtile_e], indent_size=8, dram_stride=q_dram_stride, dram_offset="q_dram_offset") }} + affine.vector_store %v0_l, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> - affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> affine.vector_store %v0_2x, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> - + %qt_buffer2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ q_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> %ot_buffer2D = memref.reinterpret_cast %out_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ out_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> affine.for %index2 = 0 to {{ s }} step {{ tile_s }} { - {{ kernel.def_dma_op("MVIN", "key", k_idx, k_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10) }} - {{ kernel.def_dma_op("MVIN", "value", v_idx, v_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10) }} + %k_dram_offset = affine.apply {{ k_offset_map }}(%index0, %index2, %index3) + {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10, dram_stride=k_dram_stride, dram_offset="k_dram_offset") }} + %v_dram_offset = affine.apply {{ v_offset_map }}(%index0, %index2, %index3) + {{ kernel.def_dma_op("MVIN", "value", [], v_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10, dram_stride=v_dram_stride, dram_offset="v_dram_offset") }} - affine.vector_store %v0_s, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> + affine.vector_store %v0_s, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> %k_buffer2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1> %vt_buffer2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1> - + // key @ query.t and scaling. - linalg.matmul + linalg.matmul { idx_map = array } ins(%k_buffer2D, %qt_buffer2D : memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1>, memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(data_stype) }}) @@ -168,7 +241,7 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: %scaled_mul_vec = arith.mulf %raw_mul_vec, %v_scale : vector<{{ tile_s }}x{{ data_stype }}> affine.vector_store %scaled_mul_vec, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> - + // Find new max. %old_max = affine.vector_load %max_buffer[0,0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> @@ -182,22 +255,22 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: %max_reduced_1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> %max_shuffled = vector.shuffle %max_reduced_1, %max_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> %max_reduced_2 = arith.maximumf %max_reduced_1, %max_shuffled : vector<2x{{ data_stype }}> - - %new_max = arith.maximumf %max_reduced_2, %old_max : vector<2x{{ data_stype }}> + + %new_max = arith.maximumf %max_reduced_2, %old_max : vector<2x{{ data_stype }}> affine.vector_store %new_max, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> - + // Compute rescale factors: exp(old_max - new_max) %max_diff = arith.subf %old_max, %new_max : vector<2x{{ data_stype }}> %max_diff_scalar = vector.extract %max_diff[0] : {{ data_stype }} from vector<2x{{ data_stype }}> - - %rescale_bcast_e = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> - %exp_rescale_e = math.exp %rescale_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> + + %rescale_bcast_e = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> + %exp_rescale_e = math.exp %rescale_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> %rescale_bcast_2 = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<2x{{ data_stype }}> %exp_rescale_2 = math.exp %rescale_bcast_2 : vector<2x{{ data_stype }}> - + // Rescale previous out and sum accumulators %old_out = affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> %rescaled_out = arith.mulf %exp_rescale_e, %old_out : vector<{{ tile_e }}x{{ data_stype }}> @@ -206,16 +279,16 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: %old_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> %rescaled_sum = arith.mulf %old_sum, %exp_rescale_2 : vector<2x{{ data_stype }}> - + // Shift scores and apply exp: exp(x - new_max) %scaled_scores_reload = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> %new_max_scalar = vector.extract %new_max[0] : {{ data_stype }} from vector<2x{{ data_stype }}> %new_max_bcast = vector.broadcast %new_max_scalar : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> - + %shifted_scores = arith.subf %scaled_scores_reload, %new_max_bcast : vector<{{ tile_s }}x{{ data_stype }}> %exp_scores = math.exp %shifted_scores : vector<{{ tile_s }}x{{ data_stype }}> affine.vector_store %exp_scores, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> - + // accumulate current sum %chunk_sum_res = affine.for %index5 = 0 to {{ tile_s }} step {{ chunk_size }} iter_args(%iter_sum=%v0_c) -> (vector<{{ chunk_size }}x{{ data_stype }}>) { @@ -223,19 +296,19 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: %local_sum = arith.addf %chunk_exp, %iter_sum : vector<{{ chunk_size }}x{{ data_stype }}> affine.yield %local_sum : vector<{{ chunk_size }}x{{ data_stype }}> } - + %zero_2x = vector.broadcast %c0 : {{ data_stype }} to vector<2x{{ data_stype }}> %sum_cast = vector.shape_cast %chunk_sum_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> %sum_reduced_1 = vector.multi_reduction , %sum_cast, %zero_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> %sum_shuffled = vector.shuffle %sum_reduced_1, %sum_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> %sum_reduced_2 = arith.addf %sum_reduced_1, %sum_shuffled : vector<2x{{ data_stype }}> - + %new_sum = arith.addf %sum_reduced_2, %rescaled_sum : vector<2x{{ data_stype }}> affine.vector_store %new_sum, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> - + // value.t @ mul - linalg.matmul + linalg.matmul { idx_map = array } ins(%vt_buffer2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(data_stype) }}) outs(%ot_buffer2D : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) @@ -244,20 +317,21 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: // out @ row_sum^(-1) %final_row_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> %one_2x = vector.broadcast %c1 : {{ data_stype }} to vector<2x{{ data_stype }}> - + %reciprocal_row_sum_2x = arith.divf %one_2x, %final_row_sum : vector<2x{{ data_stype }}> %reciprocal_scalar = vector.extract %reciprocal_row_sum_2x[0] : {{ data_stype }} from vector<2x{{ data_stype }}> %reciprocal_bcast_e = vector.broadcast %reciprocal_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> - + %accumulated_out = affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> %stable_final_out = arith.mulf %accumulated_out, %reciprocal_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> affine.vector_store %stable_final_out, %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> - {{ kernel.store_output(indent_size=8) }} - } { accumulation_loop=true } + %out_dram_offset = affine.apply {{ out_offset_map }}(%index0, %index1, %index3) + {{ kernel.def_dma_op("MVOUT", "out", [], out_tile_desc, indent_size=8, dram_stride=out_dram_stride, dram_offset="out_dram_offset") }} + } { accumulation_loop=true } } { outer_loop=true } } { outer_loop=true } - return + return } """ @@ -273,10 +347,10 @@ def render(self, prologue_nodes: Optional[List[IRNode]] = None, tile_info = None, **kwargs): - + # Except for kernel, other arguments are usually None. query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) - + if tile_info is None: tile_l, tile_s, tile_e, subtile_l, subtile_s, subtile_e = self.select_tile(kernel, l, s, e, n_extra_node, 0, n_prologue_node)[0] else: @@ -299,10 +373,10 @@ def render(self, # Prepare tile descriptors for input and output tensors. # Intermediate buffers (transient data) do not require DRAM settings(dram stride and dram indices) - # as they are not synchronized with external DRAM. + # as they are not synchronized with external DRAM. # DRAM and SRAM tile shapes must match. vlane_stride = 1 - + # (n, l, s, e, ev) loop_dim = [sympy.Symbol("index0"), sympy.Symbol("index1"), sympy.Symbol("index2"), sympy.Symbol("index3")] @@ -317,11 +391,10 @@ def render(self, q_tile_desc.set_tile_size_stride(q_tile_size, q_tile_stride) q_tile_desc.set_name("q_buffer") q_tile_desc.offset = query.get_layout().offset - # DRAM settings + # DRAM settings q_stride = q_tensor.stride() - q_idx = [loop_dim[0]*q_stride[0], loop_dim[1]*q_stride[1], loop_dim[3]*q_stride[2]] # To keep index arguemnt order, we used index_list - # Since we use a weight-stationary approach in the Systolic Array (SA), + # Since we use a weight-stationary approach in the Systolic Array (SA), # the split axis of the first operand differs from a standard linear algebra matmul. # The first operand (key) must be split along the column axis. # This logic aligns with the relationship between the dot product's summation direction and the hardware's accumulation direction in the SA. @@ -335,7 +408,6 @@ def render(self, k_tile_desc.offset = key.get_layout().offset # DRAM settings k_stride = k_tensor.stride() - k_idx = [loop_dim[0]*k_stride[0], loop_dim[2]*k_stride[1], loop_dim[3]*k_stride[2]] # Since we compute mul = key @ query.t, we perform out.t = (value.t @ Softmax(mul).t).t, # which simplifies to (value.t @ Softmax(mul)) @@ -349,19 +421,17 @@ def render(self, v_tile_desc.offset = value.get_layout().offset # DRAM settings v_stride = v_tensor.stride() - v_idx = [loop_dim[0]*v_stride[0], loop_dim[2]*v_stride[1], loop_dim[3]*v_stride[2]] # To keep index arguemnt order, we used index_list # Output is also stored in transposed format to match the value.t @ Softmax(mul) operation. # SRAM settings vlane_split_axis = 1 - out_tile_size = [1, tile_l, tile_e] - out_tile_stride=[0, tile_e, 1] + out_tile_size = [1, tile_l, tile_e] + out_tile_stride=[0, tile_e, 1] out_tile_desc = mlir_common.MLIRMultiDimTile(out_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) out_tile_desc.set_tile_size_stride(out_tile_size, out_tile_stride) out_tile_desc.set_name("out_buffer") # DRAM settings out_stride = out.get_layout().stride[1:] - out_idx = [loop_dim[0]*out_stride[0], loop_dim[1]*out_stride[1], loop_dim[3]*out_stride[2]] # Intermediate buffers @@ -393,28 +463,46 @@ def render(self, # For reduction chunk_size = 16 + # DMA strides and offset affine maps (dram_stride + dram_offset style) + q_dram_stride = [int(q_stride[0]), int(q_stride[1]), int(q_stride[2])] + k_dram_stride = [int(k_stride[0]), int(k_stride[1]), int(k_stride[2])] + v_dram_stride = [int(v_stride[0]), int(v_stride[1]), int(v_stride[2])] + out_dram_stride = [int(out_stride[0]), int(out_stride[1]), int(out_stride[2])] + + q_offset_map = _make_offset_map(q_dram_stride, q_tile_desc.offset) + k_offset_map = _make_offset_map(k_dram_stride, k_tile_desc.offset) + v_offset_map = _make_offset_map(v_dram_stride, v_tile_desc.offset) + out_offset_map = _make_offset_map(out_dram_stride, 0) + + # Keep out_idx only for epilogue_info (not in render_options) + out_idx = [loop_dim[0]*out_stride[0], loop_dim[1]*out_stride[1], loop_dim[3]*out_stride[2]] + kernel.render_options = dict( KERNEL_NAME = self.name, kernel = kernel, - b = b, - l = l, - s = s, + b = b, + l = l, + s = s, e = e, # Input sizes (dram) - tile_l = tile_l, - tile_s = tile_s, + tile_l = tile_l, + tile_s = tile_s, tile_e = tile_e, # Tile sizes (sram) - subtile_l = subtile_l, - subtile_s = subtile_s, - subtile_e = subtile_e, # Subtile sizes (sram) + subtile_l = subtile_l, + subtile_s = subtile_s, + subtile_e = subtile_e, # Subtile sizes (sram) data_stype="f32", - query = query, + query = query, key = key, - value = value, + value = value, out = out, # Inputs and output (dram) - q_idx = q_idx, - k_idx = k_idx, - v_idx = v_idx, - out_idx = out_idx, # Strides (dram) + q_dram_stride = q_dram_stride, + k_dram_stride = k_dram_stride, + v_dram_stride = v_dram_stride, + out_dram_stride = out_dram_stride, # Per-dim DRAM strides + q_offset_map = q_offset_map, + k_offset_map = k_offset_map, + v_offset_map = v_offset_map, + out_offset_map = out_offset_map, # Affine maps for base address q_tile_desc = q_tile_desc, k_tile_desc = k_tile_desc, v_tile_desc = v_tile_desc, @@ -423,19 +511,8 @@ def render(self, max_desc = max_desc, sum_desc = sum_desc, # Intermediate buffer descriptions (sram) scale = self.scale, - chunk_size = chunk_size, - input_reorder = self.input_reorder # ETC - ) - - kernel.epilogue_info = dict( - output_node = self.output_node.name, - sram_var = "out_buffer", - dram_var = "out", - dram_idx = out_idx, - dram_tile_desc = out_tile_desc, - nr_rdim = nr_rdim, - r_dim_size = 0, - dim_aliasing = epilogue_dim_aliasing + chunk_size = chunk_size, + input_reorder = self.input_reorder # ETC ) code = self._template_from_string(template).render(**kernel.render_options) @@ -445,7 +522,7 @@ def render(self, def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): if template_buffer_node is not None: self.output_node = template_buffer_node - + query = self.input_nodes[0] key = self.input_nodes[1] value = self.input_nodes[2] @@ -462,7 +539,7 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): v_tensor = v_tensor.view([-1, v_tensor.shape[-2], v_tensor.shape[-1]]) out_tensor = out_tensor.view([-1, out_tensor.shape[-2], out_tensor.shape[-1]]) - b, l, s, e, ev = q_tensor.size(0), q_tensor.size(1), k_tensor.size(1), k_tensor.size(2), v_tensor.size(2) + b, l, s, e, ev = q_tensor.size(0), q_tensor.size(1), k_tensor.size(1), k_tensor.size(2), v_tensor.size(2) n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 @@ -549,7 +626,7 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no } } - affine.for %s0 = %blk to (%blk + {{ BlkS }}) step {{ tile_s }} { + affine.for %s0 = 0 to {{ BlkS }} step {{ tile_s }} { // Accumulate score per qsub so K tiles can be shared across qsub. affine.for %qsub = 0 to {{ g }} { affine.vector_store %v0_s_acc, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> @@ -557,11 +634,14 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no affine.for %k0 = 0 to {{ Dh }} step {{ tile_e }} { // Load K slice once for all qsub. - {{ kernel.def_dma_op("MVIN", "key", kk_idx, k_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=1) }} + %kk_offset = affine.apply {{ kk_offset_map_blk }}(%kv, %s0, %k0)[%blk] + {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=1, dram_stride=k_dram_stride, dram_offset="kk_offset") }} %k2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1> affine.for %qsub = 0 to {{ g }} { - {{ kernel.def_dma_op("MVIN", "query", qk_idx, q_tile_desc, subtile_size=[1, 1, tile_e], indent_size=12) }} + %q_head = affine.apply affine_map<(d0, d1) -> (d0 * {{ g }} + d1)>(%kv, %qsub) + %qk_offset = affine.apply {{ qk_offset_map }}(%q_head, %k0) + {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, subtile_size=[1, 1, tile_e], indent_size=12, dram_stride=q_dram_stride, dram_offset="qk_offset") }} %q2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ q_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> // mul = k @ q -> (tile_s x 1) in io dtype, then upcast and accumulate. @@ -571,9 +651,9 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(io_stype) }}) %raw_mul_io = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - %raw_mul = arith.extf %raw_mul_io : vector<{{ tile_s }}x{{ io_stype }}> to vector<{{ tile_s }}x{{ acc_stype }}> + {% if io_stype != acc_stype %}%raw_mul = arith.extf %raw_mul_io : vector<{{ tile_s }}x{{ io_stype }}> to vector<{{ tile_s }}x{{ acc_stype }}>{% endif %} %old_score = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> - %new_score = arith.addf %old_score, %raw_mul : vector<{{ tile_s }}x{{ acc_stype }}> + %new_score = arith.addf %old_score, {{ "%raw_mul" if io_stype != acc_stype else "%raw_mul_io" }} : vector<{{ tile_s }}x{{ acc_stype }}> affine.vector_store %new_score, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> } { accumulation_loop=true } } { accumulation_loop=true } @@ -618,8 +698,8 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no %shifted = arith.subf %scaled_mul_vec, %new_max_bcast : vector<{{ tile_s }}x{{ acc_stype }}> %exp_scores = math.exp %shifted : vector<{{ tile_s }}x{{ acc_stype }}> // For SV matmul: downcast softmax output to io dtype (common in practice) - %exp_scores_io = arith.truncf %exp_scores : vector<{{ tile_s }}x{{ acc_stype }}> to vector<{{ tile_s }}x{{ io_stype }}> - affine.vector_store %exp_scores_io, %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + {% if io_stype != acc_stype %}%exp_scores_io = arith.truncf %exp_scores : vector<{{ tile_s }}x{{ acc_stype }}> to vector<{{ tile_s }}x{{ io_stype }}>{% endif %} + affine.vector_store {{ "%exp_scores_io" if io_stype != acc_stype else "%exp_scores" }}, %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> // sum += reduce(exp_scores) %sum_cast = vector.shape_cast %exp_scores : vector<{{ tile_s }}x{{ acc_stype }}> to vector<{{ tile_s // 2 }}x2x{{ acc_stype }}> @@ -635,7 +715,8 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no // 2) SV accumulation: for each output dh tile, load V once and share across qsub. affine.for %dht = 0 to {{ dh_tiles }} { %dh0 = affine.apply affine_map<(d0) -> (d0 * {{ tile_e }})>(%dht) - {{ kernel.def_dma_op("MVIN", "value", v_idx, v_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=0) }} + %v_offset = affine.apply {{ v_offset_map_blk }}(%kv, %s0, %dh0)[%blk] + {{ kernel.def_dma_op("MVIN", "value", [], v_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=0, dram_stride=v_dram_stride, dram_offset="v_offset") }} %v2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1> affine.for %qsub = 0 to {{ g }} { @@ -649,9 +730,9 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no outs(%out_io_2D : memref<{{ tile_e }}x1x{{ io_stype }}, 1>) %out_io_vec = affine.vector_load %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - %out_io_f32 = arith.extf %out_io_vec : vector<{{ tile_e }}x{{ io_stype }}> to vector<{{ tile_e }}x{{ acc_stype }}> + {% if io_stype != acc_stype %}%out_io_f32 = arith.extf %out_io_vec : vector<{{ tile_e }}x{{ io_stype }}> to vector<{{ tile_e }}x{{ acc_stype }}>{% endif %} %out_acc_vec = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> - %out_acc_new = arith.addf %out_acc_vec, %out_io_f32 : vector<{{ tile_e }}x{{ acc_stype }}> + %out_acc_new = arith.addf %out_acc_vec, {{ "%out_io_f32" if io_stype != acc_stype else "%out_io_vec" }} : vector<{{ tile_e }}x{{ acc_stype }}> affine.vector_store %out_acc_new, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> } { accumulation_loop=true } } { accumulation_loop=true } @@ -669,9 +750,11 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no %dh0 = affine.apply affine_map<(d0) -> (d0 * {{ tile_e }})>(%dht) %acc_out = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> %final_out_acc = arith.mulf %acc_out, %inv_bcast : vector<{{ tile_e }}x{{ acc_stype }}> - %final_out_io = arith.truncf %final_out_acc : vector<{{ tile_e }}x{{ acc_stype }}> to vector<{{ tile_e }}x{{ io_stype }}> - affine.vector_store %final_out_io, %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - {{ kernel.store_output(indent_size=10) }} + {% if io_stype != acc_stype %}%final_out_io = arith.truncf %final_out_acc : vector<{{ tile_e }}x{{ acc_stype }}> to vector<{{ tile_e }}x{{ io_stype }}>{% endif %} + affine.vector_store {{ "%final_out_io" if io_stype != acc_stype else "%final_out_acc" }}, %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> + %q_head = affine.apply affine_map<(d0, d1) -> (d0 * {{ g }} + d1)>(%kv, %qsub) + %out_offset = affine.apply {{ out_offset_map }}(%q_head, %dh0) + {{ kernel.def_dma_op("MVOUT", "out", [], out_io_tile_desc, indent_size=10, dram_stride=out_dram_stride, dram_offset="out_offset") }} } } { outer_loop=true } } { outer_loop=true } @@ -690,7 +773,12 @@ def __init__(self, input_nodes, layout, scale, BlkS: int = 1024, input_reorder=N def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue_nodes=None, prologue_nodes=None, tile_info=None, **kwargs): # Decode-only: q is (B,Hq,1,Dh) - query, key, value, out = self.input_nodes[0], self.input_nodes[1], self.input_nodes[2], self.output_node + # Use template_buffer_node (the actual V.graph-registered CUDATemplateBuffer with its + # real name e.g. "buf0") when available, instead of the placeholder self.output_node + # (always named "buf_out"). This ensures output_buffers["buf0"] maps correctly + # in mlir_argdefs, which looks up buffer_types by the actual DRAM buffer name. + query, key, value, out = self.input_nodes[0], self.input_nodes[1], self.input_nodes[2], \ + template_buffer_node if template_buffer_node is not None else self.output_node # Materialize tensors for stride metadata q_tensor4 = empty_strided(query.layout.size, query.layout.stride) @@ -765,14 +853,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue sum_desc.set_tile_size_stride([g, 2], [2, 1]) sum_desc.set_name("sum_buffer") - # Indices - kv = sympy.Symbol("kv") - qsub = sympy.Symbol("qsub") - dh0 = sympy.Symbol("dh0") - k0 = sympy.Symbol("k0") - s0 = sympy.Symbol("s0") - q_head = kv * g + qsub - + # Strides from 3D tensor views q_stride = q_tensor.stride() k_stride = k_tensor.stride() v_stride = v_tensor.stride() @@ -780,11 +861,34 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue out_tensor = empty_strided(out.get_layout().size, out.get_layout().stride).view(Hq, 1, Dh) out_stride = out_tensor.stride() - # QK indices use k0 reduction over Dh - qk_idx = [q_head * q_stride[0], sympy.Integer(0), k0 * q_stride[2]] - kk_idx = [kv * k_stride[0], s0 * k_stride[1], k0 * k_stride[2]] - # V and output use dh0 tile offset - v_idx = [kv * v_stride[0], s0 * v_stride[1], dh0 * v_stride[2]] + # DMA strides (per-dimension DRAM strides for each tile) + k_dram_stride = [int(k_stride[0]), int(k_stride[1]), int(k_stride[2])] + # Q: q_head is pre-computed in template; stride[1]=0 since Lq=1 + q_dram_stride = [int(q_stride[0]), 0, int(q_stride[2])] + v_dram_stride = [int(v_stride[0]), int(v_stride[1]), int(v_stride[2])] + # out: q_head is pre-computed; stride[1]=0 since Lq=1 + out_dram_stride = [int(out_stride[0]), 0, int(out_stride[2])] + + # Affine maps for flat DRAM base address (used with pre-computed loop var expressions) + # K: offset(kv, s0, k0) + kk_offset_map = _make_offset_map(k_dram_stride, k_tile_desc.offset) + # Q: offset(q_head, k0) -- q_head = kv*g+qsub pre-computed in template + qk_offset_map = _make_offset_map([int(q_stride[0]), int(q_stride[2])], q_tile_desc.offset) + # V: offset(kv, s0, dh0) + v_offset_map = _make_offset_map(v_dram_stride, v_tile_desc.offset) + # Out: offset(q_head, dh0) -- q_head pre-computed in template + out_offset_map = _make_offset_map([int(out_stride[0]), int(out_stride[2])], 0) + # Blk-symbol variants: %s0 is relative (0..BlkS-1), %blk is the absolute + # block start (steps by BlkS), so actual_s = s0_rel + 1*blk → sym_stride=1. + kk_offset_map_blk = _make_offset_map_with_sym(k_dram_stride, sym_dim=1, sym_stride=1, offset=k_tile_desc.offset) + v_offset_map_blk = _make_offset_map_with_sym(v_dram_stride, sym_dim=1, sym_stride=1, offset=v_tile_desc.offset) + + # Keep sympy-based out_idx only for epilogue_info (not in render_options) + kv = sympy.Symbol("kv") + qsub = sympy.Symbol("qsub") + dh0 = sympy.Symbol("dh0") + s0 = sympy.Symbol("s0") + q_head = kv * g + qsub out_idx = [q_head * out_stride[0], sympy.Integer(0), dh0 * out_stride[2]] kernel.loop_size = [tile_s, tile_e, 1] @@ -819,24 +923,21 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue prob_desc=prob_desc, max_desc=max_desc, sum_desc=sum_desc, - qk_idx=qk_idx, - kk_idx=kk_idx, - v_idx=v_idx, - out_idx=out_idx, + # DMA strides + k_dram_stride=k_dram_stride, + q_dram_stride=q_dram_stride, + v_dram_stride=v_dram_stride, + out_dram_stride=out_dram_stride, + # Affine offset maps + kk_offset_map=kk_offset_map, + qk_offset_map=qk_offset_map, + v_offset_map=v_offset_map, + out_offset_map=out_offset_map, + kk_offset_map_blk=kk_offset_map_blk, + v_offset_map_blk=v_offset_map_blk, input_reorder=self.input_reorder, ) - kernel.epilogue_info = dict( - output_node=self.output_node.name, - sram_var="out_io_buffer", - dram_var="out", - dram_idx=out_idx, - dram_tile_desc=out_io_tile_desc, - nr_rdim=0, - r_dim_size=0, - dim_aliasing={"kv": "kv", "qsub": "qsub", "dh0": "dh0", "s0": "s0"}, - ) - return self._template_from_string(DECODE_GQA_SDPA_TEMPLATE).render(**kernel.render_options) @@ -891,27 +992,30 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue } } - affine.for %s0 = ({{ BlkS }} * %blk) to ({{ BlkS }} * (%blk + 1)) step {{ tile_s }} { + affine.for %s0 = 0 to {{ BlkS }} step {{ tile_s }} { // Accumulate score per qsub so K tiles can be shared across qsub. affine.for %qsub = 0 to {{ g }} { affine.vector_store %v0_s, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> } affine.for %k0 = 0 to {{ Dh }} step {{ tile_e }} { - {{ kernel.def_dma_op("MVIN", "key", kk_idx, k_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=1) }} + %kk_offset = affine.apply {{ kk_offset_map_blk }}(%kv, %s0, %k0)[%blk] + {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=1, dram_stride=k_dram_stride, dram_offset="kk_offset") }} %k2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1> affine.for %qsub = 0 to {{ g }} { - {{ kernel.def_dma_op("MVIN", "query", qk_idx, q_tile_desc, subtile_size=[1, 1, tile_e], indent_size=12) }} + %q_head = affine.apply affine_map<(d0, d1) -> (d0 * {{ g }} + d1)>(%kv, %qsub) + %qk_offset = affine.apply {{ qk_offset_map }}(%q_head, %k0) + {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, subtile_size=[1, 1, tile_e], indent_size=12, dram_stride=q_dram_stride, dram_offset="qk_offset") }} %q2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ q_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> linalg.matmul { idx_map = array } ins(%k2D, %q2D : memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1>, memref<{{ tile_e }}x1x{{ io_stype }}, 1>) outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(io_stype) }}) %raw_mul_io = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - %raw_mul = arith.extf %raw_mul_io : vector<{{ tile_s }}x{{ io_stype }}> to vector<{{ tile_s }}xf32> + {% if io_stype != "f32" %}%raw_mul = arith.extf %raw_mul_io : vector<{{ tile_s }}x{{ io_stype }}> to vector<{{ tile_s }}xf32>{% endif %} %old_score = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> - %new_score = arith.addf %old_score, %raw_mul : vector<{{ tile_s }}xf32> + %new_score = arith.addf %old_score, {{ "%raw_mul" if io_stype != "f32" else "%raw_mul_io" }} : vector<{{ tile_s }}xf32> affine.vector_store %new_score, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> } { accumulation_loop=true } } { accumulation_loop=true } @@ -951,8 +1055,8 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue %new_max_bcast = vector.broadcast %new_max_scalar : f32 to vector<{{ tile_s }}xf32> %shifted = arith.subf %scaled, %new_max_bcast : vector<{{ tile_s }}xf32> %exp_scores = math.exp %shifted : vector<{{ tile_s }}xf32> - %exp_scores_io = arith.truncf %exp_scores : vector<{{ tile_s }}xf32> to vector<{{ tile_s }}x{{ io_stype }}> - affine.vector_store %exp_scores_io, %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> + {% if io_stype != "f32" %}%exp_scores_io = arith.truncf %exp_scores : vector<{{ tile_s }}xf32> to vector<{{ tile_s }}x{{ io_stype }}>{% endif %} + affine.vector_store {{ "%exp_scores_io" if io_stype != "f32" else "%exp_scores" }}, %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> %sum_cast = vector.shape_cast %exp_scores : vector<{{ tile_s }}xf32> to vector<{{ tile_s // 2 }}x2xf32> %zero_2x = vector.broadcast %c0 : f32 to vector<2xf32> @@ -966,7 +1070,8 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue // For each output dh tile, load V once and share it across qsub. affine.for %dht = 0 to {{ dh_tiles }} { %dh0 = affine.apply affine_map<(d0) -> (d0 * {{ tile_e }})>(%dht) - {{ kernel.def_dma_op("MVIN", "value", v_idx, v_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=0) }} + %v_offset = affine.apply {{ v_offset_map_blk }}(%kv, %s0, %dh0)[%blk] + {{ kernel.def_dma_op("MVIN", "value", [], v_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=0, dram_stride=v_dram_stride, dram_offset="v_offset") }} %v2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1> affine.for %qsub = 0 to {{ g }} { @@ -980,9 +1085,9 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue outs(%out_io_2D : memref<{{ tile_e }}x1x{{ io_stype }}, 1>) %out_io_vec = affine.vector_load %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - %out_io_f32 = arith.extf %out_io_vec : vector<{{ tile_e }}x{{ io_stype }}> to vector<{{ tile_e }}xf32> + {% if io_stype != "f32" %}%out_io_f32 = arith.extf %out_io_vec : vector<{{ tile_e }}x{{ io_stype }}> to vector<{{ tile_e }}xf32>{% endif %} %out_acc_vec = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - %out_acc_new = arith.addf %out_acc_vec, %out_io_f32 : vector<{{ tile_e }}xf32> + %out_acc_new = arith.addf %out_acc_vec, {{ "%out_io_f32" if io_stype != "f32" else "%out_io_vec" }} : vector<{{ tile_e }}xf32> affine.vector_store %out_acc_new, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> } { accumulation_loop=true } } { accumulation_loop=true } @@ -1000,9 +1105,12 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue affine.for %dht = 0 to {{ dh_tiles }} { %out_vec = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - %packed = vector.concat %out_vec, %ml1 : vector<{{ tile_pack }}xf32> + %packed = vector.shuffle %out_vec, %ml1 [{{ range(tile_pack) | join(', ') }}] : vector<{{ tile_e }}xf32>, vector<{{ tile_e }}xf32> affine.vector_store %packed, %partial_buffer[0, 0, 0] : {{ partial_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_pack }}xf32> - {{ kernel.store_output(indent_size=10) }} + %q_head = affine.apply affine_map<(d0, d1) -> (d0 * {{ g }} + d1)>(%kv, %qsub) + %gh = affine.apply affine_map<(d0, d1) -> (d0 * {{ dh_tiles }} + d1)>(%q_head, %dht) + %partial_offset = affine.apply {{ partial_offset_map }}(%gh, %blk) + {{ kernel.def_dma_op("MVOUT", "partial", [], partial_tile_desc, indent_size=10, dram_stride=partial_dram_stride, dram_offset="partial_offset") }} } } { outer_loop=true } } { outer_loop=true } @@ -1012,83 +1120,6 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue """ -DECODE_GQA_SDPA_REDUCE_TEMPLATE = r""" -// Decode GQA SDPA reduce kernel: merge partials across blocks -// Input partial shape: (HgDhTiles, nblk, tile_pack) -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[partial], outputs=[out], names_str="partial, out", input_reorder=input_reorder)}} { - {{ kernel.def_sram_buffer("partial", partial_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("out_acc", out_acc_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} - - %c0 = arith.constant 0.0 : f32 - %c1 = arith.constant 1.0 : f32 - %c_neg_inf = arith.constant -1.0e+30 : f32 - %v0_e = arith.constant dense<0.0> : vector<{{ tile_e }}xf32> - %v0_2x = arith.constant dense<0.0> : vector<2xf32> - %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2xf32> - - {{ kernel.def_local_vars(indent_size=2) }} - - affine.for %gh = 0 to {{ HgDhTiles }} { - // reset merged accumulators - affine.vector_store %v0_e, %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - affine.vector_store %v0_2x, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - - affine.for %blk = 0 to {{ nblk }} { - {{ kernel.def_dma_op("MVIN", "partial", partial_idx, partial_tile_desc, subtile_size=[1, 1, tile_pack], indent_size=8) }} - %p = affine.vector_load %partial_buffer[0, 0, 0] : {{ partial_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_pack }}xf32> - %p2 = vector.shape_cast %p : vector<{{ tile_pack }}xf32> to vector<2x{{ tile_e }}xf32> - %o_j = vector.extract %p2[0] : vector<{{ tile_e }}xf32> from vector<2x{{ tile_e }}xf32> - %ml_j = vector.extract %p2[1] : vector<{{ tile_e }}xf32> from vector<2x{{ tile_e }}xf32> - %m_j = vector.extract %ml_j[0] : f32 from vector<{{ tile_e }}xf32> - %l_j = vector.extract %ml_j[1] : f32 from vector<{{ tile_e }}xf32> - - %old_max = affine.vector_load %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - %m_old = vector.extract %old_max[0] : f32 from vector<2xf32> - %m_new = arith.maximumf %m_old, %m_j : f32 - %m_new2 = vector.broadcast %m_new : f32 to vector<2xf32> - affine.vector_store %m_new2, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - - %diff_old = arith.subf %m_old, %m_new : f32 - %diff_j = arith.subf %m_j, %m_new : f32 - %scale_old = math.exp %diff_old : f32 - %scale_j = math.exp %diff_j : f32 - %scale_old_e = vector.broadcast %scale_old : f32 to vector<{{ tile_e }}xf32> - %scale_j_e = vector.broadcast %scale_j : f32 to vector<{{ tile_e }}xf32> - - %o_old = affine.vector_load %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - %o_old_rs = arith.mulf %o_old, %scale_old_e : vector<{{ tile_e }}xf32> - %o_j_rs = arith.mulf %o_j, %scale_j_e : vector<{{ tile_e }}xf32> - %o_new = arith.addf %o_old_rs, %o_j_rs : vector<{{ tile_e }}xf32> - affine.vector_store %o_new, %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - - %old_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - %l_old = vector.extract %old_sum[0] : f32 from vector<2xf32> - %l_new = arith.addf (arith.mulf %l_old, %scale_old : f32), (arith.mulf %l_j, %scale_j : f32) : f32 - %l_new2 = vector.broadcast %l_new : f32 to vector<2xf32> - affine.vector_store %l_new2, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - } { accumulation_loop=true } - - // finalize: out = o / l - %sum2 = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - %l = vector.extract %sum2[0] : f32 from vector<2xf32> - %inv = arith.divf %c1, %l : f32 - %inv_e = vector.broadcast %inv : f32 to vector<{{ tile_e }}xf32> - %o = affine.vector_load %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - %out_f32 = arith.mulf %o, %inv_e : vector<{{ tile_e }}xf32> - %out_io = arith.truncf %out_f32 : vector<{{ tile_e }}xf32> to vector<{{ tile_e }}x{{ io_stype }}> - affine.vector_store %out_io, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - {{ kernel.store_output(indent_size=4) }} - } { outer_loop=true } - return -} -""" - - class MLIRDecodeGQASDPAPartialTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, scale, BlkS: int = 1024, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) @@ -1097,7 +1128,8 @@ def __init__(self, input_nodes, layout, scale, BlkS: int = 1024, input_reorder=N def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue_nodes=None, prologue_nodes=None, tile_info=None, **kwargs): query, key, value = self.input_nodes[0], self.input_nodes[1], self.input_nodes[2] - partial = self.output_node + # Use the actual registered buffer node (e.g. "buf0") instead of the placeholder "buf_out". + partial = template_buffer_node if template_buffer_node is not None else self.output_node q_tensor4 = empty_strided(query.layout.size, query.layout.stride) k_tensor4 = empty_strided(key.layout.size, key.layout.stride) @@ -1173,28 +1205,39 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue partial_tile_desc.set_tile_size_stride([1, 1, tile_pack], [0, tile_pack, 1]) partial_tile_desc.set_name("partial_buffer") - # Indices - kv = sympy.Symbol("kv") - qsub = sympy.Symbol("qsub") - dht = sympy.Symbol("dht") - dh0 = sympy.Symbol("dh0") - k0 = sympy.Symbol("k0") - blk = sympy.Symbol("blk") - s0 = sympy.Symbol("s0") - q_head = kv * g + qsub - + # Strides from 3D tensor views q_stride = q_tensor.stride() k_stride = k_tensor.stride() v_stride = v_tensor.stride() - qk_idx = [q_head * q_stride[0], sympy.Integer(0), k0 * q_stride[2]] - kk_idx = [kv * k_stride[0], s0 * k_stride[1], k0 * k_stride[2]] - v_idx = [kv * v_stride[0], s0 * v_stride[1], dh0 * v_stride[2]] - # partial tensor is view(HgDhTiles, nblk, tile_pack) contiguous p_tensor = empty_strided(partial.get_layout().size, partial.get_layout().stride).view(HgDhTiles, nblk, tile_pack) p_stride = p_tensor.stride() - # group head index: ((kv*g + qsub)*dh_tiles + dht) + + # DMA strides + k_dram_stride = [int(k_stride[0]), int(k_stride[1]), int(k_stride[2])] + q_dram_stride = [int(q_stride[0]), 0, int(q_stride[2])] + v_dram_stride = [int(v_stride[0]), int(v_stride[1]), int(v_stride[2])] + partial_dram_stride = [int(p_stride[0]), int(p_stride[1]), 1] + + # Affine offset maps + kk_offset_map = _make_offset_map(k_dram_stride, k_tile_desc.offset) + qk_offset_map = _make_offset_map([int(q_stride[0]), int(q_stride[2])], q_tile_desc.offset) + v_offset_map = _make_offset_map(v_dram_stride, v_tile_desc.offset) + # partial: offset(gh, blk) -- gh = (kv*g+qsub)*dh_tiles+dht, pre-computed in template + partial_offset_map = _make_offset_map([int(p_stride[0]), int(p_stride[1])], 0) + # Blk-symbol variants: %s0 is relative (0..BlkS-1), %blk is a block index (0..nblk-1), + # so actual_s = s0_rel + BlkS * blk → sym_stride=BlkS. + kk_offset_map_blk = _make_offset_map_with_sym(k_dram_stride, sym_dim=1, sym_stride=int(BlkS), offset=k_tile_desc.offset) + v_offset_map_blk = _make_offset_map_with_sym(v_dram_stride, sym_dim=1, sym_stride=int(BlkS), offset=v_tile_desc.offset) + + # Keep sympy-based indices only for epilogue_info + kv = sympy.Symbol("kv") + qsub = sympy.Symbol("qsub") + dht = sympy.Symbol("dht") + dh0 = sympy.Symbol("dh0") + blk = sympy.Symbol("blk") + q_head = kv * g + qsub gh = (kv * g + qsub) * dh_tiles + dht partial_idx = [gh * p_stride[0], blk * p_stride[1], sympy.Integer(0)] @@ -1230,26 +1273,110 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue max_desc=max_desc, sum_desc=sum_desc, partial_tile_desc=partial_tile_desc, - qk_idx=qk_idx, - kk_idx=kk_idx, - v_idx=v_idx, - partial_idx=partial_idx, + # DMA strides + k_dram_stride=k_dram_stride, + q_dram_stride=q_dram_stride, + v_dram_stride=v_dram_stride, + partial_dram_stride=partial_dram_stride, + # Affine offset maps + kk_offset_map=kk_offset_map, + qk_offset_map=qk_offset_map, + v_offset_map=v_offset_map, + partial_offset_map=partial_offset_map, + kk_offset_map_blk=kk_offset_map_blk, + v_offset_map_blk=v_offset_map_blk, input_reorder=self.input_reorder, ) - kernel.epilogue_info = dict( - output_node=self.output_node.name, - sram_var="partial_buffer", - dram_var="partial", - dram_idx=partial_idx, - dram_tile_desc=partial_tile_desc, - nr_rdim=0, - r_dim_size=0, - dim_aliasing={"kv": "kv", "qsub": "qsub", "dht": "dht", "dh0": "dh0", "k0": "k0", "blk": "blk", "s0": "s0"}, - ) return self._template_from_string(DECODE_GQA_SDPA_PARTIAL_TEMPLATE).render(**kernel.render_options) +DECODE_GQA_SDPA_REDUCE_TEMPLATE = r""" +// Decode GQA SDPA reduce kernel: merge partials across blocks +// Input partial shape: (HgDhTiles, nblk, tile_pack) +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[partial], outputs=[out], names_str="partial, out", input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("partial", partial_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("out_acc", out_acc_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("out", out_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} + + %c0 = arith.constant 0.0 : f32 + %c1 = arith.constant 1.0 : f32 + %c_neg_inf = arith.constant -1.0e+30 : f32 + %v0_e = arith.constant dense<0.0> : vector<{{ tile_e }}xf32> + %v0_2x = arith.constant dense<0.0> : vector<2xf32> + %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2xf32> + + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %gh = 0 to {{ HgDhTiles }} { + // reset merged accumulators + affine.vector_store %v0_e, %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + affine.vector_store %v0_2x, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + + affine.for %blk = 0 to {{ nblk }} { + %partial_offset = affine.apply {{ partial_offset_map }}(%gh, %blk) + {{ kernel.def_dma_op("MVIN", "partial", [], partial_tile_desc, subtile_size=[1, 1, tile_pack], indent_size=8, dram_stride=partial_dram_stride, dram_offset="partial_offset") }} + %p = affine.vector_load %partial_buffer[0, 0, 0] : {{ partial_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_pack }}xf32> + %p2 = vector.shape_cast %p : vector<{{ tile_pack }}xf32> to vector<2x{{ tile_e }}xf32> + %o_j = vector.extract %p2[0] : vector<{{ tile_e }}xf32> from vector<2x{{ tile_e }}xf32> + %ml_j = vector.extract %p2[1] : vector<{{ tile_e }}xf32> from vector<2x{{ tile_e }}xf32> + %m_j = vector.extract %ml_j[0] : f32 from vector<{{ tile_e }}xf32> + %l_j = vector.extract %ml_j[1] : f32 from vector<{{ tile_e }}xf32> + + %old_max = affine.vector_load %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + %m_old = vector.extract %old_max[0] : f32 from vector<2xf32> + %m_new = arith.maximumf %m_old, %m_j : f32 + %m_new2 = vector.broadcast %m_new : f32 to vector<2xf32> + affine.vector_store %m_new2, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> + + %diff_old = arith.subf %m_old, %m_new : f32 + %diff_j = arith.subf %m_j, %m_new : f32 + %diff_old_v = vector.broadcast %diff_old : f32 to vector<1xf32> + %diff_j_v = vector.broadcast %diff_j : f32 to vector<1xf32> + %scale_old_v = math.exp %diff_old_v : vector<1xf32> + %scale_j_v = math.exp %diff_j_v : vector<1xf32> + %scale_old = vector.extract %scale_old_v[0] : f32 from vector<1xf32> + %scale_j = vector.extract %scale_j_v[0] : f32 from vector<1xf32> + %scale_old_e = vector.broadcast %scale_old : f32 to vector<{{ tile_e }}xf32> + %scale_j_e = vector.broadcast %scale_j : f32 to vector<{{ tile_e }}xf32> + + %o_old = affine.vector_load %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + %o_old_rs = arith.mulf %o_old, %scale_old_e : vector<{{ tile_e }}xf32> + %o_j_rs = arith.mulf %o_j, %scale_j_e : vector<{{ tile_e }}xf32> + %o_new = arith.addf %o_old_rs, %o_j_rs : vector<{{ tile_e }}xf32> + affine.vector_store %o_new, %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + + %old_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + %l_old = vector.extract %old_sum[0] : f32 from vector<2xf32> + %l_old_rs = arith.mulf %l_old, %scale_old : f32 + %l_j_rs = arith.mulf %l_j, %scale_j : f32 + %l_new = arith.addf %l_old_rs, %l_j_rs : f32 + %l_new2 = vector.broadcast %l_new : f32 to vector<2xf32> + affine.vector_store %l_new2, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + } { accumulation_loop=true } + + // finalize: out = o / l + %sum2 = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> + %l = vector.extract %sum2[0] : f32 from vector<2xf32> + %inv = arith.divf %c1, %l : f32 + %inv_e = vector.broadcast %inv : f32 to vector<{{ tile_e }}xf32> + %o = affine.vector_load %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> + %out_f32 = arith.mulf %o, %inv_e : vector<{{ tile_e }}xf32> + {% if io_stype != "f32" %}%out_io = arith.truncf %out_f32 : vector<{{ tile_e }}xf32> to vector<{{ tile_e }}x{{ io_stype }}>{% endif %} + affine.vector_store {{ "%out_io" if io_stype != "f32" else "%out_f32" }}, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> + %out_offset = affine.apply {{ out_offset_map }}(%gh) + {{ kernel.def_dma_op("MVOUT", "out", [], out_tile_desc, indent_size=4, dram_stride=out_dram_stride, dram_offset="out_offset") }} + } { outer_loop=true } + return +} +""" + + class MLIRDecodeGQASDPAReduceTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, BlkS: int = 1024, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) @@ -1257,7 +1384,8 @@ def __init__(self, input_nodes, layout, BlkS: int = 1024, input_reorder=None): def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue_nodes=None, prologue_nodes=None, tile_info=None, **kwargs): partial = self.input_nodes[0] - out = self.output_node + # Use the actual registered buffer node (e.g. "buf0") instead of the placeholder "buf_out". + out = template_buffer_node if template_buffer_node is not None else self.output_node tile_e = kernel.vector_lane tile_pack = tile_e * 2 @@ -1288,21 +1416,33 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue out_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) out_tile_desc.set_name("out_buffer") - # Indexing: partial is already 3D; out is (Hq,1,Dh) but view as (Hq*Dh/tile_e, 1, tile_e) + # Partial tensor strides p_tensor = empty_strided(partial.get_layout().size, partial.get_layout().stride) p_stride = p_tensor.stride() - gh = sympy.Symbol("gh") - blk = sympy.Symbol("blk") - partial_idx = [gh * p_stride[0], blk * p_stride[1], sympy.Integer(0)] - # out view + # Out view: (Hq*dh_tiles, 1, tile_e) out_tensor4 = empty_strided(out.get_layout().size, out.get_layout().stride) B, Hq, Lq, Dh = out_tensor4.shape assert B == 1 and Lq == 1 dh_tiles = int(Dh) // int(tile_e) out_tensor = out_tensor4.view(Hq * dh_tiles, 1, tile_e) o_stride = out_tensor.stride() - out_idx = [gh * o_stride[0], sympy.Integer(0), sympy.Integer(0)] + + # DMA strides + partial_dram_stride = [int(p_stride[0]), int(p_stride[1]), 1] + out_dram_stride = [int(o_stride[0]), 0, 0] + + # Affine offset maps + # partial: offset(gh, blk) + partial_offset_map = _make_offset_map([int(p_stride[0]), int(p_stride[1])], partial_tile_desc.offset) + # out: offset(gh) -- single dimension + out_offset_map = _make_offset_map([int(o_stride[0])], 0) + + # Keep sympy-based indices for epilogue_info + gh = sympy.Symbol("gh") + blk = sympy.Symbol("blk") + partial_idx = [gh * p_stride[0], blk * p_stride[1], sympy.Integer(0)] + out_idx = [gh * o_stride[0], sympy.Integer(0), sympy.Integer(0)] kernel.loop_size = [tile_pack, tile_e, 1] @@ -1321,19 +1461,13 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue max_desc=max_desc, sum_desc=sum_desc, out_tile_desc=out_tile_desc, - partial_idx=partial_idx, - out_idx=out_idx, + # DMA strides + partial_dram_stride=partial_dram_stride, + out_dram_stride=out_dram_stride, + # Affine offset maps + partial_offset_map=partial_offset_map, + out_offset_map=out_offset_map, input_reorder=self.input_reorder, ) - kernel.epilogue_info = dict( - output_node=self.output_node.name, - sram_var="out_buffer", - dram_var="out", - dram_idx=out_idx, - dram_tile_desc=out_tile_desc, - nr_rdim=0, - r_dim_size=0, - dim_aliasing={"gh": "gh", "blk": "blk"}, - ) return self._template_from_string(DECODE_GQA_SDPA_REDUCE_TEMPLATE).render(**kernel.render_options) From bfc2b22b334599fe8ddd959adb2e17ac1f576474 Mon Sep 17 00:00:00 2001 From: HamHyungkyu Date: Fri, 13 Mar 2026 19:37:08 +0900 Subject: [PATCH 23/31] [Frontend/template] SPDA implementation debug --- PyTorchSimFrontend/extension_codecache.py | 2 - PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 592 ++---------------- 2 files changed, 48 insertions(+), 546 deletions(-) diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index b1c457d3..d3ac7259 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -37,7 +37,6 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding \ - -dma-fine-grained='systolic-array-size={vectorlane_size}' \ -global-idx='vlen={vlen}' \ -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ -test-memref-to-gemmini="vectorlane={vectorlane_size}" \ @@ -87,7 +86,6 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding='timing_mode=1' \ - -dma-fine-grained='systolic-array-size={vectorlane_size}' \ -global-idx='vlen={vlen}' \ -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ -test-tile-operation-graph='vectorlane={vectorlane_size} tls_mode={extension_config.CONFIG_TLS_MODE}' \ diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index 077a8cd2..adcc7801 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -563,384 +563,6 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no return tile_candidates -# --------------------------- -# Decode-only GQA SDPA (Lq == 1) -# --------------------------- - -DECODE_GQA_SDPA_TEMPLATE = r""" -// Decode GQA SDPA kernel (Lq == 1) -// B = {{ B }} -// Hq = {{ Hq }} -// H = {{ H }} -// g = {{ g }} -// S = {{ S }} -// Dh = {{ Dh }} -// BlkS = {{ BlkS }} -// tile_s = {{ tile_s }} -// tile_e = {{ tile_e }} -// dh_tiles = {{ dh_tiles }} -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[query, key, value], outputs=[out], names_str="query, key, value, out", input_reorder=input_reorder)}} { - // IO buffers follow input dtype (fp16/bf16/f32) - {{ kernel.def_sram_buffer("query", q_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} - // Softmax output used for SV matmul (io dtype) - {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("score", score_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("prob", prob_desc, indent_size=2) }} - // Accumulator in fp32 (stable) - {{ kernel.def_sram_buffer("out_acc", out_acc_tile_desc, indent_size=2) }} - // Temp output in io dtype for SV matmul result - {{ kernel.def_sram_buffer("out_io", out_io_tile_desc, indent_size=2) }} - // Softmax running stats in fp32 - {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} - - %c0 = arith.constant 0.0 : {{ acc_stype }} - %c1 = arith.constant 1.0 : {{ acc_stype }} - %c_scale = arith.constant {{ scale }} : {{ acc_stype }} - %c_neg_inf = arith.constant -1.0e+30 : {{ acc_stype }} - - %v0_e_acc = arith.constant dense<0.0> : vector<{{ tile_e }}x{{ acc_stype }}> - %v0_e_io = arith.constant dense<0.0> : vector<{{ tile_e }}x{{ io_stype }}> - %v0_2x = arith.constant dense<0.0> : vector<2x{{ acc_stype }}> - %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2x{{ acc_stype }}> - %v0_s_acc = arith.constant dense<0.0> : vector<{{ tile_s }}x{{ acc_stype }}> - - %v_scale = vector.broadcast %c_scale : {{ acc_stype }} to vector<{{ tile_s }}x{{ acc_stype }}> - - {{ kernel.def_local_vars(indent_size=2) }} - - // kv_head parallelism is the natural unit for GQA reuse - affine.for %kv = 0 to {{ H }} { - // Process S in blocks (BlkS). Sequential inside a core. - affine.for %blk = 0 to {{ S }} step {{ BlkS }} { - // Initialize per-qsub accumulators for this (kv, blk) - affine.for %qsub = 0 to {{ g }} { - affine.vector_store %v_neg_inf_2x, %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> - affine.vector_store %v0_2x, %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> - affine.for %dht = 0 to {{ dh_tiles }} { - affine.vector_store %v0_e_acc, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> - } - } - - affine.for %s0 = 0 to {{ BlkS }} step {{ tile_s }} { - // Accumulate score per qsub so K tiles can be shared across qsub. - affine.for %qsub = 0 to {{ g }} { - affine.vector_store %v0_s_acc, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> - } - - affine.for %k0 = 0 to {{ Dh }} step {{ tile_e }} { - // Load K slice once for all qsub. - %kk_offset = affine.apply {{ kk_offset_map_blk }}(%kv, %s0, %k0)[%blk] - {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=1, dram_stride=k_dram_stride, dram_offset="kk_offset") }} - %k2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1> - - affine.for %qsub = 0 to {{ g }} { - %q_head = affine.apply affine_map<(d0, d1) -> (d0 * {{ g }} + d1)>(%kv, %qsub) - %qk_offset = affine.apply {{ qk_offset_map }}(%q_head, %k0) - {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, subtile_size=[1, 1, tile_e], indent_size=12, dram_stride=q_dram_stride, dram_offset="qk_offset") }} - %q2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ q_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> - - // mul = k @ q -> (tile_s x 1) in io dtype, then upcast and accumulate. - linalg.matmul - { idx_map = array } - ins(%k2D, %q2D : memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1>, memref<{{ tile_e }}x1x{{ io_stype }}, 1>) - outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(io_stype) }}) - - %raw_mul_io = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - {% if io_stype != acc_stype %}%raw_mul = arith.extf %raw_mul_io : vector<{{ tile_s }}x{{ io_stype }}> to vector<{{ tile_s }}x{{ acc_stype }}>{% endif %} - %old_score = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> - %new_score = arith.addf %old_score, {{ "%raw_mul" if io_stype != acc_stype else "%raw_mul_io" }} : vector<{{ tile_s }}x{{ acc_stype }}> - affine.vector_store %new_score, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> - } { accumulation_loop=true } - } { accumulation_loop=true } - - affine.for %qsub = 0 to {{ g }} { - %score_acc = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_s }}x{{ acc_stype }}> - // scale after full Dh reduction - %scaled_mul_vec = arith.mulf %score_acc, %v_scale : vector<{{ tile_s }}x{{ acc_stype }}> - - // Online softmax update (max/sum/out) identical to FLASH_SDPA_TEMPLATE but specialized to Lq==1. - %old_max = affine.vector_load %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> - // Reduce max over tile_s - %max_init = vector.broadcast %c_neg_inf : {{ acc_stype }} to vector<{{ tile_s }}x{{ acc_stype }}> - %local_max_vec = arith.maximumf %scaled_mul_vec, %max_init : vector<{{ tile_s }}x{{ acc_stype }}> - %max_cast = vector.shape_cast %local_max_vec : vector<{{ tile_s }}x{{ acc_stype }}> to vector<{{ tile_s // 2 }}x2x{{ acc_stype }}> - %max_red1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<{{ tile_s // 2 }}x2x{{ acc_stype }}> to vector<2x{{ acc_stype }}> - %max_shuf = vector.shuffle %max_red1, %max_red1 [1, 0] : vector<2x{{ acc_stype }}>, vector<2x{{ acc_stype }}> - %max_red2 = arith.maximumf %max_red1, %max_shuf : vector<2x{{ acc_stype }}> - %new_max = arith.maximumf %max_red2, %old_max : vector<2x{{ acc_stype }}> - affine.vector_store %new_max, %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> - - // rescale = exp(old_max - new_max) - %max_diff = arith.subf %old_max, %new_max : vector<2x{{ acc_stype }}> - %max_diff_scalar = vector.extract %max_diff[0] : {{ acc_stype }} from vector<2x{{ acc_stype }}> - %rescale_e = vector.broadcast %max_diff_scalar : {{ acc_stype }} to vector<{{ tile_e }}x{{ acc_stype }}> - %exp_rescale_e = math.exp %rescale_e : vector<{{ tile_e }}x{{ acc_stype }}> - %rescale_2 = vector.broadcast %max_diff_scalar : {{ acc_stype }} to vector<2x{{ acc_stype }}> - %exp_rescale_2 = math.exp %rescale_2 : vector<2x{{ acc_stype }}> - - // out *= rescale - %old_out = affine.vector_load %out_acc_buffer[%qsub, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> - %rescaled_out = arith.mulf %exp_rescale_e, %old_out : vector<{{ tile_e }}x{{ acc_stype }}> - affine.vector_store %rescaled_out, %out_acc_buffer[%qsub, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> - - // sum *= rescale - %old_sum = affine.vector_load %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> - %rescaled_sum = arith.mulf %old_sum, %exp_rescale_2 : vector<2x{{ acc_stype }}> - - // exp(score - new_max) - %new_max_scalar = vector.extract %new_max[0] : {{ acc_stype }} from vector<2x{{ acc_stype }}> - %new_max_bcast = vector.broadcast %new_max_scalar : {{ acc_stype }} to vector<{{ tile_s }}x{{ acc_stype }}> - %shifted = arith.subf %scaled_mul_vec, %new_max_bcast : vector<{{ tile_s }}x{{ acc_stype }}> - %exp_scores = math.exp %shifted : vector<{{ tile_s }}x{{ acc_stype }}> - // For SV matmul: downcast softmax output to io dtype (common in practice) - {% if io_stype != acc_stype %}%exp_scores_io = arith.truncf %exp_scores : vector<{{ tile_s }}x{{ acc_stype }}> to vector<{{ tile_s }}x{{ io_stype }}>{% endif %} - affine.vector_store {{ "%exp_scores_io" if io_stype != acc_stype else "%exp_scores" }}, %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - - // sum += reduce(exp_scores) - %sum_cast = vector.shape_cast %exp_scores : vector<{{ tile_s }}x{{ acc_stype }}> to vector<{{ tile_s // 2 }}x2x{{ acc_stype }}> - %zero_2x = vector.broadcast %c0 : {{ acc_stype }} to vector<2x{{ acc_stype }}> - %sum_red1 = vector.multi_reduction , %sum_cast, %zero_2x [0] : vector<{{ tile_s // 2 }}x2x{{ acc_stype }}> to vector<2x{{ acc_stype }}> - %sum_shuf = vector.shuffle %sum_red1, %sum_red1 [1, 0] : vector<2x{{ acc_stype }}>, vector<2x{{ acc_stype }}> - %sum_red2 = arith.addf %sum_red1, %sum_shuf : vector<2x{{ acc_stype }}> - %new_sum = arith.addf %sum_red2, %rescaled_sum : vector<2x{{ acc_stype }}> - affine.vector_store %new_sum, %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> - - } { accumulation_loop=true } - - // 2) SV accumulation: for each output dh tile, load V once and share across qsub. - affine.for %dht = 0 to {{ dh_tiles }} { - %dh0 = affine.apply affine_map<(d0) -> (d0 * {{ tile_e }})>(%dht) - %v_offset = affine.apply {{ v_offset_map_blk }}(%kv, %s0, %dh0)[%blk] - {{ kernel.def_dma_op("MVIN", "value", [], v_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=0, dram_stride=v_dram_stride, dram_offset="v_offset") }} - %v2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1> - - affine.for %qsub = 0 to {{ g }} { - %prob_vec = affine.vector_load %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - affine.vector_store %prob_vec, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - affine.vector_store %v0_e_io, %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - %out_io_2D = memref.reinterpret_cast %out_io_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> - linalg.matmul - { idx_map = array } - ins(%v2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(io_stype) }}) - outs(%out_io_2D : memref<{{ tile_e }}x1x{{ io_stype }}, 1>) - - %out_io_vec = affine.vector_load %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - {% if io_stype != acc_stype %}%out_io_f32 = arith.extf %out_io_vec : vector<{{ tile_e }}x{{ io_stype }}> to vector<{{ tile_e }}x{{ acc_stype }}>{% endif %} - %out_acc_vec = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> - %out_acc_new = arith.addf %out_acc_vec, {{ "%out_io_f32" if io_stype != acc_stype else "%out_io_vec" }} : vector<{{ tile_e }}x{{ acc_stype }}> - affine.vector_store %out_acc_new, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> - } { accumulation_loop=true } - } { accumulation_loop=true } - } { accumulation_loop=true } - - // finalize per-qsub for this (kv, blk) and store out for all dh tiles - affine.for %qsub = 0 to {{ g }} { - %final_sum = affine.vector_load %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape(acc_stype) }}, vector<2x{{ acc_stype }}> - %one_2x = vector.broadcast %c1 : {{ acc_stype }} to vector<2x{{ acc_stype }}> - %inv_sum_2x = arith.divf %one_2x, %final_sum : vector<2x{{ acc_stype }}> - %inv_sum = vector.extract %inv_sum_2x[0] : {{ acc_stype }} from vector<2x{{ acc_stype }}> - %inv_bcast = vector.broadcast %inv_sum : {{ acc_stype }} to vector<{{ tile_e }}x{{ acc_stype }}> - - affine.for %dht = 0 to {{ dh_tiles }} { - %dh0 = affine.apply affine_map<(d0) -> (d0 * {{ tile_e }})>(%dht) - %acc_out = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape(acc_stype) }}, vector<{{ tile_e }}x{{ acc_stype }}> - %final_out_acc = arith.mulf %acc_out, %inv_bcast : vector<{{ tile_e }}x{{ acc_stype }}> - {% if io_stype != acc_stype %}%final_out_io = arith.truncf %final_out_acc : vector<{{ tile_e }}x{{ acc_stype }}> to vector<{{ tile_e }}x{{ io_stype }}>{% endif %} - affine.vector_store {{ "%final_out_io" if io_stype != acc_stype else "%final_out_acc" }}, %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - %q_head = affine.apply affine_map<(d0, d1) -> (d0 * {{ g }} + d1)>(%kv, %qsub) - %out_offset = affine.apply {{ out_offset_map }}(%q_head, %dh0) - {{ kernel.def_dma_op("MVOUT", "out", [], out_io_tile_desc, indent_size=10, dram_stride=out_dram_stride, dram_offset="out_offset") }} - } - } { outer_loop=true } - } { outer_loop=true } - } { outer_loop=true } - - return -} -""" - - -class MLIRDecodeGQASDPATemplate(MLIRTemplate): - def __init__(self, input_nodes, layout, scale, BlkS: int = 1024, input_reorder=None): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.scale = scale - self.BlkS = BlkS - - def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue_nodes=None, prologue_nodes=None, tile_info=None, **kwargs): - # Decode-only: q is (B,Hq,1,Dh) - # Use template_buffer_node (the actual V.graph-registered CUDATemplateBuffer with its - # real name e.g. "buf0") when available, instead of the placeholder self.output_node - # (always named "buf_out"). This ensures output_buffers["buf0"] maps correctly - # in mlir_argdefs, which looks up buffer_types by the actual DRAM buffer name. - query, key, value, out = self.input_nodes[0], self.input_nodes[1], self.input_nodes[2], \ - template_buffer_node if template_buffer_node is not None else self.output_node - - # Materialize tensors for stride metadata - q_tensor4 = empty_strided(query.layout.size, query.layout.stride) - k_tensor4 = empty_strided(key.layout.size, key.layout.stride) - v_tensor4 = empty_strided(value.layout.size, value.layout.stride) - - B, Hq, Lq, Dh = q_tensor4.shape - Bk, H, S, Dhk = k_tensor4.shape - assert B == 1, "Decode GQA template currently supports B==1" - assert Lq == 1, "Decode GQA template requires Lq==1" - assert Dh == Dhk - g = Hq // H - BlkS = min(int(self.BlkS), int(S)) - - # Use 3D views to match the existing SDPA indexing scheme - # q: (Hq, 1, Dh), k/v: (H, S, Dh), out: (Hq, 1, Dh) - q_tensor = q_tensor4.view(Hq, 1, Dh) - k_tensor = k_tensor4.view(H, S, Dh) - v_tensor = v_tensor4.view(H, S, Dh) - - tile_s = kernel.vector_lane - tile_e = kernel.vector_lane - dh_tiles = int(Dh) // int(tile_e) - - io_stype = mlir_common.DTYPE_TO_MLIR[query.get_dtype()] - acc_stype = "f32" - - # SRAM tiles: q(1x1xtile_e), k/v(1xtile_sxtile_e), mul(tile_sx1) in io dtype. - # out_acc in f32; out_io temp in io dtype. - vlane_stride = 1 - q_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) - q_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) - q_tile_desc.set_name("q_buffer") - q_tile_desc.offset = query.get_layout().offset - - k_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 2, vlane_stride) - k_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [0, 1, tile_s]) - k_tile_desc.set_name("k_buffer") - k_tile_desc.offset = key.get_layout().offset - - v_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 1, vlane_stride) - v_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [0, tile_e, 1]) - v_tile_desc.set_name("v_buffer") - v_tile_desc.offset = value.get_layout().offset - - mul_tile_desc = mlir_common.MLIRMultiDimTile([tile_s, 1], kernel.vector_lane, 1, vlane_stride) - mul_tile_desc.set_tile_size_stride([tile_s, 1], [1, 1]) - mul_tile_desc.set_name("mul_buffer") - - score_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) - score_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) - score_desc.set_name("score_buffer") - - prob_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) - prob_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) - prob_desc.set_name("prob_buffer") - - # Per-qsub accumulators so KV tiles can be shared across qsub - out_acc_tile_desc = mlir_common.MLIRMultiDimTile([g, dh_tiles, tile_e], kernel.vector_lane, 2, vlane_stride) - out_acc_tile_desc.set_tile_size_stride([g, dh_tiles, tile_e], [dh_tiles * tile_e, tile_e, 1]) - out_acc_tile_desc.set_name("out_acc_buffer") - - out_io_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) - out_io_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) - out_io_tile_desc.set_name("out_io_buffer") - - max_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) - max_desc.set_tile_size_stride([g, 2], [2, 1]) - max_desc.set_name("max_buffer") - - sum_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) - sum_desc.set_tile_size_stride([g, 2], [2, 1]) - sum_desc.set_name("sum_buffer") - - # Strides from 3D tensor views - q_stride = q_tensor.stride() - k_stride = k_tensor.stride() - v_stride = v_tensor.stride() - # out is (B,Hq,1,Dh) but we address it as (Hq,1,Dh) - out_tensor = empty_strided(out.get_layout().size, out.get_layout().stride).view(Hq, 1, Dh) - out_stride = out_tensor.stride() - - # DMA strides (per-dimension DRAM strides for each tile) - k_dram_stride = [int(k_stride[0]), int(k_stride[1]), int(k_stride[2])] - # Q: q_head is pre-computed in template; stride[1]=0 since Lq=1 - q_dram_stride = [int(q_stride[0]), 0, int(q_stride[2])] - v_dram_stride = [int(v_stride[0]), int(v_stride[1]), int(v_stride[2])] - # out: q_head is pre-computed; stride[1]=0 since Lq=1 - out_dram_stride = [int(out_stride[0]), 0, int(out_stride[2])] - - # Affine maps for flat DRAM base address (used with pre-computed loop var expressions) - # K: offset(kv, s0, k0) - kk_offset_map = _make_offset_map(k_dram_stride, k_tile_desc.offset) - # Q: offset(q_head, k0) -- q_head = kv*g+qsub pre-computed in template - qk_offset_map = _make_offset_map([int(q_stride[0]), int(q_stride[2])], q_tile_desc.offset) - # V: offset(kv, s0, dh0) - v_offset_map = _make_offset_map(v_dram_stride, v_tile_desc.offset) - # Out: offset(q_head, dh0) -- q_head pre-computed in template - out_offset_map = _make_offset_map([int(out_stride[0]), int(out_stride[2])], 0) - # Blk-symbol variants: %s0 is relative (0..BlkS-1), %blk is the absolute - # block start (steps by BlkS), so actual_s = s0_rel + 1*blk → sym_stride=1. - kk_offset_map_blk = _make_offset_map_with_sym(k_dram_stride, sym_dim=1, sym_stride=1, offset=k_tile_desc.offset) - v_offset_map_blk = _make_offset_map_with_sym(v_dram_stride, sym_dim=1, sym_stride=1, offset=v_tile_desc.offset) - - # Keep sympy-based out_idx only for epilogue_info (not in render_options) - kv = sympy.Symbol("kv") - qsub = sympy.Symbol("qsub") - dh0 = sympy.Symbol("dh0") - s0 = sympy.Symbol("s0") - q_head = kv * g + qsub - out_idx = [q_head * out_stride[0], sympy.Integer(0), dh0 * out_stride[2]] - - kernel.loop_size = [tile_s, tile_e, 1] - - kernel.render_options = dict( - KERNEL_NAME=self.name, - kernel=kernel, - B=B, - Hq=Hq, - H=H, - g=g, - S=S, - Dh=Dh, - dh_tiles=dh_tiles, - BlkS=BlkS, - tile_s=tile_s, - tile_e=tile_e, - io_stype=io_stype, - acc_stype=acc_stype, - scale=self.scale, - query=query, - key=key, - value=value, - out=out, - q_tile_desc=q_tile_desc, - k_tile_desc=k_tile_desc, - v_tile_desc=v_tile_desc, - out_acc_tile_desc=out_acc_tile_desc, - out_io_tile_desc=out_io_tile_desc, - mul_tile_desc=mul_tile_desc, - score_desc=score_desc, - prob_desc=prob_desc, - max_desc=max_desc, - sum_desc=sum_desc, - # DMA strides - k_dram_stride=k_dram_stride, - q_dram_stride=q_dram_stride, - v_dram_stride=v_dram_stride, - out_dram_stride=out_dram_stride, - # Affine offset maps - kk_offset_map=kk_offset_map, - qk_offset_map=qk_offset_map, - v_offset_map=v_offset_map, - out_offset_map=out_offset_map, - kk_offset_map_blk=kk_offset_map_blk, - v_offset_map_blk=v_offset_map_blk, - input_reorder=self.input_reorder, - ) - - return self._template_from_string(DECODE_GQA_SDPA_TEMPLATE).render(**kernel.render_options) - - # --------------------------- # Decode-only GQA SDPA: 2-kernel pipeline (partial blocks + reduce) # --------------------------- @@ -960,13 +582,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("score", score_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("prob", prob_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("out_io", out_io_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("out_acc", out_acc_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("partial", partial_tile_desc, indent_size=2) }} + %c0 = arith.constant 0.0 : f32 %c_scale = arith.constant {{ scale }} : f32 @@ -984,135 +600,21 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue affine.for %kv = 0 to {{ H }} { affine.for %blk = 0 to {{ nblk }} step 1 { // Reset per-block accumulators for all qsub/dh tiles. - affine.for %qsub = 0 to {{ g }} { - affine.vector_store %v_neg_inf_2x, %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - affine.vector_store %v0_2x, %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - affine.for %dht = 0 to {{ dh_tiles }} { - affine.vector_store %v0_e, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - } - } - + %qk_offset = affine.apply {{ qk_offset_map }}(%kv) + {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, subtile_size=[Dh, 1, g_size], indent_size=8, dram_stride=q_dram_stride, dram_offset="qk_offset") }} + %q2D_buffer = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ Dh }}, {{ g_size }}], strides: [{{g_size}}, 1] : {{ q_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ Dh }}x{{ g_size }}x{{ io_stype }}, 1> affine.for %s0 = 0 to {{ BlkS }} step {{ tile_s }} { - // Accumulate score per qsub so K tiles can be shared across qsub. - affine.for %qsub = 0 to {{ g }} { - affine.vector_store %v0_s, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> - } - affine.for %k0 = 0 to {{ Dh }} step {{ tile_e }} { %kk_offset = affine.apply {{ kk_offset_map_blk }}(%kv, %s0, %k0)[%blk] {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=1, dram_stride=k_dram_stride, dram_offset="kk_offset") }} - %k2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1> - - affine.for %qsub = 0 to {{ g }} { - %q_head = affine.apply affine_map<(d0, d1) -> (d0 * {{ g }} + d1)>(%kv, %qsub) - %qk_offset = affine.apply {{ qk_offset_map }}(%q_head, %k0) - {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, subtile_size=[1, 1, tile_e], indent_size=12, dram_stride=q_dram_stride, dram_offset="qk_offset") }} - %q2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ q_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> - linalg.matmul - { idx_map = array } - ins(%k2D, %q2D : memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1>, memref<{{ tile_e }}x1x{{ io_stype }}, 1>) - outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(io_stype) }}) - %raw_mul_io = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - {% if io_stype != "f32" %}%raw_mul = arith.extf %raw_mul_io : vector<{{ tile_s }}x{{ io_stype }}> to vector<{{ tile_s }}xf32>{% endif %} - %old_score = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> - %new_score = arith.addf %old_score, {{ "%raw_mul" if io_stype != "f32" else "%raw_mul_io" }} : vector<{{ tile_s }}xf32> - affine.vector_store %new_score, %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> - } { accumulation_loop=true } - } { accumulation_loop=true } - - // Softmax once per qsub; persist probabilities in SRAM for all SV dh tiles. - affine.for %qsub = 0 to {{ g }} { - %score = affine.vector_load %score_buffer[%qsub, 0] : {{ score_desc.get_mlir_shape("f32") }}, vector<{{ tile_s }}xf32> - %scaled = arith.mulf %score, %v_scale : vector<{{ tile_s }}xf32> - - %old_max = affine.vector_load %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - %max_init = vector.broadcast %c_neg_inf : f32 to vector<{{ tile_s }}xf32> - %local_max_vec = arith.maximumf %scaled, %max_init : vector<{{ tile_s }}xf32> - %max_cast = vector.shape_cast %local_max_vec : vector<{{ tile_s }}xf32> to vector<{{ tile_s // 2 }}x2xf32> - %max_red1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<{{ tile_s // 2 }}x2xf32> to vector<2xf32> - %max_shuf = vector.shuffle %max_red1, %max_red1 [1, 0] : vector<2xf32>, vector<2xf32> - %max_red2 = arith.maximumf %max_red1, %max_shuf : vector<2xf32> - %new_max = arith.maximumf %max_red2, %old_max : vector<2xf32> - affine.vector_store %new_max, %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - - %max_diff = arith.subf %old_max, %new_max : vector<2xf32> - %max_diff_scalar = vector.extract %max_diff[0] : f32 from vector<2xf32> - %rescale_e = vector.broadcast %max_diff_scalar : f32 to vector<{{ tile_e }}xf32> - %exp_rescale_e = math.exp %rescale_e : vector<{{ tile_e }}xf32> - %rescale_2 = vector.broadcast %max_diff_scalar : f32 to vector<2xf32> - %exp_rescale_2 = math.exp %rescale_2 : vector<2xf32> - - %old_sum = affine.vector_load %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - %rescaled_sum = arith.mulf %old_sum, %exp_rescale_2 : vector<2xf32> - - affine.for %dht = 0 to {{ dh_tiles }} { - %old_out = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - %rescaled_out = arith.mulf %exp_rescale_e, %old_out : vector<{{ tile_e }}xf32> - affine.vector_store %rescaled_out, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - } - - %new_max_scalar = vector.extract %new_max[0] : f32 from vector<2xf32> - %new_max_bcast = vector.broadcast %new_max_scalar : f32 to vector<{{ tile_s }}xf32> - %shifted = arith.subf %scaled, %new_max_bcast : vector<{{ tile_s }}xf32> - %exp_scores = math.exp %shifted : vector<{{ tile_s }}xf32> - {% if io_stype != "f32" %}%exp_scores_io = arith.truncf %exp_scores : vector<{{ tile_s }}xf32> to vector<{{ tile_s }}x{{ io_stype }}>{% endif %} - affine.vector_store {{ "%exp_scores_io" if io_stype != "f32" else "%exp_scores" }}, %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - - %sum_cast = vector.shape_cast %exp_scores : vector<{{ tile_s }}xf32> to vector<{{ tile_s // 2 }}x2xf32> - %zero_2x = vector.broadcast %c0 : f32 to vector<2xf32> - %sum_red1 = vector.multi_reduction , %sum_cast, %zero_2x [0] : vector<{{ tile_s // 2 }}x2xf32> to vector<2xf32> - %sum_shuf = vector.shuffle %sum_red1, %sum_red1 [1, 0] : vector<2xf32>, vector<2xf32> - %sum_red2 = arith.addf %sum_red1, %sum_shuf : vector<2xf32> - %new_sum = arith.addf %sum_red2, %rescaled_sum : vector<2xf32> - affine.vector_store %new_sum, %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - } { accumulation_loop=true } + %k2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }},1] : {{ k_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1> + %q2D = memref.reinterpret_cast %q2D_buffer to offset: [%k0], sizes: [{{ tile_e }}, {{ g_size }}], strides: [{{ g_size }}, 1] : memref<{{ Dh }}x{{ g_size }}x{{ io_stype }}, 1> to memref<{{ tile_e }}x{{ g_size }}x{{ io_stype }}, 1> + linalg.matmul + ins(%k2D, %q2D : memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1>, memref<{{ tile_e }}x{{ g_size }}x{{ io_stype }}, 1>) + outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(io_stype) }}) - // For each output dh tile, load V once and share it across qsub. - affine.for %dht = 0 to {{ dh_tiles }} { - %dh0 = affine.apply affine_map<(d0) -> (d0 * {{ tile_e }})>(%dht) - %v_offset = affine.apply {{ v_offset_map_blk }}(%kv, %s0, %dh0)[%blk] - {{ kernel.def_dma_op("MVIN", "value", [], v_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=0, dram_stride=v_dram_stride, dram_offset="v_offset") }} - %v2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1> - - affine.for %qsub = 0 to {{ g }} { - %prob_vec = affine.vector_load %prob_buffer[%qsub, 0] : {{ prob_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - affine.vector_store %prob_vec, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_s }}x{{ io_stype }}> - affine.vector_store %v0_e_io, %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - %out_io_2D = memref.reinterpret_cast %out_io_buffer to offset: [0], sizes: [{{ tile_e }}, 1], strides: [1, 1] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_e }}x1x{{ io_stype }}, 1> - linalg.matmul - { idx_map = array } - ins(%v2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ io_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(io_stype) }}) - outs(%out_io_2D : memref<{{ tile_e }}x1x{{ io_stype }}, 1>) - - %out_io_vec = affine.vector_load %out_io_buffer[0, 0, 0] : {{ out_io_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - {% if io_stype != "f32" %}%out_io_f32 = arith.extf %out_io_vec : vector<{{ tile_e }}x{{ io_stype }}> to vector<{{ tile_e }}xf32>{% endif %} - %out_acc_vec = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - %out_acc_new = arith.addf %out_acc_vec, {{ "%out_io_f32" if io_stype != "f32" else "%out_io_vec" }} : vector<{{ tile_e }}xf32> - affine.vector_store %out_acc_new, %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - } { accumulation_loop=true } } { accumulation_loop=true } } { accumulation_loop=true } - - // Store packed partials for all qsub/dh tiles. - affine.for %qsub = 0 to {{ g }} { - %final_max = affine.vector_load %max_buffer[%qsub, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - %m_scalar = vector.extract %final_max[0] : f32 from vector<2xf32> - %final_sum = affine.vector_load %sum_buffer[%qsub, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - %l_scalar = vector.extract %final_sum[0] : f32 from vector<2xf32> - %ml_vec = vector.broadcast %c0 : f32 to vector<{{ tile_e }}xf32> - %ml0 = vector.insert %m_scalar, %ml_vec[0] : f32 into vector<{{ tile_e }}xf32> - %ml1 = vector.insert %l_scalar, %ml0[1] : f32 into vector<{{ tile_e }}xf32> - - affine.for %dht = 0 to {{ dh_tiles }} { - %out_vec = affine.vector_load %out_acc_buffer[%qsub, %dht, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - %packed = vector.shuffle %out_vec, %ml1 [{{ range(tile_pack) | join(', ') }}] : vector<{{ tile_e }}xf32>, vector<{{ tile_e }}xf32> - affine.vector_store %packed, %partial_buffer[0, 0, 0] : {{ partial_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_pack }}xf32> - %q_head = affine.apply affine_map<(d0, d1) -> (d0 * {{ g }} + d1)>(%kv, %qsub) - %gh = affine.apply affine_map<(d0, d1) -> (d0 * {{ dh_tiles }} + d1)>(%q_head, %dht) - %partial_offset = affine.apply {{ partial_offset_map }}(%gh, %blk) - {{ kernel.def_dma_op("MVOUT", "partial", [], partial_tile_desc, indent_size=10, dram_stride=partial_dram_stride, dram_offset="partial_offset") }} - } - } { outer_loop=true } } { outer_loop=true } } { outer_loop=true } return @@ -1138,6 +640,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue _, H, S, _ = k_tensor4.shape assert B == 1 and Lq == 1 g = Hq // H + g_size = g BlkS = min(int(self.BlkS), int(S)) nblk = (int(S) + int(BlkS) - 1) // int(BlkS) @@ -1157,53 +660,53 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue # tile descs vlane_stride = 1 - q_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) - q_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) + q_tile_desc = mlir_common.MLIRMultiDimTile([Dh, 1, g_size], kernel.vector_lane, 2, vlane_stride) + q_tile_desc.set_tile_size_stride([Dh, 1, g_size], [g_size, 1, 1]) q_tile_desc.set_name("q_buffer") q_tile_desc.offset = query.get_layout().offset k_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 2, vlane_stride) - k_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [0, 1, tile_s]) + k_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [1, 1, tile_s]) k_tile_desc.set_name("k_buffer") k_tile_desc.offset = key.get_layout().offset v_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 1, vlane_stride) - v_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [0, tile_e, 1]) + v_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [1, tile_e, 1]) v_tile_desc.set_name("v_buffer") v_tile_desc.offset = value.get_layout().offset - mul_tile_desc = mlir_common.MLIRMultiDimTile([tile_s, 1], kernel.vector_lane, 1, vlane_stride) - mul_tile_desc.set_tile_size_stride([tile_s, 1], [1, 1]) + mul_tile_desc = mlir_common.MLIRMultiDimTile([tile_s, g_size], kernel.vector_lane, 1, vlane_stride) + mul_tile_desc.set_tile_size_stride([tile_s, g_size], [1, tile_s]) mul_tile_desc.set_name("mul_buffer") - score_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) - score_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) - score_desc.set_name("score_buffer") + # score_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) + # score_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) + # score_desc.set_name("score_buffer") - prob_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) - prob_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) - prob_desc.set_name("prob_buffer") + # prob_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) + # prob_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) + # prob_desc.set_name("prob_buffer") - # Per-qsub, per-dh-tile accumulators so QK is computed once and SV expands across dh tiles. - out_acc_tile_desc = mlir_common.MLIRMultiDimTile([g, dh_tiles, tile_e], kernel.vector_lane, 2, vlane_stride) - out_acc_tile_desc.set_tile_size_stride([g, dh_tiles, tile_e], [dh_tiles * tile_e, tile_e, 1]) - out_acc_tile_desc.set_name("out_acc_buffer") + # # Per-qsub, per-dh-tile accumulators so QK is computed once and SV expands across dh tiles. + # out_acc_tile_desc = mlir_common.MLIRMultiDimTile([g, dh_tiles, tile_e], kernel.vector_lane, 2, vlane_stride) + # out_acc_tile_desc.set_tile_size_stride([g, dh_tiles, tile_e], [dh_tiles * tile_e, tile_e, 1]) + # out_acc_tile_desc.set_name("out_acc_buffer") - max_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) - max_desc.set_tile_size_stride([g, 2], [2, 1]) - max_desc.set_name("max_buffer") + # max_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) + # max_desc.set_tile_size_stride([g, 2], [2, 1]) + # max_desc.set_name("max_buffer") - sum_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) - sum_desc.set_tile_size_stride([g, 2], [2, 1]) - sum_desc.set_name("sum_buffer") + # sum_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) + # sum_desc.set_tile_size_stride([g, 2], [2, 1]) + # sum_desc.set_name("sum_buffer") - out_io_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) - out_io_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) - out_io_tile_desc.set_name("out_io_buffer") + # out_io_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) + # out_io_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) + # out_io_tile_desc.set_name("out_io_buffer") - partial_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_pack], kernel.vector_lane, 1, vlane_stride) - partial_tile_desc.set_tile_size_stride([1, 1, tile_pack], [0, tile_pack, 1]) - partial_tile_desc.set_name("partial_buffer") + # partial_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_pack], kernel.vector_lane, 1, vlane_stride) + # partial_tile_desc.set_tile_size_stride([1, 1, tile_pack], [0, tile_pack, 1]) + # partial_tile_desc.set_name("partial_buffer") # Strides from 3D tensor views q_stride = q_tensor.stride() @@ -1216,13 +719,13 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue # DMA strides k_dram_stride = [int(k_stride[0]), int(k_stride[1]), int(k_stride[2])] - q_dram_stride = [int(q_stride[0]), 0, int(q_stride[2])] + q_dram_stride = [int(q_stride[2]), 0, int(q_stride[1])] v_dram_stride = [int(v_stride[0]), int(v_stride[1]), int(v_stride[2])] partial_dram_stride = [int(p_stride[0]), int(p_stride[1]), 1] # Affine offset maps kk_offset_map = _make_offset_map(k_dram_stride, k_tile_desc.offset) - qk_offset_map = _make_offset_map([int(q_stride[0]), int(q_stride[2])], q_tile_desc.offset) + qk_offset_map = _make_offset_map([int(g) * int(q_stride[2])], q_tile_desc.offset) v_offset_map = _make_offset_map(v_dram_stride, v_tile_desc.offset) # partial: offset(gh, blk) -- gh = (kv*g+qsub)*dh_tiles+dht, pre-computed in template partial_offset_map = _make_offset_map([int(p_stride[0]), int(p_stride[1])], 0) @@ -1254,6 +757,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue nblk=nblk, tile_s=tile_s, tile_e=tile_e, + g_size=g_size, dh_tiles=dh_tiles, tile_pack=tile_pack, io_stype=io_stype, @@ -1266,13 +770,13 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue k_tile_desc=k_tile_desc, v_tile_desc=v_tile_desc, mul_tile_desc=mul_tile_desc, - score_desc=score_desc, - prob_desc=prob_desc, - out_io_tile_desc=out_io_tile_desc, - out_acc_tile_desc=out_acc_tile_desc, - max_desc=max_desc, - sum_desc=sum_desc, - partial_tile_desc=partial_tile_desc, + # score_desc=score_desc, + # prob_desc=prob_desc, + # out_io_tile_desc=out_io_tile_desc, + # out_acc_tile_desc=out_acc_tile_desc, + # max_desc=max_desc, + # sum_desc=sum_desc, + # partial_tile_desc=partial_tile_desc, # DMA strides k_dram_stride=k_dram_stride, q_dram_stride=q_dram_stride, From ce9330670c60bb4debf795c4771b8d80057e92e5 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 13 Mar 2026 21:30:07 +0900 Subject: [PATCH 24/31] [Template/SPDA] Remove subtile size temporarily --- PyTorchSimFrontend/extension_codecache.py | 2 ++ PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 18 ++++++------------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index d3ac7259..b1c457d3 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -37,6 +37,7 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding \ + -dma-fine-grained='systolic-array-size={vectorlane_size}' \ -global-idx='vlen={vlen}' \ -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ -test-memref-to-gemmini="vectorlane={vectorlane_size}" \ @@ -86,6 +87,7 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/mlir-opt \ -test-loop-padding='timing_mode=1' \ + -dma-fine-grained='systolic-array-size={vectorlane_size}' \ -global-idx='vlen={vlen}' \ -test-pytorchsim-to-vcix='systolic-array-size={vectorlane_size} vlen={vlen}' \ -test-tile-operation-graph='vectorlane={vectorlane_size} tls_mode={extension_config.CONFIG_TLS_MODE}' \ diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index adcc7801..b1569be6 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -169,9 +169,6 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: // tile_l = {{ tile_l }} // tile_s = {{ tile_s }} // tile_e = {{ tile_e }} -// subtile_l = {{ subtile_l }} -// subtile_s = {{ subtile_s }} -// subtile_e = {{ subtile_e }} {{kernel.def_global_vars()}} func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[query, key, value], outputs=[out], names_str="query, key, value, out", input_reorder=input_reorder)}} { @@ -210,7 +207,7 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: affine.for %index3 = 0 to 1 step 1 { affine.for %index1 = 0 to {{ l }} step {{ tile_l }} { %q_dram_offset = affine.apply {{ q_offset_map }}(%index0, %index1, %index3) - {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, subtile_size=[1, subtile_l, subtile_e], indent_size=8, dram_stride=q_dram_stride, dram_offset="q_dram_offset") }} + {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, indent_size=8, dram_stride=q_dram_stride, dram_offset="q_dram_offset") }} affine.vector_store %v0_l, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> @@ -221,9 +218,9 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: affine.for %index2 = 0 to {{ s }} step {{ tile_s }} { %k_dram_offset = affine.apply {{ k_offset_map }}(%index0, %index2, %index3) - {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10, dram_stride=k_dram_stride, dram_offset="k_dram_offset") }} + {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, indent_size=10, dram_stride=k_dram_stride, dram_offset="k_dram_offset") }} %v_dram_offset = affine.apply {{ v_offset_map }}(%index0, %index2, %index3) - {{ kernel.def_dma_op("MVIN", "value", [], v_tile_desc, subtile_size=[1, subtile_s, subtile_e], indent_size=10, dram_stride=v_dram_stride, dram_offset="v_dram_offset") }} + {{ kernel.def_dma_op("MVIN", "value", [], v_tile_desc, indent_size=10, dram_stride=v_dram_stride, dram_offset="v_dram_offset") }} affine.vector_store %v0_s, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> @@ -487,9 +484,6 @@ def render(self, tile_l = tile_l, tile_s = tile_s, tile_e = tile_e, # Tile sizes (sram) - subtile_l = subtile_l, - subtile_s = subtile_s, - subtile_e = subtile_e, # Subtile sizes (sram) data_stype="f32", query = query, key = key, @@ -601,12 +595,12 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no affine.for %blk = 0 to {{ nblk }} step 1 { // Reset per-block accumulators for all qsub/dh tiles. %qk_offset = affine.apply {{ qk_offset_map }}(%kv) - {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, subtile_size=[Dh, 1, g_size], indent_size=8, dram_stride=q_dram_stride, dram_offset="qk_offset") }} + {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, indent_size=8, dram_stride=q_dram_stride, dram_offset="qk_offset") }} %q2D_buffer = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ Dh }}, {{ g_size }}], strides: [{{g_size}}, 1] : {{ q_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ Dh }}x{{ g_size }}x{{ io_stype }}, 1> affine.for %s0 = 0 to {{ BlkS }} step {{ tile_s }} { affine.for %k0 = 0 to {{ Dh }} step {{ tile_e }} { %kk_offset = affine.apply {{ kk_offset_map_blk }}(%kv, %s0, %k0)[%blk] - {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, subtile_size=[1, tile_s, tile_e], indent_size=10, padding=1, dram_stride=k_dram_stride, dram_offset="kk_offset") }} + {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, indent_size=10, padding=1, dram_stride=k_dram_stride, dram_offset="kk_offset") }} %k2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }},1] : {{ k_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1> %q2D = memref.reinterpret_cast %q2D_buffer to offset: [%k0], sizes: [{{ tile_e }}, {{ g_size }}], strides: [{{ g_size }}, 1] : memref<{{ Dh }}x{{ g_size }}x{{ io_stype }}, 1> to memref<{{ tile_e }}x{{ g_size }}x{{ io_stype }}, 1> linalg.matmul @@ -824,7 +818,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue affine.for %blk = 0 to {{ nblk }} { %partial_offset = affine.apply {{ partial_offset_map }}(%gh, %blk) - {{ kernel.def_dma_op("MVIN", "partial", [], partial_tile_desc, subtile_size=[1, 1, tile_pack], indent_size=8, dram_stride=partial_dram_stride, dram_offset="partial_offset") }} + {{ kernel.def_dma_op("MVIN", "partial", [], partial_tile_desc, indent_size=8, dram_stride=partial_dram_stride, dram_offset="partial_offset") }} %p = affine.vector_load %partial_buffer[0, 0, 0] : {{ partial_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_pack }}xf32> %p2 = vector.shape_cast %p : vector<{{ tile_pack }}xf32> to vector<2x{{ tile_e }}xf32> %o_j = vector.extract %p2[0] : vector<{{ tile_e }}xf32> from vector<2x{{ tile_e }}xf32> From f2717e1cd117b5229f769ebf3a7040185c984891 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Fri, 13 Mar 2026 22:45:00 +0900 Subject: [PATCH 25/31] [Template/SPDA] minor fix --- PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index b1569be6..be6e7124 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -713,7 +713,7 @@ def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue # DMA strides k_dram_stride = [int(k_stride[0]), int(k_stride[1]), int(k_stride[2])] - q_dram_stride = [int(q_stride[2]), 0, int(q_stride[1])] + q_dram_stride = [int(q_stride[2]), 0, int(q_stride[0])] v_dram_stride = [int(v_stride[0]), int(v_stride[1]), int(v_stride[2])] partial_dram_stride = [int(p_stride[0]), int(p_stride[1]), 1] From be23638400926454d8be17742eff4b6fc358b750 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 16 Mar 2026 20:43:21 +0900 Subject: [PATCH 26/31] [Cleanup] Unflag debug option --- PyTorchSimFrontend/extension_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index fe8cc380..1b7ccf8d 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -130,7 +130,7 @@ def load_plan_from_module(module_path): CONFIG_USE_TIMING_POOLING = int(os.environ.get('TORCHSIM_USE_TIMING_POOLING', default=0)) -CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=1)) +CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0)) def setup_logger(name=None, level=None): From e925ae45cad8cebca98e42de5c1cfb8c01cd35bf Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Mon, 16 Mar 2026 22:02:07 +0900 Subject: [PATCH 27/31] [CI] Add deepseek test case --- .github/workflows/pytorchsim_test.yml | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index eaaa7e50..36a62b68 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -726,6 +726,27 @@ jobs: -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Yolov5/test_yolov5.py + test_deepseek: + name: Run test_deepseek + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_deepseek_v3_base.py + run: | + echo "Running test_deepseek_v3_base.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/DeepSeek/test_deepseek_v3_base.py + test_accuracy: name: Run test_accuracy runs-on: self-hosted From db859911ed73b21db65031f84dc47dc4555dcc3f Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Tue, 17 Mar 2026 16:24:45 +0900 Subject: [PATCH 28/31] [Template/SPDA] Cleanup test case + Add an activate option --- PyTorchSimDevice/csrc/aten/native/Extra.cpp | 34 +- .../torch_openreg/openreg/__init__.py | 5 + PyTorchSimFrontend/mlir/mlir_lowering.py | 60 +-- PyTorchSimFrontend/mlir/mlir_sdpa_template.py | 423 +----------------- tests/test_sdpa.py | 241 +++++----- 5 files changed, 181 insertions(+), 582 deletions(-) diff --git a/PyTorchSimDevice/csrc/aten/native/Extra.cpp b/PyTorchSimDevice/csrc/aten/native/Extra.cpp index aaf28e1a..eb76f5d7 100644 --- a/PyTorchSimDevice/csrc/aten/native/Extra.cpp +++ b/PyTorchSimDevice/csrc/aten/native/Extra.cpp @@ -20,8 +20,38 @@ int64_t _fused_sdp_choice( std::optional scale, bool enable_gqa) { - auto backend = sdp::SDPBackend::overrideable; - return static_cast(backend); + sdp::sdp_params params{query, key, value, attn_mask, dropout_p, is_causal, enable_gqa}; + + // Reject inputs that are fundamentally unsupported (e.g. wrong rank) + if (!sdp::check_tensor_shapes(params, /*debug=*/false)) { + return static_cast(sdp::SDPBackend::error); + } + + // q: (B, Hq, L, E) k/v: (B, H, S, E) + const int64_t Hq = query.size(-3); + const int64_t H = key.size(-3); + const int64_t L = query.size(-2); // query sequence length + const int64_t S = key.size(-2); // key/value sequence length + + // Conditions required by the MLIR FlashSDPA kernel: + // Prefill only : L == S (decode has L == 1, not supported) + // Non-GQA : Hq == H (equal query and KV heads) + // No dropout : template has no dropout implementation + // Dense tensors : no nested tensor support + const bool can_use_mlir_flash = + (L == S) && + (Hq == H) && !enable_gqa && + sdp::check_for_dropout(params, /*debug=*/false) && + sdp::check_nested_tensor(params, /*debug=*/false); + + const bool ctx_flash = at::globalContext().userEnabledFlashSDP(); + const bool ctx_math = at::globalContext().userEnabledMathSDP(); + + if (ctx_flash && can_use_mlir_flash) { + return static_cast(sdp::SDPBackend::overrideable); + } + + return static_cast(sdp::SDPBackend::math); } void quantize_tensor_per_tensor_affine_stub( diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index f674ec06..592011aa 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -73,6 +73,11 @@ def _lazy_init(): register_interface_for_device(custom_device(), ExtensionDeviceInterface) _initialized = True + # Set default SDPA backend to math-only for this device. + torch._C._set_sdp_use_flash(False) + torch._C._set_sdp_use_overrideable(False) + torch._C._set_sdp_use_math(True) + # Create default streams for all devices num_devices = device_count() for device_idx in range(num_devices): diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index 7b2c07bf..b717089f 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -20,8 +20,6 @@ from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate, MLIRStableSortTemplate from PyTorchSimFrontend.mlir.mlir_sdpa_template import ( MLIRFlashSDPATemplate, - MLIRDecodeGQASDPAPartialTemplate, - MLIRDecodeGQASDPAReduceTemplate, flash_sdpa_args, calculate_scale, ) @@ -51,56 +49,27 @@ def tuned_bmm(mat1, mat2, *, layout=None): def tuned_flash_sdpa( - query : TensorBox, - key : TensorBox, - value : TensorBox, + query : TensorBox, + key : TensorBox, + value : TensorBox, attn_bias : Optional[TensorBox] = None, - dropout_p : float = 0.0, - is_causal : bool = False, + dropout_p : float = 0.0, + is_causal : bool = False, return_debug_mask : bool = False, - scale : Optional[float] = None) -> tuple: - - + scale : Optional[float] = None, + enable_gqa : bool = False) -> tuple: + # _fused_sdp_choice in C++ already guarantees: + # L == S (prefill), Hq == H (non-GQA), dropout_p == 0.0 + # before routing here via SDPBackend::overrideable. + # Non-matching shapes fall back to SDPBackend::math in C++ and decompose + # into primitive ops (matmul/softmax) before reaching this lowering. scale = calculate_scale(query, scale) N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value) - - # Decode-only GQA fast path: q is (B,Hq,1,Dh), B==1, Hq!=H, Hq%H==0. - # Always use the 2-kernel decode path: - # 1) block partials over (kv head, sequence block) - # 2) reduce/merge across blocks - # This keeps KV shared across qsub, avoids dh0-outer duplication, and - # stores compact partials instead of full score/prob tensors in DRAM. - if L == 1 and Hq != H and N == 1 and (Hq % H) == 0: - g = Hq // H - vector_lane = extension_config.vpu_num_lanes - tile_e = vector_lane - dh_tiles = E // tile_e - decode_gqa_block_size = 512 - BlkS = decode_gqa_block_size if S >= decode_gqa_block_size else int(S) - # Padding-based tail handling: allow S not divisible by BlkS. - nblk = (S + BlkS - 1) // BlkS - HgDhTiles = H * g * dh_tiles - tile_pack = tile_e * 2 - - partial_layout = ir.FixedLayout( - query.get_device(), - torch.float32, - [HgDhTiles, nblk, tile_pack], - ) - partial_tmpl = MLIRDecodeGQASDPAPartialTemplate([query, key, value], partial_layout, scale, BlkS=BlkS) - partial = partial_tmpl.generate().output_node() - partial.realize() - reduce_tmpl = MLIRDecodeGQASDPAReduceTemplate([partial], layout, BlkS=BlkS) - out_node = reduce_tmpl.generate().output_node() - return (out_node, None, None, None, None, None, None, None, None) - mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale) - - # _scaled_dot_product_flash_attention has to return a tuple which has 9 values - # since its backward(_scaled_dot_product_flash_attention_backward) needs that values. - # (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) return (mlir_template.generate().output_node(), None, None, None, None, None, None, None, None) + + def conv_layout( x: TensorBox, weight: TensorBox, @@ -345,5 +314,4 @@ def _sort_layouts(x: TensorBox, dim: int, descending: bool): if extension_config.CONFIG_USE_TIMING_POOLING: lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template - lowerings.update({getattr(aten._scaled_dot_product_fused_attention_overrideable, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_fused_attention_overrideable.overloads()}) diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py index be6e7124..37db4956 100644 --- a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -125,14 +125,6 @@ def flash_sdpa_args( "Flash SDPA currently requires matching head dimensions between query and value (e == ev)." ) - # Support head dimensions larger than vector lanes by tiling e/ev. - # For now, require multiples of vector lanes (covers 64/128 with vlanes=16). - vector_lane = extension_config.vpu_num_lanes - if (e % vector_lane) != 0: - raise NotImplementedError( - f"Flash SDPA currently requires e to be a multiple of vlanes (e: {e}, vlanes: {vector_lane})." - ) - # Minimal GQA support (single-batch only for now). # We map each query head to a KV head by grouping: hq = g * h. if hq != h: @@ -309,7 +301,7 @@ def calculate_scale(query: torch.Tensor, scale: float) -> float: { idx_map = array } ins(%vt_buffer2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(data_stype) }}) outs(%ot_buffer2D : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) - } + } {inner_loop=true} // out @ row_sum^(-1) %final_row_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> @@ -556,416 +548,3 @@ def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_no return tile_candidates - -# --------------------------- -# Decode-only GQA SDPA: 2-kernel pipeline (partial blocks + reduce) -# --------------------------- - -DECODE_GQA_SDPA_PARTIAL_TEMPLATE = r""" -// Decode GQA SDPA partial kernel (per sequence block) -// Produces partials per (kv,qsub,dh_tile,blk): -// - first half lanes: o_j (tile_e) -// - second half lanes: [m_j, l_j, 0, 0, ...] (tile_e) -// QK/softmax is computed once per (kv,qsub,s0) over full Dh using k0 reduction. -// SV then reuses those probabilities across all dh tiles. -// H = {{ H }}, g = {{ g }}, Dh = {{ Dh }}, dh_tiles = {{ dh_tiles }}, S = {{ S }}, BlkS = {{ BlkS }}, nblk = {{ nblk }} -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[query, key, value], outputs=[partial], names_str="query, key, value, partial", input_reorder=input_reorder)}} { - {{ kernel.def_sram_buffer("query", q_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} - - - %c0 = arith.constant 0.0 : f32 - %c_scale = arith.constant {{ scale }} : f32 - %c_neg_inf = arith.constant -1.0e+30 : f32 - - %v0_e = arith.constant dense<0.0> : vector<{{ tile_e }}xf32> - %v0_e_io = arith.constant dense<0.0> : vector<{{ tile_e }}x{{ io_stype }}> - %v0_s = arith.constant dense<0.0> : vector<{{ tile_s }}xf32> - %v0_2x = arith.constant dense<0.0> : vector<2xf32> - %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2xf32> - %v_scale = vector.broadcast %c_scale : f32 to vector<{{ tile_s }}xf32> - - {{ kernel.def_local_vars(indent_size=2) }} - - affine.for %kv = 0 to {{ H }} { - affine.for %blk = 0 to {{ nblk }} step 1 { - // Reset per-block accumulators for all qsub/dh tiles. - %qk_offset = affine.apply {{ qk_offset_map }}(%kv) - {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, indent_size=8, dram_stride=q_dram_stride, dram_offset="qk_offset") }} - %q2D_buffer = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ Dh }}, {{ g_size }}], strides: [{{g_size}}, 1] : {{ q_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ Dh }}x{{ g_size }}x{{ io_stype }}, 1> - affine.for %s0 = 0 to {{ BlkS }} step {{ tile_s }} { - affine.for %k0 = 0 to {{ Dh }} step {{ tile_e }} { - %kk_offset = affine.apply {{ kk_offset_map_blk }}(%kv, %s0, %k0)[%blk] - {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, indent_size=10, padding=1, dram_stride=k_dram_stride, dram_offset="kk_offset") }} - %k2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }},1] : {{ k_tile_desc.get_mlir_shape(io_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1> - %q2D = memref.reinterpret_cast %q2D_buffer to offset: [%k0], sizes: [{{ tile_e }}, {{ g_size }}], strides: [{{ g_size }}, 1] : memref<{{ Dh }}x{{ g_size }}x{{ io_stype }}, 1> to memref<{{ tile_e }}x{{ g_size }}x{{ io_stype }}, 1> - linalg.matmul - ins(%k2D, %q2D : memref<{{ tile_s }}x{{ tile_e }}x{{ io_stype }}, 1>, memref<{{ tile_e }}x{{ g_size }}x{{ io_stype }}, 1>) - outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(io_stype) }}) - - } { accumulation_loop=true } - } { accumulation_loop=true } - } { outer_loop=true } - } { outer_loop=true } - return -} -""" - - -class MLIRDecodeGQASDPAPartialTemplate(MLIRTemplate): - def __init__(self, input_nodes, layout, scale, BlkS: int = 1024, input_reorder=None): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.scale = scale - self.BlkS = BlkS - - def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue_nodes=None, prologue_nodes=None, tile_info=None, **kwargs): - query, key, value = self.input_nodes[0], self.input_nodes[1], self.input_nodes[2] - # Use the actual registered buffer node (e.g. "buf0") instead of the placeholder "buf_out". - partial = template_buffer_node if template_buffer_node is not None else self.output_node - - q_tensor4 = empty_strided(query.layout.size, query.layout.stride) - k_tensor4 = empty_strided(key.layout.size, key.layout.stride) - v_tensor4 = empty_strided(value.layout.size, value.layout.stride) - B, Hq, Lq, Dh = q_tensor4.shape - _, H, S, _ = k_tensor4.shape - assert B == 1 and Lq == 1 - g = Hq // H - g_size = g - BlkS = min(int(self.BlkS), int(S)) - nblk = (int(S) + int(BlkS) - 1) // int(BlkS) - - io_stype = mlir_common.DTYPE_TO_MLIR[query.get_dtype()] - tile_s = kernel.vector_lane - tile_e = kernel.vector_lane - tile_pack = tile_e * 2 - - # Use 3D views for indices - q_tensor = q_tensor4.view(Hq, 1, Dh) - k_tensor = k_tensor4.view(H, S, Dh) - v_tensor = v_tensor4.view(H, S, Dh) - - # Flatten (kv,qsub,dh_tile) into GH = H*g*(Dh/tile_e) - dh_tiles = int(Dh) // int(tile_e) - HgDhTiles = int(H) * int(g) * int(dh_tiles) - - # tile descs - vlane_stride = 1 - q_tile_desc = mlir_common.MLIRMultiDimTile([Dh, 1, g_size], kernel.vector_lane, 2, vlane_stride) - q_tile_desc.set_tile_size_stride([Dh, 1, g_size], [g_size, 1, 1]) - q_tile_desc.set_name("q_buffer") - q_tile_desc.offset = query.get_layout().offset - - k_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 2, vlane_stride) - k_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [1, 1, tile_s]) - k_tile_desc.set_name("k_buffer") - k_tile_desc.offset = key.get_layout().offset - - v_tile_desc = mlir_common.MLIRMultiDimTile([1, tile_s, tile_e], kernel.vector_lane, 1, vlane_stride) - v_tile_desc.set_tile_size_stride([1, tile_s, tile_e], [1, tile_e, 1]) - v_tile_desc.set_name("v_buffer") - v_tile_desc.offset = value.get_layout().offset - - mul_tile_desc = mlir_common.MLIRMultiDimTile([tile_s, g_size], kernel.vector_lane, 1, vlane_stride) - mul_tile_desc.set_tile_size_stride([tile_s, g_size], [1, tile_s]) - mul_tile_desc.set_name("mul_buffer") - - # score_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) - # score_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) - # score_desc.set_name("score_buffer") - - # prob_desc = mlir_common.MLIRMultiDimTile([g, tile_s], kernel.vector_lane, 1, vlane_stride) - # prob_desc.set_tile_size_stride([g, tile_s], [tile_s, 1]) - # prob_desc.set_name("prob_buffer") - - # # Per-qsub, per-dh-tile accumulators so QK is computed once and SV expands across dh tiles. - # out_acc_tile_desc = mlir_common.MLIRMultiDimTile([g, dh_tiles, tile_e], kernel.vector_lane, 2, vlane_stride) - # out_acc_tile_desc.set_tile_size_stride([g, dh_tiles, tile_e], [dh_tiles * tile_e, tile_e, 1]) - # out_acc_tile_desc.set_name("out_acc_buffer") - - # max_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) - # max_desc.set_tile_size_stride([g, 2], [2, 1]) - # max_desc.set_name("max_buffer") - - # sum_desc = mlir_common.MLIRMultiDimTile([g, 2], kernel.vector_lane, 0, vlane_stride) - # sum_desc.set_tile_size_stride([g, 2], [2, 1]) - # sum_desc.set_name("sum_buffer") - - # out_io_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) - # out_io_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) - # out_io_tile_desc.set_name("out_io_buffer") - - # partial_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_pack], kernel.vector_lane, 1, vlane_stride) - # partial_tile_desc.set_tile_size_stride([1, 1, tile_pack], [0, tile_pack, 1]) - # partial_tile_desc.set_name("partial_buffer") - - # Strides from 3D tensor views - q_stride = q_tensor.stride() - k_stride = k_tensor.stride() - v_stride = v_tensor.stride() - - # partial tensor is view(HgDhTiles, nblk, tile_pack) contiguous - p_tensor = empty_strided(partial.get_layout().size, partial.get_layout().stride).view(HgDhTiles, nblk, tile_pack) - p_stride = p_tensor.stride() - - # DMA strides - k_dram_stride = [int(k_stride[0]), int(k_stride[1]), int(k_stride[2])] - q_dram_stride = [int(q_stride[2]), 0, int(q_stride[0])] - v_dram_stride = [int(v_stride[0]), int(v_stride[1]), int(v_stride[2])] - partial_dram_stride = [int(p_stride[0]), int(p_stride[1]), 1] - - # Affine offset maps - kk_offset_map = _make_offset_map(k_dram_stride, k_tile_desc.offset) - qk_offset_map = _make_offset_map([int(g) * int(q_stride[2])], q_tile_desc.offset) - v_offset_map = _make_offset_map(v_dram_stride, v_tile_desc.offset) - # partial: offset(gh, blk) -- gh = (kv*g+qsub)*dh_tiles+dht, pre-computed in template - partial_offset_map = _make_offset_map([int(p_stride[0]), int(p_stride[1])], 0) - # Blk-symbol variants: %s0 is relative (0..BlkS-1), %blk is a block index (0..nblk-1), - # so actual_s = s0_rel + BlkS * blk → sym_stride=BlkS. - kk_offset_map_blk = _make_offset_map_with_sym(k_dram_stride, sym_dim=1, sym_stride=int(BlkS), offset=k_tile_desc.offset) - v_offset_map_blk = _make_offset_map_with_sym(v_dram_stride, sym_dim=1, sym_stride=int(BlkS), offset=v_tile_desc.offset) - - # Keep sympy-based indices only for epilogue_info - kv = sympy.Symbol("kv") - qsub = sympy.Symbol("qsub") - dht = sympy.Symbol("dht") - dh0 = sympy.Symbol("dh0") - blk = sympy.Symbol("blk") - q_head = kv * g + qsub - gh = (kv * g + qsub) * dh_tiles + dht - partial_idx = [gh * p_stride[0], blk * p_stride[1], sympy.Integer(0)] - - kernel.loop_size = [tile_s, tile_e, tile_pack] - - kernel.render_options = dict( - KERNEL_NAME=self.name, - kernel=kernel, - H=H, - g=g, - Dh=Dh, - S=S, - BlkS=BlkS, - nblk=nblk, - tile_s=tile_s, - tile_e=tile_e, - g_size=g_size, - dh_tiles=dh_tiles, - tile_pack=tile_pack, - io_stype=io_stype, - scale=self.scale, - query=query, - key=key, - value=value, - partial=partial, - q_tile_desc=q_tile_desc, - k_tile_desc=k_tile_desc, - v_tile_desc=v_tile_desc, - mul_tile_desc=mul_tile_desc, - # score_desc=score_desc, - # prob_desc=prob_desc, - # out_io_tile_desc=out_io_tile_desc, - # out_acc_tile_desc=out_acc_tile_desc, - # max_desc=max_desc, - # sum_desc=sum_desc, - # partial_tile_desc=partial_tile_desc, - # DMA strides - k_dram_stride=k_dram_stride, - q_dram_stride=q_dram_stride, - v_dram_stride=v_dram_stride, - partial_dram_stride=partial_dram_stride, - # Affine offset maps - kk_offset_map=kk_offset_map, - qk_offset_map=qk_offset_map, - v_offset_map=v_offset_map, - partial_offset_map=partial_offset_map, - kk_offset_map_blk=kk_offset_map_blk, - v_offset_map_blk=v_offset_map_blk, - input_reorder=self.input_reorder, - ) - - return self._template_from_string(DECODE_GQA_SDPA_PARTIAL_TEMPLATE).render(**kernel.render_options) - - -DECODE_GQA_SDPA_REDUCE_TEMPLATE = r""" -// Decode GQA SDPA reduce kernel: merge partials across blocks -// Input partial shape: (HgDhTiles, nblk, tile_pack) -{{kernel.def_global_vars()}} - -func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[partial], outputs=[out], names_str="partial, out", input_reorder=input_reorder)}} { - {{ kernel.def_sram_buffer("partial", partial_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("out_acc", out_acc_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("out", out_tile_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} - {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} - - %c0 = arith.constant 0.0 : f32 - %c1 = arith.constant 1.0 : f32 - %c_neg_inf = arith.constant -1.0e+30 : f32 - %v0_e = arith.constant dense<0.0> : vector<{{ tile_e }}xf32> - %v0_2x = arith.constant dense<0.0> : vector<2xf32> - %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2xf32> - - {{ kernel.def_local_vars(indent_size=2) }} - - affine.for %gh = 0 to {{ HgDhTiles }} { - // reset merged accumulators - affine.vector_store %v0_e, %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - affine.vector_store %v0_2x, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - - affine.for %blk = 0 to {{ nblk }} { - %partial_offset = affine.apply {{ partial_offset_map }}(%gh, %blk) - {{ kernel.def_dma_op("MVIN", "partial", [], partial_tile_desc, indent_size=8, dram_stride=partial_dram_stride, dram_offset="partial_offset") }} - %p = affine.vector_load %partial_buffer[0, 0, 0] : {{ partial_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_pack }}xf32> - %p2 = vector.shape_cast %p : vector<{{ tile_pack }}xf32> to vector<2x{{ tile_e }}xf32> - %o_j = vector.extract %p2[0] : vector<{{ tile_e }}xf32> from vector<2x{{ tile_e }}xf32> - %ml_j = vector.extract %p2[1] : vector<{{ tile_e }}xf32> from vector<2x{{ tile_e }}xf32> - %m_j = vector.extract %ml_j[0] : f32 from vector<{{ tile_e }}xf32> - %l_j = vector.extract %ml_j[1] : f32 from vector<{{ tile_e }}xf32> - - %old_max = affine.vector_load %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - %m_old = vector.extract %old_max[0] : f32 from vector<2xf32> - %m_new = arith.maximumf %m_old, %m_j : f32 - %m_new2 = vector.broadcast %m_new : f32 to vector<2xf32> - affine.vector_store %m_new2, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape("f32") }}, vector<2xf32> - - %diff_old = arith.subf %m_old, %m_new : f32 - %diff_j = arith.subf %m_j, %m_new : f32 - %diff_old_v = vector.broadcast %diff_old : f32 to vector<1xf32> - %diff_j_v = vector.broadcast %diff_j : f32 to vector<1xf32> - %scale_old_v = math.exp %diff_old_v : vector<1xf32> - %scale_j_v = math.exp %diff_j_v : vector<1xf32> - %scale_old = vector.extract %scale_old_v[0] : f32 from vector<1xf32> - %scale_j = vector.extract %scale_j_v[0] : f32 from vector<1xf32> - %scale_old_e = vector.broadcast %scale_old : f32 to vector<{{ tile_e }}xf32> - %scale_j_e = vector.broadcast %scale_j : f32 to vector<{{ tile_e }}xf32> - - %o_old = affine.vector_load %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - %o_old_rs = arith.mulf %o_old, %scale_old_e : vector<{{ tile_e }}xf32> - %o_j_rs = arith.mulf %o_j, %scale_j_e : vector<{{ tile_e }}xf32> - %o_new = arith.addf %o_old_rs, %o_j_rs : vector<{{ tile_e }}xf32> - affine.vector_store %o_new, %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - - %old_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - %l_old = vector.extract %old_sum[0] : f32 from vector<2xf32> - %l_old_rs = arith.mulf %l_old, %scale_old : f32 - %l_j_rs = arith.mulf %l_j, %scale_j : f32 - %l_new = arith.addf %l_old_rs, %l_j_rs : f32 - %l_new2 = vector.broadcast %l_new : f32 to vector<2xf32> - affine.vector_store %l_new2, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - } { accumulation_loop=true } - - // finalize: out = o / l - %sum2 = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape("f32") }}, vector<2xf32> - %l = vector.extract %sum2[0] : f32 from vector<2xf32> - %inv = arith.divf %c1, %l : f32 - %inv_e = vector.broadcast %inv : f32 to vector<{{ tile_e }}xf32> - %o = affine.vector_load %out_acc_buffer[0, 0, 0] : {{ out_acc_tile_desc.get_mlir_shape("f32") }}, vector<{{ tile_e }}xf32> - %out_f32 = arith.mulf %o, %inv_e : vector<{{ tile_e }}xf32> - {% if io_stype != "f32" %}%out_io = arith.truncf %out_f32 : vector<{{ tile_e }}xf32> to vector<{{ tile_e }}x{{ io_stype }}>{% endif %} - affine.vector_store {{ "%out_io" if io_stype != "f32" else "%out_f32" }}, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(io_stype) }}, vector<{{ tile_e }}x{{ io_stype }}> - %out_offset = affine.apply {{ out_offset_map }}(%gh) - {{ kernel.def_dma_op("MVOUT", "out", [], out_tile_desc, indent_size=4, dram_stride=out_dram_stride, dram_offset="out_offset") }} - } { outer_loop=true } - return -} -""" - - -class MLIRDecodeGQASDPAReduceTemplate(MLIRTemplate): - def __init__(self, input_nodes, layout, BlkS: int = 1024, input_reorder=None): - super().__init__("kernel", input_nodes, layout, input_reorder) - self.BlkS = BlkS - - def render(self, kernel: MLIRTemplateKernel, template_buffer_node=None, epilogue_nodes=None, prologue_nodes=None, tile_info=None, **kwargs): - partial = self.input_nodes[0] - # Use the actual registered buffer node (e.g. "buf0") instead of the placeholder "buf_out". - out = template_buffer_node if template_buffer_node is not None else self.output_node - - tile_e = kernel.vector_lane - tile_pack = tile_e * 2 - - # Infer sizes from partial layout: (HgDhTiles, nblk, tile_pack) - HgDhTiles, nblk, _ = partial.get_size() - io_stype = mlir_common.DTYPE_TO_MLIR[out.get_dtype()] - - vlane_stride = 1 - partial_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_pack], kernel.vector_lane, 1, vlane_stride) - partial_tile_desc.set_tile_size_stride([1, 1, tile_pack], [0, tile_pack, 1]) - partial_tile_desc.set_name("partial_buffer") - partial_tile_desc.offset = partial.get_layout().offset - - out_acc_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) - out_acc_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) - out_acc_tile_desc.set_name("out_acc_buffer") - - max_desc = mlir_common.MLIRMultiDimTile([1, 2], kernel.vector_lane, 0, vlane_stride) - max_desc.set_tile_size_stride([1, 2], [2, 1]) - max_desc.set_name("max_buffer") - - sum_desc = mlir_common.MLIRMultiDimTile([1, 2], kernel.vector_lane, 0, vlane_stride) - sum_desc.set_tile_size_stride([1, 2], [2, 1]) - sum_desc.set_name("sum_buffer") - - out_tile_desc = mlir_common.MLIRMultiDimTile([1, 1, tile_e], kernel.vector_lane, 1, vlane_stride) - out_tile_desc.set_tile_size_stride([1, 1, tile_e], [0, tile_e, 1]) - out_tile_desc.set_name("out_buffer") - - # Partial tensor strides - p_tensor = empty_strided(partial.get_layout().size, partial.get_layout().stride) - p_stride = p_tensor.stride() - - # Out view: (Hq*dh_tiles, 1, tile_e) - out_tensor4 = empty_strided(out.get_layout().size, out.get_layout().stride) - B, Hq, Lq, Dh = out_tensor4.shape - assert B == 1 and Lq == 1 - dh_tiles = int(Dh) // int(tile_e) - out_tensor = out_tensor4.view(Hq * dh_tiles, 1, tile_e) - o_stride = out_tensor.stride() - - # DMA strides - partial_dram_stride = [int(p_stride[0]), int(p_stride[1]), 1] - out_dram_stride = [int(o_stride[0]), 0, 0] - - # Affine offset maps - # partial: offset(gh, blk) - partial_offset_map = _make_offset_map([int(p_stride[0]), int(p_stride[1])], partial_tile_desc.offset) - # out: offset(gh) -- single dimension - out_offset_map = _make_offset_map([int(o_stride[0])], 0) - - # Keep sympy-based indices for epilogue_info - gh = sympy.Symbol("gh") - blk = sympy.Symbol("blk") - partial_idx = [gh * p_stride[0], blk * p_stride[1], sympy.Integer(0)] - out_idx = [gh * o_stride[0], sympy.Integer(0), sympy.Integer(0)] - - kernel.loop_size = [tile_pack, tile_e, 1] - - kernel.render_options = dict( - KERNEL_NAME=self.name, - kernel=kernel, - HgDhTiles=HgDhTiles, - nblk=nblk, - tile_e=tile_e, - tile_pack=tile_pack, - io_stype=io_stype, - partial=partial, - out=out, - partial_tile_desc=partial_tile_desc, - out_acc_tile_desc=out_acc_tile_desc, - max_desc=max_desc, - sum_desc=sum_desc, - out_tile_desc=out_tile_desc, - # DMA strides - partial_dram_stride=partial_dram_stride, - out_dram_stride=out_dram_stride, - # Affine offset maps - partial_offset_map=partial_offset_map, - out_offset_map=out_offset_map, - input_reorder=self.input_reorder, - ) - - return self._template_from_string(DECODE_GQA_SDPA_REDUCE_TEMPLATE).render(**kernel.render_options) diff --git a/tests/test_sdpa.py b/tests/test_sdpa.py index ed7ae8f8..c4825731 100644 --- a/tests/test_sdpa.py +++ b/tests/test_sdpa.py @@ -1,128 +1,145 @@ import sys -import math +import os import torch -import inspect -from typing import List +import torch._dynamo import torch.nn.functional as F -from torch.nn.attention import SDPBackend, sdpa_kernel -from torch.fx.passes.graph_drawer import FxGraphDrawer -from torch._inductor.decomposition import decompositions -def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): - message = f"|{name} Test Passed|" +base_dir = os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim") +sys.path.append(base_dir) + +device = torch.device("npu:0") + +# --------------------------------------------------------------------------- +# Default sweep configs - edit here to change what gets tested +# --------------------------------------------------------------------------- +SDPA_DEFAULTS = dict( + n_batch_list = [1, 4, 8, 16], + n_head_list = [4, 6, 8, 12], + n_token_list = [128, 256, 512, 1024], + head_dim_list = [32, 64, 128], + is_causal = False, +) + +GQA_DEFAULTS = dict( + batch_list = [1], + num_kv_heads = 1, + gqa_ratios = [4, 5, 8, 16], # Hq = ratio * num_kv_heads + seq_len_list = [128, 256, 1024], + head_dim_list = [64, 128], + query_len = 1, # decode shape: Lq == 1 + is_causal = True, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def clear_caches(): + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache + from torch._inductor.codecache import FxGraphCache + AOTAutogradCache.clear() + torch._dynamo.reset() + os.environ["TORCHINDUCTOR_CACHE"] = "0" + FxGraphCache.clear() + + +def assert_close(name, out, cpu_out, rtol=1e-4, atol=1e-4): + msg = f"|{name} Test Passed|" if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): - print("-" * len(message)) - print(message) - print("-" * len(message)) - pass + print("-" * len(msg)) + print(msg) + print("-" * len(msg)) else: - print("custom out: ", out.cpu()) - print("cpu out: ", cpu_out) + print(f"[FAIL] {name}") + print(" device out:", out.cpu()) + print(" cpu out:", cpu_out) exit(1) -def test_scaled_dot_product_attention(device, backends="flash"): + +def _run_sdpa(device, q, k, v, **kwargs): + """Compile and run SDPA on device; return result on device.""" + opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) + return opt_fn(q.to(device), k.to(device), v.to(device), **kwargs) + + +def _cpu_sdpa(q, k, v, **kwargs): + """Run reference SDPA on CPU.""" + return F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), **kwargs) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +def test_sdpa( + device, + n_batch_list = SDPA_DEFAULTS["n_batch_list"], + n_head_list = SDPA_DEFAULTS["n_head_list"], + n_token_list = SDPA_DEFAULTS["n_token_list"], + head_dim_list = SDPA_DEFAULTS["head_dim_list"], + is_causal = SDPA_DEFAULTS["is_causal"], +): torch.manual_seed(0) - n_batch_list = [1, 4, 8, 16] - n_head_list = [1, 4, 8, 12] - n_token_list = [128, 256, 512, 1024] - head_dim_list = [32, 64, 128] - - for n_batch in n_batch_list: - for n_head in n_head_list: - for n_token in n_token_list: - for head_dim in head_dim_list: - # Inputs + sdpa_kwargs = dict(attn_mask=None, dropout_p=0.0, is_causal=is_causal) + + for B in n_batch_list: + for H in n_head_list: + for S in n_token_list: + for D in head_dim_list: clear_caches() - query = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) - key = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) - value = torch.rand(n_batch, n_head, n_token, head_dim, dtype=torch.float32) - - # With NPU - query = query.to(device=device) - key = key.to(device=device) - value = value.to(device=device) - - opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) - out = opt_fn(query, key, value) - out = out.to(device) - - # With CPU - cpu_device = torch.device('cpu') - query = query.to(device=cpu_device) - key = key.to(device=cpu_device) - value = value.to(device=cpu_device) - cpu_out = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) - - name = f"SDPA(n_batch: {n_batch}, n_head: {n_head}, n_token: {n_token}, head_dim: {head_dim})" - test_result(name, out, cpu_out) - - print("All tests passed!") - -def test_scaled_dot_product_attention_gqa_single_batch(device): + q = torch.rand(B, H, S, D, dtype=torch.float32) + k = torch.rand(B, H, S, D, dtype=torch.float32) + v = torch.rand(B, H, S, D, dtype=torch.float32) + + out = _run_sdpa(device, q, k, v, **sdpa_kwargs) + cpu_out = _cpu_sdpa(q, k, v, **sdpa_kwargs) + + assert_close(f"SDPA(B:{B}, H:{H}, S:{S}, D:{D})", out, cpu_out) + + print("All SDPA tests passed!") + + +def test_gqa( + device, + batch_list = GQA_DEFAULTS["batch_list"], + num_kv_heads = GQA_DEFAULTS["num_kv_heads"], + gqa_ratios = GQA_DEFAULTS["gqa_ratios"], + seq_len_list = GQA_DEFAULTS["seq_len_list"], + head_dim_list= GQA_DEFAULTS["head_dim_list"], + query_len = GQA_DEFAULTS["query_len"], + is_causal = GQA_DEFAULTS["is_causal"], +): """ - Focused GQA testcases for single-batch (n==1). - Shapes: - q: (B, Hq, Lq, Dh) - k: (B, H, S, Dh) - v: (B, H, S, Dh) + GQA sweep: q shape (B, Hq, Lq, D), kv shape (B, H, S, D). + Hq = ratio * num_kv_heads for each ratio in gqa_ratios. """ torch.manual_seed(0) + sdpa_kwargs = dict(attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True) - B = 1 - # Decode-focused: include a larger S to hit BlkS logic - seq_len_list = [128, 256, 1024] - head_dim_list = [64, 128] - # GQA ratios requested: Hq / H in {4, 5, 8, 16}. - # Keep H=1 to directly realize those ratios. - gqa_ratios = [4, 5, 8, 16] - H = 1 - - for seq_len in seq_len_list: - for head_dim in head_dim_list: - for ratio in gqa_ratios: - Hq = ratio * H - - clear_caches() - # Decode shape: Lq == 1 - q = torch.rand(B, Hq, 1, head_dim, dtype=torch.float32) - k = torch.rand(B, H, seq_len, head_dim, dtype=torch.float32) - v = torch.rand(B, H, seq_len, head_dim, dtype=torch.float32) - - # NPU - q_npu = q.to(device=device) - k_npu = k.to(device=device) - v_npu = v.to(device=device) - opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) - out = opt_fn(q_npu, k_npu, v_npu, attn_mask=None, dropout_p=0.0, is_causal=True, enable_gqa=True) - - # CPU reference - cpu_device = torch.device("cpu") - cpu_out = F.scaled_dot_product_attention( - q.to(device=cpu_device), - k.to(device=cpu_device), - v.to(device=cpu_device), - attn_mask=None, - dropout_p=0.0, - is_causal=True, - enable_gqa=True, - ) - - name = f"SDPA-GQA(B: {B}, Hq: {Hq}, H: {H}, S: {seq_len}, head_dim: {head_dim})" - test_result(name, out, cpu_out) - - print("All GQA single-batch tests passed!") + for B in batch_list: + for S in seq_len_list: + for D in head_dim_list: + for ratio in gqa_ratios: + Hq = ratio * num_kv_heads + clear_caches() + q = torch.rand(B, Hq, query_len, D, dtype=torch.float32) + k = torch.rand(B, num_kv_heads, S, D, dtype=torch.float32) + v = torch.rand(B, num_kv_heads, S, D, dtype=torch.float32) -def clear_caches(): - import os - from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache - from torch._inductor.codecache import FxGraphCache - AOTAutogradCache.clear() - torch._dynamo.reset() - os.environ["TORCHINDUCTOR_CACHE"] = "0" - FxGraphCache.clear() + out = _run_sdpa(device, q, k, v, **sdpa_kwargs) + cpu_out = _cpu_sdpa(q, k, v, **sdpa_kwargs) + + assert_close( + f"GQA(B:{B}, Hq:{Hq}, H:{num_kv_heads}, S:{S}, D:{D})", + out, cpu_out, + ) + + print("All GQA tests passed!") + + +if __name__ == "__main__": + with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]): + test_sdpa(device) + #test_gqa(device) -if __name__ == "__main__": - device = torch.device('npu:0') - # test_scaled_dot_product_attention(device, backends="flash") - test_scaled_dot_product_attention_gqa_single_batch(device) - \ No newline at end of file + # Example: quick single-config run + # test_gqa(device, batch_list=[1], gqa_ratios=[5], seq_len_list=[32], head_dim_list=[128]) From dd71c70766a06149a975615585f51536d1ea2904 Mon Sep 17 00:00:00 2001 From: HamHyungkyu Date: Tue, 17 Mar 2026 02:31:49 +0000 Subject: [PATCH 29/31] [Frontend] Handle RecompileSignal in MLIRKernel code generation --- PyTorchSimFrontend/mlir/mlir_codegen_backend.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 38125e31..672c35f7 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -964,7 +964,10 @@ def make_choices(self, nodes, kernel_name): # Try initial tile size self.reset(None) - src_code, meta_code = super().codegen_nodes(nodes, kernel_name) + try: + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) + except mlir_common.RecompileSignal: + continue current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) search_space.add(current_tile_sz) @@ -986,14 +989,12 @@ def make_choices(self, nodes, kernel_name): # Try increase tile size for this axis try: self.kernel_group.tile_desc.scale_tile_dim(axis, prev_ranges[axis], 2) - except extension_codecache.TileSizeError as e: - # Failed to find proper tile size + self.reset(None) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) + except (extension_codecache.TileSizeError, mlir_common.RecompileSignal): candidate_axes.remove(axis) self.reset(None) continue - - self.reset(None) - src_code, meta_code = super().codegen_nodes(nodes, kernel_name) current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) # FIXME. How to intergrate this constraint to tile system? From c5f085ece4e9523ca1e97ee165c6cb976df5427c Mon Sep 17 00:00:00 2001 From: HamHyungkyu Date: Tue, 17 Mar 2026 02:39:03 +0000 Subject: [PATCH 30/31] [Frontend] Enhance vector size handling for low-precision paths in MLIR kernels --- PyTorchSimFrontend/mlir/mlir_common.py | 73 +++++++++++++++++++++--- PyTorchSimFrontend/mlir/mlir_template.py | 4 +- 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 9f5dc6ab..32805261 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -103,14 +103,17 @@ def get_dtype_nbytes(dtype): MLIR_INF = { "inf" : { + "f16" : 0x7C00, "f32" : 0x7F800000, "f64" : 0x7FF0000000000000 }, "-inf" : { + "f16" : 0xFC00, "f32" : 0xFF800000, "f64" : 0xFFF0000000000000 }, "nan" : { + "f16" : 0x7C00, "f32" : 0x7FC00000, "f64" : 0x7FF8000000000000 } @@ -260,17 +263,23 @@ def get_tile_stride_per_lane(self, tile_size: list[int], tile_stride: list[int]) return tile_stride def get_compute_vec_size(self, tile_size: list[int], reduction_numel: int, nr_rdim: int) -> int: - if self.forced_vec_size is not None: - return self.forced_vec_size - per_lane = self.get_numel_per_lane(tile_size) stride = self.vlane_stride if nr_rdim: val = per_lane // max(reduction_numel, 1) + result = val for mult in [8, 4, 2]: if per_lane >= val * mult: - return val * mult - return val + result = val * mult + break + if self.forced_vec_size is not None: + # Cap while keeping result divisible by val (= reduction_size). + # This preserves the assert(vec_len % reduction_size == 0) invariant. + capped = (min(result, self.forced_vec_size) // max(val, 1)) * max(val, 1) + result = max(capped, val) + return result + if self.forced_vec_size is not None: + return self.forced_vec_size for mult in [8, 4, 2]: if (per_lane // stride) >= mult: return stride * mult @@ -787,10 +796,24 @@ def codegen_nodes(self, nodes, kernel_name): # Set node range info vars, reduction_vars = self.set_ranges(group, reduction_group) tile_desc = self.compute_tile_size(nodes, vars, reduction_vars) + _, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs() + safe_vec_size = self.get_safe_vec_size(tile_desc.get_compute_vec_size()) + # For pointwise (non-reduction) kernels, cap the MLIR vector size so that + # f16->f32 widening stays within LMUL<=4 (step and forced_vec_size must match). + # Reduction kernels are left unchanged: their accumulator/multi_reduction + # structure assumes compute_vec_size == step, so we must not split them here. + tile_desc.vmap.forced_vec_size = safe_vec_size + compute_vec = tile_desc.get_compute_vec_size() + # RVV requires vector lengths that produce integer power-of-2 LMUL values. + # Non-power-of-2 element counts (e.g. 24) cause LLVM WidenVectorResult crashes. + # Raise BEFORE the try/except so this propagates to make_choices (not retried). + if compute_vec > 1 and (compute_vec & (compute_vec - 1)) != 0: + raise RecompileSignal( + f"Non-power-of-2 compute_vec_size {compute_vec}: tile rejected (RVV requires power-of-2 LMUL)" + ) self.compute_body_loop.size = tile_desc.get_numel_per_lane() - self.compute_body_loop.step = tile_desc.get_compute_vec_size() + self.compute_body_loop.step = compute_vec try: - _, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs() with self as kernel: for node in nodes: node.run(vars, reduction_vars) @@ -1035,6 +1058,42 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._nested_context_depth -= 1 if self._nested_context_depth == 0: super().__exit__(exc_type, exc_val, exc_tb) + + def get_safe_vec_size(self, default_vec_size: int = 64) -> int: + """ + Cap forced vector size for low-precision paths so widening ops + (e.g., f16/bf16 -> f32) do not exceed RVV LMUL limits. + + Widening is legal up to source LMUL<=4 (destination LMUL<=8). + Using RVV relation LMUL = (SEW * VL) / VLEN, the safe source VL is: + VL <= 4 * VLEN / SEW + """ + + if not hasattr(self, "buffer_types") or not self.buffer_types: + return default_vec_size + + lowp_bits = [] + for info in self.buffer_types.values(): + dtype = info[0] if info else None + if dtype in DTYPE_LOWP_FP: + mlir_dtype = DTYPE_TO_MLIR[dtype] + lowp_bits.append(MLIR_TO_BIT[mlir_dtype]) + + if not lowp_bits: + return default_vec_size + + min_lowp_bits = min(lowp_bits) + # Constraint: Vector element count must be compatible across all types. + # VLEN=256: f16 (LMUL=2) and f32 (LMUL=4) both yield 32 elements. + # Note: Gem5 version restricts widening ops to LMUL < 8 for destination registers. + # Max LMUL set to 2 to ensure compatibility/safety. + + widen_safe_cap = self.vlen * 2 // min_lowp_bits + if widen_safe_cap <= 0: + return default_vec_size + + vec_size = min(default_vec_size, widen_safe_cap) + return vec_size @dataclasses.dataclass class LoopLevel: diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index 53db988b..851f070f 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -1255,7 +1255,7 @@ def set_tile_size(self, template_fusion_info, prologue=False): numel_per_lane = tile_desc.get_numel_per_lane() r_tile_size = tile_desc.get_tile_size()[-1] nr_outer_loop = (numel_per_lane + r_tile_size-1) // r_tile_size - tile_desc.vmap.forced_vec_size = nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality... + tile_desc.vmap.forced_vec_size = self.get_safe_vec_size(nr_outer_loop * 32) # Why? Emprically selected, other option failed to functionality... self.reduction_fusion = True self.r_tile_size = tile_desc.get_tile_size()[-1] @@ -1266,7 +1266,7 @@ def set_tile_size(self, template_fusion_info, prologue=False): self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: - tile_desc.vmap.forced_vec_size = 64 + tile_desc.vmap.forced_vec_size = self.get_safe_vec_size(64) if prologue: self.prologue_compute_body_loop.size = tile_desc.get_numel_per_lane() From fdd5b5459c41892b4d1a738b5baa3e21cd945b31 Mon Sep 17 00:00:00 2001 From: Wonhyuk Yang Date: Thu, 19 Mar 2026 02:01:08 +0900 Subject: [PATCH 31/31] [Refactor] move to TOGSimulator-based scheduler API --- experiments/BERT.py | 77 ++++++++++++++++------------------------ experiments/attention.py | 70 +++++++++++++----------------------- experiments/conv.py | 76 +++++++++++++++------------------------ experiments/gemm.py | 61 +++++++++++-------------------- experiments/layernorm.py | 59 +++++++++++------------------- experiments/resnet18.py | 57 ++++++++++------------------- experiments/resnet50.py | 57 ++++++++++------------------- experiments/softmax.py | 58 +++++++++++------------------- tests/Fusion/__init__.py | 0 tests/__init__.py | 0 10 files changed, 182 insertions(+), 333 deletions(-) create mode 100644 tests/Fusion/__init__.py create mode 100644 tests/__init__.py diff --git a/experiments/BERT.py b/experiments/BERT.py index fd671833..b938f4e6 100644 --- a/experiments/BERT.py +++ b/experiments/BERT.py @@ -1,57 +1,42 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime -def run_BERT(size, input_seq, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - # from tests.test_transformer import EncoderBlock - from tests.Fusion.test_transformer_fusion import EncoderBlock - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - hidden_dim = {'base': 768, 'large': 1024, 'xlarge': 2048} - embedding_size = {'base': 768, 'large': 1024, 'xlarge': 2048} - heads = {'base': 12, 'large': 16, 'xlarge': 32} # hidden/64 https://arxiv.org/pdf/1909.11942 - cpu_query = torch.randn(input_seq, hidden_dim[size]) - encoder_block = EncoderBlock(embedding_size[size], heads[size]).eval() - - query = cpu_query.clone().to(device=device) - opt_fn = torch.compile(dynamic=False)(encoder_block.to(device=device)) +import torch +from Simulator.simulator import TOGSimulator - SchedulerDNNModel.register_model(f"BERT-{size}", opt_fn) - request = Request(f"BERT-{size}", [query], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') +os.environ['TOGSIM_CONFIG'] = config - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() +# Try Fusion EncoderBlock first, fall back to standard test_transformer +try: + from tests.Fusion.test_transformer_fusion import EncoderBlock +except ImportError: + from tests.test_transformer import EncoderBlock - print(f"BERT-{size} Simulation Done") +HIDDEN_DIM = {'base': 768, 'large': 1024, 'xlarge': 2048} +EMBEDDING_SIZE = {'base': 768, 'large': 1024, 'xlarge': 2048} +HEADS = {'base': 12, 'large': 16, 'xlarge': 32} if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path FIXME: gem5 result is different as directoy name - sys.path.append(base_dir) args = argparse.ArgumentParser() - args.add_argument('--size', type=str, default='base') - args.add_argument('--dump_path', type=str, default='results') + args.add_argument('--size', type=str, default='base', choices=['base', 'large', 'xlarge']) args.add_argument('--input_size', type=int, default=512) args = args.parse_args() - size = args.size - input_seq = args.input_size - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"BERT_{size}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_BERT(size, input_seq, config) + + hidden_dim = HIDDEN_DIM[args.size] + embedding_size = EMBEDDING_SIZE[args.size] + heads = HEADS[args.size] + + device = torch.device("npu:0") + model = EncoderBlock(embedding_size, heads).eval().to(device=device) + model_input = torch.randn(args.input_size, hidden_dim).to(device=device) + opt_fn = torch.compile(dynamic=False)(model) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"BERT-{args.size} Simulation Done") diff --git a/experiments/attention.py b/experiments/attention.py index 211433f1..b56ed537 100644 --- a/experiments/attention.py +++ b/experiments/attention.py @@ -1,56 +1,36 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys +import math import argparse -import datetime +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) -def run_attention(size, config): - def attention(query, key, value): - import math - d_k = query.size(-1) - scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) - p_attn = scores.softmax(dim=-2) - return torch.matmul(value.transpose(-1, -2), p_attn) - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - query = torch.randn(size).to(device=device) - key = torch.randn(size).to(device=device) - value = torch.randn(size).to(device=device) - opt_fn = torch.compile(dynamic=False)(attention) - - SchedulerDNNModel.register_model("attention", opt_fn) - request = Request("attention", [query, key, value], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') +os.environ['TOGSIM_CONFIG'] = config - print(f"Attention {str(size)} Simulation Done") +def attention(query, key, value): + d_k = query.size(-1) + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[12, 512, 64], help='Tensor Shape') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"attention_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] + size = tuple(args.size) + + device = torch.device("npu:0") + query = torch.randn(*size).to(device=device) + key = torch.randn(*size).to(device=device) + value = torch.randn(*size).to(device=device) + opt_fn = torch.compile(dynamic=False)(attention) - run_attention(size, config) + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, query, key, value, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"Attention {size} Simulation Done") diff --git a/experiments/conv.py b/experiments/conv.py index 61f7ad80..98391fae 100644 --- a/experiments/conv.py +++ b/experiments/conv.py @@ -1,57 +1,39 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) + +import torch +from Simulator.simulator import TOGSimulator + +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config -def run_conv2d(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - def custom_conv2d(a, b, bias): - i_c = a.shape[1] - o_c = b.shape[0] - conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1, bias=False) +def conv2d_fn(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding): + def _conv(a, b, bias): + conv2d = torch.nn.Conv2d(i_c, o_c, kernel_size, stride=stride, padding=padding, dilation=1, bias=False) conv2d.weight = torch.nn.Parameter(b) - # conv2d.bias = torch.nn.Parameter(bias) return conv2d(a) - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() + return _conv + +if __name__ == "__main__": + args = argparse.ArgumentParser() + args.add_argument('--size', nargs='+', type=int, default=[8, 28, 28, 128, 128, 3, 1, 1], + help='B H W I_C O_C K S P') + args = args.parse_args() + batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding = args.size + + device = torch.device("npu:0") conv_input = torch.randn(batch_size, i_c, i_h, i_w).to(memory_format=torch.channels_last, device=device) conv_kernel = torch.randn(o_c, i_c, kernel_size, kernel_size).to(memory_format=torch.channels_last, device=device) conv_bias = torch.randn(o_c).to(device=device) - opt_fn = torch.compile(dynamic=False)(custom_conv2d) - - SchedulerDNNModel.register_model("CONV", opt_fn) - request = Request("CONV", [conv_input, conv_kernel, conv_bias], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() - print(f"CONV {batch_size}_{i_h}_{i_w}_{i_c}_{o_c}_{kernel_size}_{stride}_{padding} (B_H_W_I_C_O_C_K_S_P) Simulation Done") + custom_conv = conv2d_fn(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding) + opt_fn = torch.compile(dynamic=False)(custom_conv) -if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) - args = argparse.ArgumentParser() - args.add_argument('--size', nargs='+', type=int, default=[8, 28, 28, 128, 128, 3, 1, 1], help='B H W I_C O_C K S P') - args.add_argument('--dump_path', type=str, default='results') - args = args.parse_args() - size = args.size - size_str = "_".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"CONV_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_conv2d(size[0], size[1], size[2], size[3], size[4], size[5], size[6], size[7], config) \ No newline at end of file + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, conv_input, conv_kernel, conv_bias, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"CONV {batch_size}_{i_h}_{i_w}_{i_c}_{o_c}_{kernel_size}_{stride}_{padding} Simulation Done") diff --git a/experiments/gemm.py b/experiments/gemm.py index 0e1a15e4..d256e931 100644 --- a/experiments/gemm.py +++ b/experiments/gemm.py @@ -1,51 +1,32 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime - -def run_matmul(input_size, hidden_size, output_size, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - def custom_matmul(a, b): - return torch.matmul(a, b) - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - torch.manual_seed(0) - input = torch.randn(input_size, hidden_size).to(device=device) - weight = torch.randn(hidden_size, output_size).to(device=device) - opt_fn = torch.compile(dynamic=False)(custom_matmul) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("GEMM", opt_fn) - request = Request("GEMM", [input, weight], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config - print(f"GEMM {input_size}x{hidden_size}x{output_size} (MxKxN) Simulation Done") +def matmul_fn(a, b): + return torch.matmul(a, b) if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[128, 128, 128], help='M K N') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"GEMM_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] + M, K, N = args.size[0], args.size[1], args.size[2] - run_matmul(size[0], size[1], size[2], config) + device = torch.device("npu:0") + torch.manual_seed(0) + input_a = torch.randn(M, K).to(device=device) + input_b = torch.randn(K, N).to(device=device) + opt_fn = torch.compile(dynamic=False)(matmul_fn) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, input_a, input_b, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"GEMM {M}x{K}x{N} (MxKxN) Simulation Done") diff --git a/experiments/layernorm.py b/experiments/layernorm.py index a6b16986..a9170c6b 100644 --- a/experiments/layernorm.py +++ b/experiments/layernorm.py @@ -1,48 +1,29 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime - -def run_layernorm(size, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - input = torch.randn(size).to(device=device) - opt_fn = torch.compile(dynamic=False)(torch.nn.LayerNorm(size[-1]).to(device=device)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("LayerNorm", opt_fn) - request = Request("LayerNorm", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() +import torch +from Simulator.simulator import TOGSimulator - print(f"LayerNorm {str(size)} Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[512, 768], help='Tensor Shape') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"LayerNorm_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - os.environ['TORCHSIM_FUSION_REDUCTION_REDUCTION'] = "0" - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_layernorm(size, config) + size = tuple(args.size) + normalized_shape = size[-1] + + device = torch.device("npu:0") + model = torch.nn.LayerNorm(normalized_shape).to(device=device) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(*size).to(device=device) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"LayerNorm {size} Simulation Done") diff --git a/experiments/resnet18.py b/experiments/resnet18.py index c7763d86..38fb80fe 100644 --- a/experiments/resnet18.py +++ b/experiments/resnet18.py @@ -1,49 +1,28 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime -def run_resnet(batch, config): - from torchvision.models import resnet18 - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - model = resnet18().eval() - input = torch.randn(batch, 3, 224, 224).to(device=device) - opt_fn = torch.compile(dynamic=False)(model.to(device, memory_format=torch.channels_last)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("resnet18", opt_fn) - request = Request("resnet18", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from torchvision.models import resnet18 +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() - - print("ResNet18 Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--batch', type=int, default=1) - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - batch = args.batch - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet18_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - run_resnet(batch, config) + device = torch.device("npu:0") + model = resnet18().eval().to(device=device, memory_format=torch.channels_last) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(args.batch, 3, 224, 224).to(device=device) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print("ResNet18 Simulation Done") diff --git a/experiments/resnet50.py b/experiments/resnet50.py index 4e611541..5b134c13 100644 --- a/experiments/resnet50.py +++ b/experiments/resnet50.py @@ -1,49 +1,28 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime -def run_resnet(batch, config): - from torchvision.models import resnet50 - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - model = resnet50().eval() - input = torch.randn(batch, 3, 224, 224).to(device=device) - opt_fn = torch.compile(dynamic=False)(model.to(device, memory_format=torch.channels_last)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("resnet50", opt_fn) - request = Request("resnet50", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from torchvision.models import resnet50 +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() - - print("ResNet50 Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--batch', type=int, default=1) - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - batch = args.batch - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet50_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - run_resnet(batch, config) + device = torch.device("npu:0") + model = resnet50().eval().to(device=device, memory_format=torch.channels_last) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(args.batch, 3, 224, 224).to(device=device) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print("ResNet50 Simulation Done") diff --git a/experiments/softmax.py b/experiments/softmax.py index d30559f7..b86febe0 100644 --- a/experiments/softmax.py +++ b/experiments/softmax.py @@ -1,47 +1,29 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime - -def run_softmax(size, config, dim=1): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - input = torch.randn(size).to(device=device) - opt_fn = torch.compile(dynamic=False)(torch.nn.Softmax(dim=dim).to(device=device)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("Softmax", opt_fn) - request = Request("Softmax", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() +import torch +from Simulator.simulator import TOGSimulator - print(f"Softmax {str(size)} Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[512, 512], help='Tensor Shape') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"Softmax_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_softmax(size, config) + size = tuple(args.size) + dim = 1 + + device = torch.device("npu:0") + model = torch.nn.Softmax(dim=dim).to(device=device) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(*size).to(device=device) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"Softmax {size} Simulation Done") diff --git a/tests/Fusion/__init__.py b/tests/Fusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b