diff --git a/.github/workflows/pytorchsim_test.yml b/.github/workflows/pytorchsim_test.yml index 9589384b..36a62b68 100644 --- a/.github/workflows/pytorchsim_test.yml +++ b/.github/workflows/pytorchsim_test.yml @@ -163,6 +163,27 @@ jobs: -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/test_conv2d.py + test_cat: + name: Run test_cat.py + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_cat.py + run: | + echo "Running test_cat.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/test_cat.py + test_matmul: name: Run test_matmul.py runs-on: self-hosted @@ -705,6 +726,27 @@ jobs: -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ ${{ inputs.image_name }} python3 PyTorchSim/tests/Yolov5/test_yolov5.py + test_deepseek: + name: Run test_deepseek + runs-on: self-hosted + steps: + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Run test_deepseek_v3_base.py + run: | + echo "Running test_deepseek_v3_base.py" + docker run --rm \ + -v /tmp/torchsim-ci/${GITHUB_SHA}:/dump \ + -e TORCHSIM_DUMP_PATH=/dump \ + -e vpu_num_lanes="${{ inputs.vector_lane }}" \ + -e vpu_spad_size_kb_per_lane="${{ inputs.spad_size }}" \ + ${{ inputs.image_name }} python3 PyTorchSim/tests/DeepSeek/test_deepseek_v3_base.py + test_accuracy: name: Run test_accuracy runs-on: self-hosted diff --git a/AsmParser/tog_generator.py b/AsmParser/tog_generator.py index 5f586d99..a12460e3 100644 --- a/AsmParser/tog_generator.py +++ b/AsmParser/tog_generator.py @@ -37,7 +37,7 @@ class tog_generator: StonneTraceCompute= 6 StonneTraceLoad = 7 StonneTraceStore = 8 - def __init__(self, origins="Unknown") -> None: + def __init__(self, origins={"Unknown"}) -> None: self.module_name = "tile_operation_graph" self.module = None self.raw_graph = {} @@ -226,7 +226,7 @@ def generate_tile_graph(self, name="tile_graph", cycle_list=list, x_offset=int, offset = w_offset if is_preload else x_offset iter_node.torchsim_overlapping_cycle = max(iter_node.torchsim_cycle - offset, 0) - origin_info = "_".join(map(str, self.origins)) + origin_info = self.origins if isinstance(self.origins, str) else "_".join(map(str, self.origins)) onnx_node_list = [node.to_onnx() for node in node_list] # Exclude root node dump_onnx_graph(name, onnx_node_list, vector_lane, origin_info, stonneGraph=stonneGraph) diff --git a/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp index 04ba6d48..f048f878 100644 --- a/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp +++ b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -40,36 +41,6 @@ void wrapper_quantize_tensor_per_tensor_affine_stub( rtensor, qtensor, scale, zero_point); } -std::tuple< - at::Tensor, - at::Tensor, - at::Tensor, - at::Tensor, - c10::SymInt, - c10::SymInt, - at::Tensor, - at::Tensor, - at::Tensor> -wrapper__scaled_dot_product_fused_attention_overrideable( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const std::optional& attn_bias, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - return at::native::openreg::_scaled_dot_product_fused_attention_overrideable( - query, - key, - value, - attn_bias, - dropout_p, - is_causal, - return_debug_mask, - scale); -} - std::tuple wrapper_scaled_dot_product_fused_attention_overrideable_backward( const at::Tensor& grad_out, @@ -172,9 +143,6 @@ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { m.impl("abs.out", &wrapper_abs_out); m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor); m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice); - m.impl( - "_scaled_dot_product_fused_attention_overrideable", - &wrapper__scaled_dot_product_fused_attention_overrideable); m.impl( "_scaled_dot_product_fused_attention_overrideable_backward", &wrapper_scaled_dot_product_fused_attention_overrideable_backward); diff --git a/PyTorchSimDevice/csrc/aten/native/Extra.cpp b/PyTorchSimDevice/csrc/aten/native/Extra.cpp index 711d114c..eb76f5d7 100644 --- a/PyTorchSimDevice/csrc/aten/native/Extra.cpp +++ b/PyTorchSimDevice/csrc/aten/native/Extra.cpp @@ -19,8 +19,39 @@ int64_t _fused_sdp_choice( bool is_causal, std::optional scale, bool enable_gqa) { - auto backend = sdp::SDPBackend::math; - return static_cast(backend); + + sdp::sdp_params params{query, key, value, attn_mask, dropout_p, is_causal, enable_gqa}; + + // Reject inputs that are fundamentally unsupported (e.g. wrong rank) + if (!sdp::check_tensor_shapes(params, /*debug=*/false)) { + return static_cast(sdp::SDPBackend::error); + } + + // q: (B, Hq, L, E) k/v: (B, H, S, E) + const int64_t Hq = query.size(-3); + const int64_t H = key.size(-3); + const int64_t L = query.size(-2); // query sequence length + const int64_t S = key.size(-2); // key/value sequence length + + // Conditions required by the MLIR FlashSDPA kernel: + // Prefill only : L == S (decode has L == 1, not supported) + // Non-GQA : Hq == H (equal query and KV heads) + // No dropout : template has no dropout implementation + // Dense tensors : no nested tensor support + const bool can_use_mlir_flash = + (L == S) && + (Hq == H) && !enable_gqa && + sdp::check_for_dropout(params, /*debug=*/false) && + sdp::check_nested_tensor(params, /*debug=*/false); + + const bool ctx_flash = at::globalContext().userEnabledFlashSDP(); + const bool ctx_math = at::globalContext().userEnabledMathSDP(); + + if (ctx_flash && can_use_mlir_flash) { + return static_cast(sdp::SDPBackend::overrideable); + } + + return static_cast(sdp::SDPBackend::math); } void quantize_tensor_per_tensor_affine_stub( @@ -29,54 +60,6 @@ void quantize_tensor_per_tensor_affine_stub( double scale, int64_t zero_point) {} -std::tuple< - at::Tensor, - at::Tensor, - at::Tensor, - at::Tensor, - c10::SymInt, - c10::SymInt, - at::Tensor, - at::Tensor, - at::Tensor> -_scaled_dot_product_fused_attention_overrideable( - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const std::optional& attn_bias, - double dropout_p, - bool is_causal, - bool return_debug_mask, - std::optional scale) { - const int64_t batch_size = query.size(0); - const int64_t num_heads = query.size(1); - const int64_t head_dim_v = value.size(3); - const int64_t max_seqlen_q = query.size(2); - const int64_t max_seqlen_kv = key.size(2); - - auto opts = query.options(); - auto output = - at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts); - auto logsumexp = - at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - auto debug_attn_mask = at::empty( - {batch_size, num_heads, max_seqlen_q, max_seqlen_kv}, - opts.dtype(at::kFloat)); - auto philox_seed = at::empty({}, at::dtype(at::kLong)); - auto philox_offset = at::empty({}, at::dtype(at::kLong)); - - return std::make_tuple( - output, - logsumexp, - at::Tensor(), - at::Tensor(), - max_seqlen_q, - max_seqlen_kv, - philox_seed, - philox_offset, - debug_attn_mask); -} - std::tuple _scaled_dot_product_fused_attention_overrideable_backward( const at::Tensor& grad_out, diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py index f5aabc18..592011aa 100644 --- a/PyTorchSimDevice/torch_openreg/openreg/__init__.py +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -24,7 +24,7 @@ class device: def __init__(self, device): self.idx = torch.accelerator._get_device_index(device, optional=True) - self.prev_idx = -1 + self.prev_idx = -1 def __enter__(self): self.prev_idx = torch_openreg._C._exchangeDevice(self.idx) @@ -64,10 +64,20 @@ def _lazy_init(): global _initialized, _tog_simulator if is_initialized(): return + + # Replace the global C++ binding with our custom dispatcher patch + # from PyTorchSimFrontend.mlir.mlir_sdpa_template import patched_scaled_dot_product_attention + # torch._C._nn.scaled_dot_product_attention = patched_scaled_dot_product_attention + torch_openreg._C._init() register_interface_for_device(custom_device(), ExtensionDeviceInterface) _initialized = True + # Set default SDPA backend to math-only for this device. + torch._C._set_sdp_use_flash(False) + torch._C._set_sdp_use_overrideable(False) + torch._C._set_sdp_use_math(True) + # Create default streams for all devices num_devices = device_count() for device_idx in range(num_devices): diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index d6b47123..b1c457d3 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -67,8 +67,17 @@ def mlir_compile_command(filename, vectorlane_size, vlen=256): f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ - -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \ + -mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \ + -filetype=obj \ {'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \ + -O2 {filename}.ll -o {filename}.o + """, + ).strip(), + re.sub(r"[ \n]+", " ", + f""" + {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ + -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ + -mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \ -O2 {filename}.ll -o {filename}.s """, ).strip()] @@ -109,9 +118,10 @@ def mlir_gem5_compile_command(filename, sample_filename, tog_file, vectorlane_si f""" {extension_config.CONFIG_TORCHSIM_LLVM_PATH}/llc \ -relocation-model=pic -march=riscv64 -O3 --stack-size-section \ - -mattr=+m,+f,+d,+a,+c,+v,+xsfvcp,zvl{vlen}b \ + -mattr=+m,+f,+d,+a,+c,+v,+zvfh,+xsfvcp,zvl{vlen}b \ + -filetype=obj \ {'--print-after-all' if extension_config.CONFIG_TORCHSIM_DUMP_LLVM_IR else ''} \ - -O2 {sample_filename}.ll -o {sample_filename}.s + -O2 {sample_filename}.ll -o {sample_filename}.o """, ).strip()] @@ -166,11 +176,13 @@ def load(cls, source_code, opt_cmd = shlex.split(cmds[0]) translate_cmd = shlex.split(cmds[1]) llc_cmd = shlex.split(cmds[2]) + llc_asm_cmd = shlex.split(cmds[3]) with lock: try: subprocess.check_call(opt_cmd) subprocess.check_call(translate_cmd) subprocess.check_call(llc_cmd) + subprocess.check_call(llc_asm_cmd) except subprocess.CalledProcessError as e: logger.error(f"Command failed with exit code {e.returncode}") logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index eff6f573..1b7ccf8d 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -31,8 +31,6 @@ def __getattr__(name): "spad_size" : config_yaml["vpu_spad_size_kb_per_lane"] << 10 # Note: spad size per lane } - if name == "CONFIG_PRECISION": - return 4 # 32bit if name == "CONFIG_NUM_CORES": return config_yaml["num_cores"] if name == "vpu_vector_length_bits": diff --git a/PyTorchSimFrontend/mlir/mlir_autotune.py b/PyTorchSimFrontend/mlir/mlir_autotune.py index 4503584c..caf4d6da 100644 --- a/PyTorchSimFrontend/mlir/mlir_autotune.py +++ b/PyTorchSimFrontend/mlir/mlir_autotune.py @@ -85,7 +85,7 @@ def cached_run_fn(*args, **kwargs): self.source_code, vectorlane_size=self.extra_args["vector_lane"], loop_size=None, spad_info=self.extra_args["spad_info"], vlen=self.extra_args["vlen"], arg_attributes=self.extra_args["arg_attributes"], - origins="Unknown", silent_mode=True, + origins=self.extra_args["origins"], silent_mode=True, autotune=self.extra_args['autotune']) args = [ diff --git a/PyTorchSimFrontend/mlir/mlir_bmm_template.py b/PyTorchSimFrontend/mlir/mlir_bmm_template.py index 178ea987..c5fd902f 100644 --- a/PyTorchSimFrontend/mlir/mlir_bmm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_bmm_template.py @@ -26,20 +26,20 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0 = 0 to {{ B }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{ DATA_STYPE }}, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{ DATA_STYPE }}, 1> {% if Bias -%} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { @@ -74,20 +74,20 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0 = 0 to {{ B }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> {% if Bias -%} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { {{kernel.load_input(indent_size=10)}} @@ -120,21 +120,21 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0=0 to {{ B }} { affine.for %index2 = 0 to {{ N }} step {{ TILE_N }} { affine.for %index1 = 0 to {{ M }} step {{ TILE_M }} { - %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_K }}xf32, 1> - %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, 1> - %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1> to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %X_buffer2D = memref.reinterpret_cast %X_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : memref<1x{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, 1> + %W_buffer2D = memref.reinterpret_cast %W_buffer to offset: [0], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> + %Y_buffer2D = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<1x{{ TILE_N }}x{{ TILE_M }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> {% if Bias -%} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Y_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_N], indent_size=8) }} // Why not N,M? Currently, dma-fine-grained pass assume M->N order... {%- else -%} - affine.vector_store %v0, %Y_buffer[0, 0, 0] : memref<1x{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0, 0] : memref<1x{{ TILE_N }}x{{ TILE_M }}x{{DATA_STYPE}}, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} affine.for %index3 = 0 to {{ K }} step {{ TILE_K }} { {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[1, SUB_TILE_M, SUB_TILE_K], indent_size=10) }} @@ -154,6 +154,9 @@ class MLIRBMMTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, @@ -163,8 +166,9 @@ def render(self, tile_info = None, **kwargs): X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) if tile_info is None: - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node)[0] + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node, precision_bytes)[0] else: TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info @@ -234,6 +238,7 @@ def render(self, else: Bias_idx = None + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -242,7 +247,7 @@ def render(self, SUB_TILE_M=SUB_TILE_M, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, - DATA_STYPE="f32", + DATA_STYPE=data_stype, X = X, W = W,Y = Y, Bias = Bias, X_idx = X_idx, W_idx = W_idx, @@ -316,6 +321,12 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + dtype_infos = [("X", X.get_dtype()), ("W", W.get_dtype()), ("Y", Y.get_dtype())] + if Bias is not None: + dtype_infos.append(("Bias", Bias.get_dtype())) + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype BMM is not implemented yet ({dtype_desc})") W_tensor = empty_strided(W.layout.size, W.layout.stride) X_tensor = empty_strided(X.layout.size, X.layout.stride) @@ -340,10 +351,11 @@ def get_tile_candidates(self, prologue_nodes: Optional[List[IRNode]] = None, **kwargs): X, W, Y, Bias, W_tensor, X_tensor, B, M, N, K, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) - return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) + return self.select_tile(kernel, M, N, K, n_extra_node, 0, n_prologue_node, precision_bytes) - def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): - tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node) + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node, precision_bytes): + tile_candidates = kernel.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, precision_bytes=precision_bytes) for idx, (TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): SUB_TILE_M = TILE_M if (TILE_M < kernel.vector_lane) or n_prologue_node else kernel.vector_lane SUB_TILE_N = TILE_N # if (TILE_N < kernel.vector_lane) or prologue_nodes else kernel.vector_lane diff --git a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py index 06d41ea2..7c842272 100644 --- a/PyTorchSimFrontend/mlir/mlir_caller_codegen.py +++ b/PyTorchSimFrontend/mlir/mlir_caller_codegen.py @@ -182,22 +182,18 @@ def add_extention(self, name, extension): def compile_wih_kernel(self, write_path, llvm_name, wrapper_name, binary_name, link_option=""): main_path = os.path.join(write_path, self.add_extention(wrapper_name, 'c')) main_obj_path = os.path.join(write_path, self.add_extention(wrapper_name, 'o')) - kernel_path = os.path.join(write_path, self.add_extention(llvm_name, 's')) kernel_obj_path = os.path.join(write_path, self.add_extention(llvm_name, 'o')) main_compile = f'riscv64-unknown-elf-gcc -march=rv64gcv -c {main_path} -o {main_obj_path}' - kernel_compile = f'clang -c --target="riscv64" -march=rv64gcv -O2 -nostdlib {kernel_path} -o {kernel_obj_path}' target = os.path.join(write_path, binary_name) link = f'riscv64-unknown-elf-gcc -march=rv64gcv {main_obj_path} {kernel_obj_path} -o {target} -lm {link_option}' main_compile_cmd = shlex.split(main_compile) - kernel_compile_cmd = shlex.split(kernel_compile) link_cmd = shlex.split(link) try: subprocess.check_call(main_compile_cmd) - subprocess.check_call(kernel_compile_cmd) subprocess.check_call(link_cmd) except subprocess.CalledProcessError as e: print("Command failed with exit code", e.returncode) diff --git a/PyTorchSimFrontend/mlir/mlir_cat_template.py b/PyTorchSimFrontend/mlir/mlir_cat_template.py new file mode 100644 index 00000000..7abdfee6 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_cat_template.py @@ -0,0 +1,367 @@ +from typing import List, Optional, Set +import math +import itertools + +import sympy +from torch._inductor.ir import IRNode + +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=INPUT_NAMES, outputs=[Y], names_str=NAMES_STR, input_reorder=input_reorder)}} { +{%- for buffer_name, tile_desc in UNIQUE_BUFFER_TILE_DESCS.items() %} + {{ kernel.def_sram_buffer(buffer_name, tile_desc, indent_size=2) }} +{%- endfor %} + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %cat_block = 0 to 1 step 1 { +{%- for d in range(RANK-1) %} + affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ TILE_SIZES[d] }} { +{%- endfor %} +{%- for i in range(NUM_INPUTS) %} + // Input tensor{{ i }} + affine.for %index_local{{ DIM }}_{{ i }} = 0 to {{ INPUTS[i].sizes[DIM] }} step {{ INPUTS[i].tile_size_dim }} { + %index{{ DIM }}_{{ i }} = affine.apply affine_map<(d0) -> (d0 + {{ INPUTS[i].cum_offset }})> (%index_local{{ DIM }}_{{ i }}) + %input_dram_offset_{{ i }} = affine.apply {{ INPUTS[i].offset_map }}({{ INPUTS[i].offset_vars }}) + %output_dram_offset_{{ i }} = affine.apply {{ OUTPUTS[i].offset_map }}({{ OUTPUTS[i].offset_vars }}) + {{ kernel.def_dma_op("MVIN", INPUTS[i].dram_name, [], INPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=INPUTS[i].dram_strides, dram_offset="input_dram_offset_" ~ i) }} + {{ kernel.def_dma_op("MVOUT", "Y", [], OUTPUTS[i].tile_desc, indent_size=INDENT_SIZE, dram_stride=OUTPUTS[i].dram_strides, dram_offset="output_dram_offset_" ~ i) }} + } { inner_loop=true } +{%- endfor %} + +{%- for d in range(RANK-1) %} + } { outer_loop=true } +{%- endfor %} + } { outer_loop=true } + return +} +""" + + +class MLIRCatTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim): + super().__init__("kernel", input_nodes, layout) + self.dim = dim + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + input_nodes = self.input_nodes + y = self.output_node + dtype_infos = [("Y", y.get_dtype())] + [(f"X{i}", x.get_dtype()) for i, x in enumerate(input_nodes)] + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype Cat is not implemented yet ({dtype_desc})") + precision_bytes = mlir_common.get_dtype_nbytes(y.get_dtype()) + num_inputs = len(input_nodes) + rank = len(y.get_size()) + + input_sizes = [x.get_size() for x in input_nodes] + output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] + output_dim = [d for d, _ in enumerate(y.get_size()) if d != self.dim] + output_strides = y.get_layout().stride + + tile_sizes = list(tile_info) if tile_info is not None else [1] * len(output_sizes) + excluded_dims = self._compute_excluded_dims(tile_sizes) + + input_tile_sizes_dim = self._calculate_input_tile_sizes( + kernel, input_sizes, tile_sizes, num_inputs, rank, precision_bytes + ) + buffer_name_to_template_name, input_dram_names = self._build_buffer_mapping(input_nodes) + input_tile_descs, output_tile_descs, unique_tile_descs = self._build_tile_descriptors( + kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_dram_names, y, excluded_dims=excluded_dims + ) + (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) = self._build_dma_info( + input_nodes, input_sizes, output_strides, input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=excluded_dims + ) + + unique_buffer_tile_descs = { + buffer_name_to_template_name[name]: desc + for name, desc in unique_tile_descs.items() + } + names_str = ", ".join(input_dram_names + ["Y"]) + indent_size = 2 + (rank - 1) * 2 + 4 + + inputs_info = [ + dict( + dram_name = input_dram_names[i], + sizes = input_sizes[i], + tile_size_dim= input_tile_sizes_dim[i], + tile_desc = input_tile_descs[i], + offset_map = input_offset_maps[i], + offset_vars = input_offset_var_strs[i], + dram_strides = input_dram_strides[i], + cum_offset = cumulative_offsets[i], + ) + for i in range(num_inputs) + ] + outputs_info = [ + dict( + tile_desc = output_tile_descs[i], + offset_map = output_offset_maps[i], + offset_vars = output_offset_var_strs[i], + dram_strides = output_dram_strides[i], + ) + for i in range(num_inputs) + ] + + kernel.render_options = dict( + KERNEL_NAME = self.name, + kernel = kernel, + NUM_INPUTS = num_inputs, + NAMES_STR = names_str, + Y = y, + INPUT_NAMES = input_nodes, + RANK = rank, + DIM = self.dim, + OUTPUT_SIZES = output_sizes, + OUTPUT_DIM = output_dim, + TILE_SIZES = tile_sizes, + UNIQUE_BUFFER_TILE_DESCS = unique_buffer_tile_descs, + INPUTS = inputs_info, + OUTPUTS = outputs_info, + INDENT_SIZE = indent_size, + input_reorder = self.input_reorder, + ) + + return self._template_from_string(TEMPLATE).render(**kernel.render_options) + + def get_tile_candidates( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, + ): + """Generate tile candidates for cat operation. Concat dimension always has tile size 1.""" + if template_buffer_node is not None: + self.output_node = template_buffer_node + + y = self.output_node + dtype_infos = [("Y", y.get_dtype())] + [(f"X{i}", x.get_dtype()) for i, x in enumerate(self.input_nodes)] + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype Cat is not implemented yet ({dtype_desc})") + precision_bytes = mlir_common.get_dtype_nbytes(y.get_dtype()) + num_inputs = len(self.input_nodes) + output_sizes = [sz for d, sz in enumerate(y.get_size()) if d != self.dim] + + if not output_sizes: + return [[1]] + + max_tile_total = kernel.spad_info["spad_size"] // ( + kernel.vector_lane * precision_bytes * 2 * num_inputs + ) + + dim_tile_candidates = [] + for dim_size in output_sizes: + max_tile = min(dim_size, max_tile_total) + candidates = set() + for mult in range(1, max_tile // kernel.vector_lane + 1): + t = mult * kernel.vector_lane + if t <= dim_size and dim_size % t == 0: + candidates.add(t) + if max_tile > 0: + for exp in range(int(math.log2(max_tile)) + 1): + t = 2 ** exp + if t <= dim_size and dim_size % t == 0: + candidates.add(t) + candidates.add(dim_size) # dim_size always divides itself + dim_tile_candidates.append(sorted(candidates)[:5]) + + tile_candidates = [ + list(combo) + for combo in itertools.product(*dim_tile_candidates) + if math.prod(combo) * (num_inputs + 1) * precision_bytes + <= kernel.spad_info["spad_size"] * kernel.vector_lane + ] + + if not tile_candidates: + tile_candidates = [[1] * len(output_sizes)] + + tile_candidates.sort(key=lambda x: -math.prod(x)) + return tile_candidates[:4] + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _compute_excluded_dims(self, tile_sizes: list) -> list: + """Return non-tiled dimension indices when rank exceeds the 4-dim limit.""" + max_tiled = 3 + if len(tile_sizes) <= max_tiled: + return [] + sorted_dims = sorted(enumerate(tile_sizes), key=lambda x: x[1], reverse=True) + excluded = [idx for idx, _ in sorted_dims[max_tiled:]] + for idx in excluded: + tile_sizes[idx] = 1 + return excluded + + def _calculate_input_tile_sizes(self, kernel, input_sizes, tile_sizes, num_inputs, rank, precision_bytes): + """Calculate tile sizes along the concat dimension for each input.""" + non_dim_tile_elements = math.prod(tile_sizes) if tile_sizes else 1 + max_spad_per_input = kernel.spad_info["spad_size"] * kernel.vector_lane // 2 + extra_concat = math.ceil(max_spad_per_input / (non_dim_tile_elements * precision_bytes)) - num_inputs + + input_tile_sizes_dim = [] + for i in range(num_inputs): + if extra_concat > 0 and non_dim_tile_elements > 0: + tile_dim = min(input_sizes[i][self.dim], extra_concat) + extra_concat -= tile_dim + else: + tile_dim = 1 + input_tile_sizes_dim.append(tile_dim) + return input_tile_sizes_dim + + def _build_buffer_mapping(self, input_nodes): + """Map actual buffer names to short template names (X0, X1, ...).""" + name_map = {} + template_names = [] + for x in input_nodes: + actual = x.get_name() + template = name_map.setdefault(actual, f"X{len(name_map)}") + template_names.append(template) + return name_map, template_names + + def _build_tile_descriptors( + self, kernel, input_nodes, input_sizes, input_tile_sizes_dim, tile_sizes, rank, + input_buffer_names, output_node, excluded_dims=None + ): + """Build tile descriptors for every input (and its paired output).""" + if excluded_dims is None: + excluded_dims = set() + + def make_tile_desc(tile_sz, vector_lane, name, offset): + desc = mlir_common.MLIRMultiDimTile( + tile_sz, vector_lane, + vlane_split_axis=len(tile_sz) - 1, + vlane_stride=1 + ) + desc.set_tile_size(tile_sz) + desc.set_name(name) + desc.offset = offset + return desc + + output_offset = output_node.get_layout().offset + input_tile_descs, output_tile_descs, unique_tile_descs = [], [], {} + + for i, x in enumerate(input_nodes): + # Collect tile sizes for tiled dimensions only (skip excluded non-concat dims) + tile_sz = [] + tile_idx = 0 + for d in range(rank): + if d != self.dim: + if tile_idx not in excluded_dims: + tile_sz.append(tile_sizes[tile_idx]) + tile_idx += 1 + else: + tile_sz.append(input_tile_sizes_dim[i]) + + sram_name = f"{input_buffer_names[i].lower()}_cat_tile" + input_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, x.get_layout().offset)) + output_tile_descs.append(make_tile_desc(tile_sz, kernel.vector_lane, sram_name, output_offset)) + + actual_name = x.get_name() + if actual_name not in unique_tile_descs: + unique_tile_descs[actual_name] = input_tile_descs[-1] + + return input_tile_descs, output_tile_descs, unique_tile_descs + + def _build_dma_info( + self, input_nodes, input_sizes, output_strides, + input_tile_descs, output_tile_descs, + rank, num_inputs, excluded_dims=None + ): + """Build per-input DRAM offset affine maps and tile strides. + + Three stride concepts are maintained: + + * layout_strides (internal) - raw DRAM buffer strides for every rank + dimension, used to compute the flat base-address affine map. + These reflect how the tensor is physically laid out in DRAM. + * dram_strides (returned, ``def_dma_op dram_stride=``) - stride in + DRAM per *tiled* dimension (excluded dims removed). The DMA engine + uses these to walk DRAM when loading/storing a tile. + * sram_strides (inside ``def_dma_op``, from tile_desc) - stride in + SRAM per tiled dimension. The DMA engine uses these to place data + into the SRAM tile buffer. + + Returns: + input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets + """ + if excluded_dims is None: + excluded_dims = set() + + def make_affine_map(idx_syms, strides, layout_offset): + terms = [] + for j, s in enumerate(strides): + s = int(s) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(layout_offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(len(idx_syms))) + return f"affine_map<({dim_str}) -> ({' + '.join(terms) if terms else '0'})>" + + cumulative_offsets = [0] + for i in range(num_inputs - 1): + cumulative_offsets.append(cumulative_offsets[-1] + input_sizes[i][self.dim]) + + input_offset_maps, input_offset_var_strs, input_dram_strides = [], [], [] + output_offset_maps, output_offset_var_strs, output_dram_strides = [], [], [] + + for i, x in enumerate(input_nodes): + x_stride = x.get_layout().stride + in_syms, in_layout_strides, in_dram_strides = [], [], [] + out_syms, out_layout_strides, out_dram_strides = [], [], [] + tile_idx = 0 + + for d in range(rank): + if d != self.dim: + in_syms.append(sympy.Symbol(f"index{d}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{d}")) + out_layout_strides.append(int(output_strides[d])) + if tile_idx not in excluded_dims: + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + tile_idx += 1 + else: + in_syms.append(sympy.Symbol(f"index_local{self.dim}_{i}")) + in_layout_strides.append(int(x_stride[d])) + out_syms.append(sympy.Symbol(f"index{self.dim}_{i}")) + out_layout_strides.append(int(output_strides[d])) + in_dram_strides.append(int(x_stride[d])) + out_dram_strides.append(int(output_strides[d])) + + input_offset_maps.append(make_affine_map(in_syms, in_layout_strides, input_tile_descs[i].offset)) + input_offset_var_strs.append(", ".join(f"%{s}" for s in in_syms)) + input_dram_strides.append(in_dram_strides) + + output_offset_maps.append(make_affine_map(out_syms, out_layout_strides, output_tile_descs[i].offset)) + output_offset_var_strs.append(", ".join(f"%{s}" for s in out_syms)) + output_dram_strides.append(out_dram_strides) + + return (input_offset_maps, input_offset_var_strs, input_dram_strides, + output_offset_maps, output_offset_var_strs, output_dram_strides, + cumulative_offsets) diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index d6ddb025..672c35f7 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -285,7 +285,7 @@ def __init__(self, kernel_group, reason=None): self.gem5_header = IndentedBuffer() self.header.writeline("#include ") self.header.writeline("#include ") - self.header.writeline("void* __wrap_malloc(size_t size) { return sbrk(size); }") + self.header.writeline("void* __wrap_malloc(size_t size) { size = (size + 511UL) & ~511UL; return sbrk(size); }") # Align to 512 bytes self.header.writeline("void __wrap_free(void *ptr) { return; }") self.reduction_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc") self.spad_cse = common.CSE(self.newvar_prefix, self.suffix, name_prefix="spad") @@ -470,7 +470,12 @@ def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSE new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx]) indices.append(str(new_arg)) elif not arg.is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg))) + try: + new_arg = sympy.Symbol(str(self.convert_index(arg))) + #not implemented case + except NotImplementedError: + print(f"Not implemented case: {arg}") + raise NotImplementedError(f"Not implemented case: {arg}") new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx]) indices.append(str(new_arg)) else: @@ -959,7 +964,10 @@ def make_choices(self, nodes, kernel_name): # Try initial tile size self.reset(None) - src_code, meta_code = super().codegen_nodes(nodes, kernel_name) + try: + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) + except mlir_common.RecompileSignal: + continue current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) search_space.add(current_tile_sz) @@ -981,14 +989,12 @@ def make_choices(self, nodes, kernel_name): # Try increase tile size for this axis try: self.kernel_group.tile_desc.scale_tile_dim(axis, prev_ranges[axis], 2) - except extension_codecache.TileSizeError as e: - # Failed to find proper tile size + self.reset(None) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) + except (extension_codecache.TileSizeError, mlir_common.RecompileSignal): candidate_axes.remove(axis) self.reset(None) continue - - self.reset(None) - src_code, meta_code = super().codegen_nodes(nodes, kernel_name) current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) # FIXME. How to intergrate this constraint to tile system? @@ -1060,6 +1066,7 @@ def run_bench(self, nodes, kernel_name, src_code): "vlen" : self.vlen, "arg_attributes" : arg_attributes, "autotune" : True, + "origins" : {str(i) for node in nodes for i in node.node.origins}, }, source_code=src_code, ) @@ -1423,11 +1430,11 @@ def get_const_cse(self, value, dtype="index") -> common.CSEVariable: value = float(value) else: value = int(value) - - if value not in self.consts: - self.consts[str(value)+dtype] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") - self.register_var_info(self.consts[str(value)+dtype], [1, dtype]) - return self.consts[str(value)+dtype] + key = str(value)+dtype + if key not in self.consts: + self.consts[key] = self.const_cse.generate(self.const_buffer, f"arith.constant {value} : {dtype}") + self.register_var_info(self.consts[key], [1, dtype]) + return self.consts[key] def get_tag_cse(self, value=None, shape="memref<1xi32>"): if value is None: diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index 34b185b8..32805261 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -67,7 +67,7 @@ DTYPE_TO_C = { torch.float32: "float", torch.float64: "double", - torch.float16: "half", + torch.float16: "uint16_t", torch.int64: "int64_t", torch.int32: "int32_t", torch.int16: "int16_t", @@ -90,6 +90,12 @@ "index": 64 } +def get_dtype_nbytes(dtype): + mlir_dtype = DTYPE_TO_MLIR.get(dtype) + if mlir_dtype is None or mlir_dtype not in MLIR_TO_BIT: + raise NotImplementedError(f"Unsupported dtype for precision calculation: {dtype}") + return MLIR_TO_BIT[mlir_dtype] // 8 + DTYPE_LOWP_FP = [ torch.bfloat16, torch.float16, @@ -97,14 +103,17 @@ MLIR_INF = { "inf" : { + "f16" : 0x7C00, "f32" : 0x7F800000, "f64" : 0x7FF0000000000000 }, "-inf" : { + "f16" : 0xFC00, "f32" : 0xFF800000, "f64" : 0xFFF0000000000000 }, "nan" : { + "f16" : 0x7C00, "f32" : 0x7FC00000, "f64" : 0x7FF8000000000000 } @@ -173,7 +182,11 @@ def get_mlir_shape(info): def mlir_argdefs(self, extra_node=dict()): buffer_types = {} for x in V.graph.buffers: - if not isinstance(x.layout, MultiOutputLayout): # FIXME: MultiOutputLayout should be handled + if isinstance(x.layout, MultiOutputLayout): + # MultiOutput kernel containers own concrete output nodes in `outputs`. + for out in getattr(x, "outputs", []): + buffer_types[out.get_name()] = [out.get_dtype(), out.get_numel(), out.get_size(), out.get_stride()] + else: buffer_types[x.get_name()] = [x.get_dtype(), x.get_numel(), x.get_size(), x.get_stride()] for name, val in V.graph.graph_inputs.items(): if isinstance(val, sympy.Expr): @@ -250,17 +263,23 @@ def get_tile_stride_per_lane(self, tile_size: list[int], tile_stride: list[int]) return tile_stride def get_compute_vec_size(self, tile_size: list[int], reduction_numel: int, nr_rdim: int) -> int: - if self.forced_vec_size is not None: - return self.forced_vec_size - per_lane = self.get_numel_per_lane(tile_size) stride = self.vlane_stride if nr_rdim: val = per_lane // max(reduction_numel, 1) + result = val for mult in [8, 4, 2]: if per_lane >= val * mult: - return val * mult - return val + result = val * mult + break + if self.forced_vec_size is not None: + # Cap while keeping result divisible by val (= reduction_size). + # This preserves the assert(vec_len % reduction_size == 0) invariant. + capped = (min(result, self.forced_vec_size) // max(val, 1)) * max(val, 1) + result = max(capped, val) + return result + if self.forced_vec_size is not None: + return self.forced_vec_size for mult in [8, 4, 2]: if (per_lane // stride) >= mult: return stride * mult @@ -575,7 +594,6 @@ def __init__(self): # Default HW setting self.vector_lane = extension_config.vpu_num_lanes self.spad_info = extension_config.CONFIG_SPAD_INFO - self.precision = extension_config.CONFIG_PRECISION self.num_cores = extension_config.CONFIG_NUM_CORES self.vlen = extension_config.vpu_vector_length_bits @@ -778,10 +796,24 @@ def codegen_nodes(self, nodes, kernel_name): # Set node range info vars, reduction_vars = self.set_ranges(group, reduction_group) tile_desc = self.compute_tile_size(nodes, vars, reduction_vars) + _, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs() + safe_vec_size = self.get_safe_vec_size(tile_desc.get_compute_vec_size()) + # For pointwise (non-reduction) kernels, cap the MLIR vector size so that + # f16->f32 widening stays within LMUL<=4 (step and forced_vec_size must match). + # Reduction kernels are left unchanged: their accumulator/multi_reduction + # structure assumes compute_vec_size == step, so we must not split them here. + tile_desc.vmap.forced_vec_size = safe_vec_size + compute_vec = tile_desc.get_compute_vec_size() + # RVV requires vector lengths that produce integer power-of-2 LMUL values. + # Non-power-of-2 element counts (e.g. 24) cause LLVM WidenVectorResult crashes. + # Raise BEFORE the try/except so this propagates to make_choices (not retried). + if compute_vec > 1 and (compute_vec & (compute_vec - 1)) != 0: + raise RecompileSignal( + f"Non-power-of-2 compute_vec_size {compute_vec}: tile rejected (RVV requires power-of-2 LMUL)" + ) self.compute_body_loop.size = tile_desc.get_numel_per_lane() - self.compute_body_loop.step = tile_desc.get_compute_vec_size() + self.compute_body_loop.step = compute_vec try: - _, _, _, self.buffer_types = self.kernel_group.args.mlir_argdefs() with self as kernel: for node in nodes: node.run(vars, reduction_vars) @@ -1026,6 +1058,42 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._nested_context_depth -= 1 if self._nested_context_depth == 0: super().__exit__(exc_type, exc_val, exc_tb) + + def get_safe_vec_size(self, default_vec_size: int = 64) -> int: + """ + Cap forced vector size for low-precision paths so widening ops + (e.g., f16/bf16 -> f32) do not exceed RVV LMUL limits. + + Widening is legal up to source LMUL<=4 (destination LMUL<=8). + Using RVV relation LMUL = (SEW * VL) / VLEN, the safe source VL is: + VL <= 4 * VLEN / SEW + """ + + if not hasattr(self, "buffer_types") or not self.buffer_types: + return default_vec_size + + lowp_bits = [] + for info in self.buffer_types.values(): + dtype = info[0] if info else None + if dtype in DTYPE_LOWP_FP: + mlir_dtype = DTYPE_TO_MLIR[dtype] + lowp_bits.append(MLIR_TO_BIT[mlir_dtype]) + + if not lowp_bits: + return default_vec_size + + min_lowp_bits = min(lowp_bits) + # Constraint: Vector element count must be compatible across all types. + # VLEN=256: f16 (LMUL=2) and f32 (LMUL=4) both yield 32 elements. + # Note: Gem5 version restricts widening ops to LMUL < 8 for destination registers. + # Max LMUL set to 2 to ensure compatibility/safety. + + widen_safe_cap = self.vlen * 2 // min_lowp_bits + if widen_safe_cap <= 0: + return default_vec_size + + vec_size = min(default_vec_size, widen_safe_cap) + return vec_size @dataclasses.dataclass class LoopLevel: diff --git a/PyTorchSimFrontend/mlir/mlir_conv_common.py b/PyTorchSimFrontend/mlir/mlir_conv_common.py index f8566b6d..386e9bd5 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_common.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_common.py @@ -2,7 +2,7 @@ import math from typing import List, Optional -from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs +from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs, get_dtype_nbytes from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel from torch._inductor.ir import IRNode @@ -12,6 +12,9 @@ class MLIRConvCommonTemplate(MLIRTemplate): WRAPPER_TEMPLATE = None def __init__(self, input_nodes, layout, input_reorder=None, **kwargs): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = False + self.support_reduction_fusion = False self.stride = kwargs["stride"] self.padding = kwargs["padding"] self.dilation = kwargs["dilation"] @@ -37,7 +40,7 @@ def render(self, **kwargs): raise NotImplementedError() - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): raise NotImplementedError() def extract_info(self, kernel, template_buffer_node, epilogue_nodes): @@ -49,6 +52,13 @@ def extract_info(self, kernel, template_buffer_node, epilogue_nodes): X, W = self.input_nodes[0], self.input_nodes[1] Y = self.output_node Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + dtype_infos = [("X", X.get_dtype()), ("W", W.get_dtype()), ("Y", Y.get_dtype())] + if Bias is not None: + dtype_infos.append(("Bias", Bias.get_dtype())) + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype Conv is not implemented yet ({dtype_desc})") + precision_bytes = get_dtype_nbytes(X.get_dtype()) if epilogue_nodes is not None: extra_node_rw = { @@ -66,7 +76,7 @@ def extract_info(self, kernel, template_buffer_node, epilogue_nodes): PADDING_W=self.padding[1] STRIDE_H=self.stride[0] STRIDE_W=self.stride[1] - return X,W,Y,Bias,n_extra_node,BATCH,I_C,I_H,I_W,O_C,K_H,K_W,O_H,O_W,PADDING_H,PADDING_W,STRIDE_H,STRIDE_W + return X,W,Y,Bias,n_extra_node,BATCH,I_C,I_H,I_W,O_C,K_H,K_W,O_H,O_W,PADDING_H,PADDING_W,STRIDE_H,STRIDE_W,precision_bytes def get_tile_candidates(self, kernel: MLIRTemplateKernel, @@ -74,8 +84,8 @@ def get_tile_candidates(self, epilogue_nodes: Optional[List[IRNode]] = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) - return self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + return self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes) def outer_func_render(self, kernel_name, input_args): X, W = self.input_nodes[0], self.input_nodes[1] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py index da2bc829..8b8288a8 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_mt_template.py @@ -47,7 +47,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{- kernel.def_local_vars(indent_size=2) }} @@ -59,7 +59,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %tile_k = 0 to {{ I_C * K_W }} step {{ TILE_K }} { @@ -71,16 +71,16 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to 1 { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ TILE_O_W }} { %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_o_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -131,12 +131,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -170,7 +170,7 @@ def render(self, Y_tile_desc.set_name("output_buffer") Y_dim = [Symbol("tile_m"), Symbol("tile_n"), Symbol("o_h"), Symbol("o_w")] Y_idx = [Y_dim[0]*O_C*O_H*O_W, Y_dim[1]*O_H*O_W, Y_dim[2]*O_W, Y_dim[3]] - + # Extract Bias info Bias_idx = [Number(0), Symbol("tile_n"), Number(0), Number(0)] Bias_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) @@ -179,6 +179,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -220,7 +222,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) @@ -237,8 +239,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_multi_tile_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py index cc284522..92efff66 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sb_template.py @@ -48,7 +48,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{- kernel.def_local_vars(indent_size=2) }} affine.for %tile_n = 0 to {{ O_C }} step {{ TILE_N }} { @@ -58,7 +58,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { @@ -72,16 +72,16 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -132,12 +132,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -178,6 +178,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -219,7 +221,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) @@ -236,8 +238,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, 1, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) # TODO: implement K_W for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py index 6d768bf2..dfd418d9 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_sbs_template.py @@ -48,7 +48,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{- kernel.def_local_vars(indent_size=2) }} @@ -59,7 +59,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[1, SUB_TILE_N, TILE_O_H, SUB_TILE_M], indent_size=8) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { @@ -72,16 +72,16 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}xf32, 1> to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : memref<{{ TILE_K_H }}x{{ TILE_K_W }}x{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ 1 }} { // TILE_O_W %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_k_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -132,12 +132,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info SUB_TILE_N = TILE_N if TILE_N > 512 else SUB_TILE_N @@ -179,6 +179,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -220,7 +222,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) @@ -237,8 +239,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) # TODO: implement K_W + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_single_batch_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) # TODO: implement K_W for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] diff --git a/PyTorchSimFrontend/mlir/mlir_conv_template.py b/PyTorchSimFrontend/mlir/mlir_conv_template.py index e2cd61fd..178ba7c6 100644 --- a/PyTorchSimFrontend/mlir/mlir_conv_template.py +++ b/PyTorchSimFrontend/mlir/mlir_conv_template.py @@ -48,7 +48,7 @@ {{ kernel.def_sram_buffer("X", X_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> %c0 = arith.constant 0 : index {{ kernel.def_local_vars(indent_size=2) }} @@ -60,7 +60,7 @@ {%- if BIAS %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N, TILE_O_H, TILE_O_W], indent_size=10) }} {%- else %} - affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %output_buffer[%c0, %c0, %c0, %c0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_O_H * TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %k_h = 0 to {{ K_H }} step {{ TILE_K_H }} { affine.for %k_w = 0 to {{ K_W }} step {{ TILE_K_W }} { @@ -74,17 +74,17 @@ affine.for %tile_k_h = 0 to {{ TILE_K_H }} { // loop order should be fixed for timing simulation. Do not change this order. affine.for %tile_k_w = 0 to {{ TILE_K_W }} { %offset_w = affine.apply #offset_w_map(%tile_k_h, %tile_k_w) - %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + %W_buffer = memref.reinterpret_cast %weight_buffer to offset: [%offset_w], sizes: [{{ TILE_K }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ W_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> affine.for %tile_o_h = 0 to {{ TILE_O_H }} { affine.for %tile_o_w = 0 to {{ TILE_O_W }} { %tile_i_h = affine.apply #map_I_H(%tile_o_h, %tile_k_h) %tile_i_w = affine.apply #map_I_W(%tile_o_w, %tile_k_w) %offset_x = affine.apply #offset_x_map(%tile_i_h, %tile_i_w) %offset_y = affine.apply #offset_y_map(%tile_o_h, %tile_o_w) - %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1> - %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1> - linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}xf32, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) - outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}xf32, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + %X_buffer = memref.reinterpret_cast %input_buffer to offset: [%offset_x], sizes: [{{ TILE_M }}, {{ TILE_K }}], strides: [{{ TILE_K }}, 1] : {{ X_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1> + %Y_buffer = memref.reinterpret_cast %output_buffer to offset: [%offset_y], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1> + linalg.matmul ins(%X_buffer, %W_buffer : memref<{{ TILE_M }}x{{ TILE_K }}x{{DATA_STYPE}}, strided<[{{ TILE_K }}, 1], offset: ?>, 1>, memref<{{ TILE_K }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) + outs(%Y_buffer : memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, strided<[{{ TILE_N }}, 1], offset: ?>, 1>) } { inner_loop=true } } { inner_loop=true } } { inner_loop=true } @@ -136,12 +136,12 @@ def render(self, tile_info = None, **kwargs): # Extract input arguments info - X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W = self.extract_info(kernel, template_buffer_node, epilogue_nodes) + X, W, Y, Bias, n_extra_node, BATCH, I_C, I_H, I_W, O_C, K_H, K_W, O_H, O_W, PADDING_H, PADDING_W, STRIDE_H, STRIDE_W, precision_bytes = self.extract_info(kernel, template_buffer_node, epilogue_nodes) # Select tile size adn template conv_template = CONV_TEMPLATE if tile_info is None: - TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W)[0] + TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes)[0] else: TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K, TILE_I_H, TILE_I_W, SUB_TILE_I_H, SUB_TILE_I_W, SUB_TILE_K_H, SUB_TILE_K_W, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info TOG_latency = BATCH if TILE_M > BATCH else TILE_M @@ -183,6 +183,8 @@ def render(self, if Bias is not None: Bias_tile_desc.offset = Bias.get_layout().offset + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -224,7 +226,7 @@ def render(self, X_idx = X_idx, W_idx = W_idx, Bias_idx = Bias_idx, - DATA_STYPE="f32", + DATA_STYPE=data_stype, input_reorder=self.input_reorder ) @@ -241,8 +243,8 @@ def render(self, kernel.add_loop_info([kernel.render_options["K_H"], kernel.render_options["K_W"], kernel.render_options["O_H"], kernel.render_options["O_W"], kernel.render_options["BATCH"], kernel.render_options["O_C"], kernel.render_options["I_C"]], [kernel.render_options["TILE_M"], kernel.render_options["TILE_N"], kernel.render_options["TILE_K"]]) return code - def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W): - tile_candidates = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node) + def select_tile(self, kernel, n_extra_node, BATCH, I_C, O_C, K_H, K_W, O_H, O_W, precision_bytes): + tile_candidates = kernel.conv_combination_mapping(BATCH, O_C, I_C, K_H, K_W, O_H, O_W, self.stride, self.dilation, n_extra_node, precision_bytes=precision_bytes) for idx, (TILE_K_H, TILE_K_W, TILE_O_H, TILE_O_W, TILE_M, TILE_N, TILE_K) in enumerate(tile_candidates): TILE_I_H = 1 + (TILE_O_H - 1) * self.stride[0] + (TILE_K_H - 1) * self.dilation[0] TILE_I_W = 1 + (TILE_O_W - 1) * self.stride[1] + (TILE_K_W - 1) * self.dilation[1] diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index 0158caa6..9c61c3d9 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -27,14 +27,14 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32>{% endif %} + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}>{% endif %} {{ kernel.def_local_vars(indent_size=2) }} affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { {%- if Bias %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }}, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { {% if prologue_nodes -%} @@ -77,16 +77,16 @@ {{ kernel.def_sram_buffer("W", W_tile_desc, indent_size=2) }} {{ kernel.def_sram_buffer("Y", Y_tile_desc, indent_size=2) }} {% if not Bias %} - %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + %v0 = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {% endif %} {{ kernel.def_local_vars(indent_size=2) }} affine.for %index1 = 0 to {{ N }} step {{ TILE_N }} { affine.for %index0 = 0 to {{ M }} step {{ TILE_M }} { - %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}xf32, 1> + %Y_bufferT = memref.reinterpret_cast %Y_buffer to offset: [0], sizes: [{{ TILE_M }}, {{ TILE_N }}], strides: [{{ TILE_N }}, 1] : {{ Y_tile_desc.get_mlir_shape(DATA_STYPE) }} to memref<{{ TILE_M }}x{{ TILE_N }}x{{DATA_STYPE}}, 1> {%- if Bias %} {{ kernel.def_dma_op("MVIN", "Bias", Bias_idx, Bias_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_N], indent_size=6) }} {%- else %} - affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}xf32, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}xf32> + affine.vector_store %v0, %Y_buffer[0, 0] : memref<{{ TILE_N }}x{{ TILE_M }}x{{DATA_STYPE}}, 1>, vector<{{ kernel.get_spad_size_per_lane(TILE_M, TILE_N) }}x{{DATA_STYPE}}> {%- endif %} affine.for %index2 = 0 to {{ K }} step {{ TILE_K }} { {{ kernel.def_dma_op("MVIN", "X", X_idx, X_tile_desc, subtile_size=[SUB_TILE_M, SUB_TILE_K], indent_size=8) }} @@ -105,6 +105,9 @@ class MLIRGemmTemplate(MLIRTemplate): def __init__(self, input_nodes, layout, input_reorder=None): super().__init__("kernel", input_nodes, layout, input_reorder) + self.support_epilogue_fusion = True + self.support_prologue_fusion = True + self.support_reduction_fusion = True def render(self, kernel: MLIRTemplateKernel, @@ -114,8 +117,9 @@ def render(self, tile_info = None, **kwargs): X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) if tile_info is None: - TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node)[0] + TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node, precision_bytes)[0] else: TILE_M, TILE_N, TILE_K, SUB_TILE_M, SUB_TILE_N, SUB_TILE_K = tile_info @@ -184,6 +188,8 @@ def render(self, else: Bias_idx = None + data_stype = mlir_common.DTYPE_TO_MLIR[X.get_dtype()] + kernel.render_options = dict( KERNEL_NAME=self.name, kernel=kernel, @@ -194,7 +200,7 @@ def render(self, SUB_TILE_M=SUB_TILE_M, SUB_TILE_N=SUB_TILE_N, SUB_TILE_K=SUB_TILE_K, - DATA_STYPE="f32", + DATA_STYPE=data_stype, X = X, W = W, Y = Y, Bias = Bias, X_idx = X_idx, @@ -269,7 +275,8 @@ def get_tile_candidates(self, prologue_nodes: Optional[List[IRNode]] = None, **kwargs): X, W, Y, M, N, K, n_epilogue_node, n_prologue_node, n_extra_read = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) - return self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node) + precision_bytes = mlir_common.get_dtype_nbytes(X.get_dtype()) + return self.select_tile(kernel, M, N, K, n_epilogue_node, n_extra_read, n_prologue_node, precision_bytes) def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): if template_buffer_node is not None: @@ -277,6 +284,12 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): # Extract input arguments info X, W, Y = self.input_nodes[0], self.input_nodes[1], self.output_node + dtype_infos = [("X", X.get_dtype()), ("W", W.get_dtype()), ("Y", Y.get_dtype())] + if len(self.input_nodes) > 2: + dtype_infos.append(("Bias", self.input_nodes[2].get_dtype())) + if len({dtype for _, dtype in dtype_infos}) != 1: + dtype_desc = ", ".join(f"{name}={dtype}" for name, dtype in dtype_infos) + raise NotImplementedError(f"Mixed dtype GEMM is not implemented yet ({dtype_desc})") X_tensor = empty_strided(X.layout.size, X.layout.stride) W_tensor = empty_strided(W.layout.size, W.layout.stride) if len(W_tensor.size()) > 2 or len(X_tensor.size()) > 2: @@ -296,7 +309,7 @@ def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): M, N, K = X_tensor.size()[0], W_tensor.size()[1], X_tensor.size()[1] return X,W,Y,M,N,K,n_epilogue_node,n_prologue_node,len(n_extra_read) - def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node): + def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_node, precision_bytes): data = {} gemm_shape = f"{M}_{N}_{K}" if "external" in extension_config.codegen_mapping_strategy: @@ -316,7 +329,7 @@ def select_tile(self, kernel, M, N, K, n_extra_node, n_extra_read, n_prologue_no else: # case 2: use heuristic mapping min_tile = (n_extra_node + n_prologue_node) == 0 - tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True) + tile_candidates = kernel.gemm_combination_mapping(M, N, K, max(n_extra_read-2, 0), n_prologue_node, min_tile=True, precision_bytes=precision_bytes) # Edge case if (M == 0) or (N == 0) or (K == 0): diff --git a/PyTorchSimFrontend/mlir/mlir_lowering.py b/PyTorchSimFrontend/mlir/mlir_lowering.py index ebf0c80e..b717089f 100644 --- a/PyTorchSimFrontend/mlir/mlir_lowering.py +++ b/PyTorchSimFrontend/mlir/mlir_lowering.py @@ -1,3 +1,4 @@ +import math from typing import List, Optional, Sequence import torch @@ -15,10 +16,18 @@ from PyTorchSimFrontend.mlir.mlir_conv_sb_template import MLIRConvSingleBatchTemplate from PyTorchSimFrontend.mlir.mlir_conv_sbs_template import MLIRConvSingleBatchStridedTemplate from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate +from PyTorchSimFrontend.mlir.mlir_cat_template import MLIRCatTemplate +from PyTorchSimFrontend.mlir.mlir_sort_template import MLIRSortTemplate, MLIRStableSortTemplate +from PyTorchSimFrontend.mlir.mlir_sdpa_template import ( + MLIRFlashSDPATemplate, + flash_sdpa_args, + calculate_scale, +) from PyTorchSimFrontend import extension_config aten = torch.ops.aten aten_spmm = MLIRExternKernelChoice(torch.sparse.mm, "custom_op::sparse_addmm") +_orig_sort_values_stable_lowering = lowerings.get(aten.sort.values_stable) def tuned_mm(mat1, mat2, * ,layout=None): m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) @@ -38,6 +47,29 @@ def tuned_bmm(mat1, mat2, *, layout=None): return mlir_template.generate().output_node() + +def tuned_flash_sdpa( + query : TensorBox, + key : TensorBox, + value : TensorBox, + attn_bias : Optional[TensorBox] = None, + dropout_p : float = 0.0, + is_causal : bool = False, + return_debug_mask : bool = False, + scale : Optional[float] = None, + enable_gqa : bool = False) -> tuple: + # _fused_sdp_choice in C++ already guarantees: + # L == S (prefill), Hq == H (non-GQA), dropout_p == 0.0 + # before routing here via SDPBackend::overrideable. + # Non-matching shapes fall back to SDPBackend::math in C++ and decompose + # into primitive ops (matmul/softmax) before reaching this lowering. + scale = calculate_scale(query, scale) + N, Hq, H, L, S, E, Ev, layout, query, key, value = flash_sdpa_args(query, key, value) + mlir_template = MLIRFlashSDPATemplate([query, key, value], layout, scale) + return (mlir_template.generate().output_node(), None, None, None, None, None, None, None, None) + + + def conv_layout( x: TensorBox, weight: TensorBox, @@ -181,11 +213,105 @@ def custom_unsafe_index(x, indices): x.realize() return index_impl(x, indices, check=False) + +def _cat_layout(tensors: Sequence[TensorBox], dim: int) -> ir.Layout: + with V.graph.fake_mode: + output = torch.ops.aten.cat( + [ir.ir_node_to_tensor(t, guard_shape=True) for t in tensors], + dim, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) + return ir.FixedLayout( + tensors[0].get_device(), + tensors[0].get_dtype(), + sizes, + stride, + ) + +def custom_cat_default(tensors: Sequence[TensorBox], dim: int = 0): + if tensors and dim < 0: + dim += len(tensors[0].get_size()) + copy_default_lowering = lowerings.get(aten.copy_.default) + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + new_tensors = [] + for t in tensors: + t.realize() + # If the tensor is backed by a view (ReinterpretView, PermuteView, etc.), + # materialise it into a fresh contiguous FixedLayout buffer so the cat + # kernel always receives plain, dense strides. + if isinstance(t.data, ir.BaseView): + sizes = list(t.get_size()) + strides = [math.prod(sizes[i + 1:]) for i in range(len(sizes))] + new_buf = empty_strided_lowering( + sizes, strides, dtype=t.get_dtype(), device=t.get_device() + ) + tt = copy_default_lowering(new_buf, t) + else: + tt = t + new_tensors.append(tt) + + layout = _cat_layout(new_tensors, dim) + mlir_template = MLIRCatTemplate(list(new_tensors), layout, dim=dim) + return mlir_template.generate().output_node() + +def custom_sort_default( + value: TensorBox, + dim: int = -1, + descending: bool = False, + stable: Optional[bool] = None, +): + if dim < 0: + dim += len(value.get_size()) + + value.realize() + + value_layout, index_layout = _sort_layouts(value, dim, descending) + empty_strided_lowering = lowerings.get(aten.empty_strided.default) + indices = empty_strided_lowering( + value.get_size(), + index_layout.stride, + dtype=torch.int64, + device=value.get_device(), + ) + stable_required = True if stable is None else stable + sort_template_cls = MLIRStableSortTemplate if stable_required else MLIRSortTemplate + mlir_template = sort_template_cls( + [value, indices], + value_layout, + dim=dim, + descending=descending, + stable=stable_required, + ) + sorted_values = mlir_template.generate(template_buffer_node=value).output_node() + return sorted_values, indices + + +def _sort_layouts(x: TensorBox, dim: int, descending: bool): + with V.graph.fake_mode: + v, i = torch.ops.aten.sort( + ir.ir_node_to_tensor(x, guard_shape=True), + dim, + descending, + ) + v_sizes = ir.convert_shape_to_inductor(v.size()) + v_stride = ir.convert_shape_to_inductor(v.stride()) + i_sizes = ir.convert_shape_to_inductor(i.size()) + i_stride = ir.convert_shape_to_inductor(i.stride()) + + value_layout = ir.FixedLayout(x.get_device(), x.get_dtype(), v_sizes, v_stride) + index_layout = ir.FixedLayout(x.get_device(), torch.int64, i_sizes, i_stride) + return value_layout, index_layout + lowerings.update({getattr(aten.mm, overload): tuned_mm for overload in aten.mm.overloads()}) lowerings.update({getattr(aten.addmm, overload): tuned_addmm for overload in aten.addmm.overloads()}) lowerings.update({getattr(aten.convolution, overload): convolution for overload in aten.convolution.overloads()}) lowerings.update({getattr(aten.bmm, overload): tuned_bmm for overload in aten.bmm.overloads()}) lowerings.update({getattr(aten._sparse_addmm, overload): sparse_addmm for overload in aten._sparse_addmm.overloads()}) lowerings.update({getattr(aten._unsafe_index, overload): custom_unsafe_index for overload in aten._unsafe_index.overloads()}) +lowerings.update({getattr(aten.cat, overload): custom_cat_default for overload in aten.cat.overloads()}) +lowerings.update({getattr(aten.sort, overload): custom_sort_default for overload in aten.sort.overloads()}) + if extension_config.CONFIG_USE_TIMING_POOLING: - lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template \ No newline at end of file + lowerings.update({getattr(aten.max_pool2d_with_indices, overload): custom_maxpool for overload in aten.max_pool2d_with_indices.overloads()}) # FIXME: maxpool should be implemented as a template +lowerings.update({getattr(aten._scaled_dot_product_fused_attention_overrideable, overload): tuned_flash_sdpa for overload in aten._scaled_dot_product_fused_attention_overrideable.overloads()}) diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py index 9edd2e44..76a0e273 100644 --- a/PyTorchSimFrontend/mlir/mlir_ops.py +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -59,6 +59,10 @@ def constant(value, src_type, *args, **kwargs): str_val = str(value) if "inf" == str_val or "-inf" == str_val or "nan" == str_val: value = f"0x{mlir_common.MLIR_INF[str_val][src_type]:x}" + elif isinstance(value, bool): + value = 1 if value else 0 + if src_type[0] == "f": + value = format(float(value), ".20f") # scientific notation check elif "e" in str_val: value = format(float(value), ".20f") @@ -182,7 +186,7 @@ def to_dtype(operand, dst_mlir_dtype, *args, **kwargs): # Case A: Integer -> Float if src_type_char == "i" and dst_type_char == "f": - op_str = f"arith.sitofp %{operand} : {src_shape} to {shape}" + op_str = f"arith.uitofp %{operand} : {src_shape} to {shape}" # Case B: Float -> Integer elif src_type_char == "f" and dst_type_char == "i": op_str = f"arith.fptosi %{operand} : {src_shape} to {shape}" @@ -1142,6 +1146,80 @@ def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_nam line = reduction_combine_vec(red_type, value, init, axis=0, shape=new_vshape, reduced_shape=final_reduced_shape) return line, [red_size, type_name] + @staticmethod + def vector_shuffle(operand, indices, operand2=None, *args, **kwargs): + tile_size1, dtype1 = V.kernel.var_info[operand] + if operand2 is None: + operand2 = operand + tile_size2, dtype2 = V.kernel.var_info[operand2] + if dtype1 != dtype2: + raise ValueError( + f"vector_shuffle expects same element type, got {dtype1} and {dtype2}" + ) + total_size = tile_size1 + tile_size2 + for idx in indices: + if idx < -1 or idx >= total_size: + raise ValueError( + f"vector_shuffle index out of range: {idx}, expected in [-1, {total_size - 1}]" + ) + vt1 = f"vector<{tile_size1}x{dtype1}>" + vt2 = f"vector<{tile_size2}x{dtype1}>" + idx_str = ", ".join(str(i) for i in indices) + op_str = f"vector.shuffle %{operand}, %{operand2} [{idx_str}]" + return format_mlir_op(op_str, f"{vt1}, {vt2}", **kwargs), [len(indices), dtype1] + + @staticmethod + def constant_mask(select_min, N, *args, **kwargs): + vals = ", ".join("true" if x else "false" for x in select_min) + op_str = f"arith.constant dense<[{vals}]>" + return format_mlir_op(op_str, f"vector<{N}xi1>", **kwargs), [N, "i1"] + + @staticmethod + def bitonic_sort(operand, descending=False, *args, **kwargs): + def _compute_bitonic_stages(N: int, descending: bool): + assert N >= 2 and (N & (N - 1)) == 0, "N must be power-of-2 >= 2" + stages = [] + size = 2 + while size <= N: + stride = size // 2 + while stride >= 1: + merged_shuffle = list(range(N)) + merged_mask = [None] * N + + for start in range(0, N, size): + blk_dir = "ASCENDING" if (start // size) % 2 == 0 else "DESCENDING" + for i in range(start, start + size - stride, stride * 2): + for j in range(stride): + a, b = i + j, i + j + stride + merged_shuffle[a] = b + merged_shuffle[b] = a + if blk_dir == "ASCENDING": + merged_mask[a] = True # a = min + merged_mask[b] = False # b = max + else: + merged_mask[a] = False # a = max + merged_mask[b] = True # b = min + select_min = [bool(x) if x is not None else False for x in merged_mask] + if descending: + select_min = [not x for x in select_min] + stages.append({ + "shuffle": merged_shuffle, + "select_min": select_min, + }) + stride //= 2 + size *= 2 + return stages + + tile_size, _ = V.kernel.var_info[operand] + cur = operand + for stage in _compute_bitonic_stages(tile_size, descending): + mask = ops.constant_mask(stage["select_min"], tile_size) + shuffled = ops.vector_shuffle(cur, stage["shuffle"]) + vmin = ops.minimum(cur, shuffled) + vmax = ops.maximum(cur, shuffled) + cur = ops.where(mask, vmin, vmax) + return cur, V.kernel.var_info[cur] + @staticmethod def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, **kwargs): if compute_vec_size == 1: diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index af960533..22d1011b 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -44,12 +44,10 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule # Case 3: Prologue(Pointwise) + Tempalte if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - target_node = base_template_node2[0].node - # Currently only BMM, MM support prologue fusion - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + + # Check if template supports prologue fusion + if not getattr(target_node.template, 'support_prologue_fusion', False): return False if len(node1.read_writes.writes) != 1: @@ -129,12 +127,14 @@ def can_fuse_horizontal(self, node1, node2): if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction(): # Don't fuse maxpool template code from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate template_node = base_template_node1[0] epilogue_node = node2 + # Check if template supports epilogue fusion + if not getattr(template_node.node.template, 'support_epilogue_fusion', False): + return False + if isinstance(template_node.node.template, MLIRMaxPoolTemplate): return False @@ -161,7 +161,7 @@ def can_fuse_horizontal(self, node1, node2): # Revert act_node.group : simplify_and_reorder() modified _body, _size, group if template_node.group != epilogue_node.group: # We don't fuse this case... - if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: + if getattr(template_node.node.template, 'support_prologue_fusion', False) and template_node.group[1][0][0] == 1: return False if list(template_node.group[1][0]) != list(epilogue_node.get_nodes()[0].node.data.get_size()): @@ -171,10 +171,10 @@ def can_fuse_horizontal(self, node1, node2): # Case 2: Tempalte + Reduction fusion if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate target_node = base_template_node1[0].node - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + + # Check if template supports reduction fusion + if not getattr(target_node.template, 'support_reduction_fusion', False): return False size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) @@ -276,7 +276,7 @@ def codegen_node(self, _node): MLIRScheduling.count += 1 src_code, meta_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) kernel_name = self.define_kernel(src_code, meta_code, kernel_name_candidate, ex_kernel.vector_lane, - ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) + ex_kernel.spad_info, origins={str(i) for node in nodes for i in node.node.origins}) ex_kernel.call_kernel(kernel_name) _, args, _, _ = ex_kernel.args.mlir_argdefs() args = ", ".join(args) @@ -332,8 +332,10 @@ def codegen_template(self, template_node, epilogue_nodes, prologue_nodes): src_code, meta_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) with kernel: + all_nodes = [template_node] + (epilogue_nodes or []) + (prologue_nodes or []) + origins = {str(i) for n in all_nodes for i in n.node.origins} kernel_name = self.define_kernel(src_code, meta_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, - kernel.loop_size, origins={str(i) for i in template_node.node.origins}) + kernel.loop_size, origins=origins) self.define_function(kernel) kernel.call_kernel(kernel_name) diff --git a/PyTorchSimFrontend/mlir/mlir_sdpa_template.py b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py new file mode 100644 index 00000000..37db4956 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_sdpa_template.py @@ -0,0 +1,550 @@ +import math # sqrt +import sympy + +from typing import List, Optional + +import torch +from torch import empty_strided +from torch._inductor.ir import IRNode, TensorBox, FixedLayout +from torch._inductor.virtualized import V +from torch._inductor.select_algorithm import realize_inputs +from torch.backends.cuda import flash_sdp_enabled, mem_efficient_sdp_enabled + +from PyTorchSimFrontend import extension_config +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplateKernel + + +def _make_offset_map_with_sym(strides, sym_dim, sym_stride, offset=0): + """Like _make_offset_map but injects a block symbol ``s`` into dimension ``sym_dim``. + + The effective index for that dimension becomes ``d{sym_dim} + sym_stride * s``. + Use this to keep ``affine.for`` bounds static and encode the block contribution + directly inside the ``affine.apply`` call that computes the DRAM offset. + + Args: + strides: per-dimension DRAM strides. + sym_dim: which dimension carries the block symbol. + sym_stride: multiplier for the symbol (1 for abs-position loops like FLASH + ``%blk``; ``BlkS`` for block-index loops like PARTIAL ``%blk``). + offset: constant layout offset. + + Returns: + MLIR affine_map string with one symbol, e.g. + ``affine_map<(d0, d1, d2)[s] -> (d0 * 8192 + (d1 + 128 * s) * 64 + d2)>`` + """ + n = len(strides) + terms = [] + for j, sv in enumerate(strides): + sv = int(sv) + if sv == 0: + continue + if j == sym_dim: + inner = f"d{j} + s" if sym_stride == 1 else f"d{j} + {sym_stride} * s" + terms.append(f"({inner})" if sv == 1 else f"({inner}) * {sv}") + else: + terms.append(f"d{j}" if sv == 1 else f"d{j} * {sv}") + try: + off = int(offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(n)) + expr = " + ".join(terms) if terms else "0" + return f"affine_map<({dim_str})[s] -> ({expr})>" + + +def _make_offset_map(strides, offset=0): + """Generate an MLIR affine_map string for a flat DRAM base-address. + + Args: + strides: list of integer per-dimension strides. + A stride of 0 means the dimension does not contribute. + offset: constant layout offset (e.g. from IRNode.get_layout().offset). + + Returns: + MLIR affine_map string, e.g. ``affine_map<(d0, d1) -> (d0 * 128 + d1)>`` + """ + n = len(strides) + terms = [] + for j, s in enumerate(strides): + s = int(s) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + dim_str = ", ".join(f"d{j}" for j in range(n)) + expr = " + ".join(terms) if terms else "0" + return f"affine_map<({dim_str}) -> ({expr})>" + + +def flash_sdpa_args( + query : TensorBox, + key : TensorBox, + value : TensorBox) -> list: + """ + Arg processing for flash SDPA. + Its logic is based on: + mm_args() which is in torch._inductor.kernel.mm_common.py (142 line). + """ + + # Materialize input buffers for the codegen backend. + query, key, value = realize_inputs(query, key, value) + + # query : (n, hq, l, e) + # key : (n, h, s, e) + # value : (n, h, s, ev) + # out : (n, hq, l, ev) + # n: Batch size + # hq: query's head counts, h: key and value's head counts. + # l: target sequence lenght and s: source sequence length. + # e: embeding dimension of the query and key and ev: embeding dimension of the value. + nq, hq, l, eq = query.get_size() + nk, hk, sk, ek = key.get_size() + nk, hv, sv, ev = value.get_size() + + n = V.graph.sizevars.guard_equals(nq, nk) + n = V.graph.sizevars.guard_equals(nq, nk) + + h = V.graph.sizevars.guard_equals(hk, hv) + s = V.graph.sizevars.guard_equals(sk, sv) + e = V.graph.sizevars.guard_equals(eq, ek) + + # While there are no theoretical requirements for e == ev, + # this implementation currently enforces e == ev for simplicity. + if e != ev: + raise NotImplementedError( + "Flash SDPA currently requires matching head dimensions between query and value (e == ev)." + ) + + # Minimal GQA support (single-batch only for now). + # We map each query head to a KV head by grouping: hq = g * h. + if hq != h: + if n != 1: + raise NotImplementedError("Flash SDPA GQA is currently supported only for n == 1.") + if (hq % h) != 0: + raise NotImplementedError(f"Flash SDPA GQA requires hq % h == 0 (hq: {hq}, h: {h}).") + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [n, hq, l, ev] + ) + + return [n, hq, h, l, s, e, ev, layout, query, key, value] + +def calculate_scale(query: torch.Tensor, scale: float) -> float: + """ + Calculate the scaling factor based on the head dimension if scale is None + Otherwise, use the provided scale. + """ + if scale is None: + return 1.0 / math.sqrt(query.layout.size[-1]) + else: + return scale + + +FLASH_SDPA_TEMPLATE = r""" +// SDPA kernel +// b = {{ b }} +// l = {{ l }} +// s = {{ s }} +// e = {{ e }} +// tile_l = {{ tile_l }} +// tile_s = {{ tile_s }} +// tile_e = {{ tile_e }} +{{kernel.def_global_vars()}} + +func.func @{{ KERNEL_NAME }}{{kernel.def_kernel(inputs=[query, key, value], outputs=[out], names_str="query, key, value, out", input_reorder=input_reorder)}} { + // Inputs + {{ kernel.def_sram_buffer("query", q_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("key", k_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("value", v_tile_desc, indent_size=2) }} + + // Output + {{ kernel.def_sram_buffer("out", out_tile_desc, indent_size=2) }} + + // Intermediate buffers + {{ kernel.def_sram_buffer("mul", mul_tile_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("max", max_desc, indent_size=2) }} + {{ kernel.def_sram_buffer("sum", sum_desc, indent_size=2) }} + + // Constants + %c0 = arith.constant 0.0 : {{ data_stype }} + %c1 = arith.constant 1.0 : {{ data_stype }} + %c_scale = arith.constant {{ scale }} : {{ data_stype }} + %c_neg_inf = arith.constant -1.0e+30 : {{ data_stype }} + + %v0_c = arith.constant dense<0.0> : vector<{{ chunk_size }}x{{ data_stype }}> + %v0_l = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> + %v0_s = arith.constant dense<0.0> : vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> + %v0_2x = arith.constant dense<0.0> : vector<2x{{ data_stype }}> + + %v_neg_inf_c = arith.constant dense<-1.0e+30> : vector<{{ chunk_size }}x{{ data_stype }}> + %v_neg_inf_2x = arith.constant dense<-1.0e+30> : vector<2x{{ data_stype }}> + + %v_scale = vector.broadcast %c_scale : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> + + {{ kernel.def_local_vars(indent_size=2) }} + + affine.for %index0 = 0 to {{ b }} { + affine.for %index3 = 0 to 1 step 1 { + affine.for %index1 = 0 to {{ l }} step {{ tile_l }} { + %q_dram_offset = affine.apply {{ q_offset_map }}(%index0, %index1, %index3) + {{ kernel.def_dma_op("MVIN", "query", [], q_tile_desc, indent_size=8, dram_stride=q_dram_stride, dram_offset="q_dram_offset") }} + + affine.vector_store %v0_l, %out_buffer[0, 0, 0] : {{ out_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_l, tile_e) }}x{{ data_stype }}> + affine.vector_store %v_neg_inf_2x, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + affine.vector_store %v0_2x, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + %qt_buffer2D = memref.reinterpret_cast %q_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ q_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> + %ot_buffer2D = memref.reinterpret_cast %out_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_l }}], strides: [{{ tile_l }}, 1] : {{ out_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1> + + affine.for %index2 = 0 to {{ s }} step {{ tile_s }} { + %k_dram_offset = affine.apply {{ k_offset_map }}(%index0, %index2, %index3) + {{ kernel.def_dma_op("MVIN", "key", [], k_tile_desc, indent_size=10, dram_stride=k_dram_stride, dram_offset="k_dram_offset") }} + %v_dram_offset = affine.apply {{ v_offset_map }}(%index0, %index2, %index3) + {{ kernel.def_dma_op("MVIN", "value", [], v_tile_desc, indent_size=10, dram_stride=v_dram_stride, dram_offset="v_dram_offset") }} + + affine.vector_store %v0_s, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ kernel.get_spad_size_per_lane(tile_s, tile_l) }}x{{ data_stype }}> + + %k_buffer2D = memref.reinterpret_cast %k_buffer to offset: [0], sizes: [{{ tile_s }}, {{ tile_e }}], strides: [{{ tile_e }}, 1] : {{ k_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1> + %vt_buffer2D = memref.reinterpret_cast %v_buffer to offset: [0], sizes: [{{ tile_e }}, {{ tile_s }}], strides: [{{ tile_s }}, 1] : {{ v_tile_desc.get_mlir_shape(data_stype) }} to memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1> + + + // key @ query.t and scaling. + linalg.matmul + { idx_map = array } + ins(%k_buffer2D, %qt_buffer2D : memref<{{ tile_s }}x{{ tile_e }}x{{ data_stype }}, 1>, memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) + outs(%mul_buffer : {{ mul_tile_desc.get_mlir_shape(data_stype) }}) + + %raw_mul_vec = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + %scaled_mul_vec = arith.mulf %raw_mul_vec, %v_scale : vector<{{ tile_s }}x{{ data_stype }}> + affine.vector_store %scaled_mul_vec, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + + + // Find new max. + %old_max = affine.vector_load %max_buffer[0,0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + %chunk_max_res = affine.for %index5 = 0 to {{ tile_s }} step {{ chunk_size }} iter_args(%iter_max=%v_neg_inf_c) -> (vector<{{ chunk_size }}x{{ data_stype }}>) { + %chunk_val = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> + %local_max = arith.maximumf %chunk_val, %iter_max : vector<{{ chunk_size }}x{{ data_stype }}> + affine.yield %local_max : vector<{{ chunk_size }}x{{ data_stype }}> + } + + %max_cast = vector.shape_cast %chunk_max_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> + %max_reduced_1 = vector.multi_reduction , %max_cast, %v_neg_inf_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> + %max_shuffled = vector.shuffle %max_reduced_1, %max_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> + %max_reduced_2 = arith.maximumf %max_reduced_1, %max_shuffled : vector<2x{{ data_stype }}> + + %new_max = arith.maximumf %max_reduced_2, %old_max : vector<2x{{ data_stype }}> + affine.vector_store %new_max, %max_buffer[0, 0] : {{ max_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + + // Compute rescale factors: exp(old_max - new_max) + %max_diff = arith.subf %old_max, %new_max : vector<2x{{ data_stype }}> + %max_diff_scalar = vector.extract %max_diff[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + + %rescale_bcast_e = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> + %exp_rescale_e = math.exp %rescale_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> + + %rescale_bcast_2 = vector.broadcast %max_diff_scalar : {{ data_stype }} to vector<2x{{ data_stype }}> + %exp_rescale_2 = math.exp %rescale_bcast_2 : vector<2x{{ data_stype }}> + + + // Rescale previous out and sum accumulators + %old_out = affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + %rescaled_out = arith.mulf %exp_rescale_e, %old_out : vector<{{ tile_e }}x{{ data_stype }}> + affine.vector_store %rescaled_out, %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + + %old_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + %rescaled_sum = arith.mulf %old_sum, %exp_rescale_2 : vector<2x{{ data_stype }}> + + + // Shift scores and apply exp: exp(x - new_max) + %scaled_scores_reload = affine.vector_load %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + %new_max_scalar = vector.extract %new_max[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + %new_max_bcast = vector.broadcast %new_max_scalar : {{ data_stype }} to vector<{{ tile_s }}x{{ data_stype }}> + + %shifted_scores = arith.subf %scaled_scores_reload, %new_max_bcast : vector<{{ tile_s }}x{{ data_stype }}> + %exp_scores = math.exp %shifted_scores : vector<{{ tile_s }}x{{ data_stype }}> + affine.vector_store %exp_scores, %mul_buffer[0, 0] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ tile_s }}x{{ data_stype }}> + + + // accumulate current sum + %chunk_sum_res = affine.for %index5 = 0 to {{ tile_s }} step {{ chunk_size }} iter_args(%iter_sum=%v0_c) -> (vector<{{ chunk_size }}x{{ data_stype }}>) { + %chunk_exp = affine.vector_load %mul_buffer[0, %index5] : {{ mul_tile_desc.get_mlir_shape(data_stype) }}, vector<{{ chunk_size }}x{{ data_stype }}> + %local_sum = arith.addf %chunk_exp, %iter_sum : vector<{{ chunk_size }}x{{ data_stype }}> + affine.yield %local_sum : vector<{{ chunk_size }}x{{ data_stype }}> + } + + %zero_2x = vector.broadcast %c0 : {{ data_stype }} to vector<2x{{ data_stype }}> + %sum_cast = vector.shape_cast %chunk_sum_res : vector<{{ chunk_size }}x{{ data_stype }}> to vector<{{ chunk_size // 2 }}x2x{{ data_stype }}> + %sum_reduced_1 = vector.multi_reduction , %sum_cast, %zero_2x [0] : vector<8x2x{{ data_stype }}> to vector<2x{{ data_stype }}> + %sum_shuffled = vector.shuffle %sum_reduced_1, %sum_reduced_1 [1, 0] : vector<2x{{ data_stype }}>, vector<2x{{ data_stype }}> + %sum_reduced_2 = arith.addf %sum_reduced_1, %sum_shuffled : vector<2x{{ data_stype }}> + + %new_sum = arith.addf %sum_reduced_2, %rescaled_sum : vector<2x{{ data_stype }}> + affine.vector_store %new_sum, %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + + + // value.t @ mul + linalg.matmul + { idx_map = array } + ins(%vt_buffer2D, %mul_buffer : memref<{{ tile_e }}x{{ tile_s }}x{{ data_stype }}, 1>, {{ mul_tile_desc.get_mlir_shape(data_stype) }}) + outs(%ot_buffer2D : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>) + } {inner_loop=true} + + // out @ row_sum^(-1) + %final_row_sum = affine.vector_load %sum_buffer[0, 0] : {{ sum_desc.get_mlir_shape(data_stype) }}, vector<2x{{ data_stype }}> + %one_2x = vector.broadcast %c1 : {{ data_stype }} to vector<2x{{ data_stype }}> + + %reciprocal_row_sum_2x = arith.divf %one_2x, %final_row_sum : vector<2x{{ data_stype }}> + %reciprocal_scalar = vector.extract %reciprocal_row_sum_2x[0] : {{ data_stype }} from vector<2x{{ data_stype }}> + %reciprocal_bcast_e = vector.broadcast %reciprocal_scalar : {{ data_stype }} to vector<{{ tile_e }}x{{ data_stype }}> + + %accumulated_out = affine.vector_load %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + %stable_final_out = arith.mulf %accumulated_out, %reciprocal_bcast_e : vector<{{ tile_e }}x{{ data_stype }}> + affine.vector_store %stable_final_out, %ot_buffer2D[0, 0] : memref<{{ tile_e }}x{{ tile_l }}x{{ data_stype }}, 1>, vector<{{ tile_e }}x{{ data_stype }}> + + %out_dram_offset = affine.apply {{ out_offset_map }}(%index0, %index1, %index3) + {{ kernel.def_dma_op("MVOUT", "out", [], out_tile_desc, indent_size=8, dram_stride=out_dram_stride, dram_offset="out_dram_offset") }} + } { accumulation_loop=true } + } { outer_loop=true } + } { outer_loop=true } + return +} +""" + +class MLIRFlashSDPATemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, scale, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.scale = scale + + def render(self, + kernel: MLIRTemplateKernel, + template_buffer_node = None, + epilogue_nodes: Optional[List[IRNode]] = None, + prologue_nodes: Optional[List[IRNode]] = None, + tile_info = None, + **kwargs): + + # Except for kernel, other arguments are usually None. + query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node = self.extract_info(template_buffer_node, epilogue_nodes, prologue_nodes) + + if tile_info is None: + tile_l, tile_s, tile_e, subtile_l, subtile_s, subtile_e = self.select_tile(kernel, l, s, e, n_extra_node, 0, n_prologue_node)[0] + else: + tile_l, tile_s, tile_e, subtile_l, subtile_s, subtile_e = tile_info + + TOG_latency = l if tile_l > l else tile_l + kernel.loop_size = [TOG_latency, tile_s, tile_e] + + # Select template code + # Other templates will be added according to situations. + nr_reduction_nodes = [node for node in epilogue_nodes if node.is_reduction()] if epilogue_nodes is not None else [] + if nr_reduction_nodes: + raise NotImplementedError("FLASH_SDPA_REDUCTION_TEMPLATE is not implemented yet.") + elif prologue_nodes: + raise NotImplementedError("FLASH_SDPA_PROLOGUE_TEMPLATE is not implemented yet.") + else: + template = FLASH_SDPA_TEMPLATE + epilogue_dim_aliasing = {"index0":"index0", "index1":"index1", "index2": "index2", "index3": "index3"} + nr_rdim = 0 + + # Prepare tile descriptors for input and output tensors. + # Intermediate buffers (transient data) do not require DRAM settings(dram stride and dram indices) + # as they are not synchronized with external DRAM. + # DRAM and SRAM tile shapes must match. + vlane_stride = 1 + + # (n, l, s, e, ev) + loop_dim = [sympy.Symbol("index0"), sympy.Symbol("index1"), sympy.Symbol("index2"), sympy.Symbol("index3")] + + + # Hardware constraint: The tile split axis is restricted. + # To accommodate this, we compute (key @ query.t) instead of (query @ key.t). + # SRAM settings + vlane_split_axis = 1 + q_tile_size = [1, tile_l, tile_e] + q_tile_stride = [0, tile_e, 1] + q_tile_desc = mlir_common.MLIRMultiDimTile(q_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + q_tile_desc.set_tile_size_stride(q_tile_size, q_tile_stride) + q_tile_desc.set_name("q_buffer") + q_tile_desc.offset = query.get_layout().offset + # DRAM settings + q_stride = q_tensor.stride() + + # Since we use a weight-stationary approach in the Systolic Array (SA), + # the split axis of the first operand differs from a standard linear algebra matmul. + # The first operand (key) must be split along the column axis. + # This logic aligns with the relationship between the dot product's summation direction and the hardware's accumulation direction in the SA. + # SRAM settings + vlane_split_axis = 2 + k_tile_size = [1, tile_s, tile_e] + k_tile_stride = [0, 1, tile_s] + k_tile_desc = mlir_common.MLIRMultiDimTile(k_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + k_tile_desc.set_tile_size_stride(k_tile_size, k_tile_stride) + k_tile_desc.set_name("k_buffer") + k_tile_desc.offset = key.get_layout().offset + # DRAM settings + k_stride = k_tensor.stride() + + # Since we compute mul = key @ query.t, we perform out.t = (value.t @ Softmax(mul).t).t, + # which simplifies to (value.t @ Softmax(mul)) + # SRAM settings + vlane_split_axis = 1 + v_tile_size = [1, tile_s, tile_e] + v_tile_stride = [0, tile_e, 1] + v_tile_desc = mlir_common.MLIRMultiDimTile(v_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + v_tile_desc.set_tile_size_stride(v_tile_size, v_tile_stride) + v_tile_desc.set_name("v_buffer") + v_tile_desc.offset = value.get_layout().offset + # DRAM settings + v_stride = v_tensor.stride() + + # Output is also stored in transposed format to match the value.t @ Softmax(mul) operation. + # SRAM settings + vlane_split_axis = 1 + out_tile_size = [1, tile_l, tile_e] + out_tile_stride=[0, tile_e, 1] + out_tile_desc = mlir_common.MLIRMultiDimTile(out_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + out_tile_desc.set_tile_size_stride(out_tile_size, out_tile_stride) + out_tile_desc.set_name("out_buffer") + # DRAM settings + out_stride = out.get_layout().stride[1:] + + # Intermediate buffers + + # For mul = key @ query.t + vlane_split_axis = 1 + mul_tile_size = [tile_s, tile_l] + mul_tile_stride = [tile_l, 1] + mul_tile_desc = mlir_common.MLIRMultiDimTile(mul_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + mul_tile_desc.set_tile_size_stride(mul_tile_size, mul_tile_stride) + mul_tile_desc.set_name("mul_buffer") + #FIXME. What is the offset? -> It doesn't matter at this time. + + # For storing maximum values per row + vlane_split_axis = 0 + max_size = [tile_l, 2] + max_stride = [2, 1] + max_desc = mlir_common.MLIRMultiDimTile(max_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + max_desc.set_tile_size_stride(max_size, max_stride) + max_desc.set_name("max_buffer") + + # For storing summation per row + vlane_split_axis = 0 + sum_size = [tile_l, 2] + sum_stride = [2, 1] + sum_desc = mlir_common.MLIRMultiDimTile(sum_size, kernel.vector_lane, vlane_split_axis, vlane_stride) + sum_desc.set_tile_size_stride(sum_size, sum_stride) + sum_desc.set_name("sum_buffer") + + # For reduction + chunk_size = 16 + + # DMA strides and offset affine maps (dram_stride + dram_offset style) + q_dram_stride = [int(q_stride[0]), int(q_stride[1]), int(q_stride[2])] + k_dram_stride = [int(k_stride[0]), int(k_stride[1]), int(k_stride[2])] + v_dram_stride = [int(v_stride[0]), int(v_stride[1]), int(v_stride[2])] + out_dram_stride = [int(out_stride[0]), int(out_stride[1]), int(out_stride[2])] + + q_offset_map = _make_offset_map(q_dram_stride, q_tile_desc.offset) + k_offset_map = _make_offset_map(k_dram_stride, k_tile_desc.offset) + v_offset_map = _make_offset_map(v_dram_stride, v_tile_desc.offset) + out_offset_map = _make_offset_map(out_dram_stride, 0) + + # Keep out_idx only for epilogue_info (not in render_options) + out_idx = [loop_dim[0]*out_stride[0], loop_dim[1]*out_stride[1], loop_dim[3]*out_stride[2]] + + kernel.render_options = dict( + KERNEL_NAME = self.name, + kernel = kernel, + b = b, + l = l, + s = s, + e = e, # Input sizes (dram) + tile_l = tile_l, + tile_s = tile_s, + tile_e = tile_e, # Tile sizes (sram) + data_stype="f32", + query = query, + key = key, + value = value, + out = out, # Inputs and output (dram) + q_dram_stride = q_dram_stride, + k_dram_stride = k_dram_stride, + v_dram_stride = v_dram_stride, + out_dram_stride = out_dram_stride, # Per-dim DRAM strides + q_offset_map = q_offset_map, + k_offset_map = k_offset_map, + v_offset_map = v_offset_map, + out_offset_map = out_offset_map, # Affine maps for base address + q_tile_desc = q_tile_desc, + k_tile_desc = k_tile_desc, + v_tile_desc = v_tile_desc, + mul_tile_desc = mul_tile_desc, + out_tile_desc = out_tile_desc, # Tile descriptions (sram) + max_desc = max_desc, + sum_desc = sum_desc, # Intermediate buffer descriptions (sram) + scale = self.scale, + chunk_size = chunk_size, + input_reorder = self.input_reorder # ETC + ) + + code = self._template_from_string(template).render(**kernel.render_options) + kernel.add_loop_info([kernel.render_options["l"], kernel.render_options["s"], kernel.render_options["e"]], [kernel.render_options["tile_l"], kernel.render_options["tile_s"], kernel.render_options["tile_e"]]) + return code + + def extract_info(self, template_buffer_node, epilogue_nodes, prologue_nodes): + if template_buffer_node is not None: + self.output_node = template_buffer_node + + query = self.input_nodes[0] + key = self.input_nodes[1] + value = self.input_nodes[2] + out = self.output_node + + q_tensor = empty_strided(query.layout.size, query.layout.stride) + k_tensor = empty_strided(key.layout.size, key.layout.stride) + v_tensor = empty_strided(value.layout.size, value.layout.stride) + out_tensor = empty_strided(out.layout.size, out.layout.stride) + + # Flatten batch and head dimensions (n, h) into a single dimension (b = n*h) + q_tensor = q_tensor.view([-1, q_tensor.shape[-2], q_tensor.shape[-1]]) + k_tensor = k_tensor.view([-1, k_tensor.shape[-2], k_tensor.shape[-1]]) + v_tensor = v_tensor.view([-1, v_tensor.shape[-2], v_tensor.shape[-1]]) + out_tensor = out_tensor.view([-1, out_tensor.shape[-2], out_tensor.shape[-1]]) + + b, l, s, e, ev = q_tensor.size(0), q_tensor.size(1), k_tensor.size(1), k_tensor.size(2), v_tensor.size(2) + + n_extra_node = len(epilogue_nodes) if epilogue_nodes is not None else 0 + n_prologue_node = len(prologue_nodes) if prologue_nodes is not None else 0 + + return query, key, value, out, q_tensor, k_tensor, v_tensor, out_tensor, b, l, s, e, ev, n_extra_node, n_prologue_node + + # Reuse the existing function in MLIRBMMTemplate. + def select_tile(self, kernel, l, s, e, n_extra_node, n_extra_read, n_prologue_node): + + # FIXME: Update the method for getting tile candidates once TestDmaFineGrained oass works correctly with Flash Attention. + # tile_candidates = kernel.flash_sdpa_mapping(l, s, e, n_extra_node=n_extra_node) + tile_candidates = [[kernel.vector_lane, kernel.vector_lane, e]] + + for idx, (tile_l, tile_s, tile_e) in enumerate(tile_candidates): + subtile_l = tile_l if (tile_l < kernel.vector_lane) or n_prologue_node else kernel.vector_lane + subtile_s = tile_s # if (tile_s < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + subtile_e = tile_e # if (tile_e < kernel.vector_lane) or prologue_nodes else kernel.vector_lane + + tile_candidates[idx] = tile_l,tile_s,tile_e,subtile_l,subtile_s,subtile_e + + return tile_candidates + diff --git a/PyTorchSimFrontend/mlir/mlir_sort_template.py b/PyTorchSimFrontend/mlir/mlir_sort_template.py new file mode 100644 index 00000000..24b3a460 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_sort_template.py @@ -0,0 +1,474 @@ +from typing import List, Optional +import contextlib + +from torch._inductor.ir import Buffer, IRNode +from torch._inductor.virtualized import _ops as ops +from torch._inductor.codegen import common + +from PyTorchSimFrontend.mlir import mlir_common +from PyTorchSimFrontend.mlir.mlir_template import MLIRTemplate, MLIRTemplateKernel +from PyTorchSimFrontend.mlir.mlir_common import LoopLevel + +VECTOR_SIZE = 16 + + +TEMPLATE = r""" +{{kernel.def_global_vars()}} +// chunk index -> element index +#map_chunk_to_elem = affine_map<(d0) -> (d0 * {{ VECTOR_SIZE }})> + +func.func @{{ KERNEL_NAME }} {{kernel.def_kernel(inputs=[X, XI], outputs=[YV], names_str=NAMES_STR, input_reorder=input_reorder)}} { + {{ kernel.def_sram_buffer("X", X_TILE_DESC, id=0, indent_size=2) }} + {{ kernel.def_sram_buffer("XI", XI_TILE_DESC, id=1, indent_size=2) }} + {{ kernel.def_sram_buffer("YV", YV_TILE_DESC, id=2, indent_size=2) }} + {{ kernel.def_local_vars(indent_size=2) }} + + + affine.for %sort_block = 0 to 1 step 1 { + {%- for d in range(RANK-1) %} + affine.for %index{{ OUTPUT_DIM[d] }} = 0 to {{ OUTPUT_SIZES[d] }} step {{ STEP_SIZES[d] }} { + {%- endfor %} + + %x_dram_offset = affine.apply {{ X_OFFSET_MAP }}({{ OUTER_VARS }}) + %xi_dram_offset = affine.apply {{ XI_OFFSET_MAP }}({{ OUTER_VARS }}) + %yv_dram_offset = affine.apply {{ YV_OFFSET_MAP }}({{ OUTER_VARS }}) + {{ kernel.def_dma_op("MVIN", "X", [], X_TILE_DESC, indent_size=INDENT_SIZE, dram_stride=X_DRAM_STRIDE, dram_offset="x_dram_offset") }} + + // SIMD local sort + loop-based chunk merge. +{{ BITONIC_BODY }} + + {{ kernel.def_dma_op("MVOUT", "XI", [], XI_TILE_DESC, indent_size=INDENT_SIZE, dram_stride=XI_DRAM_STRIDE, dram_offset="xi_dram_offset") }} + {{ kernel.def_dma_op("MVOUT", "YV", [], YV_TILE_DESC, indent_size=INDENT_SIZE, dram_stride=YV_DRAM_STRIDE, dram_offset="yv_dram_offset") }} + {%- for d in range(RANK-1) %} + } { outer_loop=true } + {%- endfor %} + } { outer_loop=true } + return +} +""" + + +def _make_offset_map(outer_dims, all_strides, layout_offset): + """Build an affine_map over outer-dim loop variables that computes the flat DRAM offset.""" + terms = [] + for j, d in enumerate(outer_dims): + s = int(all_strides[d]) + if s == 1: + terms.append(f"d{j}") + elif s != 0: + terms.append(f"d{j} * {s}") + try: + off = int(layout_offset) + except (TypeError, ValueError): + off = 0 + if off: + terms.append(str(off)) + nd = len(outer_dims) + dim_str = ", ".join(f"d{j}" for j in range(nd)) + expr = " + ".join(terms) if terms else "0" + return f"affine_map<({dim_str}) -> ({expr})>" + + +def _compute_bitonic_stages(n: int, descending: bool): + stages = [] + size = 2 + while size <= n: + stride = size // 2 + while stride >= 1: + merged_shuffle = list(range(n)) + merged_mask = [None] * n + for start in range(0, n, size): + blk_dir = "ASCENDING" if (start // size) % 2 == 0 else "DESCENDING" + for i in range(start, start + size - stride, stride * 2): + for j2 in range(stride): + a, b = i + j2, i + j2 + stride + merged_shuffle[a] = b + merged_shuffle[b] = a + if blk_dir == "ASCENDING": + merged_mask[a] = True + merged_mask[b] = False + else: + merged_mask[a] = False + merged_mask[b] = True + select_min = [bool(x) if x is not None else False for x in merged_mask] + if descending: + select_min = [not x for x in select_min] + stages.append({"shuffle": merged_shuffle, "select_min": select_min}) + stride //= 2 + size *= 2 + return stages + + +def _pair_less_equal(left_v, right_v, left_i, right_i): + cmp_val = ops.lt(left_v, right_v) + cmp_eq = ops.eq(left_v, right_v) + cmp_idx = ops.le(left_i, right_i) + return ops.or_(cmp_val, ops.and_(cmp_eq, cmp_idx)) + + +def _pair_greater_equal(left_v, right_v, left_i, right_i): + cmp_val = ops.gt(left_v, right_v) + cmp_eq = ops.eq(left_v, right_v) + cmp_idx = ops.le(left_i, right_i) + return ops.or_(cmp_val, ops.and_(cmp_eq, cmp_idx)) + + +def _bitonic_sort_pair(values, indices, vector_size: int, descending: bool, stable_sort: bool): + cur_v = values + cur_i = indices + for stage_desc in _compute_bitonic_stages(vector_size, descending): + mask = ops.constant_mask(stage_desc["select_min"], vector_size) + shuf_v = ops.vector_shuffle(cur_v, stage_desc["shuffle"]) + shuf_i = ops.vector_shuffle(cur_i, stage_desc["shuffle"]) + if stable_sort: + # `cmp` drives the "min side" selection in the bitonic network. + # For descending stable sort, tie elements with smaller original index + # must stay earlier, so the min side should treat larger index as smaller. + if descending: + cmp_val = ops.lt(cur_v, shuf_v) + cmp_eq = ops.eq(cur_v, shuf_v) + cmp_idx = ops.ge(cur_i, shuf_i) + cmp = ops.or_(cmp_val, ops.and_(cmp_eq, cmp_idx)) + else: + cmp = _pair_less_equal(cur_v, shuf_v, cur_i, shuf_i) + else: + cmp = ops.le(cur_v, shuf_v) + min_v = ops.where(cmp, cur_v, shuf_v) + min_i = ops.where(cmp, cur_i, shuf_i) + max_v = ops.where(cmp, shuf_v, cur_v) + max_i = ops.where(cmp, shuf_i, cur_i) + cur_v = ops.where(mask, min_v, max_v) + cur_i = ops.where(mask, min_i, max_i) + return cur_v, cur_i + + +def _merge_sorted_pair_vectors( + left_norm, + left_idx_norm, + right_norm, + right_idx_norm, + ascending: bool, + stable_sort: bool, + vector_size: int, + rev_indices, +): + right_pair = ops.vector_shuffle(right_norm, rev_indices, right_norm) + right_idx_pair = ops.vector_shuffle(right_idx_norm, rev_indices, right_idx_norm) + if ascending: + cmp = ( + _pair_less_equal(left_norm, right_pair, left_idx_norm, right_idx_pair) + if stable_sort + else ops.le(left_norm, right_pair) + ) + else: + cmp = ( + _pair_greater_equal(left_norm, right_pair, left_idx_norm, right_idx_pair) + if stable_sort + else ops.ge(left_norm, right_pair) + ) + left_merge = ops.where(cmp, left_norm, right_pair) + left_idx_merge = ops.where(cmp, left_idx_norm, right_idx_pair) + right_merge = ops.where(cmp, right_pair, left_norm) + right_idx_merge = ops.where(cmp, right_idx_pair, left_idx_norm) + return left_merge, left_idx_merge, right_merge, right_idx_merge + + +class MLIRSortTemplate(MLIRTemplate): + def __init__(self, input_nodes, layout, dim, descending=False, stable=False, input_reorder=None): + super().__init__("kernel", input_nodes, layout, input_reorder) + self.dim = dim + self.descending = descending + self.stable = stable + self.use_stable_sort = False + self.output_nodes = [ + Buffer(name="buf_out_values", layout=layout), + ] + self.output_node = self.output_nodes[0] + + def render( + self, + kernel: MLIRTemplateKernel, + template_buffer_node=None, + epilogue_nodes: Optional[List[IRNode]] = None, + tile_info=None, + **kwargs, + ): + if template_buffer_node is not None: + self.output_nodes[0] = template_buffer_node + self.output_node = template_buffer_node + + x = self.input_nodes[0] + xi = self.input_nodes[1] + yv = self.output_nodes[0] + # XI is updated in-place by the sort kernel, so mark it as an inout arg. + kernel.kernel_group.args.make_inplace(xi.get_name(), xi.get_name()) + sort_size = int(x.get_size()[self.dim]) + vector_size = VECTOR_SIZE + if sort_size <= 0: + raise NotImplementedError("Sort size must be > 0") + if sort_size < vector_size or sort_size % vector_size != 0: + raise NotImplementedError( + f"Sort size must be a multiple of vector size (sort_size={sort_size}, vector_size={vector_size})" + ) + num_chunks = sort_size // vector_size + if num_chunks & (num_chunks - 1): + raise NotImplementedError( + f"Loop-based bitonic chunk merge requires power-of-two chunk count (num_chunks={num_chunks})" + ) + + # --- N-D generalization: outer loops over all non-sort dims --- + rank = len(x.get_size()) + sort_dim = self.dim if self.dim >= 0 else self.dim + rank + if sort_dim < 0 or sort_dim >= rank: + raise NotImplementedError(f"Invalid sort dim for rank-{rank} tensor (dim={self.dim})") + x_layout = x.get_layout() + xi_layout = xi.get_layout() + yv_layout = yv.get_layout() + + if rank == 1: + # Edge case for 1D tensor + output_sizes = [1] + output_dim = [0] + step_sizes = [1] + tile_sizes = [1, sort_size] + x_dram_stride = [int(x_layout.stride[sort_dim]), int(x_layout.stride[sort_dim])] + xi_dram_stride = [int(xi_layout.stride[sort_dim]), int(xi_layout.stride[sort_dim])] + yv_dram_stride = [int(yv_layout.stride[sort_dim]), int(yv_layout.stride[sort_dim])] + template_rank = 2 + else: + output_sizes = [sz for d, sz in enumerate(yv.get_size()) if d != sort_dim] + output_dim = [d for d, _ in enumerate(yv.get_size()) if d != sort_dim] + step_sizes = [1] * len(output_sizes) + + tile_dim = max(output_dim, key=lambda d: int(yv.get_size()[d])) + tile_sizes = [min(kernel.vector_lane, int(yv.get_size()[tile_dim])), sort_size] + step_sizes[output_dim.index(tile_dim)] = tile_sizes[0] + + x_dram_stride = [int(x_layout.stride[tile_dim]), int(x_layout.stride[sort_dim])] + xi_dram_stride = [int(xi_layout.stride[tile_dim]), int(xi_layout.stride[sort_dim])] + yv_dram_stride = [int(yv_layout.stride[tile_dim]), int(yv_layout.stride[sort_dim])] + template_rank = rank + + x_offset_map = _make_offset_map(output_dim, x_layout.stride, x_layout.offset) + xi_offset_map = _make_offset_map(output_dim, xi_layout.stride, xi_layout.offset) + yv_offset_map = _make_offset_map(output_dim, yv_layout.stride, yv_layout.offset) + outer_vars = ", ".join(f"%index{d}" for d in output_dim) + + # indent for DMA ops = 2 (inside func) + 2 per outer loop + indent_size = 2 + len(output_dim) * 2 + 4 + + vlane_stride = 1 + vlane_split_axis = 0 + x_tile_desc = mlir_common.MLIRMultiDimTile(tile_sizes, kernel.vector_lane, vlane_split_axis, vlane_stride) + x_tile_desc.set_tile_size_stride(tile_sizes, [sort_size, 1]) + x_tile_desc.set_name("X_buffer") + x_tile_desc.offset = x_layout.offset + + xi_tile_desc = mlir_common.MLIRMultiDimTile(tile_sizes, kernel.vector_lane, vlane_split_axis, vlane_stride) + xi_tile_desc.set_tile_size_stride(tile_sizes, [sort_size, 1]) + xi_tile_desc.set_name("XI_buffer") + xi_tile_desc.offset = xi_layout.offset + + yv_tile_desc = mlir_common.MLIRMultiDimTile(tile_sizes, kernel.vector_lane, vlane_split_axis, vlane_stride) + yv_tile_desc.set_tile_size_stride(tile_sizes, [sort_size, 1]) + yv_tile_desc.set_name("YV_buffer") + yv_tile_desc.offset = yv_layout.offset + + data_stype = mlir_common.DTYPE_TO_MLIR[x.get_dtype()] + idx_stype = mlir_common.DTYPE_TO_MLIR[xi.get_dtype()] + + elem_memref_t = f"memref<1x{sort_size}x{data_stype}, 1>" + rev_indices = list(range(vector_size - 1, -1, -1)) + + bitonic_body = mlir_common.ParallelLoopBuffer(initial_indent=2) + bitonic_body.tabwidth = 2 + # 1) Local SIMD sort per chunk. + init_cse = common.CSE(kernel.newvar_prefix, kernel.suffix, name_prefix="sort_init") + with kernel, kernel.override_buffer_cse(buffer=bitonic_body, cse=init_cse): + bitonic_body.writelines(LoopLevel("chunk", num_chunks).lines()) + with bitonic_body.indent(attribute="{inner_loop=true}"): + bitonic_body.writeline("%elem = affine.apply #map_chunk_to_elem(%chunk)") + x_chunk = ops._load( + vector_size, + data_stype, + "X_buffer", + "%t_const0, %elem", + x_tile_desc.get_mlir_shape(data_stype), + ) + idx_step_index = kernel.register_var_cse("idx_step_index", vector_size, "index") + bitonic_body.writeline(f"%{idx_step_index} = vector.step : vector<{vector_size}xindex>") + idx_step = ops.index_cast(idx_step_index, idx_stype) + idx_base = kernel.register_var_cse("idx_base", 1, idx_stype) + bitonic_body.writeline(f"%{idx_base} = arith.index_cast %elem : index to {idx_stype}") + idx_base_vec = ops.broadcast(idx_base, vector_size) + idx_chunk = ops.add(idx_base_vec, idx_step) + yv_chunk, yi_chunk = _bitonic_sort_pair( + x_chunk, idx_chunk, vector_size, descending=self.descending, stable_sort=self.use_stable_sort + ) + ops._store( + yv_chunk, + "YV_buffer", + "%t_const0, %elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + ops._store( + yi_chunk, + "XI_buffer", + "%t_const0, %elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + + # 2) Chunk-level bitonic merge (loop form). + stage = 0 + k = 2 + while k <= num_chunks: + j = k // 2 + while j >= 1: + for block_start, is_even_block in ((0, True), (k, False)): + if block_start >= num_chunks: + continue + asc_dir = is_even_block if not self.descending else (not is_even_block) + stage_cse = common.CSE(kernel.newvar_prefix, kernel.suffix, name_prefix=f"sort_stage_{stage}") + with kernel, kernel.override_buffer_cse(buffer=bitonic_body, cse=stage_cse): + stage_loops = [ + LoopLevel("base", num_chunks, start=block_start, step=2 * k), + LoopLevel("p", k, step=2 * j), + LoopLevel("q", j), + ] + with contextlib.ExitStack() as stack: + for loop in stage_loops: + bitonic_body.writelines(loop.lines()) + stack.enter_context(bitonic_body.indent(attribute="{inner_loop=true}")) + + bitonic_body.writeline( + f"%left_elem = affine.apply affine_map<(d0, d1, d2) -> ((d0 + d1 + d2) * {vector_size})>(%base, %p, %q)" + ) + bitonic_body.writeline( + f"%right_elem = affine.apply affine_map<(d0, d1, d2) -> ((d0 + d1 + d2 + {j}) * {vector_size})>(%base, %p, %q)" + ) + + left_vec = ops._load( + vector_size, + data_stype, + "YV_buffer", + "%t_const0, %left_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + right_vec = ops._load( + vector_size, + data_stype, + "YV_buffer", + "%t_const0, %right_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + left_idx = ops._load( + vector_size, + idx_stype, + "XI_buffer", + "%t_const0, %left_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + right_idx = ops._load( + vector_size, + idx_stype, + "XI_buffer", + "%t_const0, %right_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + norm_desc = not asc_dir + left_norm, left_idx_norm = _bitonic_sort_pair( + left_vec, left_idx, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + right_norm, right_idx_norm = _bitonic_sort_pair( + right_vec, right_idx, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + left_merge, left_idx_merge, right_merge, right_idx_merge = _merge_sorted_pair_vectors( + left_norm, + left_idx_norm, + right_norm, + right_idx_norm, + ascending=asc_dir, + stable_sort=self.use_stable_sort, + vector_size=vector_size, + rev_indices=rev_indices, + ) + left_new, left_idx_new = _bitonic_sort_pair( + left_merge, left_idx_merge, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + right_new, right_idx_new = _bitonic_sort_pair( + right_merge, right_idx_merge, vector_size, descending=norm_desc, stable_sort=self.use_stable_sort + ) + ops._store( + left_new, + "YV_buffer", + "%t_const0, %left_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + ops._store( + right_new, + "YV_buffer", + "%t_const0, %right_elem", + yv_tile_desc.get_mlir_shape(data_stype), + ) + ops._store( + left_idx_new, + "XI_buffer", + "%t_const0, %left_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + ops._store( + right_idx_new, + "XI_buffer", + "%t_const0, %right_elem", + xi_tile_desc.get_mlir_shape(idx_stype), + ) + stage += 1 + j //= 2 + k *= 2 + + kernel.render_options = dict( + KERNEL_NAME=self.name, + NAMES_STR="X, XI, YV", + kernel=kernel, + X=x, + XI=xi, + YV=yv, + X_TILE_DESC=x_tile_desc, + XI_TILE_DESC=xi_tile_desc, + YV_TILE_DESC=yv_tile_desc, + SORT_SIZE=sort_size, + VECTOR_SIZE=vector_size, + DATA_STYPE=data_stype, + IDX_STYPE=idx_stype, + ELEM_MEMREF_T=elem_memref_t, + BITONIC_BODY=bitonic_body.getvalue().rstrip(), + input_reorder=self.input_reorder, + # N-D generalization + RANK = template_rank, + OUTPUT_SIZES = output_sizes, + OUTPUT_DIM = output_dim, + STEP_SIZES = step_sizes, + OUTER_VARS = outer_vars, + X_OFFSET_MAP = x_offset_map, + XI_OFFSET_MAP = xi_offset_map, + YV_OFFSET_MAP = yv_offset_map, + X_DRAM_STRIDE = x_dram_stride, + XI_DRAM_STRIDE = xi_dram_stride, + YV_DRAM_STRIDE = yv_dram_stride, + INDENT_SIZE = indent_size, + ) + code = self._template_from_string(TEMPLATE).render(**kernel.render_options) + return code + + +class MLIRStableSortTemplate(MLIRSortTemplate): + def __init__(self, input_nodes, layout, dim, descending=False, stable=True, input_reorder=None): + super().__init__( + input_nodes=input_nodes, + layout=layout, + dim=dim, + descending=descending, + stable=stable, + input_reorder=input_reorder, + ) + self.use_stable_sort = True diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index b1c756ba..851f070f 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -14,7 +14,7 @@ from unittest.mock import patch from torch._inductor.codegen.common import KernelTemplate, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller, ir_node_to_tensor from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -112,7 +112,8 @@ def __init__(self, self.outer_func_name = outer_func_name self.outer_func_render = outer_func_render self.kernel_arg_attributes = kernel_arg_attributes - self.render_hooks = OrderedDict() + self.render_hooks = OrderedDict() # Stores {key: (priority, hook)} + self.dma_op_counter = itertools.count() # Add counter for unique DMA op keys self.buffer_names = dict() self.render_options = dict() self.tile_size = [] @@ -124,6 +125,7 @@ def __init__(self, self.epilogue_buffer_group = IndentedBufferGroup(self, prefix="epilogue_") self.global_vars = IndentedBuffer() self.exception_nodes = {} + self.epilogue_info = {} # Reduction data structure self.reduction_epilogue_suffix = IndentedBuffer() self.reduction_fusion = False @@ -148,10 +150,10 @@ def add_loop_info(self, mat_size, tile_size): for idx, (loop_size, stride) in enumerate(zip(mat_size, tile_size)): self.loop_info[f"index{idx}"] = [0, loop_size, stride] - def gemmini_gemm_mapping(self, M, N, K): + def gemmini_gemm_mapping(self, M, N, K, precision_bytes=4): spad_size = self.spad_info["spad_size"] * self.vector_lane num_cores = self.num_cores - precision = self.precision + precision = precision_bytes dim_I, dim_J, dim_K = M, N, K dim = self.vector_lane @@ -203,7 +205,7 @@ def gemmini_gemm_mapping(self, M, N, K): return inner_I, inner_J, inner_K - def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False, is_conv=False): + def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, pad_k=True, min_tile=False, is_conv=False, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -231,11 +233,11 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_M = i * self.vector_lane if M > self.vector_lane else M_padded for j in tile_N_range: tile_N = j * self.vector_lane if N > self.vector_lane else N_padded - used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: dir_path = f"{extension_config.CONFIG_TORCHSIM_DIR}/validation/gemm_candidates" @@ -257,11 +259,11 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_M = i * self.vector_lane if M > self.vector_lane else M_padded for j in tile_N_range: tile_N = j * self.vector_lane if N > self.vector_lane else N_padded - used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * self.precision + used_spad_size = (tile_M * tile_K * (1 + n_prologue_node) + tile_K * tile_N + tile_M * tile_N * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(tile_K, tile_N) input_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_prologue_node), tile_K) output_size_per_lane = self.get_spad_size_per_lane(tile_M * (1 + n_extra_node), tile_N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes n_tile = math.ceil(M / max(tile_M, 128)) * math.ceil(N / max(tile_N, 128)) check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size and max_used_spad_size < used_spad_size and maximize_i_j <= tile_M * tile_N and n_tile >= minimum_n_tile and max(tile_N, 128) // max(tile_M, 128) < 10: @@ -275,7 +277,7 @@ def gemm_combination_mapping(self, M, N, K, n_extra_node=0, n_prologue_node=0, p tile_candidates = [v for _, v in tile_candidates] return tile_candidates - def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -283,7 +285,7 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation max_spad_per_lane = spad_size_per_lane // 2 # double buffer max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] + M, N, K = self.gemm_combination_mapping(M, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True, precision_bytes=precision_bytes)[0] max_k_h_w = 1 # maximize kernel size max_o_h_w = 1 # maximize output size K = min(K, self.vector_lane) @@ -296,11 +298,11 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation weight_size = k_w * k_h * K * N input_size = i_w * i_h * M * K output_size = o_w * o_h * M * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, k_w, o_h, o_w, M, N, K))) @@ -316,7 +318,7 @@ def conv_combination_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation tile_candidates = [v for _, v in tile_candidates] return tile_candidates - def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -324,7 +326,7 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] + M, N, K = self.gemm_combination_mapping(M, N, K * K_W, n_extra_node=n_extra_node, pad_k=False, is_conv=True, precision_bytes=precision_bytes)[0] max_k_h_w = K_W for o_h in sympy.divisors(O_H): for o_w in sympy.divisors(O_W): @@ -334,11 +336,11 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, weight_size = 1 * k_h * K * N input_size = i_w * i_h * M * K output_size = o_w * o_h * M * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(1 * k_h * K, N) input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * M, K) output_size_per_lane = self.get_spad_size_per_lane(o_w * o_h * M * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, K_W, o_h, o_w, M, N, K))) @@ -352,7 +354,7 @@ def conv_multi_tile_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, tile_candidates = [v for _, v in tile_candidates] return tile_candidates - def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0): + def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilation, n_extra_node=0, precision_bytes=4): tile_candidates = [] spad_size_per_lane = self.spad_info["spad_size"] spad_size = spad_size_per_lane * self.vector_lane @@ -360,7 +362,7 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio max_spad_per_lane = spad_size_per_lane // 2 max_used_spad_size = 0 - M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True)[0] + M, N, K = self.gemm_combination_mapping(O_W, N, K, n_extra_node=n_extra_node, pad_k=False, is_conv=True, precision_bytes=precision_bytes)[0] max_k_h_w = 1 for o_h in sympy.divisors(O_H): for k_h in sympy.divisors(K_H): @@ -370,11 +372,11 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio weight_size = k_w * k_h * K * N input_size = i_w * i_h * k_w * K output_size = M * o_h * N - used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * self.precision + used_spad_size = (weight_size + input_size + output_size * (1 + n_extra_node)) * precision_bytes weight_size_per_lane = self.get_spad_size_per_lane(k_w * k_h * K, N) input_size_per_lane = self.get_spad_size_per_lane(i_w * i_h * k_w, K) output_size_per_lane = self.get_spad_size_per_lane(M * o_h * (1 + n_extra_node), N) - used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * self.precision + used_spad_size_per_lane = (weight_size_per_lane + input_size_per_lane + output_size_per_lane) * precision_bytes check_spad_size = (used_spad_size < max_spad_size and used_spad_size_per_lane < max_spad_per_lane) if check_spad_size: tile_candidates.append((used_spad_size, (k_h, k_w, o_h, M, M, N, K))) @@ -387,6 +389,100 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) tile_candidates = [v for _, v in tile_candidates] return tile_candidates + + # Flash Attention requires more SRAM compared to standard GEMM. + # Total buffers needed: query, key, value, out, mul, max, sum + # Tensor Shapes: + # query (tile_l, tile_e), key (tile_s, tile_e), value (tile_s, tile_e), mul (tile_s, tile_l), out(tile_l, tile_e) + # max, sum : (tile_l, 2) + def flash_sdpa_mapping(self, l, s, e, n_extra_node=0, n_prologue_node=0, pad_e=True, min_tile=False, is_conv=False): + tile_candidates = [] + + spad_size_per_lane = self.spad_info["spad_size"] + spad_size = spad_size_per_lane * self.vector_lane + + # Double buffering + max_spad_per_lane = spad_size_per_lane // 2 + max_spad_size = spad_size // 2 + + # Padding for utilization + minimum_tile_size = 8 + minimum_n_tile = self.num_cores if min_tile else 1 + l_pad_factor = self.vector_lane if l > self.vector_lane else minimum_tile_size + s_pad_factor = self.vector_lane if s > self.vector_lane else minimum_tile_size + + pad = lambda x, factor: ((x + factor - 1) // factor) * factor + l_padded = pad(l, l_pad_factor) + s_padded = pad(s, s_pad_factor) + + # Calculate the total number of vector-sized blocks + l_idx = l_padded // self.vector_lane + s_idx = s_padded // self.vector_lane + + # Generate candidates for the number of blocks per tile + l_tile_range = sympy.divisors(l_idx) if l > self.vector_lane else [1] + s_tile_range = sympy.divisors(s_idx) if s > self.vector_lane else [1] + + # Convert block count to actual tile size + maximize_i_j = 1 + max_used_spad_size = 0 + + # Flash Attention does not tile along the head dimension (e or ev). + tile_e = e + + for i in l_tile_range: + tile_l = i * self.vector_lane if l > self.vector_lane else l_padded + for j in s_tile_range: + tile_s = j * self.vector_lane if s > self.vector_lane else s_padded + + # Calculate used spad size + used_spad_size = ( + tile_l * tile_e * (1 + n_prologue_node) # query + + tile_s * tile_e # key + + tile_s * tile_e # value + + tile_s * tile_l # mul + + tile_l * tile_e * (1 + n_extra_node) # out + + (tile_l * 2) * 2 # max, sum + ) * self.precision + + # Calculate used spad size per lane. + query_per_lane = tile_e * (1+n_prologue_node) + key_per_lane = tile_s + value_per_lane = tile_e + mul_per_lane = tile_s + out_per_lane = tile_e * (1 + n_extra_node) + vec_per_lane = 2 * 2 + + used_spad_per_lane = ( + query_per_lane + + key_per_lane + + value_per_lane + + mul_per_lane + + out_per_lane + + vec_per_lane + ) * self.precision + + # Add the validated candidate to the list if it passes all hardware constraints. + n_tile = math.ceil(l / max(tile_l, 128)) * math.ceil(s / max(tile_s, 128)) + check_spad_size = (used_spad_size < max_spad_size and used_spad_per_lane < max_spad_per_lane) + + if (check_spad_size + and max_used_spad_size < used_spad_size # SRAM utilization + and maximize_i_j <= tile_l * tile_s # Larger tile + and n_tile >= minimum_n_tile # Pallelism + and max(tile_s, 128) // max(tile_l, 128) < 10): # Balanced Shape + max_used_spad_size = used_spad_size + maximize_i_j = tile_l * tile_s + + if check_spad_size: + tile_candidates.append((used_spad_size, (tile_l, tile_s, tile_e))) + + # Sort by used_spad_size. + # tile_candidates[0] is the best solution we have. + tile_candidates = sorted(tile_candidates, key=lambda x: x[0], reverse=True) + tile_candidates = [v for _, v in tile_candidates] + + return tile_candidates def meta_kernel(self): kernel_arg_attributes = self.kernel_arg_attributes @@ -460,11 +556,11 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ } node.codegen((vars, reduction_vars)) - # Codegen epilogue nodes - tile_desc = kernel.set_tile_size(kernel.epilogue_info) - kernel.kernel_group.set_tile_info(tile_desc) - kernel.call_ranges = None if epilogue_nodes: + # Codegen epilogue nodes + tile_desc = kernel.set_tile_size(kernel.epilogue_info) + kernel.kernel_group.set_tile_info(tile_desc) + kernel.call_ranges = None with kernel.epilogue_buffer_group.as_local(): _, (group, reduction_group) = max( epilogue_nodes, key=lambda x: int(x.is_reduction()) @@ -554,7 +650,7 @@ def template_store(): dram_var = self.epilogue_info["dram_var"] index_list = self.epilogue_info["dram_idx"] tile_desc = self.epilogue_info["dram_tile_desc"] - code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc) + code = self.def_dma_op("MVOUT", dram_var, index_list, tile_desc, lazy_mode=False) self.cse.generate(self.dma_stores, code, assignment = False) body = IndentedBuffer() @@ -625,14 +721,34 @@ def def_kernel( extra_node[node.get_name()] = node.node else: extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] + + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] def hook(): - arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) - return f"({', '.join(arg_defs)})" + arg_defs, call_args, *_ = self.kernel_group.args.mlir_argdefs(extra_node=extra_node) + output_names = names[len(inputs) : len(inputs) + len(outputs)] + out_ptr_idx = 0 + renamed_arg_defs = [] + for outer, arg_def in zip(call_args, arg_defs): + raw_symbol = arg_def.split(":", 1)[0].strip().lstrip("%") + if outer in self.kernel_group.args.input_buffers: + symbol = self.kernel_group.args.input_buffers[outer] + elif outer in self.kernel_group.args.output_buffers: + symbol = self.kernel_group.args.output_buffers[outer] + elif raw_symbol.startswith("out_ptr") and out_ptr_idx < len(output_names): + symbol = output_names[out_ptr_idx] + out_ptr_idx += 1 + elif outer in self.kernel_group.args.sizevars: + symbol = self.kernel_group.args.sizevars[outer] + else: + symbol = raw_symbol + _, arg_type = arg_def.split(":", 1) + renamed_arg_defs.append(f"%{symbol}:{arg_type}") + return f"({', '.join(renamed_arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (5, hook) # Default priority 5 return "" # This function is a temporal function for convolution because currently convolution kernel is not considering padding. @@ -670,7 +786,8 @@ def def_conv_kernel( self.kernel_group.args.output_buffers[node.get_name()] = name self.store_buffer_names.add(node.get_name()) #TODO: Is this enough not calling store() in mlir_common.py? self.extra_node[node.get_name()] = node - self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed + if 'sram_var' in self.epilogue_info: + self.buffer_names[node.get_name()] = self.epilogue_info['sram_var'] #TODO: Buffer name fixed def kernel_hook(): arg_defs, *_ = self.kernel_group.args.mlir_argdefs(extra_node=self.extra_node) @@ -678,7 +795,7 @@ def kernel_hook(): return f"({', '.join(arg_defs)})" assert "" not in self.render_hooks - self.render_hooks[""] = kernel_hook + self.render_hooks[""] = (5, kernel_hook) # Default priority 5 return "" # This function is for convolution wrapper function finalizing. @@ -689,7 +806,7 @@ def wrapper_hook(): return f"({', '.join(wrapper_arg_defs)})" if "" not in self.render_hooks: - self.render_hooks[""] = wrapper_hook + self.render_hooks[""] = (5, wrapper_hook) # Default priority 5 return "" def get_conv_inputs(self): @@ -698,15 +815,15 @@ def get_conv_inputs(self): def get_conv_outputs(self): return {k: v for k, v in self.kernel_group.args.output_buffers.items() if v != 'REMOVED'} - def load_input(self, indent_size: int = 0): + def load_input(self, indent_size: int = 0, priority: int = 1): def hook(): code = IndentedBuffer() prologue_code = self.codegen_prologue_body() if prologue_code.getvalue(): input_dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) weight_dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) if (self.prologue_info["is_input_fused"]): code.splice(input_dma_code) code.splice(prologue_code) @@ -717,58 +834,63 @@ def hook(): code.splice(input_dma_code) else: dma_code = self.def_dma_op("MVIN", self.prologue_info["input_dram_var"], self.prologue_info["input_idx"], - self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False) + self.prologue_info["input_tile_desc"], subtile_size=self.prologue_info["input_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) dma_code = self.def_dma_op("MVIN", self.prologue_info["weight_dram_var"], self.prologue_info["weight_idx"], - self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False) + self.prologue_info["weight_tile_desc"], subtile_size=self.prologue_info["weight_subtile_size"], async_type=False, lazy_mode=False) code.splice(dma_code) code = textwrap.indent(code.getvalue(), " "*indent_size).strip() return code assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def store_output(self, indent_size: int = 0): + def store_output(self, indent_size: int = 0, priority: int = 1): def hook(): epilogue_code = self.codegen_epilogue_body() return textwrap.indent(epilogue_code.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook - self.render_hooks.move_to_end("", last=False) # Force order to be triggered first + self.render_hooks[""] = (priority, hook) return "" - def reduction_output(self, indent_size: int = 0): + def reduction_output(self, indent_size: int = 0, priority: int = 5): def hook(): return textwrap.indent(self.reductions_suffix.getvalue(), " "*indent_size).strip() assert "" not in self.render_hooks - self.render_hooks[""] = hook + self.render_hooks[""] = (priority, hook) return "" + def _sort_hooks_by_priority(self): + """Sort hooks by priority (lower priority executes first).""" + sorted_hooks = OrderedDict() + for key, (priority, hook) in sorted(self.render_hooks.items(), key=lambda x: x[1][0]): + sorted_hooks[key] = hook + return sorted_hooks + def def_function(self): _, call_args, _, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: partial_code, function_name = self.outer_func_render(input_args=call_args) + return PartialRender( partial_code, - self.render_hooks, + self._sort_hooks_by_priority(), ), function_name else: return None, None - def def_global_vars(self): + def def_global_vars(self, priority: int = 10): key = "" def hook(): return textwrap.indent(self.global_vars.getvalue(), "").strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key - def def_local_vars(self, indent_size=0): + def def_local_vars(self, indent_size=0, priority: int = 10): key = "" def hook(): code = IndentedBuffer() @@ -777,57 +899,84 @@ def hook(): code.splice(self.alloc_buffer) return textwrap.indent(code.getvalue(), " "*indent_size).strip() - assert key not in self.render_hooks - self.render_hooks[key] = hook + self.render_hooks[key] = (priority, hook) return key def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_common.MLIRMultiDimTile, - subtile_size:list=[], async_type=None, indent_size=0): - # Prepare code block - local_code = IndentedBuffer() - with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): - index_var = self.parse_index_list(index_list, offset=tile_desc.offset) - node_layout = self.named_nodes[dram_var].get_layout() - if dram_var in self.exception_nodes: - numel = self.exception_nodes[dram_var]["numel"] - else: - numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() - mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] - dram_shape = f"memref<{numel}x{mlir_dtype}>" - dram_stride = [] - for idx in index_list: - if idx.is_Mul: - dram_stride.append(int(idx.args[0])) - elif idx == sympy.Symbol("c0"): - dram_stride.append(0) - elif not idx.is_Number: - dram_stride.append(1) + subtile_size:list=[], async_type=None, indent_size=0, priority: int = 5, lazy_mode: bool = True, + dram_stride:list=None, dram_offset=None, padding: int = 0): + # Todo. Remove legacy behavior (i.e., index_list parsing) + def generate_dma_code(): + """Internal method to generate DMA code directly.""" + local_code = IndentedBuffer() + with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): + if dram_offset is not None: + # Use explicitly provided offset (pre-computed MLIR SSA variable name) + index_var = dram_offset + else: + index_var = self.parse_index_list(index_list, offset=tile_desc.offset) + node_layout = self.named_nodes[dram_var].get_layout() + if dram_var in self.exception_nodes: + numel = self.exception_nodes[dram_var]["numel"] else: - dram_stride.append(0) - - sram_var = tile_desc.get_name() - tile_shape = tile_desc.get_mlir_shape(mlir_dtype) - tile_stride = tile_desc.get_tile_stride() - vlane_split_axis = tile_desc.vmap.vlane_split_axis - vlane_stride = tile_desc.vmap.vlane_stride - - zero_cse = self.get_const_cse(0, "index") - sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) - - attribute_parts = [f"dram_stride={dram_stride}", f"sram_stride={tile_stride}", "padding=0"] - if subtile_size: - attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") - attribute = " {" + ", ".join(attribute_parts) + "}" - code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, - dram_shape, tile_shape, "") - local_code.writeline(code) - local_code.writeline(attribute) - return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + numel = self.get_arg_info(self.named_nodes[dram_var].get_name()).get_numel() + mlir_dtype = mlir_common.DTYPE_TO_MLIR[node_layout.dtype] + dram_shape = f"memref<{numel}x{mlir_dtype}>" + + if dram_stride is not None: + # Use explicitly provided dram_stride + _dram_stride = dram_stride + else: + # Extract dram_stride from index_list (legacy behavior) + _dram_stride = [] + for idx in index_list: + if idx.is_Mul: + _dram_stride.append(int(idx.args[0])) + elif idx == sympy.Symbol("c0"): + _dram_stride.append(0) + elif not idx.is_Number: + _dram_stride.append(1) + else: + _dram_stride.append(0) + + sram_var = tile_desc.get_name() + tile_shape = tile_desc.get_mlir_shape(mlir_dtype) + sram_strides = tile_desc.get_tile_stride() + vlane_split_axis = tile_desc.vmap.vlane_split_axis + vlane_stride = tile_desc.vmap.vlane_stride + + zero_cse = self.get_const_cse(0, "index") + sram_index_var = ", ".join([f"%{str(zero_cse)}"]*tile_desc.get_nr_dim()) + + attribute_parts = [f"dram_stride={_dram_stride}", f"sram_stride={sram_strides}", f"padding={int(padding)}"] + if subtile_size: + attribute_parts.append(f"subtile_size={subtile_size}, async={int(async_type) if async_type is not None else 1}") + attribute = " {" + ", ".join(attribute_parts) + "}" + code = self.get_dma_code(dma_type, vlane_split_axis, vlane_stride, mlir_dtype, dram_var, index_var, sram_var, sram_index_var, + dram_shape, tile_shape, "") + local_code.writeline(code) + local_code.writeline(attribute) + return textwrap.indent(local_code.getvalue(), " "*indent_size).strip() + + if not lazy_mode: + # Immediate mode: generate code directly and return it + return generate_dma_code() + + # Lazy mode: register hook and return key + dma_op_id = next(self.dma_op_counter) + key = f"" + self.render_hooks[key] = (priority, generate_dma_code) + return key def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): # Prepare code block with self: - dtype = self.named_nodes[dram_name].get_layout().dtype + try: + dtype = self.named_nodes[dram_name].get_layout().dtype + except (KeyError, AttributeError, TypeError): + import torch + dtype = torch.float32 + tile_shape = tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[dtype]) buffer_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, id, forced_name=dram_name) code = f"%{tile_desc.name} = memref.get_global @{buffer_name} : {tile_shape}" @@ -840,7 +989,7 @@ def render(self, template, kwargs, define_function=None): return PartialRender( code, - self.render_hooks, + self._sort_hooks_by_priority(), ) def get_spad_size_per_lane(self, tile_m, tile_n): @@ -1106,7 +1255,7 @@ def set_tile_size(self, template_fusion_info, prologue=False): numel_per_lane = tile_desc.get_numel_per_lane() r_tile_size = tile_desc.get_tile_size()[-1] nr_outer_loop = (numel_per_lane + r_tile_size-1) // r_tile_size - tile_desc.vmap.forced_vec_size = nr_outer_loop * 32 # Why? Emprically selected, other option failed to functionality... + tile_desc.vmap.forced_vec_size = self.get_safe_vec_size(nr_outer_loop * 32) # Why? Emprically selected, other option failed to functionality... self.reduction_fusion = True self.r_tile_size = tile_desc.get_tile_size()[-1] @@ -1117,7 +1266,7 @@ def set_tile_size(self, template_fusion_info, prologue=False): self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) else: - tile_desc.vmap.forced_vec_size = 64 + tile_desc.vmap.forced_vec_size = self.get_safe_vec_size(64) if prologue: self.prologue_compute_body_loop.size = tile_desc.get_numel_per_lane() @@ -1128,6 +1277,15 @@ def set_tile_size(self, template_fusion_info, prologue=False): return tile_desc class MLIRTemplateCaller(CUDATemplateCaller): + def __init__(self, name, category, input_nodes, layout, make_kernel_render, supports_epilogue_fusion, template, info_kwargs, description): + bmreq = MLIRBenchmarkRequest( + kernel_name=name, + input_tensor_meta=list(), + output_tensor_meta=list(), + extra_args=[], + source_code="", + ) + super().__init__(name, category, input_nodes, layout, make_kernel_render, bmreq, supports_epilogue_fusion, template, info_kwargs, description) def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})" @@ -1151,8 +1309,14 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): super().__init__(name) self.input_nodes = [node for node in input_nodes if node is not None] self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + # Multi-output templates can override this with explicit output buffers. + self.output_nodes = [self.output_node] self.input_reorder = input_reorder self.layout = layout + # Fusion support flags (default to False) + self.support_epilogue_fusion = False + self.support_prologue_fusion = False + self.support_reduction_fusion = False def generate(self, **kwargs) -> ChoiceCaller: kernel_name = f"mlir_{self.name}" @@ -1164,15 +1328,8 @@ def generate(self, **kwargs) -> ChoiceCaller: code = self.render(kernel=kernel, **kwargs) kernel_hash_name = f"mlir_{self.name}_{next(self.index_counter)}" - extra_args = [] # create the BenchmarkRequest - bmreq = MLIRBenchmarkRequest( - kernel_name=kernel_name, - input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), - output_tensor_meta=TensorMeta.from_irnodes(self.output_node), - extra_args=extra_args, - source_code=code, - ) + output_nodes = getattr(self, "output_nodes", None) or [self.output_node] def make_kernel_render( template_node: TemplateBuffer, @@ -1214,7 +1371,6 @@ def make_kernel_render( self.input_nodes, self.output_node.get_layout(), make_kernel_render, - bmreq, False, # supports_epilogue_fusion self, kwargs, diff --git a/README.md b/README.md index 4a3ef145..f55995c9 100644 --- a/README.md +++ b/README.md @@ -396,7 +396,6 @@ export TORCHSIM_USE_TIMING_POOLING=0 # use lightweight pooling for timing "icnt_injection_ports_per_core" : 16 // Interconnect injection ports per core "icnt_config_path" : "../configs/booksim2_configs/fly_c4_m32.icnt", // Booksim2 config file path - "precision" : 4, // Element's precision in tensor (Byte) "scheduler" : "simple", // Scheduler type (Now, only support simple scheduler) "num_partition" : 2, // Multi-core Partitioning "partition": { // allocate request queue index diff --git a/Simulator/simulator.py b/Simulator/simulator.py index 13f2b4f0..f24835ba 100644 --- a/Simulator/simulator.py +++ b/Simulator/simulator.py @@ -68,6 +68,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch.uint8: np.uint8, torch.bool: np.uint8, torch.bfloat16: np.float16, + torch.float16: np.float16, } class FunctionalSimulator(): @@ -143,7 +144,7 @@ def run_spike(self, args, arg_attributes, runtime_path, binary, vectorlane_size= base_path= f"--base-path={runtime_path}" os.makedirs(os.path.join(runtime_path, "indirect_access"), exist_ok=True) os.makedirs(os.path.join(runtime_path, "dma_access"), exist_ok=True) - run = f'spike --isa rv64gcv --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_path} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' + run = f'spike --isa rv64gcv_zfh --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_path} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' if not silent_mode: logger.debug(f"[Spike] cmd> {run}") logger.info("[Spike] Running Spike simulator") diff --git a/TOGSim/src/DMA.cc b/TOGSim/src/DMA.cc index f8f21025..fefee6d2 100644 --- a/TOGSim/src/DMA.cc +++ b/TOGSim/src/DMA.cc @@ -12,7 +12,7 @@ void DMA::issue_tile(std::shared_ptr inst) { _current_inst = std::move(inst); std::vector& tile_size = _current_inst->get_tile_size(); if (tile_size.size() <= 0 || tile_size.size() > get_max_dim()) { - spdlog::error("[DMA {}] issued tile is not supported format..", _id); + spdlog::error("[DMA {}] issued tile is not supported format.. tile.size: {}, tile_size: [{}]", _id, tile_size.size(), fmt::join(tile_size, ", ")); exit(EXIT_FAILURE); } _finished = false; diff --git a/TOGSim/src/Simulator.cc b/TOGSim/src/Simulator.cc index b5b9c778..d7fe9f1b 100644 --- a/TOGSim/src/Simulator.cc +++ b/TOGSim/src/Simulator.cc @@ -121,7 +121,7 @@ void Simulator::icnt_cycle() { front->set_core_id(core_id); if (!_icnt->is_full(port_id, front)) { int node_id = _dram->get_channel_id(front) / _config.dram_channels_per_partitions; - if (core_id == node_id) + if (get_partition_id(core_id) == node_id) _cores[core_id]->inc_numa_local_access(); else _cores[core_id]->inc_numa_remote_access(); diff --git a/TOGSim/src/helper/CommandLineParser.cc b/TOGSim/src/helper/CommandLineParser.cc index 66aebbe1..9cd177ac 100644 --- a/TOGSim/src/helper/CommandLineParser.cc +++ b/TOGSim/src/helper/CommandLineParser.cc @@ -12,9 +12,13 @@ void CommandLineParser::parse(int argc, char **argv) noexcept(false) { po::notify(variables_map); } +void CommandLineParser::print_help_message() const noexcept { + std::cout << options_description << std::endl; +} + void CommandLineParser::print_help_message_if_required() const noexcept { if (variables_map.count("help") > 0) { - std::cout << options_description << std::endl; + print_help_message(); exit(0); } } diff --git a/TOGSim/src/helper/CommandLineParser.h b/TOGSim/src/helper/CommandLineParser.h index 39174d5d..b41eabf3 100644 --- a/TOGSim/src/helper/CommandLineParser.h +++ b/TOGSim/src/helper/CommandLineParser.h @@ -19,7 +19,7 @@ class CommandLineParser { * Command Line Parser constructor */ CommandLineParser() noexcept { - options_description.add_options()("help", "Prints help message"); + options_description.add_options()("help,h", "Prints help message"); } /** @@ -38,6 +38,12 @@ class CommandLineParser { */ void print_help_message_if_required() const noexcept; + /** + * Prints the help message. + * (Can be called to show help for invalid options) + */ + void print_help_message() const noexcept; + /** * Add a new command line argument option. * (Should be called before `parse` method is called) diff --git a/TOGSim/src/main.cc b/TOGSim/src/main.cc index 7c596af5..cda8f986 100644 --- a/TOGSim/src/main.cc +++ b/TOGSim/src/main.cc @@ -96,19 +96,24 @@ int main(int argc, char** argv) { // parse command line argumnet CommandLineParser cmd_parser = CommandLineParser(); cmd_parser.add_command_line_option( - "config", "Path for hardware configuration file"); + "config", "Path for hardware configuration file (.yml)"); cmd_parser.add_command_line_option( - "models_list", "Path for the models list file (can be FIFO or regular file)"); + "models_list", "Path for the trace file (.trace)"); cmd_parser.add_command_line_option( "log_level", "Set for log level [trace, debug, info], default = info"); try { cmd_parser.parse(argc, argv); } catch (const CommandLineParser::ParsingError& e) { spdlog::error( - "Command line argument parrsing error captured. Error message: {}", + "Command line argument parsing error captured. Error message: {}", e.what()); - throw(e); + std::cerr << std::endl; + cmd_parser.print_help_message(); + exit(1); } + + // Check if help was requested + cmd_parser.print_help_message_if_required(); std::string level = "info"; cmd_parser.set_if_defined("log_level", &level); diff --git a/experiments/BERT.py b/experiments/BERT.py index fd671833..b938f4e6 100644 --- a/experiments/BERT.py +++ b/experiments/BERT.py @@ -1,57 +1,42 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime -def run_BERT(size, input_seq, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - # from tests.test_transformer import EncoderBlock - from tests.Fusion.test_transformer_fusion import EncoderBlock - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - hidden_dim = {'base': 768, 'large': 1024, 'xlarge': 2048} - embedding_size = {'base': 768, 'large': 1024, 'xlarge': 2048} - heads = {'base': 12, 'large': 16, 'xlarge': 32} # hidden/64 https://arxiv.org/pdf/1909.11942 - cpu_query = torch.randn(input_seq, hidden_dim[size]) - encoder_block = EncoderBlock(embedding_size[size], heads[size]).eval() - - query = cpu_query.clone().to(device=device) - opt_fn = torch.compile(dynamic=False)(encoder_block.to(device=device)) +import torch +from Simulator.simulator import TOGSimulator - SchedulerDNNModel.register_model(f"BERT-{size}", opt_fn) - request = Request(f"BERT-{size}", [query], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') +os.environ['TOGSIM_CONFIG'] = config - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() +# Try Fusion EncoderBlock first, fall back to standard test_transformer +try: + from tests.Fusion.test_transformer_fusion import EncoderBlock +except ImportError: + from tests.test_transformer import EncoderBlock - print(f"BERT-{size} Simulation Done") +HIDDEN_DIM = {'base': 768, 'large': 1024, 'xlarge': 2048} +EMBEDDING_SIZE = {'base': 768, 'large': 1024, 'xlarge': 2048} +HEADS = {'base': 12, 'large': 16, 'xlarge': 32} if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path FIXME: gem5 result is different as directoy name - sys.path.append(base_dir) args = argparse.ArgumentParser() - args.add_argument('--size', type=str, default='base') - args.add_argument('--dump_path', type=str, default='results') + args.add_argument('--size', type=str, default='base', choices=['base', 'large', 'xlarge']) args.add_argument('--input_size', type=int, default=512) args = args.parse_args() - size = args.size - input_seq = args.input_size - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"BERT_{size}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_BERT(size, input_seq, config) + + hidden_dim = HIDDEN_DIM[args.size] + embedding_size = EMBEDDING_SIZE[args.size] + heads = HEADS[args.size] + + device = torch.device("npu:0") + model = EncoderBlock(embedding_size, heads).eval().to(device=device) + model_input = torch.randn(args.input_size, hidden_dim).to(device=device) + opt_fn = torch.compile(dynamic=False)(model) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"BERT-{args.size} Simulation Done") diff --git a/experiments/attention.py b/experiments/attention.py index 211433f1..b56ed537 100644 --- a/experiments/attention.py +++ b/experiments/attention.py @@ -1,56 +1,36 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys +import math import argparse -import datetime +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) -def run_attention(size, config): - def attention(query, key, value): - import math - d_k = query.size(-1) - scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) - p_attn = scores.softmax(dim=-2) - return torch.matmul(value.transpose(-1, -2), p_attn) - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - query = torch.randn(size).to(device=device) - key = torch.randn(size).to(device=device) - value = torch.randn(size).to(device=device) - opt_fn = torch.compile(dynamic=False)(attention) - - SchedulerDNNModel.register_model("attention", opt_fn) - request = Request("attention", [query, key, value], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') +os.environ['TOGSIM_CONFIG'] = config - print(f"Attention {str(size)} Simulation Done") +def attention(query, key, value): + d_k = query.size(-1) + scores = torch.matmul(key, query.transpose(-2, -1)) / math.sqrt(d_k) + p_attn = scores.softmax(dim=-2) + return torch.matmul(value.transpose(-1, -2), p_attn) if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[12, 512, 64], help='Tensor Shape') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"attention_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] + size = tuple(args.size) + + device = torch.device("npu:0") + query = torch.randn(*size).to(device=device) + key = torch.randn(*size).to(device=device) + value = torch.randn(*size).to(device=device) + opt_fn = torch.compile(dynamic=False)(attention) - run_attention(size, config) + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, query, key, value, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"Attention {size} Simulation Done") diff --git a/experiments/conv.py b/experiments/conv.py index 61f7ad80..98391fae 100644 --- a/experiments/conv.py +++ b/experiments/conv.py @@ -1,57 +1,39 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) + +import torch +from Simulator.simulator import TOGSimulator + +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config -def run_conv2d(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - def custom_conv2d(a, b, bias): - i_c = a.shape[1] - o_c = b.shape[0] - conv2d = torch.nn.Conv2d(i_c, o_c, b.shape[-1], stride=stride, padding=padding, dilation=1, bias=False) +def conv2d_fn(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding): + def _conv(a, b, bias): + conv2d = torch.nn.Conv2d(i_c, o_c, kernel_size, stride=stride, padding=padding, dilation=1, bias=False) conv2d.weight = torch.nn.Parameter(b) - # conv2d.bias = torch.nn.Parameter(bias) return conv2d(a) - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() + return _conv + +if __name__ == "__main__": + args = argparse.ArgumentParser() + args.add_argument('--size', nargs='+', type=int, default=[8, 28, 28, 128, 128, 3, 1, 1], + help='B H W I_C O_C K S P') + args = args.parse_args() + batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding = args.size + + device = torch.device("npu:0") conv_input = torch.randn(batch_size, i_c, i_h, i_w).to(memory_format=torch.channels_last, device=device) conv_kernel = torch.randn(o_c, i_c, kernel_size, kernel_size).to(memory_format=torch.channels_last, device=device) conv_bias = torch.randn(o_c).to(device=device) - opt_fn = torch.compile(dynamic=False)(custom_conv2d) - - SchedulerDNNModel.register_model("CONV", opt_fn) - request = Request("CONV", [conv_input, conv_kernel, conv_bias], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() - print(f"CONV {batch_size}_{i_h}_{i_w}_{i_c}_{o_c}_{kernel_size}_{stride}_{padding} (B_H_W_I_C_O_C_K_S_P) Simulation Done") + custom_conv = conv2d_fn(batch_size, i_h, i_w, i_c, o_c, kernel_size, stride, padding) + opt_fn = torch.compile(dynamic=False)(custom_conv) -if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) - args = argparse.ArgumentParser() - args.add_argument('--size', nargs='+', type=int, default=[8, 28, 28, 128, 128, 3, 1, 1], help='B H W I_C O_C K S P') - args.add_argument('--dump_path', type=str, default='results') - args = args.parse_args() - size = args.size - size_str = "_".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"CONV_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_conv2d(size[0], size[1], size[2], size[3], size[4], size[5], size[6], size[7], config) \ No newline at end of file + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, conv_input, conv_kernel, conv_bias, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"CONV {batch_size}_{i_h}_{i_w}_{i_c}_{o_c}_{kernel_size}_{stride}_{padding} Simulation Done") diff --git a/experiments/gemm.py b/experiments/gemm.py index 0e1a15e4..d256e931 100644 --- a/experiments/gemm.py +++ b/experiments/gemm.py @@ -1,51 +1,32 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime - -def run_matmul(input_size, hidden_size, output_size, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - def custom_matmul(a, b): - return torch.matmul(a, b) - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - torch.manual_seed(0) - input = torch.randn(input_size, hidden_size).to(device=device) - weight = torch.randn(hidden_size, output_size).to(device=device) - opt_fn = torch.compile(dynamic=False)(custom_matmul) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("GEMM", opt_fn) - request = Request("GEMM", [input, weight], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config - print(f"GEMM {input_size}x{hidden_size}x{output_size} (MxKxN) Simulation Done") +def matmul_fn(a, b): + return torch.matmul(a, b) if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[128, 128, 128], help='M K N') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"GEMM_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] + M, K, N = args.size[0], args.size[1], args.size[2] - run_matmul(size[0], size[1], size[2], config) + device = torch.device("npu:0") + torch.manual_seed(0) + input_a = torch.randn(M, K).to(device=device) + input_b = torch.randn(K, N).to(device=device) + opt_fn = torch.compile(dynamic=False)(matmul_fn) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, input_a, input_b, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"GEMM {M}x{K}x{N} (MxKxN) Simulation Done") diff --git a/experiments/layernorm.py b/experiments/layernorm.py index a6b16986..a9170c6b 100644 --- a/experiments/layernorm.py +++ b/experiments/layernorm.py @@ -1,48 +1,29 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime - -def run_layernorm(size, config): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - input = torch.randn(size).to(device=device) - opt_fn = torch.compile(dynamic=False)(torch.nn.LayerNorm(size[-1]).to(device=device)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("LayerNorm", opt_fn) - request = Request("LayerNorm", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() +import torch +from Simulator.simulator import TOGSimulator - print(f"LayerNorm {str(size)} Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[512, 768], help='Tensor Shape') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"LayerNorm_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - os.environ['TORCHSIM_FUSION_REDUCTION_REDUCTION'] = "0" - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_layernorm(size, config) + size = tuple(args.size) + normalized_shape = size[-1] + + device = torch.device("npu:0") + model = torch.nn.LayerNorm(normalized_shape).to(device=device) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(*size).to(device=device) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"LayerNorm {size} Simulation Done") diff --git a/experiments/resnet18.py b/experiments/resnet18.py index c7763d86..38fb80fe 100644 --- a/experiments/resnet18.py +++ b/experiments/resnet18.py @@ -1,49 +1,28 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime -def run_resnet(batch, config): - from torchvision.models import resnet18 - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - model = resnet18().eval() - input = torch.randn(batch, 3, 224, 224).to(device=device) - opt_fn = torch.compile(dynamic=False)(model.to(device, memory_format=torch.channels_last)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("resnet18", opt_fn) - request = Request("resnet18", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from torchvision.models import resnet18 +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() - - print("ResNet18 Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c1_simple_noc_tpuv3.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--batch', type=int, default=1) - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - batch = args.batch - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet18_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - run_resnet(batch, config) + device = torch.device("npu:0") + model = resnet18().eval().to(device=device, memory_format=torch.channels_last) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(args.batch, 3, 224, 224).to(device=device) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print("ResNet18 Simulation Done") diff --git a/experiments/resnet50.py b/experiments/resnet50.py index 4e611541..5b134c13 100644 --- a/experiments/resnet50.py +++ b/experiments/resnet50.py @@ -1,49 +1,28 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime -def run_resnet(batch, config): - from torchvision.models import resnet50 - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - model = resnet50().eval() - input = torch.randn(batch, 3, 224, 224).to(device=device) - opt_fn = torch.compile(dynamic=False)(model.to(device, memory_format=torch.channels_last)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("resnet50", opt_fn) - request = Request("resnet50", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) +import torch +from torchvision.models import resnet50 +from Simulator.simulator import TOGSimulator - # Run scheduler - while not scheduler.is_finished(): - with torch.no_grad(): - scheduler.schedule() - - print("ResNet50 Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--batch', type=int, default=1) - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - batch = args.batch - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet50_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - run_resnet(batch, config) + device = torch.device("npu:0") + model = resnet50().eval().to(device=device, memory_format=torch.channels_last) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(args.batch, 3, 224, 224).to(device=device) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print("ResNet50 Simulation Done") diff --git a/experiments/softmax.py b/experiments/softmax.py index d30559f7..b86febe0 100644 --- a/experiments/softmax.py +++ b/experiments/softmax.py @@ -1,47 +1,29 @@ -import torch -import torch._dynamo -import torch.utils.cpp_extension - +import os +import sys import argparse -import datetime - -def run_softmax(size, config, dim=1): - from Scheduler.scheduler import Scheduler, SchedulerDNNModel, Request - scheduler = Scheduler(num_request_queue=1, engine_select=Scheduler.FIFO_ENGINE, togsim_config=config) - device = scheduler.execution_engine.module.custom_device() - input = torch.randn(size).to(device=device) - opt_fn = torch.compile(dynamic=False)(torch.nn.Softmax(dim=dim).to(device=device)) +base_path = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') +sys.path.insert(0, base_path) - SchedulerDNNModel.register_model("Softmax", opt_fn) - request = Request("Softmax", [input], [], request_queue_idx=0) - scheduler.add_request(request, request_time=0) - - # Run scheduler - while not scheduler.is_finished(): - scheduler.schedule() +import torch +from Simulator.simulator import TOGSimulator - print(f"Softmax {str(size)} Simulation Done") +config = os.environ.get('TOGSIM_CONFIG', f'{base_path}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') +os.environ['TOGSIM_CONFIG'] = config if __name__ == "__main__": - import os - import sys - base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') - config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path - sys.path.append(base_dir) args = argparse.ArgumentParser() args.add_argument('--size', nargs='+', type=int, default=[512, 512], help='Tensor Shape') - args.add_argument('--dump_path', type=str, default='results') args = args.parse_args() - size = args.size - size_str = "x".join([str(i) for i in size]) - result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"Softmax_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") - # setting environment variables - os.environ['TORCHSIM_LOG_PATH'] = result_path - # only timing simulation - os.environ['TORCHSIM_VALIDATION_MODE'] = "0" - if 'pytorchsim_functional_mode' in os.environ: - del os.environ['pytorchsim_functional_mode'] - - run_softmax(size, config) + size = tuple(args.size) + dim = 1 + + device = torch.device("npu:0") + model = torch.nn.Softmax(dim=dim).to(device=device) + opt_fn = torch.compile(dynamic=False)(model) + model_input = torch.randn(*size).to(device=device) + + with TOGSimulator(config_path=config): + torch.npu.launch_model(opt_fn, model_input, stream_index=0, timestamp=0) + torch.npu.synchronize() + print(f"Softmax {size} Simulation Done") diff --git a/tests/DeepSeek/test_deepseek_v3_base.py b/tests/DeepSeek/test_deepseek_v3_base.py index b8402c8b..ade787c5 100644 --- a/tests/DeepSeek/test_deepseek_v3_base.py +++ b/tests/DeepSeek/test_deepseek_v3_base.py @@ -1,8 +1,55 @@ import os import sys import argparse +import copy +from pathlib import Path import torch +# recursive compile for some ops that are caused by graph break +torch.npu.register_eager_to_compile([ + "aten::zero_", + "aten::sum.IntList_out", + "aten::mul.out", + "aten::floor_divide", + "aten::floor_divide.Tensor", + "aten::floor_divide.Scalar", + "aten::cat.out", + "aten::sort.values_stable", +]) + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + out_cpu = out.cpu() + max_diff = (out_cpu - cpu_out).abs().max().item() + mean_diff = (out_cpu - cpu_out).abs().mean().item() + if torch.allclose(out_cpu, cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("NPU out: ", out_cpu) + print("CPU out: ", cpu_out) + print(f"Max absolute difference: {max_diff:.6f}") + print(f"Mean absolute difference: {mean_diff:.6f}") + exit(1) + + +def _extract_logits(output): + if isinstance(output, torch.Tensor): + return output + if hasattr(output, "logits"): + return output.logits + if isinstance(output, (list, tuple)) and len(output) > 0 and isinstance(output[0], torch.Tensor): + return output[0] + raise TypeError(f"Unsupported output type for comparison: {type(output)}") + def _dtype_from_str(name: str) -> torch.dtype: return { @@ -81,7 +128,7 @@ def _maybe_scale_config(config, scale=1.0, max_layers=None): def _apply_preset(scale, max_layers, batch, seq_len, preset): if preset == "tiny": - return 0.03, 4, 1, min(seq_len, 16) + return 0.03, 1, 1, min(seq_len, 16) if preset == "small": return 0.07, 8, 1, min(seq_len, 32) if preset == "medium": @@ -89,8 +136,58 @@ def _apply_preset(scale, max_layers, batch, seq_len, preset): return scale, max_layers, batch, seq_len +def _togsim_log_count() -> int: + log_dir = Path("togsim_results") + if not log_dir.exists(): + return 0 + return len(list(log_dir.glob("*.log"))) + + +def _assert_simulation_happened(before_count: int, case_name: str): + after_count = _togsim_log_count() + if after_count <= before_count: + raise RuntimeError( + f"{case_name}: TOGSim log count did not increase " + f"(before={before_count}, after={after_count})" + ) + print(f"{case_name}: TOGSim logs increased ({before_count} -> {after_count})") + + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + before = _togsim_log_count() + out = opt_fn(x, y) + _assert_simulation_happened(before, "cat.default") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + before = _togsim_log_count() + out = opt_fn(x, y, out_buf) + _assert_simulation_happened(before, "cat.out") + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + @torch.no_grad() -def run_deep_seek_v3_base_test( +def run_deepseek_v3_base( model_id, device, init_mode="config-random", @@ -120,7 +217,6 @@ def run_deep_seek_v3_base_test( # (call .to_dict()), so only disable it for pretrained loading path. if init_mode == "pretrained" and getattr(config, "quantization_config", None) is not None: config.quantization_config = None - config = _maybe_scale_config(config, scale=scale, max_layers=max_layers) if init_mode == "config-random": @@ -141,7 +237,6 @@ def run_deep_seek_v3_base_test( else: raise ValueError(f"Unsupported init mode: {init_mode}") - model = model.to(device) model_params = sum(p.numel() for p in model.parameters()) print("init mode:", init_mode) print("scaled hidden_size:", getattr(config, "hidden_size", "n/a")) @@ -157,23 +252,33 @@ def run_deep_seek_v3_base_test( revision=revision, ) encoded = tokenizer(prompt, return_tensors="pt") - input_ids = encoded["input_ids"].to(device) + cpu_input_ids = encoded["input_ids"].cpu() else: vocab_size = getattr(config, "vocab_size", None) if vocab_size is None: raise ValueError("Config has no vocab_size; use --use-tokenizer or pass a model with vocab_size.") - input_ids = _build_random_inputs(batch, seq_len, vocab_size, device) + cpu_input_ids = _build_random_inputs(batch, seq_len, vocab_size, torch.device("cpu")) + input_ids = cpu_input_ids.to(device) - if compile_model: - model = torch.compile(model, dynamic=False) + # CPU version + model_cpu = copy.deepcopy(model).cpu().eval() + cpu_out = _extract_logits(model_cpu(cpu_input_ids)) - out = model(input_ids) - logits = out.logits + # NPU version + model_npu = copy.deepcopy(model_cpu).to(device).eval() + if compile_model: + model_npu = torch.compile(model_npu, dynamic=False) + npu_out = _extract_logits(model_npu(input_ids)) + + # Campare results + test_result( + "DeepSeek V3 Base", + npu_out, + cpu_out, + rtol=3e-1, + atol=2e-1, + ) - print("logits shape:", tuple(logits.shape)) - print("logits dtype:", logits.dtype) - print("logits max:", logits.max().item()) - if __name__ == "__main__": parser = argparse.ArgumentParser(description="DeepSeek V3 download-based test") @@ -181,7 +286,7 @@ def run_deep_seek_v3_base_test( parser.add_argument("--revision", type=str, default=None) parser.add_argument("--trust-remote-code", action="store_true", default=True) parser.add_argument("--init-mode", type=str, default="config-random", choices=["config-random", "pretrained"]) - parser.add_argument("--preset", type=str, default="tiny", choices=["none", "tiny", "small", "medium"]) + parser.add_argument("--preset", type=str, default="small", choices=["none", "tiny", "small", "medium"]) parser.add_argument("--scale", type=float, default=1.0) parser.add_argument("--max-layers", type=int, default=None) parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"]) @@ -190,6 +295,7 @@ def run_deep_seek_v3_base_test( parser.add_argument("--use-tokenizer", action="store_true") parser.add_argument("--prompt", type=str, default="Hello, DeepSeek V3") parser.add_argument("--compile", action="store_true", default=True) + parser.add_argument("--test", type=str, default="e2e", choices=["all", "e2e", "cat"]) args = parser.parse_args() @@ -203,18 +309,22 @@ def run_deep_seek_v3_base_test( device = torch.device("npu:0") - run_deep_seek_v3_base_test( - model_id=args.model_id, - device=device, - init_mode=args.init_mode, - scale=args.scale, - max_layers=args.max_layers, - dtype=args.dtype, - batch=args.batch, - seq_len=args.seq_len, - use_tokenizer=args.use_tokenizer, - prompt=args.prompt, - trust_remote_code=args.trust_remote_code, - revision=args.revision, - compile_model=args.compile, - ) + if args.test in ("all", "cat"): + test_cat_default(device) + test_cat_out(device) + if args.test in ("all", "e2e"): + run_deepseek_v3_base( + model_id=args.model_id, + device=device, + init_mode=args.init_mode, + scale=args.scale, + max_layers=args.max_layers, + dtype=args.dtype, + batch=args.batch, + seq_len=args.seq_len, + use_tokenizer=args.use_tokenizer, + prompt=args.prompt, + trust_remote_code=args.trust_remote_code, + revision=args.revision, + compile_model=args.compile, + ) diff --git a/tests/Fusion/__init__.py b/tests/Fusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_cat.py b/tests/test_cat.py new file mode 100644 index 00000000..97fcc754 --- /dev/null +++ b/tests/test_cat.py @@ -0,0 +1,200 @@ +import argparse +from pathlib import Path + +import torch + + +def _test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + return + + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + raise RuntimeError(f"{name} mismatch") + +def test_cat_default(device): + def cat_default_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_default_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.default", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_out(device): + def cat_out_fn(a, b, out): + return torch.ops.aten.cat.out([a, b], 0, out=out) + + x = torch.randn(8, 16, device=device) + y = torch.randn(6, 16, device=device) + out_buf = torch.empty(14, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_out_fn) + + out = opt_fn(x, y, out_buf) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.out", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim0(device): + def cat_4d_dim0_fn(a, b): + return torch.cat([a, b], dim=0) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(3, 3, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim0_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=0) + _test_result("cat.4d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim1(device): + def cat_4d_dim1_fn(a, b): + return torch.cat([a, b], dim=1) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 5, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim1_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=1) + _test_result("cat.4d.dim1", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim2(device): + def cat_4d_dim2_fn(a, b): + return torch.cat([a, b], dim=2) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 6, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim2_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=2) + _test_result("cat.4d.dim2", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_dim3(device): + def cat_4d_dim3_fn(a, b): + return torch.cat([a, b], dim=3) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 3, 4, 7, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_dim3_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=3) + _test_result("cat.4d.dim3", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_three_inputs(device): + def cat_three_inputs_fn(a, b, c): + return torch.cat([a, b, c], dim=0) + + x = torch.randn(4, 16, device=device) + y = torch.randn(5, 16, device=device) + z = torch.randn(3, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_three_inputs_fn) + + out = opt_fn(x, y, z) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=0) + _test_result("cat.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_four_inputs(device): + def cat_four_inputs_fn(a, b, c, d): + return torch.cat([a, b, c, d], dim=0) + + x = torch.randn(3, 16, device=device) + y = torch.randn(4, 16, device=device) + z = torch.randn(5, 16, device=device) + w = torch.randn(2, 16, device=device) + opt_fn = torch.compile(dynamic=False)(cat_four_inputs_fn) + + out = opt_fn(x, y, z, w) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu(), w.cpu()], dim=0) + _test_result("cat.four_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + + +def test_cat_4d_three_inputs(device): + def cat_4d_three_inputs_fn(a, b, c): + return torch.cat([a, b, c], dim=1) + + x = torch.randn(2, 3, 4, 5, device=device) + y = torch.randn(2, 4, 4, 5, device=device) + z = torch.randn(2, 5, 4, 5, device=device) + opt_fn = torch.compile(dynamic=False)(cat_4d_three_inputs_fn) + + out = opt_fn(x, y, z) + + cpu_out = torch.cat([x.cpu(), y.cpu(), z.cpu()], dim=1) + _test_result("cat.4d.three_inputs", out, cpu_out, rtol=1e-4, atol=1e-4) + +def test_cat_5d(device, dim=0): + def cat_5d_fn(a, b): + return torch.cat([a, b], dim=dim) + + x = torch.randn(2, 3, 4, 5, 6, device=device) + y = torch.randn(3, 3, 4, 5, 6, device=device) + opt_fn = torch.compile(dynamic=False)(cat_5d_fn) + + out = opt_fn(x, y) + + cpu_out = torch.cat([x.cpu(), y.cpu()], dim=dim) + _test_result("cat.5d.dim0", out, cpu_out, rtol=1e-4, atol=1e-4) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run cat simulation tests") + parser.add_argument( + "--case", + choices=[ + "default", "out", "4d_dim0", "4d_dim1", "4d_dim2", "4d_dim3", "5d" + "three_inputs", "four_inputs", "4d_three_inputs", "all" + ], + default="all", + help="Which cat case to run", + ) + args = parser.parse_args() + + device = torch.device("npu:0") + + if args.case in ("default", "all"): + test_cat_default(device) + if args.case in ("out", "all"): + test_cat_out(device) + if args.case in ("4d_dim0", "all"): + test_cat_4d_dim0(device) + if args.case in ("4d_dim1", "all"): + test_cat_4d_dim1(device) + if args.case in ("4d_dim2", "all"): + test_cat_4d_dim2(device) + if args.case in ("4d_dim3", "all"): + test_cat_4d_dim3(device) + if args.case in ("three_inputs", "all"): + test_cat_three_inputs(device) + if args.case in ("four_inputs", "all"): + test_cat_four_inputs(device) + if args.case in ("4d_three_inputs", "all"): + test_cat_4d_three_inputs(device) + if args.case in ("5d", "all"): + test_cat_5d(device) diff --git a/tests/test_sdpa.py b/tests/test_sdpa.py new file mode 100644 index 00000000..c4825731 --- /dev/null +++ b/tests/test_sdpa.py @@ -0,0 +1,145 @@ +import sys +import os +import torch +import torch._dynamo +import torch.nn.functional as F + +base_dir = os.environ.get("TORCHSIM_DIR", default="/workspace/PyTorchSim") +sys.path.append(base_dir) + +device = torch.device("npu:0") + +# --------------------------------------------------------------------------- +# Default sweep configs - edit here to change what gets tested +# --------------------------------------------------------------------------- +SDPA_DEFAULTS = dict( + n_batch_list = [1, 4, 8, 16], + n_head_list = [4, 6, 8, 12], + n_token_list = [128, 256, 512, 1024], + head_dim_list = [32, 64, 128], + is_causal = False, +) + +GQA_DEFAULTS = dict( + batch_list = [1], + num_kv_heads = 1, + gqa_ratios = [4, 5, 8, 16], # Hq = ratio * num_kv_heads + seq_len_list = [128, 256, 1024], + head_dim_list = [64, 128], + query_len = 1, # decode shape: Lq == 1 + is_causal = True, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def clear_caches(): + from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache + from torch._inductor.codecache import FxGraphCache + AOTAutogradCache.clear() + torch._dynamo.reset() + os.environ["TORCHINDUCTOR_CACHE"] = "0" + FxGraphCache.clear() + + +def assert_close(name, out, cpu_out, rtol=1e-4, atol=1e-4): + msg = f"|{name} Test Passed|" + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + print("-" * len(msg)) + print(msg) + print("-" * len(msg)) + else: + print(f"[FAIL] {name}") + print(" device out:", out.cpu()) + print(" cpu out:", cpu_out) + exit(1) + + +def _run_sdpa(device, q, k, v, **kwargs): + """Compile and run SDPA on device; return result on device.""" + opt_fn = torch.compile(dynamic=False)(F.scaled_dot_product_attention) + return opt_fn(q.to(device), k.to(device), v.to(device), **kwargs) + + +def _cpu_sdpa(q, k, v, **kwargs): + """Run reference SDPA on CPU.""" + return F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), **kwargs) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +def test_sdpa( + device, + n_batch_list = SDPA_DEFAULTS["n_batch_list"], + n_head_list = SDPA_DEFAULTS["n_head_list"], + n_token_list = SDPA_DEFAULTS["n_token_list"], + head_dim_list = SDPA_DEFAULTS["head_dim_list"], + is_causal = SDPA_DEFAULTS["is_causal"], +): + torch.manual_seed(0) + sdpa_kwargs = dict(attn_mask=None, dropout_p=0.0, is_causal=is_causal) + + for B in n_batch_list: + for H in n_head_list: + for S in n_token_list: + for D in head_dim_list: + clear_caches() + q = torch.rand(B, H, S, D, dtype=torch.float32) + k = torch.rand(B, H, S, D, dtype=torch.float32) + v = torch.rand(B, H, S, D, dtype=torch.float32) + + out = _run_sdpa(device, q, k, v, **sdpa_kwargs) + cpu_out = _cpu_sdpa(q, k, v, **sdpa_kwargs) + + assert_close(f"SDPA(B:{B}, H:{H}, S:{S}, D:{D})", out, cpu_out) + + print("All SDPA tests passed!") + + +def test_gqa( + device, + batch_list = GQA_DEFAULTS["batch_list"], + num_kv_heads = GQA_DEFAULTS["num_kv_heads"], + gqa_ratios = GQA_DEFAULTS["gqa_ratios"], + seq_len_list = GQA_DEFAULTS["seq_len_list"], + head_dim_list= GQA_DEFAULTS["head_dim_list"], + query_len = GQA_DEFAULTS["query_len"], + is_causal = GQA_DEFAULTS["is_causal"], +): + """ + GQA sweep: q shape (B, Hq, Lq, D), kv shape (B, H, S, D). + Hq = ratio * num_kv_heads for each ratio in gqa_ratios. + """ + torch.manual_seed(0) + sdpa_kwargs = dict(attn_mask=None, dropout_p=0.0, is_causal=is_causal, enable_gqa=True) + + for B in batch_list: + for S in seq_len_list: + for D in head_dim_list: + for ratio in gqa_ratios: + Hq = ratio * num_kv_heads + clear_caches() + q = torch.rand(B, Hq, query_len, D, dtype=torch.float32) + k = torch.rand(B, num_kv_heads, S, D, dtype=torch.float32) + v = torch.rand(B, num_kv_heads, S, D, dtype=torch.float32) + + out = _run_sdpa(device, q, k, v, **sdpa_kwargs) + cpu_out = _cpu_sdpa(q, k, v, **sdpa_kwargs) + + assert_close( + f"GQA(B:{B}, Hq:{Hq}, H:{num_kv_heads}, S:{S}, D:{D})", + out, cpu_out, + ) + + print("All GQA tests passed!") + + +if __name__ == "__main__": + with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.FLASH_ATTENTION]): + test_sdpa(device) + #test_gqa(device) + + # Example: quick single-config run + # test_gqa(device, batch_list=[1], gqa_ratios=[5], seq_len_list=[32], head_dim_list=[128]) diff --git a/tests/test_sort.py b/tests/test_sort.py new file mode 100644 index 00000000..05afe92b --- /dev/null +++ b/tests/test_sort.py @@ -0,0 +1,124 @@ +import argparse +import torch + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + + +def test_equal(name, out, cpu_out): + if torch.equal(out.cpu(), cpu_out): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out:", out.cpu()) + print("cpu out:", cpu_out) + raise SystemExit(1) + +def test_sort(device, size=(128, 128), dim=-1, descending=False, stable=True): + def sort_test(x): + return torch.sort(x, dim=dim, descending=descending, stable=stable) + + x = torch.randn(size, dtype=torch.float32) + x_npu = x.to(device=device) + + opt_sort = torch.compile(dynamic=False)(sort_test) + out_values, out_indices = opt_sort(x_npu) + ref_values, ref_indices = torch.sort(x, stable=stable, dim=dim, descending=descending) + + prefix = "Sort.stable" if stable else "Sort.unstable" + test_result(f"{prefix}/values size={size}, dim={dim}, desc={descending}", out_values, ref_values) + if stable: + test_result(f"{prefix}/indices size={size}, dim={dim}, desc={descending}", out_indices, ref_indices) + else: + # Unstable sort does not guarantee tie ordering; validate index-value consistency instead. + gathered = torch.gather(x, dim, out_indices.cpu()) + test_result(f"{prefix}/indices_gather size={size}, dim={dim}, desc={descending}", gathered, out_values.cpu()) + + +def test_sort_stable_suite(device): + # Keep sort-axis sizes compatible with backend constraints (vector-size multiple). + cases = [ + {"size": (64,), "dim": 0, "descending": False}, # 1D + {"size": (4, 64), "dim": 1, "descending": True}, # 2D, last dim + {"size": (2, 8, 32), "dim": 2, "descending": False}, # 3D, last dim + {"size": (2, 16, 4), "dim": 1, "descending": True}, # 3D, middle dim + {"size": (2, 4, 8, 32), "dim": 3, "descending": False}, # 4D, last dim + {"size": (4, 2, 32, 8), "dim": 2, "descending": True}, # 4D, inner dim + ] + for case in cases: + test_sort( + device=device, + size=case["size"], + dim=case["dim"], + descending=case["descending"], + stable=True, + ) + + +def test_sort_duplicate_cases(device): + duplicate_cases = [ + {"size": (64,), "dim": 0, "descending": False}, + {"size": (4, 64), "dim": 1, "descending": True}, + {"size": (2, 8, 32), "dim": 2, "descending": False}, + ] + for case in duplicate_cases: + base = torch.arange(case["size"][case["dim"]], dtype=torch.int64) % 7 + view_shape = [1] * len(case["size"]) + view_shape[case["dim"]] = case["size"][case["dim"]] + x = base.view(view_shape).expand(case["size"]).to(torch.float32) + noise = torch.randn(case["size"], dtype=torch.float32) * 0.0 + x = x + noise + + def sort_test(inp): + return torch.sort(inp, dim=case["dim"], descending=case["descending"], stable=True) + + out_values, out_indices = torch.compile(dynamic=False)(sort_test)(x.to(device=device)) + ref_values, ref_indices = torch.sort( + x, dim=case["dim"], descending=case["descending"], stable=True + ) + test_result(f"Sort.dup/stable_values {case}", out_values, ref_values) + test_equal(f"Sort.dup/stable_indices {case}", out_indices, ref_indices) + + def sort_test_unstable(inp): + return torch.sort(inp, dim=case["dim"], descending=case["descending"], stable=False) + + out_values_u, out_indices_u = torch.compile(dynamic=False)(sort_test_unstable)(x.to(device=device)) + ref_values_u, _ = torch.sort(x, dim=case["dim"], descending=case["descending"], stable=False) + test_result(f"Sort.dup/unstable_values {case}", out_values_u, ref_values_u) + gathered_u = torch.gather(x, case["dim"], out_indices_u.cpu()) + test_result(f"Sort.dup/unstable_gather {case}", gathered_u, out_values_u.cpu()) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run sort tests") + parser.add_argument("--shape", type=str, default="(64, 32, 16)") + parser.add_argument("--dim", type=int, default=0) + parser.add_argument("--descending", action="store_true") + args = parser.parse_args() + + shape = tuple(map(int, args.shape.strip("()").split(","))) + + from Scheduler.scheduler import PyTorchSimRunner + + module = PyTorchSimRunner.setup_device() + device = module.custom_device() + + test_sort_stable_suite(device) + test_sort_duplicate_cases(device) \ No newline at end of file diff --git a/tests/test_topk.py b/tests/test_topk.py index c8565310..caf56779 100644 --- a/tests/test_topk.py +++ b/tests/test_topk.py @@ -31,21 +31,11 @@ def topk_fn(a): opt_topk = torch.compile(dynamic=False)(topk_fn) res_values, res_indices = opt_topk(x) - ref_values, ref_indices = torch.topk(x.cpu(), k, dim=dim, largest=largest, sorted=sorted) test_result("TopK/values", res_values, ref_values) test_result("TopK/indices", res_indices, ref_indices) if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") - parser.add_argument('--shape', type=str, default="(512,768)") - args = parser.parse_args() - shape = tuple(map(int, args.shape.strip('()').split(','))) - - from Scheduler.scheduler import ExecutionEngine - module = ExecutionEngine.setup_device() - device = module.custom_device() + device = torch.device('npu:0') test_topk(device, (128, 128), k=2, dim=-1) \ No newline at end of file