From 0607d31c77a6bdcc3f82489c3fd9966ef7328cf9 Mon Sep 17 00:00:00 2001 From: Jiangyong Date: Wed, 3 Dec 2025 10:36:07 +0800 Subject: [PATCH 01/22] add sparse attention VSA --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 89 + example/ck_tile/50_sparse_attn/bias.hpp | 100 ++ .../50_sparse_attn/codegen/__init__.py | 0 .../50_sparse_attn/codegen/cpp_symbol_map.py | 141 ++ .../50_sparse_attn/codegen/ops/__init__.py | 0 .../codegen/ops/fmha_fwd_jenga.py | 750 ++++++++ .../codegen/ops/fmha_fwd_vsa.py | 752 ++++++++ .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 725 ++++++++ example/ck_tile/50_sparse_attn/generate.py | 131 ++ .../50_sparse_attn/jenga_sparse_attention.cu | 216 +++ .../50_sparse_attn/jenga_sparse_attention.h | 61 + example/ck_tile/50_sparse_attn/mask.hpp | 176 ++ .../50_sparse_attn/test_vsa_sparse_attn.cpp | 490 ++++++ .../50_sparse_attn/vsa_sparse_attention.cu | 216 +++ example/ck_tile/CMakeLists.txt | 1 + ...ock_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 887 ++++++++++ ...block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 872 ++++++++++ .../ops/sparse_attn/fmha_fwd_jenga_kernel.hpp | 1507 ++++++++++++++++ .../ops/sparse_attn/fmha_fwd_vsa_kernel.hpp | 1524 +++++++++++++++++ 19 files changed, 8638 insertions(+) create mode 100644 example/ck_tile/50_sparse_attn/CMakeLists.txt create mode 100644 example/ck_tile/50_sparse_attn/bias.hpp create mode 100644 example/ck_tile/50_sparse_attn/codegen/__init__.py create mode 100644 example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py create mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/__init__.py create mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py create mode 100644 example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py create mode 100644 example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp create mode 100644 example/ck_tile/50_sparse_attn/generate.py create mode 100644 example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu create mode 100644 
example/ck_tile/50_sparse_attn/jenga_sparse_attention.h create mode 100644 example/ck_tile/50_sparse_attn/mask.hpp create mode 100644 example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp create mode 100644 example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu create mode 100644 include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp create mode 100644 include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp create mode 100644 include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp create mode 100644 include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt new file mode 100644 index 00000000000..533fe6587ac --- /dev/null +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +# CMakeLists.txt for sparse attention (Jenga and VSA) + +# Use SUPPORTED_GPU_TARGETS directly +set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) +set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS}) + +message(STATUS "VSA Sparse Attention: SUPPORTED_GPU_TARGETS=${SUPPORTED_GPU_TARGETS}, INST_TARGETS=${INST_TARGETS}") + +list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12") +if(NOT INST_TARGETS) + message(WARNING "Skipping Tile Engine Sparse Attention: No supported GPU targets found") + return() +endif() + +message(STATUS "Building VSA Sparse Attention for targets: ${INST_TARGETS}") + +# Code generation scripts +file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py +) +set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") + +# Code generation for VSA (receipt 600 for aiter integration) +set(SPARSE_ATTN_VSA_CODE_GEN_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api fwd_vsa + --receipt 600 +) + +# Generate list of VSA kernels (at configure 
time, only list, not generate) +execute_process( + COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate VSA kernel list") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt SPARSE_ATTN_VSA_GEN_BLOBS) + +# Generate the kernel source files at build time (not configure time) +add_custom_command( + OUTPUT ${SPARSE_ATTN_VSA_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile VSA Sparse Attention kernels" +) + +message(STATUS "VSA kernel files to be generated: ${SPARSE_ATTN_VSA_GEN_BLOBS}") + +# VSA Instances +set(SPARSE_ATTN_VSA_INSTANCES "tile_sparse_attn_vsa_instances") + +add_library(${SPARSE_ATTN_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${SPARSE_ATTN_VSA_GEN_BLOBS} + ${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cu +) +target_include_directories(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties(${SPARSE_ATTN_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cu PROPERTIES LANGUAGE HIP) +set_property(TARGET ${SPARSE_ATTN_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +# Compile options +target_compile_options(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + +# Test executable +set(TEST_VSA_SPARSE_ATTN "tile_test_vsa_sparse_attn") +add_executable(${TEST_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp) +target_link_libraries(${TEST_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) +target_include_directories(${TEST_VSA_SPARSE_ATTN} PRIVATE 
${CMAKE_CURRENT_LIST_DIR}) +target_compile_options(${TEST_VSA_SPARSE_ATTN} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal +) + +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/50_sparse_attn/bias.hpp b/example/ck_tile/50_sparse_attn/bias.hpp new file mode 100644 index 00000000000..f9dc656f637 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/bias.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || + str.compare(0, 11, "elementwise") == 0) + { + info.type = bias_enum::elementwise_bias; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + else if(str.compare(0, 1, "2") == 0 || 
str.compare(0, 1, "a") == 0 || + str.compare(0, 5, "alibi") == 0) + { + info.type = bias_enum::alibi; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/50_sparse_attn/codegen/__init__.py b/example/ck_tile/50_sparse_attn/codegen/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py new file mode 100644 index 00000000000..1f5a03e243d --- /dev/null +++ b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +FWD_DTYPE_MAP = { + "fp16" : "FmhaFwdFp16", + "bf16" : "FmhaFwdBf16", + "fp8" : "FmhaFwdFp8", + "fp8fp16": "FmhaFwdFp8Fp16", + "fp8bf16": "FmhaFwdFp8Bf16" +} + +BWD_DTYPE_MAP = { + "fp16": "FmhaBwdFp16", + "bf16": "FmhaBwdBf16" +} + +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +_MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + +_MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +def get_mask_map(mask : str): + if mask == "generic": + return _MASK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_MAP + else: + assert False + return None + +_MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic" : 
"t.mask_type == mask_enum::window_generic", +} + +_MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + +def get_mask_check_map(mask : str): + if mask == "generic": + return _MASK_CHECK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + +BIAS_MAP = { + "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" +} + +# TODO: this is ugly +BIAS_CHECK_MAP = { + "no" : "bias_enum::no_bias", + "bias" : "bias_enum::elementwise_bias", + "alibi" : "bias_enum::alibi" +} + +DROPOUT_MAP = { + "no" : "ck_tile::BlockDropoutBwd", + "dropout_wg32" : "ck_tile::BlockDropoutBwd", + "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", + "dropout_wg16" : "ck_tile::BlockDropoutBwd", + "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" +} + +DROPOUT_CHECK_MAP = { + "no" : "t.has_dropout == false", + "dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true", + "dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true", +} + +ROPE_MAP = { + "no" : "ck_tile::RotaryEmbeddingEnum::NONE", + "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", + "half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED" +} + +ROPE_CHECK_MAP = { + "no" : "rope_enum::none", + "inter" : "rope_enum::interleaved", + "half" : "rope_enum::half_rotated" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga", + "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_vsa" : 
"ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA", +} + +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_async_vsa" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", + "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false", + True : "true", + False : "false", +} + +SQUANT_MAP = { + "t" : "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", + "f" : "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", +} diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/__init__.py b/example/ck_tile/50_sparse_attn/codegen/ops/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py new file mode 100644 index 00000000000..388c7d3a685 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -0,0 +1,750 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +import os +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cpp_symbol_map import * + +GEN_DIR = "" + +import os.path as path + + +def update_file(file_path, content): + """Update the file at file_path with the given content if it differs from the existing content. 
+ + It avoids unnecessary touching of the file which triggers rebuilds + """ + + existing_content = "" + if path.exists(file_path): + with open(file_path, "r") as file: + existing_content = file.read() + if existing_content == content: + return + with open(file_path, "w") as file: + file.write(content) + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 192: 192, + 256: 256 +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd_trek.hpp" +#include "block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" +#include "fmha_fwd_jenga_kernel.hpp" + +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}, + {F_skip}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename 
FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + {F_trload}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdJengaKernel; + +using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + +#include + +template<> +float fmha_jenga_fwd_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + const dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_jenga_fwd_api.cpp" +FMHA_FWD_API=""" +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, 
unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float fmha_jenga_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + return fmha_jenga_fwd_(s, a); + }} +""" + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is 
None: + return 'true' + else: + return f'{self.bool_expr}' + + def __and__(self, other): + return CppConstraint(f'({str(self)}) && ({str(other)})') + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + skip : str + tr_load : str + constraint : CppConstraint + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag in ['qr_async', 'qr_async_trload']: + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr', 'qs']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def seqtune(self) -> str: + if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true + else: + return f'a.seqlen_q <= {self.bm0}' + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qs']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag == 'qr_async_trload': + if self.skpad == 't' : return 'true' + else: return 'true' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_skip : str # true/false + F_trload : str # true/false + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + + if self.F_skip == 't' : n += '_skip' + else: n += '_nskip' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' + + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? 
+ if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() + + self.pool[trait.dtype][hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true" + } + + per_tr_load =str() + for tr_load in ["t", "f"]: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + 
FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + 
+@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_trload = BOOL_MAP[self.F_pipeline.F_trload]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + 
self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + 
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? 
+ squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + if hdim == 256 and hdim_v == 256: + # print("jenga fmha only support dim=128 now.") + continue + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + # the below two is used for hdim vectorize load + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + else: + if bias == "bias": + # print("jenga_fmha with bias is not implemented.") + continue + # TODO: rocm 6.2 compiler problem if using qr_async for bias case + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + # if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) + # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) + # if receipt == 1 and bias != "bias": + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, 
skip, 'f')) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # print("jenga fmha only support 16-bit compute.") + return pipelines + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == 'fp16' or dtype == 'bf16': + if (128, 128) in result.keys(): + result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + return result + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory + + for dtype in FWD_DTYPE_MAP.keys(): + d = factory.get_hdim_tile_size_dict(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if (hdim, hdim_v) == (192, 128): + # NOTE: this is used to speedup deepseek 
prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': + continue + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + if pipeline.tag != "qr_async": + continue + k = FmhaFwdKernel(F_idx=2, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= 
pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py new file mode 100644 index 00000000000..1e7bbfc9c2a --- /dev/null +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -0,0 +1,752 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
# generate kernel instances to speed up compilation

import copy
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
import os.path as path
from pathlib import Path
from typing import List, Optional, Tuple

from codegen.cpp_symbol_map import *
from codegen.cpp_symbol_map import SQUANT_MAP

GEN_DIR = ""


def update_file(file_path, content):
    """Update the file at file_path with the given content if it differs from the existing content.

    It avoids unnecessary touching of the file which triggers rebuilds.

    Args:
        file_path: path of the file to (re)write.
        content: full text the file should contain.
    """
    existing_content = ""
    if path.exists(file_path):
        with open(file_path, "r") as file:
            existing_content = file.read()
    if existing_content == content:
        # identical content: leave the file (and its mtime) alone
        return
    with open(file_path, "w") as file:
        file.write(content)


# Bit width per data-type name. dcheck/dvcheck divide 32 * 4 by this value to
# get the divisor that hdim_q / hdim_v must be a multiple of.
DTYPE_BITS = {
    "fp32": 32,
    "fp16": 16,
    "bf16": 16,
    "fp8" : 8,
    "bf8" : 8
}

# Maps a tile's bk0max to the divisor used by the hdim checks of the
# 'qr' / 'qs' / 'qr_async_trload' pipelines (see dcheck/dvcheck).
K0_MAX_SUBMAX_MAP = {
    32 : 32,
    64 : 64,
    96 : 128,
    128: 128,
    192: 192,
    256: 256
}

# Prologue shared by every generated .cpp file (kernel instances and the api file).
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd_trek.hpp"
#include "block_fmha_pipeline_qr_ks_vs_async_vsa.hpp"
#include "fmha_fwd_vsa_kernel.hpp"

"""

# Per-kernel translation unit. Filled with str.format(), so literal C++ braces
# are doubled ({{ }}) and every {F_*} field is a format placeholder.
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;

using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
                                      ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
                                      ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
                                      ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
                                      ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
                                      {F_vlayout}>;

using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
                                                    {F_skpad},
                                                    {F_dpad},
                                                    {F_dvpad},
                                                    {F_logits},
                                                    {F_bias},
                                                    false,
                                                    {F_lse},
                                                    {F_dropout},
                                                    {F_squant_enum},
                                                    {F_occupancy},
                                                    {F_skip}>;

using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;

using fmha_mask_{F_idx} = {F_mask};

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    fmha_shape_{F_idx},
    {F_mode},
    fmha_variant_{F_idx},
    fmha_mask_{F_idx},
    {F_trload},
    fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA<
    fmha_pipeline_problem_{F_idx}>;

using fmha_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
                                           typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
                                           {F_spad}, {F_dvpad}>>;

using fmha_kernel_{F_idx} =
    ck_tile::FmhaFwdVSAKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
                        {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;

#include <iostream>

template<>
float fmha_vsa_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_jenga_fwd_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
    const dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

FMHA_FWD_API_FILENAME="fmha_vsa_fwd_api.cpp"

# Dispatcher translation unit; {F_dispatch} receives the nested if/else-if
# chain assembled by FmhaFwdApiPool.api.
FMHA_FWD_API="""
#include <cstdio>

#include <hip/hip_runtime.h>

namespace {{
bool get_num_cus(unsigned& num_cus) {{
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess) {{
        fprintf(stderr, "failed to get device");
        return false;
    }}

    hipDeviceProp_t props{{}};
    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess) {{
        fprintf(stderr, "failed to get device properties");
        return false;
    }}

    num_cus = props.multiProcessorCount;
    return true;
}}

unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
    const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
    const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1

    return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace

float fmha_vsa_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{
    float r = -1;

    [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate

    unsigned num_cus;
    if (!get_num_cus(num_cus)) {{
        return r;
    }}

    [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
        return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
    }};

    const bool has_load_tr = ck_tile::is_load_tr_supported();

{F_dispatch}
    return r;
}}
"""

# The four templates below nest inside each other (tr_load -> dtype -> hdim ->
# individual trait) to build the body substituted into FMHA_FWD_API.
FMHA_FWD_API_PER_TRLOAD="""    {F_if}({F_trload_cond}){{
{F_dtype_case}
    }}
"""

FMHA_FWD_API_PER_DTYPE="""        {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
        }}
"""
FMHA_FWD_API_PER_HDIM_CASE="""            {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
            }}
"""

FMHA_FWD_API_INNER_DISPATCH="""                {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
                        ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
                    using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
                    return fmha_vsa_fwd_<trait_>(s, a);
                }}
"""

@dataclass
class CppConstraint:
    """A C++ boolean expression emitted verbatim into a dispatch condition.

    An unset expression renders as 'true'; `a & b` yields the parenthesised
    conjunction of both rendered expressions.
    """
    bool_expr: Optional[str] = None

    def __str__(self):
        if self.bool_expr is None:
            return 'true'
        else:
            return f'{self.bool_expr}'

    def __and__(self, other):
        return CppConstraint(f'({str(self)}) && ({str(other)})')

@dataclass
class FmhaFwdApiTrait:
    """One dispatchable kernel configuration, as seen by the generated api file.

    The *check properties return C++ boolean expressions (as strings) that
    FmhaFwdApiPool.api places in the runtime dispatch condition.
    """
    pipeline_tag : str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim      : str
    dtype     : str  # data type
    mode      : str  # value from MODE_MAP
    bm0       : int  # tile size along q seqlen (block size)
    bn0       : int  # tile size along qk seqlen
    bk0       : int  # tile size along qk gemm unroll
    bn1       : int  # tile size along v head_dim
    bk1       : int  # tile size along kv gemm unroll
    bk0max    : int
    vlayout   : str
    logits    : str
    mask      : str
    bias      : str  #
    lse       : str  #
    dropout   : str
    squant    : str  #
    spad      : str
    skpad     : str
    dpad      : str
    dvpad     : str
    skip      : str
    tr_load   : str
    constraint : CppConstraint

    @property
    def name(self) -> str:
        """Dash-separated identifier built from the trait's fields (tile sizes in bm0..bk0max order)."""
        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-'+\
               f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'

    @property
    def scheck(self) -> str:
        """C++ condition on a.seqlen_q for this trait's spad setting."""
        if self.mode == 'group': return 'true/*group mode spad always true*/'                  # group mode only generate spad/skpad == true
        if self.pipeline_tag in ['qr_async_vsa', 'qr_async_trload']:
            # always 'true' for these pipelines, whatever spad is
            return 'true'
        elif self.pipeline_tag in ['qr', 'qs']:
            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                return f'a.seqlen_q % {self.bm0} == 0'
        else: assert False, f'unsupported pipeline tag: {self.pipeline_tag}'

    @property
    def seqtune(self) -> str:
        """C++ condition restricting tiles with bm0 != 128 to a.seqlen_q <= bm0."""
        if self.bm0 == 128: return 'true/*fall back to largest tile*/' # bm0 == 128 tiles accept any seqlen_q
        else:
            return f'a.seqlen_q <= {self.bm0}'

    @property
    def skcheck(self) -> str:
        """C++ condition on a.seqlen_k for this trait's skpad setting."""
        if self.mode == 'group': return 'true/*group mode skpad always true*/'                  # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async_vsa':
            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
            else :                 return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag in ['qr', 'qs']:
            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                 return f'a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag == 'qr_async_trload':
            # always 'true' for this pipeline, whatever skpad is
            return 'true'
        else: assert False, f'unsupported pipeline tag: {self.pipeline_tag}'

    @property
    def dcheck(self) -> str:
        """C++ condition on a.hdim_q for this trait's dpad setting."""
        if self.pipeline_tag == 'qr_async_vsa':
            # number of elements of this dtype in 32 * 4 bits
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
            else :               assert False, 'qr_async_vsa requires dpad == t'
        elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :               return f'a.hdim_q % {bk0submax} == 0'
        else:   assert False, f'unsupported pipeline tag: {self.pipeline_tag}'

    @property
    def dvcheck(self) -> str:
        """C++ condition on a.hdim_v for this trait's dvpad setting."""
        if self.pipeline_tag == 'qr_async_vsa':
            # number of elements of this dtype in 32 * 4 bits
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
            else :                assert False, 'qr_async_vsa requires dvpad == t'
        elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                return f'a.hdim_v % {bk0submax} == 0'
        else:   assert False, f'unsupported pipeline tag: {self.pipeline_tag}'

@dataclass
class FmhaFwdPipeline:
    """Pipeline tag plus the per-pipeline switches that select a kernel variant."""
    tag : str

    F_vlayout   : str  # row/col
    F_spad      : str  # true/false
    F_skpad     : str  #
    F_dpad      : str  #
    F_dvpad     : str  #
    F_logits    : str  # t/f
    F_bias      : str  # true/false
    F_lse       : str  #
    F_dropout   : str  #
    F_squant    : str  #
    F_mask      : str  # value from MASK_MAP
    F_skip      : str  # true/false
    F_trload    : str  # true/false
    F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        """Underscore-separated name fragment encoding every switch.

        Each switch contributes either its enabled token or an 'n'-prefixed
        token, so the fragment always has the same number of parts.
        """
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_skpad == 't' : n += 'sk'
            if self.F_dpad == 't' : n += 'd'
            if self.F_dvpad == 't' : n += 'dv'
            if n != '' : n = 'p' + n
            return n
        pn = pad_name()
        n = f'{self.tag}_v{self.F_vlayout[0]}'
        if pn != '' : n += f'_{pn}'
        else: n += '_npad'

        if self.F_logits == 't' : n += '_logits'
        else: n += '_nlogits'

        if self.F_bias != 'no' : n += f'_{self.F_bias}'
        else: n += '_nbias'

        if self.F_mask[0:2] == 's_':
            # mask keys starting with 's_' only distinguish mask / no mask
            if self.F_mask == 's_mask': n += '_mask'
            else: n += '_nmask'
        else:
            if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
            else: n += '_nmask'

        if self.F_lse == 't' : n += '_lse'
        else: n += '_nlse'

        if self.F_dropout == 't' : n += '_dropout'
        else: n += '_ndropout'

        if self.F_skip == 't' : n += '_skip'
        else: n += '_nskip'

        if self.F_squant == 't' : n += '_squant'
        else: n += '_nsquant'

        if self.F_trload == 't' : n += '_trload'
        else: n += '_ntrload'

        return n

class FmhaFwdApiPool:
    """Collects registered traits and renders the C++ dispatcher source.

    Traits are stored as pool[dtype][(hdim, bn1)] -> list, in registration order;
    that order becomes the order of the generated if / else-if chain.
    """
    def __init__(self, mask_impl):
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        """Store a shallow copy of `trait` under its dtype and (hdim, bn1) key."""
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        hdim = trait.hdim, trait.bn1
        if hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][hdim] = list()

        self.pool[trait.dtype][hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Full text of the generated api .cpp file (header + dispatcher)."""
        tr_load_cond_map = {
            "t": "has_load_tr",
            "f": "true"
        }

        per_tr_load =str()
        # tr_load == 't' traits are emitted first, guarded by has_load_tr
        for tr_load in ["t", "f"]:
            per_dtypes=str()
            for i, dtype in enumerate(self.pool.keys()):
                per_hdim_case=str()
                for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
                    traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load]
                    inners=str()
                    for k, trait in enumerate(traits):
                        if_k = 'if' if k == 0 else 'else if'
                        inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                   F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                   F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
                                   F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load],
                                   F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                   F_constraint=trait.constraint,
                                   F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                   F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
                                   F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
                    if_j = 'if' if j == 0 else 'else if'
                    per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners)
                if_i = 'if' if i == 0 else 'else if'
                per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
            per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes)
        if not per_tr_load:
            # empty string we add some ignore to suppress warning in api
            per_tr_load += '    (void)t ; (void)s ; (void)a;'
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load)

@dataclass
class FmhaFwdTileSize:
    """Block tile, warp layout and warp tile sizes of one kernel configuration."""
    F_bm0       : int  # tile size along q seqlen (block size)
    F_bn0       : int  # tile size along k seqlen
    F_bk0       : int  # tile size along qk gemm unroll
    F_bn1       : int  # tile size along v head_dim
    F_bk1       : int  # tile size along kv gemm unroll
    F_bk0max    : int  # total length of K0, used for pipeline that need load Q at once (or repeatedly load Q as a whole tile)
    F_rm0       : int  # number of warps for gemm0 along q seqlen
    F_rn0       : int  # number of warps for gemm0 along k seqlen
    F_rk0       : int  # number of warps for gemm0 along head dim q (not used)
    F_rm1       : int  # number of warps for gemm1 along q seqlen
    F_rn1       : int  # number of warps for gemm1 along head dim v
    F_rk1       : int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0       : int  # gemm0 warp size along m
    F_wn0       : int  # gemm0 warp size along n
    F_wk0       : int  # gemm0 warp size along k
    F_wm1       : int  # gemm1 warp size along m
    F_wn1       : int  # gemm1 warp size along n
    F_wk1       : int  # gemm1 warp size along k
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        """Name fragment 'b..._r..._r..._w..._w...' with an '_o<N>' suffix unless occupancy is -1."""
        return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
        f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
        f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
        ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")

+@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_squant_enum = SQUANT_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_trload = BOOL_MAP[self.F_pipeline.F_trload]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return 
f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) + +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 
128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } + else: + return None + + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? 
+ squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + if hdim == 256 and hdim_v == 256: + # print("vsa fmha only support dim=128 now.") + continue + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + # the below two is used for hdim vectorize load + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + else: + if bias == "bias": + # print("vsa_fmha with bias is not implemented.") + continue + # TODO: rocm 6.2 compiler problem if using qr_async for bias case + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + else: + pipelines.append(FmhaFwdPipeline('qr_async_vsa', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr_async_vsa', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + # if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) + # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) + # if receipt == 1 and bias != "bias": + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, 
mask, skip, 'f')) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # print("vsa fmha only support 16-bit compute.") + return pipelines + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == 'fp16' or dtype == 'bf16': + if (128, 128) in result.keys(): + result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + return result + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory + + for dtype in FWD_DTYPE_MAP.keys(): + d = factory.get_hdim_tile_size_dict(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if (hdim, hdim_v) == (192, 128): + # NOTE: this is used to speedup deepseek 
prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': + continue + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + if pipeline.tag != "qr_async_vsa": + continue + k = FmhaFwdKernel(F_idx=1, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= 
pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + update_file(autogen_dir / kernel.filename, kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp new file mode 100644 index 00000000000..613c6e7fa0c --- /dev/null +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -0,0 +1,725 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/fmha.hpp" + +#include "mask.hpp" +#include "bias.hpp" + +#include +#include +#include + +struct FmhaFwdFp16 +{ +}; + +struct FmhaFwdBf16 +{ +}; + +struct FmhaFwdFp8 +{ +}; + +struct FmhaFwdBf8 +{ +}; + +struct FmhaFwdFp8Fp16 +{ +}; + +struct FmhaFwdFp8Bf16 +{ +}; + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using 
LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +struct fmha_sparge_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lut_ptr; + const void* valid_block_num_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float pv_threshold; + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + 
+ ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.pv_threshold, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + 
args.min_seqlen_q, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.pv_threshold, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_sparge_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static 
constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kUseTrLoad = kUseTrLoad_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; +}; + +struct fmha_sparge_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + bool has_logits_soft_cap; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. 
sync with BlockAttentionBiasEnum + bool has_lse; + bool has_dropout; + bool do_fp8_static_quant; + bool skip_min_seqlen_q = false; + // TODO: padding check is inside this api +}; + +float fmha_sparge_fwd(fmha_sparge_fwd_traits, fmha_sparge_fwd_args, const ck_tile::stream_config&); + +template +float fmha_sparge_fwd_(const ck_tile::stream_config&, fmha_sparge_fwd_args); + +float fmha_sparge_fwd(fmha_sparge_fwd_args, const ck_tile::stream_config&); + +// jenga +struct fmha_jenga_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* block_relation_onehot_ptr; + const void* lut_ptr; + const void* valid_block_num_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + 
ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + if constexpr(VSA) { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + 
args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + } else { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.block_relation_onehot_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.block_relation_onehot_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + 
args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + } + }(); + + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_jenga_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kUseTrLoad = kUseTrLoad_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; +}; + +struct 
fmha_jenga_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + bool has_logits_soft_cap; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool has_dropout; + bool do_fp8_static_quant; + bool skip_min_seqlen_q = false; + // TODO: padding check is inside this api +}; + +float fmha_jenga_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); + +template +float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); + +float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&); + +float fmha_vsa_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); + +template +float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); + +float fmha_vsa_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/generate.py b/example/ck_tile/50_sparse_attn/generate.py new file mode 100644 index 00000000000..eaeb555e05b --- /dev/null +++ b/example/ck_tile/50_sparse_attn/generate.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import pkgutil +import sys +from typing import List, Optional + +import codegen.ops + + +class HandlerId(IntEnum): + LIST_BLOBS = 0 + WRITE_BLOBS = 1 + +# inspect all modules under 'codegen.ops' and register API handlers +ops = [] +for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): + full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) +unwanted_prefix = 'fmha_' +handlers = dict( + [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, + (op.list_blobs, op.write_blobs)) for op in ops] +) +assert 0 < len(handlers) + +def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + + for api, kernel_filter in zip(api_list, filters_list): + handler = handlers[api][HandlerId.WRITE_BLOBS] + handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) + +# list all the files that will be generated +def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: + assert output_file is not None + file_path = Path(output_file) + + # create an empty file / drop its contents if it exists + open(file_path, "w").close() + + for api, kernel_filter in zip(api_list, filters_list): + handler = handlers[api][HandlerId.LIST_BLOBS] + handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK fmha kernel", + ) + parser.add_argument( + "-d", + "--direction", # we keep 
'direction' option for backward compatibility + "-a", + "--api", + default='fwd_jenga', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="write all the blobs into a directory" + ) + parser.add_argument( + "-l", + "--list_blobs", + required=False, + help="list all the kernels to a file" + ) + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + default='', + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-m", + "--mask", + default="simplified", + required=False, + help="mask implementation, simplified/generic" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ + " 1: generate more instance to cover all hdim\n" + \ + " 2: Only generate instance for Flash attention integration\n" + \ + " 4: Only generate instance for PyTorch integration\n" + \ + " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" + ) + + parser.add_argument( + "--optdim", + default='-1', + required=False, + help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ + "eg. 
--optdim=32,64,128,256" + ) + + args = parser.parse_args() + api_list = args.direction.split(',') + filter_list = args.filter.split(',') + filter_list.extend([''] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(',')] + + if args.list_blobs is not None: + list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) + else: + write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu new file mode 100644 index 00000000000..a9e9f4f8010 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "jenga_sparse_attention.h" +#include "fmha_fwd_trek.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +ck_tile::HostTensor jenga_sparse_attention( + ck_tile::HostTensor &TQ, + ck_tile::HostTensor &TK, + ck_tile::HostTensor &TV, + ck_tile::HostTensor &Tblock_relation_onehot, + ck_tile::HostTensor &Y, + std::optional> bias = std::nullopt, + std::optional> lse = std::nullopt, + std::optional> seqstart_q = std::nullopt, + std::optional> seqstart_k = std::nullopt, + int bias_type = 0, + int batch = 0, + int nhead = 0, + int nhead_k = 0, + int seqlen_q = 0, + int seqlen_k = 0, + int hdim_q = 0, + int hdim_v = 0, + int mode = 0, + bool i_perm = true, + bool o_perm = true, + int max_seqlen_q = 0, + int max_seqlen_k = 0 +){ + std::string data_type = "fp16"; + if (TQ.dtype() == ck_tile::bf16_t) { + data_type = "bf16"; + } + + if (max_seqlen_q == 0) max_seqlen_q = seqlen_q; + if (max_seqlen_k == 0) max_seqlen_k = seqlen_k; + bool is_v_rowmajor = true; + int seqlen_knew = 0; + float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + float 
scale_p = 1.f; + float scale_o = 1.f; + const float logits_soft_cap = 0.0; + + std::string msk_str = "0"; + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); + + const ck_tile::index_t shape_batch = (mode == 0 ? batch : 1); + const ck_tile::index_t shape_seqlen_q = (mode == 0 ? seqlen_q : max_seqlen_q); + const ck_tile::index_t shape_seqlen_k = (mode == 0 ? seqlen_k : max_seqlen_k); + + ck_tile::stream_config stream_config{nullptr, + false, // time_kernel + 0, /* log_level = */ + 0, + 1, + false}; + + const auto init_args = [&](auto& args) { + assert(nhead % nhead_k == 0); + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? seqlen_knew : nhead_k * seqlen_knew; + }(); + const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o_acc = (hdim_v); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? 
hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + const ck_tile::index_t nhead_stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? seqlen_knew * hdim_v : hdim_v; + else + return i_perm ? hdim_v * seqlen_knew : seqlen_knew; + }(); + const ck_tile::index_t nhead_stride_bias = + (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + // const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + // setup split_stride_* arguments (only used in split-kv kernel) + const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); + const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); + + args.q_ptr = TQ.data_ptr(); + args.k_ptr = TK.data_ptr(); + args.v_ptr = 
TV.data_ptr(); + args.block_relation_onehot_ptr = Tblock_relation_onehot.data_ptr(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; // unused in group mode + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + // args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + // : bias_buf.GetDeviceBuffer(); + args.bias_ptr = bias ? bias->data_ptr() : nullptr; + args.lse_ptr = lse ? lse->data_ptr() : nullptr; + args.o_ptr = Y.data_ptr(); + + args.seqstart_q_ptr = + (mode == 1 ? seqstart_q->data_ptr() : nullptr); + args.seqstart_k_ptr = + (mode == 1 ? seqstart_k->data_ptr() : nullptr); + args.seqlen_k_ptr = nullptr; + + args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + args.scale_p = scale_p; + args.scale_o = scale_o; + + args.logits_soft_cap = logits_soft_cap; + + args.stride_bias =stride_bias; + args.stride_o = stride_o; + args.nhead_stride_bias = nhead_stride_bias; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_bias = batch_stride_bias; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + + args.rand_val_ptr = nullptr; + + args.stride_randval = stride_randval; + args.nhead_stride_randval = nhead_stride_randval; + args.batch_stride_randval = batch_stride_randval; + + args.p_drop = 0.; + args.s_randval = false; + + }; + + const auto init_traits = [&](auto& traits) { + 
traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + + + traits.is_group_mode = (mode == 1); + traits.has_logits_soft_cap = 0.f < logits_soft_cap; + traits.mask_type = mask.type; + traits.bias_type = static_cast(bias_type); + traits.has_lse = lse ? true: false; + traits.do_fp8_static_quant = false; + + traits.has_dropout = false; + + }; + + fmha_jenga_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_jenga_fwd_args args; + init_args(args); + + fmha_jenga_fwd(fmha_traits, args, stream_config); + + return Y; +} diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h new file mode 100644 index 00000000000..b8fbfdc8d8e --- /dev/null +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -0,0 +1,61 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +using DataType = ck_tile::half_t; + +ck_tile::HostTensor jenga_sparse_attention( + ck_tile::HostTensor &TQ, + ck_tile::HostTensor &TK, + ck_tile::HostTensor &TV, + ck_tile::HostTensor &Tblock_relation_onehot, + ck_tile::HostTensor &Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, + int bias_type, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + int mode, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k +); + +ck_tile::HostTensor vsa_sparse_attention( + ck_tile::HostTensor &TQ, + ck_tile::HostTensor &TK, + ck_tile::HostTensor &TV, + ck_tile::HostTensor &TKV_block_idx, // LUT must be int32_t + ck_tile::HostTensor &TKV_blocks, // valid_block_num must be int32_t + ck_tile::HostTensor &Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + 
std::optional> seqstart_k,
+    int bias_type,
+    int batch,
+    int nhead,
+    int nhead_k,
+    int seqlen_q,
+    int seqlen_k,
+    int hdim_q,
+    int hdim_v,
+    int mode,
+    bool i_perm,
+    bool o_perm,
+    int max_seqlen_q,
+    int max_seqlen_k
+);
diff --git a/example/ck_tile/50_sparse_attn/mask.hpp b/example/ck_tile/50_sparse_attn/mask.hpp
new file mode 100644
index 00000000000..b96482f5355
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/mask.hpp
@@ -0,0 +1,196 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <algorithm>
+#include <cassert>
+#include <cstdio>
+#include <cstdlib>
+#include <ostream>
+#include <string>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha.hpp"
+
+// keep this in sync with ck_tile::GenericAttentionMaskEnum
+enum class mask_enum
+{
+    no_mask = 0,
+    mask_top_left,
+    mask_bottom_right,
+    window_generic,
+};
+
+// Host-side description of the attention mask requested for a run.
+// Every member carries a default so that decoding a "no mask" request never
+// leaves y/x/left/right indeterminate for callers that copy them into kernel
+// arguments.
+struct mask_info
+{
+    mask_enum type = mask_enum::no_mask;
+    ck_tile::index_t seqlen_q = 0;
+    ck_tile::index_t seqlen_k = 0;
+    ck_tile::index_t y = 0, x = 0;
+    ck_tile::index_t left = -1, right = -1; // FA style SWA left/right
+
+    // Print the compact textual form: n, t(l:r), b(l:r) or g(y:x).
+    void serialize(std::ostream& os) const
+    {
+        if(type == mask_enum::no_mask)
+            os << "n";
+        else if(type == mask_enum::mask_top_left)
+            os << "t(" << left << ":" << right << ")";
+        else if(type == mask_enum::mask_bottom_right)
+            os << "b(" << left << ":" << right << ")";
+        else
+        {
+            os << "g(" << y << ":" << x << ")";
+        }
+    }
+
+    // Parse a mask specification string.
+    //   "xt:W" / "xb:W"   xformer style sliding window of size W
+    //   "t:L,R" / "b:L,R" left/right window anchored top-left / bottom-right
+    //   "g:Y,X"           generic window given directly as y/x
+    //   "t" / "b"         causal mask from top-left / bottom-right
+    //   otherwise         the string is read as the integer value of mask_enum
+    static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
+    {
+        ck_tile::index_t x_total = seqlen_k;
+        ck_tile::index_t y_total = seqlen_q;
+        mask_info tmp;
+        tmp.seqlen_q = seqlen_q;
+        tmp.seqlen_k = seqlen_k;
+        auto found_0 = str.find(':');
+        if(found_0 != std::string::npos)
+        {
+            std::string t = str.substr(0, found_0);
+            std::string v = str.substr(found_0 + 1);
+            if(t == "xt" || t == "xb")
+            {
+                // xformer style sliding window attn from top-left
+                ck_tile::index_t window_size = atoi(v.c_str());
+                ck_tile::index_t left_size = -1;
+                ck_tile::index_t right_size = 0;
+                if(window_size > 0)
+                {
+                    left_size = window_size / 2;
+                    right_size = window_size - 1 - left_size;
+                }
+                auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
+                    left_size, right_size, y_total, x_total, t == "xt");
+
+                tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
+                tmp.y = r.at(ck_tile::number<0>{});
+                tmp.x = r.at(ck_tile::number<1>{});
+                tmp.left = left_size;
+                tmp.right = right_size;
+            }
+            else
+            {
+                auto found_1 = v.find(",");
+                if(found_1 == std::string::npos)
+                {
+                    printf("not supported value %s, %s\n", v.c_str(), str.c_str());
+                    assert(0);
+                }
+                tmp.type = mask_enum::window_generic;
+                ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
+                ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
+                // TODO: some validation
+                if(t == "t")
+                {
+                    tmp.type = mask_enum::mask_top_left;
+                    auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
+                        v0, v1, y_total, x_total, true);
+                    tmp.y = r.at(ck_tile::number<0>{});
+                    tmp.x = r.at(ck_tile::number<1>{});
+                    tmp.left = v0;
+                    tmp.right = v1;
+                }
+                else if(t == "b")
+                {
+                    tmp.type = mask_enum::mask_bottom_right;
+                    auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
+                        v0, v1, y_total, x_total, false);
+                    tmp.y = r.at(ck_tile::number<0>{});
+                    tmp.x = r.at(ck_tile::number<1>{});
+                    tmp.left = v0;
+                    tmp.right = v1;
+                }
+                else if(t == "g")
+                {
+                    tmp.y = v0;
+                    tmp.x = v1;
+                    tmp.left = v0; // TODO: don't use this?
+                    tmp.right = v1;
+                }
+                else
+                {
+                    printf("not supported type %s, %s\n", t.c_str(), str.c_str());
+                    assert(0);
+                }
+            }
+        }
+        else
+        {
+            auto set_causal_top_left = [&]() {
+                tmp.type = mask_enum::mask_top_left;
+                tmp.y = seqlen_q;
+                tmp.x = 1;
+                tmp.left = -1;
+                tmp.right = 0;
+            };
+            auto set_causal_bottom_right = [&]() {
+                tmp.type = mask_enum::mask_bottom_right;
+                tmp.y = seqlen_q;
+                tmp.x = seqlen_k - seqlen_q + 1;
+                tmp.left = -1;
+                tmp.right = 0;
+            };
+            if(str == "t")
+                set_causal_top_left();
+            else if(str == "b")
+                set_causal_bottom_right();
+            else
+            {
+                tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
+                if(tmp.type == mask_enum::mask_top_left)
+                {
+                    set_causal_top_left();
+                }
+                else if(tmp.type == mask_enum::mask_bottom_right)
+                {
+                    set_causal_bottom_right();
+                }
+                // no_mask (and any other value) keeps the member defaults
+            }
+        }
+        return tmp;
+    }
+
+    // Count the (row, column) positions that the y/x window leaves unmasked.
+    ck_tile::index_t get_unmaskarea() const
+    {
+        if(type == mask_enum::no_mask)
+            return seqlen_q * seqlen_k;
+        ck_tile::index_t area = 0;
+        for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
+        {
+            ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
+            ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
+            if(x_end > x_start)
+            {
+                area += (x_end - x_start);
+            }
+        }
+        return area;
+    }
+
+    friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
+    {
+        mi.serialize(os);
+        return os;
+    }
+};
diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp
new file mode 100644
index 00000000000..27b18a66960
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp
@@ -0,0 +1,490 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+//
+// Test for vsa_sparse_attention function
+// Based on the Python test: test_jenga_attention.py
+
+// NOTE(review): the original include targets were lost in this copy of the
+// patch; the list below covers the standard names this file uses.
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+#include <limits>
+#include <optional>
+#include <random>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "ck_tile/host.hpp"
+#include "ck_tile/core.hpp"
+
+// Define DataType before including the header
+using DataType = ck_tile::half_t;
+
+#include "jenga_sparse_attention.h"
+#include "fmha_fwd_trek.hpp"
+
+// ============================================================================
+// Helper Functions
+// ============================================================================
+
+// Convert block_relation_onehot to LUT format (similar to triton_block_map_to_lut_kernel).
+// For every (b, h, q) row the active K-block indices are stored delta-encoded:
+// entry i holds the distance from the previously active block (the first entry
+// is the distance from block 0). valid_block_num receives the number of entries.
+template <typename T>
+void block_map_to_lut(
+    const ck_tile::HostTensor<T>& block_map,       // [B, H, Q_blocks, K_blocks]
+    ck_tile::HostTensor<int32_t>& lut,             // [B, H, Q_blocks, K_blocks] - int32_t for kernel
+    ck_tile::HostTensor<int32_t>& valid_block_num, // [B, H, Q_blocks] - int32_t for kernel
+    ck_tile::index_t num_block_k)
+{
+    auto lengths = block_map.get_lengths();
+    ck_tile::index_t B = lengths[0];
+    ck_tile::index_t H = lengths[1];
+    ck_tile::index_t Q = lengths[2];
+
+    for (ck_tile::index_t b = 0; b < B; ++b) {
+        for (ck_tile::index_t h = 0; h < H; ++h) {
+            for (ck_tile::index_t q = 0; q < Q; ++q) {
+                int32_t valid_count = 0;
+                int32_t prev_block = 0;
+
+                for (ck_tile::index_t k = 0; k < num_block_k; ++k) {
+                    T cur_block = block_map(b, h, q, k);
+                    if (static_cast<float>(cur_block) > 0.5f) {  // Check if block is active
+                        lut(b, h, q, valid_count) = static_cast<int32_t>(k - prev_block);
+                        valid_count++;
+                        prev_block = static_cast<int32_t>(k);
+                    }
+                }
+                valid_block_num(b, h, q) = valid_count;
+            }
+        }
+    }
+}
+
+// Reference implementation: blocked attention (similar to pytorch_blocked_attention).
+// For every Q block, softmax(Q K^T * scale + bias) V is evaluated over the K
+// blocks marked active in block_relation. Q blocks without any active K block
+// are left untouched in output. K/V may carry fewer heads than Q: each K/V
+// head then serves nhead / nhead_k consecutive Q heads.
+template <typename T, typename AccT>
+void reference_blocked_attention(
+    const ck_tile::HostTensor<T>& q,              // [B, H, S_q, D]
+    const ck_tile::HostTensor<T>& k,              // [B, H_k, S_k, D]
+    const ck_tile::HostTensor<T>& v,              // [B, H_k, S_k, D_v]
+    const ck_tile::HostTensor<T>& block_relation, // [B, H, Q_blocks, K_blocks]
+    const ck_tile::HostTensor<T>& bias,           // [B, H, S_q, S_k]
+    ck_tile::HostTensor<T>& output,               // [B, H, S_q, D_v]
+    ck_tile::index_t BLKQ,
+    ck_tile::index_t BLKK,
+    AccT scale)
+{
+    auto q_lengths = q.get_lengths();
+    ck_tile::index_t batch = q_lengths[0];
+    ck_tile::index_t nhead = q_lengths[1];
+    ck_tile::index_t seqlen_q = q_lengths[2];
+    ck_tile::index_t hdim = q_lengths[3];
+
+    auto v_lengths = v.get_lengths();
+    ck_tile::index_t nhead_k = v_lengths[1];
+    ck_tile::index_t seqlen_k = v_lengths[2];
+    ck_tile::index_t hdim_v = v_lengths[3];
+
+    ck_tile::index_t num_q_blocks = seqlen_q / BLKQ;
+    ck_tile::index_t num_k_blocks = seqlen_k / BLKK;
+
+    // Number of Q heads sharing one K/V head; indexing K/V with the Q head
+    // index would read out of bounds whenever nhead_k < nhead.
+    const ck_tile::index_t heads_per_kv = (nhead_k > 0 && nhead >= nhead_k) ? nhead / nhead_k : 1;
+
+    for (ck_tile::index_t b = 0; b < batch; ++b) {
+        for (ck_tile::index_t h = 0; h < nhead; ++h) {
+            const ck_tile::index_t hk = h / heads_per_kv;
+
+            for (ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) {
+                ck_tile::index_t q_start = qb * BLKQ;
+                ck_tile::index_t q_end = q_start + BLKQ;
+
+                // Find relevant K blocks
+                std::vector<ck_tile::index_t> relevant_k_indices;
+                for (ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) {
+                    if (static_cast<float>(block_relation(b, h, qb, kb)) > 0.5f) {
+                        relevant_k_indices.push_back(kb);
+                    }
+                }
+
+                if (relevant_k_indices.empty()) continue;
+
+                // For each query position in the block
+                for (ck_tile::index_t sq = q_start; sq < q_end; ++sq) {
+                    // Compute attention scores for all relevant K blocks
+                    std::vector<AccT> scores;
+                    scores.reserve(relevant_k_indices.size() * static_cast<size_t>(BLKK));
+                    AccT max_score = -std::numeric_limits<AccT>::infinity();
+
+                    for (auto kb : relevant_k_indices) {
+                        ck_tile::index_t k_start = kb * BLKK;
+                        ck_tile::index_t k_end = k_start + BLKK;
+
+                        for (ck_tile::index_t sk = k_start; sk < k_end; ++sk) {
+                            AccT score = 0.0f;
+                            for (ck_tile::index_t d = 0; d < hdim; ++d) {
+                                score += static_cast<AccT>(q(b, h, sq, d)) *
+                                         static_cast<AccT>(k(b, hk, sk, d));
+                            }
+                            score = score * scale + static_cast<AccT>(bias(b, h, sq, sk));
+                            scores.push_back(score);
+                            max_score = std::max(max_score, score);
+                        }
+                    }
+
+                    // Softmax (shifted by the row maximum for numerical stability)
+                    AccT sum_exp = 0.0f;
+                    for (auto& s : scores) {
+                        s = std::exp(s - max_score);
+                        sum_exp += s;
+                    }
+                    for (auto& s : scores) {
+                        s /= sum_exp;
+                    }
+
+                    // Compute output: P @ V
+                    for (ck_tile::index_t dv = 0; dv < hdim_v; ++dv) {
+                        AccT out_val = 0.0f;
+                        size_t score_idx = 0;
+
+                        for (auto kb : relevant_k_indices) {
+                            ck_tile::index_t k_start = kb * BLKK;
+                            ck_tile::index_t k_end = k_start + BLKK;
+
+                            for (ck_tile::index_t sk = k_start; sk < k_end; ++sk) {
+                                out_val += scores[score_idx] *
+                                           static_cast<AccT>(v(b, hk, sk, dv));
+                                score_idx++;
+                            }
+                        }
+                        output(b, h, sq, dv) = static_cast<T>(out_val);
+                    }
+                }
+            }
+        }
+    }
+}
+
+// Get error tolerance based on data type
+template <typename T>
+auto get_error_tolerance()
+{
+    double rtol = 1e-2;
+    double atol = 4e-2;  // Higher tolerance for bf16/fp16
+    return ck_tile::make_tuple(rtol, atol);
+}
+
+// ============================================================================
+// Command line argument parser
+// ============================================================================
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser
+        .insert("v", "1", "0:no validation, 1:cpu validation")
+        .insert("mode", "0", "kernel mode. 0:batch, 1:group")
+        .insert("b", "1", "batch size")
+        .insert("h", "4", "num of head for q")
+        .insert("h_k", "-1", "num of head for k/v, -1 means equal to h")
+        .insert("s", "4096", "seqlen_q")
+        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
+        .insert("d", "128", "head dim for q, k")
+        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
+        .insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)")
+        .insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)")
+        .insert("prec", "fp16", "data type: fp16/bf16")
+        .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d")
+        .insert("operm", "1", "permute output")
+        .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi")
+        .insert("lse", "0", "0:not store lse, 1:store lse")
+        .insert("seed", "42", "random seed")
+        .insert("warmup", "5", "warmup iterations")
+        .insert("repeat", "20", "benchmark iterations")
+        .insert("kname", "0", "print kernel name");
+
+    bool result = arg_parser.parse(argc, argv);
+    return std::make_tuple(result, arg_parser);
+}
+
+// ============================================================================
+// Main Test Function
+// ============================================================================
+// Builds random Q/K/V and a random block map, runs the VSA kernel wrapper
+// (warmup + timed iterations) and optionally checks the result against the
+// CPU reference. Returns true when the run and the validation succeed.
+template <typename T>
+bool run_test(const ck_tile::ArgParser& arg_parser)
+{
+    // Parse arguments
+    int do_validation = arg_parser.get_int("v");
+    int mode = arg_parser.get_int("mode");
+    ck_tile::index_t batch = arg_parser.get_int("b");
+    ck_tile::index_t nhead = arg_parser.get_int("h");
+    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
+    ck_tile::index_t seqlen_q = arg_parser.get_int("s");
+    ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
+    ck_tile::index_t hdim_q = arg_parser.get_int("d");
+    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
+    ck_tile::index_t block_size = arg_parser.get_int("block_size");
+    float sparsity = arg_parser.get_float("sparsity");
+    bool i_perm = arg_parser.get_bool("iperm");
+    bool o_perm = arg_parser.get_bool("operm");
+    int bias_type = arg_parser.get_int("bias");
+    [[maybe_unused]] bool store_lse = arg_parser.get_bool("lse");
+    uint32_t seed = arg_parser.get_uint32("seed");
+    int warmup = arg_parser.get_int("warmup");
+    int repeat = arg_parser.get_int("repeat");
+    [[maybe_unused]] int kname = arg_parser.get_int("kname");
+
+    // Handle default values
+    if (nhead_k < 0) nhead_k = nhead;
+    if (seqlen_k < 0) seqlen_k = seqlen_q;
+    if (hdim_v < 0) hdim_v = hdim_q;
+
+    // Reject configurations that would divide by zero, drop tail rows from the
+    // block map, or launch the kernel without the inputs it needs.
+    if (block_size <= 0 || seqlen_q % block_size != 0 || seqlen_k % block_size != 0) {
+        std::cerr << "block_size must be positive and divide both seqlen_q and seqlen_k" << std::endl;
+        return false;
+    }
+    if (nhead_k <= 0 || nhead % nhead_k != 0) {
+        std::cerr << "h must be a positive multiple of h_k" << std::endl;
+        return false;
+    }
+    if (mode != 0) {
+        // This test never builds seqstart_q / seqstart_k tensors.
+        std::cerr << "group mode (mode=1) is not supported by this test" << std::endl;
+        return false;
+    }
+    if (do_validation && (!i_perm || !o_perm)) {
+        // NOTE(review): host tensors below are always created as b*h*s*d, so the
+        // CPU reference presumably cannot match a b*s*h*d kernel run - confirm.
+        std::cout << "Warning: validation assumes iperm=1 and operm=1" << std::endl;
+    }
+
+    ck_tile::index_t BLKQ = block_size;
+    ck_tile::index_t BLKK = block_size;
+
+    // Calculate number of Q and K blocks
+    ck_tile::index_t num_q_blocks = seqlen_q / BLKQ;
+    ck_tile::index_t num_k_blocks = seqlen_k / BLKK;
+
+    std::cout << "============================================================" << std::endl;
+    std::cout << "[VSA Sparse Attention Test]" << std::endl;
+    std::cout << "============================================================" << std::endl;
+    std::cout << "  Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k << std::endl;
+    std::cout << "  seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl;
+    std::cout << "  hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl;
+    std::cout << "  block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")" << std::endl;
+    std::cout << "  num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks << std::endl;
+    std::cout << "  sparsity: " << sparsity << std::endl;
+    std::cout << "  i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
+
+    // Create host tensors (using BHSD layout when i_perm=true)
+    // Q: [B, H, S_q, D]
+    // K: [B, H_k, S_k, D]
+    // V: [B, H_k, S_k, D_v]
+    ck_tile::HostTensor<T> q_host({batch, nhead, seqlen_q, hdim_q});
+    ck_tile::HostTensor<T> k_host({batch, nhead_k, seqlen_k, hdim_q});
+    ck_tile::HostTensor<T> v_host({batch, nhead_k, seqlen_k, hdim_v});
+    ck_tile::HostTensor<T> output_host({batch, nhead, seqlen_q, hdim_v});
+    ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
+
+    // Bias tensor [B, H, S_q, S_k]
+    ck_tile::HostTensor<T> bias_host({batch, nhead, seqlen_q, seqlen_k});
+
+    // Block relation onehot: [B, H, Q_blocks, K_blocks]
+    ck_tile::HostTensor<T> block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks});
+
+    // LUT and valid_block_num (output of block_map_to_lut) - must be int32_t for kernel
+    ck_tile::HostTensor<int32_t> lut_host({batch, nhead, num_q_blocks, num_k_blocks});
+    ck_tile::HostTensor<int32_t> valid_block_num_host({batch, nhead, num_q_blocks});
+
+    // Initialize tensors with random values
+    std::cout << "\nInitializing tensors..." << std::endl;
+    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
+    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
+    ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
+
+    // Initialize bias to zero (as in Python test)
+    std::fill(bias_host.mData.begin(), bias_host.mData.end(), static_cast<T>(0.0f));
+
+    // Initialize block_relation_onehot with sparse pattern
+    std::mt19937 rng(seed + 100);
+    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+    ck_tile::index_t total_blocks = 0;
+    ck_tile::index_t active_blocks = 0;
+
+    for (ck_tile::index_t b = 0; b < batch; ++b) {
+        for (ck_tile::index_t h = 0; h < nhead; ++h) {
+            for (ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) {
+                for (ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) {
+                    total_blocks++;
+                    // Each Q block always attends to its diagonal K block (if exists)
+                    // Plus random blocks based on sparsity
+                    bool is_diagonal = (qb == kb && qb < num_k_blocks);
+                    bool random_active = (dist(rng) > sparsity);
+
+                    if (is_diagonal || random_active) {
+                        block_relation_onehot(b, h, qb, kb) = static_cast<T>(1.0f);
+                        active_blocks++;
+                    } else {
+                        block_relation_onehot(b, h, qb, kb) = static_cast<T>(0.0f);
+                    }
+                }
+            }
+        }
+    }
+
+    float actual_sparsity = 1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
+    std::cout << "  Actual sparsity: " << actual_sparsity
+              << " (" << active_blocks << "/" << total_blocks << " blocks active)" << std::endl;
+
+    // Convert block_relation_onehot to LUT format
+    std::cout << "Converting block map to LUT format..." << std::endl;
+    block_map_to_lut(block_relation_onehot, lut_host, valid_block_num_host, num_k_blocks);
+
+    // vsa_sparse_attention handles device memory internally
+
+    // Optional tensors
+    // NOTE(review): the element types inside these optionals were lost in this
+    // copy of the patch; they must match jenga_sparse_attention.h - confirm.
+    std::optional<ck_tile::HostTensor<T>> bias_opt = std::nullopt;
+    std::optional<ck_tile::HostTensor<T>> lse_opt = std::nullopt;
+    std::optional<ck_tile::HostTensor<int32_t>> seqstart_q_opt = std::nullopt;
+    std::optional<ck_tile::HostTensor<int32_t>> seqstart_k_opt = std::nullopt;
+
+    if (bias_type != 0) {
+        bias_opt = bias_host;
+    }
+
+    // One kernel invocation; shared by the warmup and the timed loop.
+    const auto launch = [&]() {
+        vsa_sparse_attention(
+            q_host,
+            k_host,
+            v_host,
+            lut_host,
+            valid_block_num_host,
+            output_host,
+            bias_opt,
+            lse_opt,
+            seqstart_q_opt,
+            seqstart_k_opt,
+            bias_type,
+            batch,
+            nhead,
+            nhead_k,
+            seqlen_q,
+            seqlen_k,
+            hdim_q,
+            hdim_v,
+            mode,
+            i_perm,
+            o_perm,
+            seqlen_q,
+            seqlen_k
+        );
+    };
+
+    // Run kernel
+    std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
+
+    try {
+        // Warmup
+        for (int i = 0; i < warmup; ++i) {
+            launch();
+        }
+
+        // Benchmark
+        if (hipDeviceSynchronize() != hipSuccess) {
+            std::cerr << "Error during kernel execution: device synchronization failed" << std::endl;
+            return false;
+        }
+        auto start = std::chrono::high_resolution_clock::now();
+
+        for (int i = 0; i < repeat; ++i) {
+            launch();
+        }
+
+        if (hipDeviceSynchronize() != hipSuccess) {
+            std::cerr << "Error during kernel execution: device synchronization failed" << std::endl;
+            return false;
+        }
+        auto end = std::chrono::high_resolution_clock::now();
+
+        // The time includes the host<->device copies done inside the wrapper.
+        if (repeat > 0) {
+            double avg_time_ms = std::chrono::duration<double, std::milli>(end - start).count() / repeat;
+            std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" << std::endl;
+        }
+
+    } catch (const std::exception& e) {
+        std::cerr << "Error during kernel execution: " << e.what() << std::endl;
+        return false;
+    }
+
+    // Note: vsa_sparse_attention already returns output in output_host
+
+    // Validation
+    bool pass = true;
+    if (do_validation) {
+        std::cout << "\n--- Performing CPU validation ---" << std::endl;
+
+        // Compute scale factor
+        float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
+
+        // Run reference implementation
+        std::cout << "Computing reference output..." << std::endl;
+        reference_blocked_attention<T, float>(
+            q_host, k_host, v_host,
+            block_relation_onehot, bias_host,
+            output_ref,
+            BLKQ, BLKK, scale
+        );
+
+        // Compare results
+        auto [rtol, atol] = get_error_tolerance<T>();
+
+        float max_diff = 0.0f;
+        float max_rel_diff = 0.0f;
+        size_t num_errors = 0;
+
+        for (size_t i = 0; i < output_host.mData.size(); ++i) {
+            float gpu_val = static_cast<float>(output_host.mData[i]);
+            float ref_val = static_cast<float>(output_ref.mData[i]);
+            float diff = std::abs(gpu_val - ref_val);
+            float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff;
+
+            max_diff = std::max(max_diff, diff);
+            max_rel_diff = std::max(max_rel_diff, rel_diff);
+
+            // A NaN/Inf result makes every ordered comparison false, so it has
+            // to be counted as a mismatch explicitly.
+            if (!std::isfinite(gpu_val) || (diff > atol && rel_diff > rtol)) {
+                num_errors++;
+                if (num_errors <= 5) {
+                    std::cout << "  Mismatch at index " << i
+                              << ": GPU=" << gpu_val << ", Ref=" << ref_val
+                              << ", Diff=" << diff << std::endl;
+                }
+            }
+        }
+
+        std::cout << "\nValidation results:" << std::endl;
+        std::cout << "  Max absolute difference: " << max_diff << std::endl;
+        std::cout << "  Max relative difference: " << max_rel_diff << std::endl;
+        std::cout << "  Number of mismatches: " << num_errors << " / " << output_host.mData.size() << std::endl;
+
+        if (num_errors == 0) {
+            std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl;
+        } else {
+            std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl;
+            pass = false;
+        }
+    }
+
+    std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl;
+    return pass;
+}
+
+// ============================================================================
+// Main
+// ============================================================================
+int main(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if (!result) {
+        std::cerr << "Failed to parse arguments" << std::endl;
+        return -1;
+    }
+
+    std::string prec = arg_parser.get_str("prec");
+
+    bool test_result = false;
+    if (prec == "fp16") {
+        test_result = run_test<DataType>(arg_parser);
+    } else if (prec == "bf16") {
+        // DataType is fixed to ck_tile::half_t in this build, so running here
+        // would silently test fp16 while reporting bf16.
+        std::cerr << "Unsupported precision: " << prec
+                  << " (this binary is compiled with DataType = ck_tile::half_t)" << std::endl;
+        return -1;
+    } else {
+        std::cerr << "Unsupported precision: " << prec << std::endl;
+        return -1;
+    }
+
+    return test_result ? 0 : -1;
+}
diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu
new file mode 100644
index 00000000000..3b7c3511fc2
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu
@@ -0,0 +1,216 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdint>
+#include <optional>
+#include <stdexcept>
+#include <string>
+
+#include "jenga_sparse_attention.h"
+#include "fmha_fwd_trek.hpp"
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/host/device_memory.hpp"
+
+
+// Host wrapper around the VSA sparse-attention forward kernel.
+//
+// Uploads Q/K/V, the block LUT and the per-row valid-block counts to freshly
+// allocated device buffers, fills the kernel traits/arguments, launches
+// fmha_vsa_fwd and copies the result back into Y, which is also returned.
+//
+// mode:   0 = batch mode, 1 = group mode (requires seqstart_q / seqstart_k).
+// i_perm: selects the Q/K/V strides (true: head stride spans the sequence).
+// o_perm: same selection for the output strides.
+// max_seqlen_q / max_seqlen_k: 0 means "use seqlen_q / seqlen_k".
+//
+// Throws std::invalid_argument when nhead is not a positive multiple of
+// nhead_k, or when group mode is requested without seqstart tensors.
+//
+// NOTE(review): the element types inside the optional tensors were lost in
+// this copy of the patch; they must match jenga_sparse_attention.h - confirm.
+// NOTE(review): lse is received by value, so anything the kernel writes to the
+// lse buffer is not visible to the caller.
+ck_tile::HostTensor<DataType> vsa_sparse_attention(
+    ck_tile::HostTensor<DataType> &TQ,
+    ck_tile::HostTensor<DataType> &TK,
+    ck_tile::HostTensor<DataType> &TV,
+    ck_tile::HostTensor<int32_t> &TKV_block_idx,  // LUT must be int32_t
+    ck_tile::HostTensor<int32_t> &TKV_blocks,     // valid_block_num must be int32_t
+    ck_tile::HostTensor<DataType> &Y,
+    std::optional<ck_tile::HostTensor<DataType>> bias,
+    std::optional<ck_tile::HostTensor<DataType>> lse,
+    std::optional<ck_tile::HostTensor<int32_t>> seqstart_q,
+    std::optional<ck_tile::HostTensor<int32_t>> seqstart_k,
+    int bias_type,
+    int batch,
+    int nhead,
+    int nhead_k,
+    int seqlen_q,
+    int seqlen_k,
+    int hdim_q,
+    int hdim_v,
+    int mode,
+    bool i_perm,
+    bool o_perm,
+    int max_seqlen_q,
+    int max_seqlen_k
+){
+    // Runtime checks instead of assert(): they must also hold in release builds.
+    if (nhead_k <= 0 || nhead % nhead_k != 0)
+        throw std::invalid_argument("vsa_sparse_attention: nhead must be a positive multiple of nhead_k");
+    if (mode == 1 && (!seqstart_q || !seqstart_k))
+        throw std::invalid_argument("vsa_sparse_attention: group mode requires seqstart_q and seqstart_k");
+
+    std::string data_type = "fp16";
+    // DataType is determined at compile time via template
+
+    if (max_seqlen_q == 0) max_seqlen_q = seqlen_q;
+    if (max_seqlen_k == 0) max_seqlen_k = seqlen_k;
+    bool is_v_rowmajor = true;
+    float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
+    float scale_p = 1.f;
+    float scale_o = 1.f;
+    const float logits_soft_cap = 0.0;
+
+    // "0" selects mask_enum::no_mask.
+    std::string msk_str = "0";
+    mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
+    // decode() only assigns left/right for the masked variants, so they are
+    // read only when a mask is active; -1/-1 is passed otherwise.
+    const bool has_mask = (mask.type != mask_enum::no_mask);
+
+    const ck_tile::index_t shape_seqlen_q = (mode == 0 ? seqlen_q : max_seqlen_q);
+    const ck_tile::index_t shape_seqlen_k = (mode == 0 ? seqlen_k : max_seqlen_k);
+
+    ck_tile::stream_config stream_config{nullptr,
+                                         false, // time_kernel
+                                         0, /* log_level = */
+                                         0,
+                                         1,
+                                         false};
+
+    // Create device memory and copy data to device
+    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
+
+    q_buf.ToDevice(TQ.data());
+    k_buf.ToDevice(TK.data());
+    v_buf.ToDevice(TV.data());
+    lut_buf.ToDevice(TKV_block_idx.data());
+    valid_block_num_buf.ToDevice(TKV_blocks.data());
+
+    // Optional buffers
+    ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0);
+    ck_tile::DeviceMem lse_buf(lse ? lse->get_element_space_size_in_bytes() : 0);
+    ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() : 0);
+    ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes() : 0);
+
+    if (bias) bias_buf.ToDevice(bias->data());
+    if (lse) lse_buf.ToDevice(lse->data());
+    if (seqstart_q) seqstart_q_buf.ToDevice(seqstart_q->data());
+    if (seqstart_k) seqstart_k_buf.ToDevice(seqstart_k->data());
+
+    const auto init_args = [&](auto& args) {
+        const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
+        const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
+        const ck_tile::index_t stride_v = [&]() {
+            if(is_v_rowmajor)
+                return i_perm ? hdim_v : nhead_k * hdim_v;
+            else
+                return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
+        }();
+        const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k);
+        const ck_tile::index_t stride_randval = (max_seqlen_k);
+        const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
+        // setup nhead_stride_* arguments
+        const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
+        const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
+        const ck_tile::index_t nhead_stride_v = [&]() {
+            if(is_v_rowmajor)
+                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
+            else
+                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
+        }();
+        const ck_tile::index_t nhead_stride_bias =
+            (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k);
+        const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
+        const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
+        const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
+        // setup batch_stride_* arguments
+        const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
+        const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
+        const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
+        const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k);
+        const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
+        const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
+        const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
+
+        // Use device buffer pointers instead of host tensor data pointers
+        args.q_ptr = q_buf.GetDeviceBuffer();
+        args.k_ptr = k_buf.GetDeviceBuffer();
+        args.v_ptr = v_buf.GetDeviceBuffer();
+        args.lut_ptr = lut_buf.GetDeviceBuffer();
+        args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer();
+
+        args.batch = batch;
+        args.seqlen_q = shape_seqlen_q; // unused in group mode
+        args.hdim_q = hdim_q;
+        args.hdim_v = hdim_v;
+        args.nhead_q = nhead;
+        args.nhead_k = nhead_k;
+
+        args.stride_q = stride_q;
+        args.stride_k = stride_k;
+        args.stride_v = stride_v;
+        args.nhead_stride_q = nhead_stride_q;
+        args.nhead_stride_k = nhead_stride_k;
+        args.nhead_stride_v = nhead_stride_v;
+        args.batch_stride_q = batch_stride_q;
+        args.batch_stride_k = batch_stride_k;
+        args.batch_stride_v = batch_stride_v;
+
+        args.bias_ptr = bias ? bias_buf.GetDeviceBuffer() : nullptr;
+        args.lse_ptr = lse ? lse_buf.GetDeviceBuffer() : nullptr;
+        args.o_ptr = o_buf.GetDeviceBuffer();
+
+        args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr);
+        args.seqstart_k_ptr = (mode == 1 ? seqstart_k_buf.GetDeviceBuffer() : nullptr);
+        args.seqlen_k_ptr = nullptr;
+
+        args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
+        args.max_seqlen_q = max_seqlen_q;
+
+        args.scale_s = scale_s;
+        args.scale_p = scale_p;
+        args.scale_o = scale_o;
+
+        args.logits_soft_cap = logits_soft_cap;
+
+        args.stride_bias = stride_bias;
+        args.stride_o = stride_o;
+        args.nhead_stride_bias = nhead_stride_bias;
+        args.nhead_stride_lse = nhead_stride_lse;
+        args.nhead_stride_o = nhead_stride_o;
+        args.batch_stride_bias = batch_stride_bias;
+        args.batch_stride_lse = batch_stride_lse;
+        args.batch_stride_o = batch_stride_o;
+
+        args.window_size_left = has_mask ? mask.left : -1;
+        args.window_size_right = has_mask ? mask.right : -1;
+        args.mask_type = static_cast<ck_tile::index_t>(mask.type);
+
+        args.rand_val_ptr = nullptr;
+
+        args.stride_randval = stride_randval;
+        args.nhead_stride_randval = nhead_stride_randval;
+        args.batch_stride_randval = batch_stride_randval;
+
+        args.p_drop = 0.;
+        args.s_randval = false;
+
+    };
+
+    const auto init_traits = [&](auto& traits) {
+        traits.hdim_q = hdim_q;
+        traits.hdim_v = hdim_v;
+        traits.data_type = data_type;
+        traits.is_v_rowmajor = is_v_rowmajor;
+
+
+        traits.is_group_mode = (mode == 1);
+        traits.has_logits_soft_cap = 0.f < logits_soft_cap;
+        traits.mask_type = mask.type;
+        traits.bias_type = static_cast<bias_enum>(bias_type);
+        traits.has_lse = lse ? true: false;
+        traits.do_fp8_static_quant = false;
+
+        traits.has_dropout = false;
+
+    };
+
+    // NOTE(review): the jenga traits/args types are passed to fmha_vsa_fwd and
+    // are expected to expose lut_ptr / valid_block_num_ptr - confirm against
+    // fmha_fwd_trek.hpp that these are the intended types for the VSA entry point.
+    fmha_jenga_fwd_traits fmha_traits;
+    init_traits(fmha_traits);
+
+    fmha_jenga_fwd_args args;
+    init_args(args);
+
+    fmha_vsa_fwd(fmha_traits, args, stream_config);
+
+    // Copy output back to host, into Y's own storage so that Y keeps the
+    // shape the caller gave it (o_buf was sized from Y above).
+    o_buf.FromDevice(Y.data());
+
+    return Y;
+}
diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt
index 92ee0a4c31a..bf11045a489 100644
--- a/example/ck_tile/CMakeLists.txt
+++ b/example/ck_tile/CMakeLists.txt
@@ -27,4 +27,5 @@ add_subdirectory(36_pooling)
 add_subdirectory(38_block_scale_gemm)
 add_subdirectory(40_streamk_gemm)
 add_subdirectory(41_batched_contraction)
+add_subdirectory(50_sparse_attn)
 
diff --git a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp
new file mode 100644
index 00000000000..caac6e2126e
--- /dev/null
+++ b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp
@@ -0,0 +1,887 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaPipelineQRKSVSAsyncJenga +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always 
support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) + { + return 1; + } + + if constexpr(kQKHeaddim <= 32) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 2; + else + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 192) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const bool *block_relation_onehot_ptr, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + 
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + DropoutType& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); 
+ + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + bool* block_relation_onehot = reinterpret_cast(smem_ptr) + GetSmemSize(); + amd_direct_load_global_to_lds(block_relation_onehot_ptr, 4*threadIdx.x, block_relation_onehot, 4*threadIdx.x, threadIdx.x/64==0, 256); + + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + q_dram_window.init_raw(); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q. + // however, q would be cleared in the constructor of static distributed tensor + // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + 
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + // if (threadIdx.x==0 && blockIdx.y==0) { + // printf("\nblockIdx.x : %d, seqlen_k_start: %d, seqlen_k_end: %d\n", blockIdx.x, seqlen_k_start, seqlen_k_end); + // } + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) + + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) + return bool_constant{}; + else + return bool_constant{}; + }(); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), 
seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + buffer_load_fence(1); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + + if (block_relation_onehot[0]) { + // prefetch K tile + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + } + + // buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); + buffer_load_fence(k_dram_window.get_num_of_access()); + (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // block_relation_onehot[i] == false: skip this K/V block without computing it + if (!block_relation_onehot[i_total_loops]) + { + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + + if (block_relation_onehot[i_total_loops]) { + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + } + move_tile_window(k_dram_window, {0, kK0}); + move_tile_window(v_dram_window, {0, kN0}); + // the bias window must advance with K/V on a skipped block too, otherwise + // later computed blocks would add bias taken from the wrong columns + move_tile_window(bias_dram_window, {0, kN0}); + continue; + } + break; + } + + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { +
async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_of_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0( + s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto
s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#else + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return 
!variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. 
alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + } + }(); +#else 
+ const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.template Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, p_compute)); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); + }(); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + 
store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } + } +#else + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr 
auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const bool *block_relation_onehot_ptr, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + block_relation_onehot_ptr, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp new file mode 100644 index 00000000000..f8a623f9bda --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -0,0 +1,872 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaPipelineQRKSVSAsyncVSA +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this 
pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) + { + return 1; + } + + if constexpr(kQKHeaddim <= 32) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 2; + else + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 192) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const int *kv_block_idx_ptr, + int kv_blocks, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + 
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + DropoutType& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); 
+ + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + // First K/V row to visit. kv_block_idx_ptr[0] is used as an absolute KV block index and is + // scaled by the K tile length kN0, the same unit the K/V windows are stepped by in the main + // loop. The read is skipped when the block list is empty (kv_blocks <= 0). + int seqlen_k_start = kv_blocks > 0 ? kv_block_idx_ptr[0] * kN0 : 0; + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + q_dram_window.init_raw(); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q. + // however, q would be cleared in the constructor of static distributed tensor + // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + + 
const auto num_total_loop = kv_blocks; + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) + + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) + return bool_constant{}; + else + return bool_constant{}; + }(); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + 
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + // buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); + buffer_load_fence(k_dram_window.get_num_of_access()); + (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + // number of KV blocks processed so far + index_t i_total_loops = 0; + // k0_loops: kK0-wide head-dim slices per QK gemm + constexpr index_t k0_loops = kQKHeaddim / kK0; + // k1_loops: kK1-wide slices of the kN0 tile per PV gemm + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop: one iteration per visited KV block, until i_total_loops reaches num_total_loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_of_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + }); + + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + 
// Relative step (in KV blocks) to the next selected block. It is only consumed when another + // iteration follows, so the entry past the last block is never read. + const int block_idx = + (i_total_loops + 1 < num_total_loop) ? kv_block_idx_ptr[i_total_loops + 1] : 0; + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0( + s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if 
constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#else + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if 
constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? 
type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto 
idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this uses a different equation from the FA v2 paper, + // but produces the correct result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + // NOTE(review): the column offset below assumes the visited KV blocks are + // contiguous from seqlen_k_start, while the K/V windows are stepped by the + // entries of kv_block_idx_ptr -- confirm dropout is never enabled together + // with a non-contiguous block list. + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.template Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pkrtz_fp16_fp32( + tile_elementwise_in(p_compute_element_func, p_compute)); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); + }(); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, 
kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // Jump the V window to the next selected block. block_idx is a relative step in + // units of kN0. When k1_loops > 1 the window has already advanced kN0 columns in + // this iteration (k1_loops moves of kK1, assuming kN0 is a multiple of kK1); + // otherwise it has not moved at all. + constexpr index_t v_cols_advanced = (k1_loops > 1) ? kN0 : 0; + move_tile_window(v_dram_window, {0, kN0 * block_idx - v_cols_advanced}); + // move K tile windows by the same relative step + move_tile_window(k_dram_block_window, {kN0 * block_idx, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + 
log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } + } +#else + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const int *kv_block_idx_ptr, + int kv_blocks, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + kv_block_idx_ptr, + kv_blocks, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + 
block_indices, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp new file mode 100644 index 00000000000..45d7af8ef98 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp @@ -0,0 +1,1507 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaFwdJengaKernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = 
FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using g0br = typename bfs::Gemm0BlockWarps; + using g1br = typename bfs::Gemm1BlockWarps; + using g0wt = typename bfs::Gemm0WarpTile; + using g1wt = typename bfs::Gemm1WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? 
n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. 
+ struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* block_relation_onehot_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdLogitsSoftCapKargs + { + FmhaFwdLogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + + struct FmhaFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaFwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct FmhaFwdMaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdFp8StaticQuantKargs + { + float scale_p; + float scale_o; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = 
nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdDropoutSeedOffset + { + template + union ValueOrPointer + { + T val; + const T* ptr; + }; + + ValueOrPointer drop_seed; + ValueOrPointer drop_offset; + bool is_drop_seed_offset_from_host; + }; + + struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.val = seed; + this->drop_offset.val = offset; + this->is_drop_seed_offset_from_host = true; + } + + void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.ptr = seed_ptr; + this->drop_offset.ptr = offset_ptr; + this->is_drop_seed_offset_from_host = false; + } + + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct FmhaFwdSkipMinSeqlenQKargs + { + ck_tile::index_t min_seqlen_q = 0; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + 
std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* block_relation_onehot_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + block_relation_onehot_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + 
scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + + return 
kargs; + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* block_relation_onehot_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + block_relation_onehot_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + 
batch_stride_v, + batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* block_relation_onehot_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + block_relation_onehot_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + 
stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* block_relation_onehot_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + ck_tile::index_t min_seqlen_q, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + block_relation_onehot_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + 
scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap + {}, // placeholder for min_seqlen_q + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + if constexpr(kSkipMinSeqlenQ) + { + kargs.min_seqlen_q = min_seqlen_q; + } + + 
return kargs; + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* block_relation_onehot_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + block_relation_onehot_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + // std::variant<> can't take in a list initializer, overload for 
backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* block_relation_onehot_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + block_relation_onehot_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t 
hdim_v_, + bool has_padded_seqlen_k = false) + { + has_padded_seqlen_k = true; + // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr) + if(has_padded_seqlen_k) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); + } + else + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + bool has_padded_seqlen_k = false; + + if constexpr(kIsGroupMode) + has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); + + has_padded_seqlen_k = true; + + if(has_padded_seqlen_k) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + else + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / 
divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()+256*sizeof(int)]; + + // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", int(GetSmemSize())); + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * 
kargs.stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + if constexpr(kSkipMinSeqlenQ) + { + if(kargs.seqlen_q <= kargs.min_seqlen_q) + { + return; + } + } + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead 
/ kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + + // sparse mask + const bool* block_relation_onehot_ptr = reinterpret_cast(kargs.block_relation_onehot_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) + * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) + i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); + + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : false; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + // sparse mask + // const auto lut_dram = make_naive_tensor_view( + // lut_ptr, + // make_tuple(kargs.seqlen_k/number{}, 1), + // make_tuple(1, 1), + // number<1>{}, + // number<1>{}); + + // const auto valid_block_num_dram = make_naive_tensor_view( + // valid_block_num_ptr, + // make_tuple(kargs.seqlen_q/number{}), + // make_tuple(1), + // number<1>{}, + // number<1>{}); + + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + + // auto lut_dram_window = make_tile_window( + // lut_dram, make_tuple(1,1), {0,0}); + // auto valid_block_num_window = make_tile_window( + // valid_block_num_dram, make_tuple(1), {i_tile_m}); + + /// FIXME: Before C++20, capturing structured binding variables are not supported. 
Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? 
kargs.drop_offset.val + : *kargs.drop_offset.ptr, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? 
+ SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + + auto o_acc_tile = [&]() { + // TODO: constexpr(kDoFp8StaticQuant) + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + block_relation_onehot_ptr, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp new file mode 100644 index 00000000000..619e21b0c09 --- /dev/null +++ 
b/include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp @@ -0,0 +1,1524 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaFwdVSAKernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = 
FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; + static constexpr bool kDoFp8StaticQuant = (QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); + static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using g0br = typename bfs::Gemm0BlockWarps; + using g1br = typename bfs::Gemm1BlockWarps; + using g0wt = typename bfs::Gemm0WarpTile; + using g1wt = typename bfs::Gemm1WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? 
"group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lut_ptr; + const void* valid_block_num_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. 
This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdLogitsSoftCapKargs + { + FmhaFwdLogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + + struct FmhaFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaFwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct FmhaFwdMaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdFp8StaticQuantKargs + { + float scale_p; + float scale_o; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdDropoutSeedOffset + { + template + union ValueOrPointer + { + T val; + const T* ptr; + }; + + ValueOrPointer drop_seed; + ValueOrPointer drop_offset; + bool is_drop_seed_offset_from_host; + }; + + struct FmhaFwdCommonDropoutKargs : 
FmhaFwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.val = seed; + this->drop_offset.val = offset; + this->is_drop_seed_offset_from_host = true; + } + + void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.ptr = seed_ptr; + this->drop_offset.ptr = offset_ptr; + this->is_drop_seed_offset_from_host = false; + } + + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct FmhaFwdSkipMinSeqlenQKargs + { + ck_tile::index_t min_seqlen_q = 0; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t 
kv_head_idx; + }; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* lut_ptr, + const void* valid_block_num_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for 
dropout + {}, // placeholder for logits_soft_cap + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + + return kargs; + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* lut_ptr, + const void* valid_block_num_ptr, + const void* bias_ptr, + 
void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + // std::variant<> can't take in a 
list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* lut_ptr, + const void* valid_block_num_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + 
batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* lut_ptr, + const void* valid_block_num_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + ck_tile::index_t min_seqlen_q, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + 
{}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap + {}, // placeholder for min_seqlen_q + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + if constexpr(kSkipMinSeqlenQ) + { + kargs.min_seqlen_q = min_seqlen_q; + } + + return kargs; + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const 
void* lut_ptr, + const void* valid_block_num_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* lut_ptr, + const void* valid_block_num_ptr, + const void* 
bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + logits_soft_cap, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_, + bool has_padded_seqlen_k = false) + { + // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr) + if(has_padded_seqlen_k) + { + // TODO: this may need tuning + return 
dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); + } + else + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + bool has_padded_seqlen_k = false; + + if constexpr(kIsGroupMode) + has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); + + + if(has_padded_seqlen_k) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + else + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return 
ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()+256*sizeof(int)]; + + // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", int(GetSmemSize())); + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys 
under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + if constexpr(kSkipMinSeqlenQ) + { + if(kargs.seqlen_q <= kargs.min_seqlen_q) + { + return; + } + } + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + + // sparse mask + const int* lut_ptr = reinterpret_cast(kargs.lut_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) + * 
ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) + i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); + const int* valid_block_num_ptr = reinterpret_cast(kargs.valid_block_num_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) + i_tile_m; + // int valid_block_num_value = __builtin_amdgcn_readfirstlane(valid_block_num_ptr[0]); + const int valid_block_num_value = valid_block_num_ptr[0]; + + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : false; + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? 
kPadHeadDimV : false; + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + // sparse mask + // const auto lut_dram = make_naive_tensor_view( + // lut_ptr, + // make_tuple(kargs.seqlen_k/number{}, 1), + // make_tuple(1, 1), + // number<1>{}, + // number<1>{}); + + // const auto valid_block_num_dram = make_naive_tensor_view( + // valid_block_num_ptr, + // make_tuple(kargs.seqlen_q/number{}), + // make_tuple(1), + // number<1>{}, + // number<1>{}); + + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + + // auto lut_dram_window = make_tile_window( + // lut_dram, make_tuple(1,1), {0,0}); + // auto valid_block_num_window = make_tile_window( + // valid_block_num_dram, make_tuple(1), {i_tile_m}); + + /// FIXME: Before C++20, capturing structured binding variables are not supported. 
Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? 
kargs.drop_offset.val + : *kargs.drop_offset.ptr, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? 
+ SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + + auto o_acc_tile = [&]() { + // TODO: constexpr(kDoFp8StaticQuant) + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lut_ptr, + valid_block_num_value, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr); + } +}; + +} // namespace ck_tile From 8ff98b8095763ae36511c5f81547efb6d5781187 Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Thu, 4 Dec 2025 01:09:53 +0000 Subject: [PATCH 02/22] fix the pre-commit --- 
.../50_sparse_attn/codegen/cpp_symbol_map.py | 133 +- .../codegen/ops/fmha_fwd_jenga.py | 1118 +++++++++++----- .../codegen/ops/fmha_fwd_vsa.py | 1122 ++++++++++++----- .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 337 ++--- example/ck_tile/50_sparse_attn/generate.py | 107 +- .../50_sparse_attn/jenga_sparse_attention.cu | 110 +- .../50_sparse_attn/jenga_sparse_attention.h | 96 +- .../50_sparse_attn/test_vsa_sparse_attn.cpp | 461 ++++--- .../50_sparse_attn/vsa_sparse_attention.cu | 124 +- ...ock_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 46 +- ...block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 36 +- .../ops/sparse_attn/fmha_fwd_jenga_kernel.hpp | 49 +- .../ops/sparse_attn/fmha_fwd_vsa_kernel.hpp | 64 +- 13 files changed, 2402 insertions(+), 1401 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py index 1f5a03e243d..63751d43fa2 100644 --- a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py +++ b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py @@ -3,35 +3,33 @@ # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { - "fp16" : "FmhaFwdFp16", - "bf16" : "FmhaFwdBf16", - "fp8" : "FmhaFwdFp8", + "fp16": "FmhaFwdFp16", + "bf16": "FmhaFwdBf16", + "fp8": "FmhaFwdFp8", "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16" + "fp8bf16": "FmhaFwdFp8Bf16", } -BWD_DTYPE_MAP = { - "fp16": "FmhaBwdFp16", - "bf16": "FmhaBwdBf16" -} +BWD_DTYPE_MAP = {"fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"} MASK_IMPL = { - "generic" : "ck_tile::GenericAttentionMask", - "simplified" : "ck_tile::SimplifiedGenericAttentionMask" + "generic": "ck_tile::GenericAttentionMask", + "simplified": "ck_tile::SimplifiedGenericAttentionMask", } _MASK_SIMPLIFIED_MAP = { - "s_no" : "ck_tile::SimplifiedGenericAttentionMask", - "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", + "s_no": "ck_tile::SimplifiedGenericAttentionMask", + "s_mask": "ck_tile::SimplifiedGenericAttentionMask", } 
_MASK_MAP = { - "no" : "FmhaMasks::NoMask", - "causal" : "FmhaMasks::CausalMask", - "generic" : "FmhaMasks::GenericMask" + "no": "FmhaMasks::NoMask", + "causal": "FmhaMasks::CausalMask", + "generic": "FmhaMasks::GenericMask", } -def get_mask_map(mask : str): + +def get_mask_map(mask: str): if mask == "generic": return _MASK_MAP elif mask == "simplified": @@ -40,18 +38,20 @@ def get_mask_map(mask : str): assert False return None + _MASK_CHECK_MAP = { - "no" : "t.mask_type == mask_enum::no_mask", - "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", - "generic" : "t.mask_type == mask_enum::window_generic", + "no": "t.mask_type == mask_enum::no_mask", + "causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic": "t.mask_type == mask_enum::window_generic", } _MASK_SIMPLIFIED_CHECK_MAP = { - "s_no" : "t.mask_type == mask_enum::no_mask", - "s_mask" : "t.mask_type != mask_enum::no_mask", + "s_no": "t.mask_type == mask_enum::no_mask", + "s_mask": "t.mask_type != mask_enum::no_mask", } -def get_mask_check_map(mask : str): + +def get_mask_check_map(mask: str): if mask == "generic": return _MASK_CHECK_MAP elif mask == "simplified": @@ -60,82 +60,77 @@ def get_mask_check_map(mask : str): assert False return None + BIAS_MAP = { - "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", - "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", - "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" + "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI", } # TODO: this is ugly BIAS_CHECK_MAP = { - "no" : "bias_enum::no_bias", - "bias" : "bias_enum::elementwise_bias", - "alibi" : "bias_enum::alibi" + "no": "bias_enum::no_bias", + "bias": "bias_enum::elementwise_bias", + "alibi": "bias_enum::alibi", } DROPOUT_MAP = { - "no" : "ck_tile::BlockDropoutBwd", - "dropout_wg32" : 
"ck_tile::BlockDropoutBwd", - "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", - "dropout_wg16" : "ck_tile::BlockDropoutBwd", - "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" + "no": "ck_tile::BlockDropoutBwd", + "dropout_wg32": "ck_tile::BlockDropoutBwd", + "dropout_wg32_storerandval": "ck_tile::BlockDropoutBwd", + "dropout_wg16": "ck_tile::BlockDropoutBwd", + "dropout_wg16_storerandval": "ck_tile::BlockDropoutBwd", } DROPOUT_CHECK_MAP = { - "no" : "t.has_dropout == false", - "dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false", - "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true", - "dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false", - "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true", + "no": "t.has_dropout == false", + "dropout_wg32": "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg32_storerandval": "t.has_dropout == true && t.is_store_randval == true", + "dropout_wg16": "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg16_storerandval": "t.has_dropout == true && t.is_store_randval == true", } ROPE_MAP = { - "no" : "ck_tile::RotaryEmbeddingEnum::NONE", - "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", - "half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED" + "no": "ck_tile::RotaryEmbeddingEnum::NONE", + "inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", + "half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED", } ROPE_CHECK_MAP = { - "no" : "rope_enum::none", - "inter" : "rope_enum::interleaved", - "half" : "rope_enum::half_rotated" + "no": "rope_enum::none", + "inter": "rope_enum::interleaved", + "half": "rope_enum::half_rotated", } -MODE_MAP = { - "batch" : "false", - "group" : "true" -} +MODE_MAP = {"batch": "false", "group": "true"} -LAYOUT_MAP = { - "row" : "true", - "col" : "false" -} +LAYOUT_MAP = {"row": "true", "col": "false"} PIPELINE_MAP = { - "qr" : 
"ck_tile::BlockFmhaPipelineQRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga", - "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", - "qr_async_vsa" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA", + "qr": "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga", + "qs": "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA", } PIPELINE_ENUM_MAP = { - "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", - "qr_async_vsa" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", - "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", - "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", + "qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", + "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { - "t" : "true", - "f" : "false", - True : "true", - False : "false", + "t": "true", + "f": "false", + True: "true", + False: "false", } SQUANT_MAP = { - "t" : "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", - "f" : "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", + "t": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", + "f": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", } diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index 388c7d3a685..a3a9e6bf871 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ 
b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -7,15 +7,25 @@ import fnmatch import itertools import os +import os.path as path from pathlib import Path from typing import List, Optional, Tuple -from codegen.cpp_symbol_map import * +from codegen.cpp_symbol_map import ( + BIAS_CHECK_MAP, + BIAS_MAP, + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + get_mask_check_map, + get_mask_map, +) GEN_DIR = "" -import os.path as path - def update_file(file_path, content): """Update the file at file_path with the given content if it differs from the existing content. @@ -33,22 +43,9 @@ def update_file(file_path, content): file.write(content) -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 64 : 64, - 96 : 128, - 128: 128, - 192: 192, - 256: 256 -} +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n @@ -60,7 +57,7 @@ def update_file(file_path, content): """ -FMHA_FWD_KERNEL_BODY=""" +FMHA_FWD_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -137,8 +134,8 @@ def update_file(file_path, content): }} """ -FMHA_FWD_API_FILENAME="fmha_jenga_fwd_api.cpp" -FMHA_FWD_API=""" +FMHA_FWD_API_FILENAME = "fmha_jenga_fwd_api.cpp" +FMHA_FWD_API = """ #include #include @@ -192,198 +189,255 @@ def update_file(file_path, content): }} """ -FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ {F_dtype_case} }} """ -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ {F_inner_dispatch} }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, 
{F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_jenga_fwd_(s, a); }} """ + @dataclass class CppConstraint: bool_expr: str = None def __str__(self): if self.bool_expr is None: - return 'true' + return "true" else: - return f'{self.bool_expr}' + return f"{self.bool_expr}" def __and__(self, other): - return CppConstraint(f'({str(self)}) && ({str(other)})') + return CppConstraint(f"({str(self)}) && ({str(other)})") + @dataclass class FmhaFwdApiTrait: - pipeline_tag : str + pipeline_tag: str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - logits : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - skip : str - tr_load : str - constraint : CppConstraint + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + bias: str # + lse: str # + dropout: str + squant: str # + spad: str + skpad: str + dpad: str + dvpad: str + skip: str + tr_load: str + constraint: CppConstraint @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - 
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + ) @property def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag in ['qr_async', 'qr_async_trload']: - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr', 'qs']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag in ["qr_async", "qr_async_trload"]: + if self.spad == "t": + return "true" # always support + else: + return "true" + elif self.pipeline_tag in ["qr", "qs"]: + if self.spad == "t": + return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.seqlen_q % {self.bm0} == 0" + else: + assert False + @property def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true - else: - return f'a.seqlen_q <= {self.bm0}' + if self.bm0 == 128: + return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true + else: + return f"a.seqlen_q <= {self.bm0}" @property def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qs']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag == 'qr_async_trload': - if self.skpad == 't' : return 'true' - else: return 'true' - else: assert False + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async": + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + else: + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + elif self.pipeline_tag in ["qr", "qs"]: + if self.skpad == "t": + return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.seqlen_k % {self.bn0} == 0" + elif self.pipeline_tag == "qr_async_trload": + if self.skpad == "t": + return "true" + else: + return "true" + else: + assert False @property def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False + if self.dpad == "t": + return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.hdim_q % {bk0submax} == 0" + else: + assert False @property def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag == "qr_async": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False + if self.dvpad == "t": + return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.hdim_v % {bk0submax} == 0" + else: + assert False + @dataclass class FmhaFwdPipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP - F_skip : str # true/false - F_trload : str # true/false - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_bias: str # true/false + F_lse: str # + F_dropout: str # + F_squant: str # + F_mask: str # value from MASK_MAP + F_skip: str # true/false + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n return n + pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' + if self.F_bias != "no": + n += f"_{self.F_bias}" + else: + n += "_nbias" - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + 
else: + n += "_nmask" else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' + if self.F_lse == "t": + n += "_lse" + else: + n += "_nlse" + + if self.F_dropout == "t": + n += "_dropout" + else: + n += "_ndropout" - if self.F_dropout == 't' : n += '_dropout' - else: n += '_ndropout' + if self.F_skip == "t": + n += "_skip" + else: + n += "_nskip" - if self.F_skip == 't' : n += '_skip' - else: n += '_nskip' + if self.F_squant == "t": + n += "_squant" + else: + n += "_nsquant" - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' - - if self.F_trload == 't' : n += '_trload' - else: n += '_ntrload' + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" return n + class FmhaFwdApiPool: def __init__(self, mask_impl): self.pool = dict() self.mask_impl = mask_impl - def register_traits(self, trait : FmhaFwdApiTrait) -> None: + def register_traits(self, trait: FmhaFwdApiTrait) -> None: # TODO: do we need to check duplication? 
if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() @@ -395,129 +449,171 @@ def register_traits(self, trait : FmhaFwdApiTrait) -> None: @property def api(self) -> str: - tr_load_cond_map = { - "t": "has_load_tr", - "f": "true" - } - - per_tr_load =str() + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() for tr_load in ["t", "f"]: - per_dtypes=str() + per_dtypes = str() for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() + per_hdim_case = str() for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] - inners=str() + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, 
F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_mode=MODE_MAP[trait.mode], + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_logits=BOOL_MAP[trait.logits], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], + F_dropout=BOOL_MAP[trait.dropout], + F_skip=BOOL_MAP[trait.skip], + F_trload=BOOL_MAP[trait.tr_load], + F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_tr_load += ' (void)t ; (void)s ; (void)a;' - return 
FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of 
warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + return ( + f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + @dataclass class FmhaFwdKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str @property def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = 
self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_skip = BOOL_MAP[self.F_pipeline.F_skip], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_trload = BOOL_MAP[self.F_pipeline.F_trload]) + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + 
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits=BOOL_MAP[self.F_pipeline.F_logits], + F_bias=BIAS_MAP[self.F_pipeline.F_bias], + F_lse=BOOL_MAP[self.F_pipeline.F_lse], + F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], + F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_skip=BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + ) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name + return ( + f"fmha_jenga_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) @property def filename(self) -> str: @@ -525,57 +621,207 @@ def filename(self) -> str: def api_trait(self) -> FmhaFwdApiTrait: return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - skip=self.F_pipeline.F_skip, - tr_load=self.F_pipeline.F_trload, - constraint=self.F_tile.F_constraint & 
self.F_pipeline.F_constraint) + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + class KernelComponentFactory: # TODO: design a more practical way to do it # this is current supported tile size per hdim @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": return { # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 
32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( + 16, + 32, + 64, + 128, + 32, + 128, + 1, + 1, + 1, + 1, + 1, + 1, + 16, + 16, + 32, + 16, + 16, + 32, + -1, + ), + FmhaFwdTileSize( + 32, + 32, + 128, + 128, + 32, + 128, + 1, + 1, + 1, + 1, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + ), + FmhaFwdTileSize( + 128, + 64, + 32, + 128, + 16, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + ), + FmhaFwdTileSize( + 128, + 128, + 32, + 128, + 32, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + ), + ], # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == 'fp8' or dtype == 'bf8': + elif dtype == "fp8" or dtype == "bf8": return { - (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (64, 64): [ + FmhaFwdTileSize( + 128, + 64, + 32, + 64, + 32, + 64, + 2, + 1, + 1, + 2, + 1, + 1, + 32, + 32, + 32, + 32, + 32, + 32, + -1, + ) + ], + (128, 128): [ + FmhaFwdTileSize( + 128, + 128, + 32, + 128, + 32, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 32, + 32, + 32, + 32, + -1, + ) + ], + (256, 256): [ + FmhaFwdTileSize( + 128, + 128, + 32, + 256, + 32, + 256, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 32, + 32, + 32, + 32, + -1, + ) + ], } else: return None @@ -588,94 +834,288 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> 
List[FmhaFwdPipeli # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' + squant = "t" if dtype == "fp8" else "f" pipelines = [] - if dtype in ['fp16', 'bf16']: - for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + if dtype in ["fp16", "bf16"]: + for logits, mask, bias, lse, dropout, skip in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], + ["t", "f"], + ): if hdim == 256 and hdim_v == 256: # print("jenga fmha only support dim=128 now.") continue - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "f", + "f", + "f", + "f", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "t", + "t", + "f", + "f", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "t", + "t", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) else: if bias == "bias": # print("jenga_fmha with bias is not implemented.") continue # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - 
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "f", + "f", + "f", + "f", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "t", + "t", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append( + FmhaFwdPipeline( + "qr_async", + "row", + "t", + "f", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr_async", + "row", + "t", + "t", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) # if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) # if receipt == 1 and bias != "bias": # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: + elif dtype in ["fp8", "bf8"]: # print("jenga fmha only support 16-bit compute.") return pipelines # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, 
mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: + for logits, mask, bias in itertools.product( + ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + ): + pipelines.append( + FmhaFwdPipeline( + "qr", + "col", + "f", + "f", + "f", + "f", + logits, + bias, + "f", + "f", + squant, + mask, + "f", + "f", + ) + ) + elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None else: assert False return pipelines + class CustomFactory(KernelComponentFactory): @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == 'fp16' or dtype == 'bf16': + if dtype == "fp16" or dtype == "bf16": if (128, 128) in result.keys(): - result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) return result -def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool(mask_impl) - factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) for dtype in FWD_DTYPE_MAP.keys(): d = factory.get_hdim_tile_size_dict(dtype) - if d == None: + if d is None: continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), 
MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): - for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): + # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for ((hdim, hdim_v), tiles), mode in itertools.product( + d.items(), MODE_MAP.keys() + ): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue if (hdim, hdim_v) == (192, 128): # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': + if pipeline.F_bias != "no" or pipeline.F_dropout == "t": continue - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + if pipeline.tag != "qr_async_trload" and ( + ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) + or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128) + ): # non qr_async_trload only support km0=128 tile size when hdim is not 128 # non qr_async only support kn0=128 tile size when hdim is 128 continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + if pipeline.tag == "qr_async_trload" and ( + ((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) + or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) + ): continue # logits_soft_cap is only allowed if no bias - if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + if not ( 
+ (pipeline.F_logits == "t" and pipeline.F_bias == "no") + or pipeline.F_logits == "f" + ): continue if pipeline.tag != "qr_async": continue - k = FmhaFwdKernel(F_idx=2, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': + k = FmhaFwdKernel( + F_idx=2, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue if optdim_list != [-1]: @@ -683,45 +1123,45 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # 2 - Flash attention integration if receipt in (2, 3): - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - cond &= pipeline.F_skip == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_squant == "f" + cond &= pipeline.F_skip == "f" if not cond: continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' - cond &= mode == 'batch' - cond &= pipeline.F_skip == 'f' - cond &= pipeline.F_logits == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_squant == "f" + cond &= mode == "batch" + cond &= pipeline.F_skip == "f" + cond &= pipeline.F_logits == "f" if not cond: continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'batch' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not 
cond: continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'group' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue @@ -730,20 +1170,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) + def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: update_file(autogen_dir / kernel.filename, kernel.template) -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") diff --git 
a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 1e7bbfc9c2a..0d4bf9b7811 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -7,16 +7,26 @@ import fnmatch import itertools import os +import os.path as path from pathlib import Path from typing import List, Optional, Tuple -from codegen.cpp_symbol_map import * -from codegen.cpp_symbol_map import SQUANT_MAP +from codegen.cpp_symbol_map import ( + BIAS_CHECK_MAP, + BIAS_MAP, + BOOL_MAP, + FWD_DTYPE_MAP, + LAYOUT_MAP, + MODE_MAP, + PIPELINE_ENUM_MAP, + PIPELINE_MAP, + SQUANT_MAP, + get_mask_check_map, + get_mask_map, +) GEN_DIR = "" -import os.path as path - def update_file(file_path, content): """Update the file at file_path with the given content if it differs from the existing content. @@ -34,22 +44,9 @@ def update_file(file_path, content): file.write(content) -DTYPE_BITS = { - "fp32": 32, - "fp16": 16, - "bf16": 16, - "fp8" : 8, - "bf8" : 8 -} +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} -K0_MAX_SUBMAX_MAP = { - 32 : 32, - 64 : 64, - 96 : 128, - 128: 128, - 192: 192, - 256: 256 -} +K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n @@ -61,7 +58,7 @@ def update_file(file_path, content): """ -FMHA_FWD_KERNEL_BODY=""" +FMHA_FWD_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -138,8 +135,8 @@ def update_file(file_path, content): }} """ -FMHA_FWD_API_FILENAME="fmha_vsa_fwd_api.cpp" -FMHA_FWD_API=""" +FMHA_FWD_API_FILENAME = "fmha_vsa_fwd_api.cpp" +FMHA_FWD_API = """ #include #include @@ -193,198 +190,255 @@ def update_file(file_path, content): }} """ -FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +FMHA_FWD_API_PER_TRLOAD = """ {F_if}({F_trload_cond}){{ {F_dtype_case} }} """ -FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ {F_inner_dispatch} }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, 
{F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_vsa_fwd_(s, a); }} """ + @dataclass class CppConstraint: bool_expr: str = None def __str__(self): if self.bool_expr is None: - return 'true' + return "true" else: - return f'{self.bool_expr}' + return f"{self.bool_expr}" def __and__(self, other): - return CppConstraint(f'({str(self)}) && ({str(other)})') + return CppConstraint(f"({str(self)}) && ({str(other)})") + @dataclass class FmhaFwdApiTrait: - pipeline_tag : str + pipeline_tag: str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - logits : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str - skip : str - tr_load : str - constraint : CppConstraint + hdim: str + dtype: str # data type + mode: str # value from MODE_MAP + bm0: int # tile size along q seqlen (block size) + bn0: int # tile size along qk seqlen + bk0: int # tile size along qk gemm unroll + bn1: int # tile size along v head_dim + bk1: int # tile size along kv gemm unroll + bk0max: int + vlayout: str + logits: str + mask: str + bias: str # + lse: str # + dropout: str + squant: str # + spad: str + skpad: str + dpad: str + dvpad: str + skip: str + tr_load: str + constraint: CppConstraint @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - 
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + return ( + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + ) @property def scheck(self) -> str: - if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag in ['qr_async_vsa', 'qr_async_trload']: - if self.spad == 't' : return 'true' # always support - else : return 'true' - elif self.pipeline_tag in ['qr', 'qs']: - if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_q % {self.bm0} == 0' - else: assert False - + if self.mode == "group": + return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag in ["qr_async_vsa", "qr_async_trload"]: + if self.spad == "t": + return "true" # always support + else: + return "true" + elif self.pipeline_tag in ["qr", "qs"]: + if self.spad == "t": + return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.seqlen_q % {self.bm0} == 0" + else: + assert False + @property def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true - else: - return f'a.seqlen_q <= {self.bm0}' + if self.bm0 == 128: + return "true/*fall back to largest tile*/" # group mode only generate spad/skpad == true + else: + return f"a.seqlen_q <= {self.bm0}" @property def skcheck(self) -> str: - if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async_vsa': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qs']: - if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag == 'qr_async_trload': - if self.skpad == 't' : return 'true' - else: return 'true' - else: assert False + if self.mode == "group": + return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true + if self.pipeline_tag == "qr_async_vsa": + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + else: + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" + elif self.pipeline_tag in ["qr", "qs"]: + if self.skpad == "t": + return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.seqlen_k % {self.bn0} == 0" + elif self.pipeline_tag == "qr_async_trload": + if self.skpad == "t": + return "true" + else: + return "true" + else: + assert False @property def dcheck(self) -> str: - if self.pipeline_tag == 'qr_async_vsa': + if self.pipeline_tag == "qr_async_vsa": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == 't': return f'a.hdim_q % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {bk0submax} == 0' - else: assert False + if self.dpad == "t": + return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) + else: + return f"a.hdim_q % {bk0submax} == 0" + else: + assert False @property def dvcheck(self) -> str: - if self.pipeline_tag == 'qr_async_vsa': + if self.pipeline_tag == "qr_async_vsa": vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' - else : assert False - elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + else: + assert False + elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {bk0submax} == 0' - else: assert False + if self.dvpad == "t": + return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) + else: + return f"a.hdim_v % {bk0submax} == 0" + else: + assert False + @dataclass class FmhaFwdPipeline: - tag : str - - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP - F_skip : str # true/false - F_trload : str # true/false - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + tag: str + + F_vlayout: str # row/col + F_spad: str # true/false + F_skpad: str # + F_dpad: str # + F_dvpad: str # + F_logits: str # t/f + F_bias: str # true/false + F_lse: str # + F_dropout: str # + F_squant: str # + F_mask: str # value from MASK_MAP + F_skip: str # true/false + F_trload: str # true/false + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: def pad_name() -> str: - n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' - if n != '' : n = 'p' + n + n = "" + if self.F_spad == "t": + n += "s" + if self.F_skpad == "t": + n += "sk" + if self.F_dpad == "t": + n += "d" + if self.F_dvpad == "t": + n += "dv" + if n != "": + n = "p" + n return n + pn = pad_name() - n = f'{self.tag}_v{self.F_vlayout[0]}' - if pn != '' : n += f'_{pn}' - else: n += '_npad' + n = f"{self.tag}_v{self.F_vlayout[0]}" + if pn != "": + n += f"_{pn}" + else: + n += "_npad" - if self.F_logits == 't' : n += '_logits' - else: n += '_nlogits' + if self.F_logits == "t": + n += "_logits" + else: + n += "_nlogits" - if self.F_bias != 'no' : n += f'_{self.F_bias}' - else: n += '_nbias' + if self.F_bias != "no": + n += f"_{self.F_bias}" + else: + n += "_nbias" - if self.F_mask[0:2] == 's_': - if self.F_mask == 's_mask': n += f'_mask' - else: n += '_nmask' + if self.F_mask[0:2] == "s_": + if self.F_mask == "s_mask": + n += "_mask" + 
else: + n += "_nmask" else: - if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' - else: n += '_nmask' + if self.F_mask != "no": + n += f"_m{self.F_mask[0]}" + else: + n += "_nmask" - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' + if self.F_lse == "t": + n += "_lse" + else: + n += "_nlse" + + if self.F_dropout == "t": + n += "_dropout" + else: + n += "_ndropout" - if self.F_dropout == 't' : n += '_dropout' - else: n += '_ndropout' + if self.F_skip == "t": + n += "_skip" + else: + n += "_nskip" - if self.F_skip == 't' : n += '_skip' - else: n += '_nskip' + if self.F_squant == "t": + n += "_squant" + else: + n += "_nsquant" - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' - - if self.F_trload == 't' : n += '_trload' - else: n += '_ntrload' + if self.F_trload == "t": + n += "_trload" + else: + n += "_ntrload" return n + class FmhaFwdApiPool: def __init__(self, mask_impl): self.pool = dict() self.mask_impl = mask_impl - def register_traits(self, trait : FmhaFwdApiTrait) -> None: + def register_traits(self, trait: FmhaFwdApiTrait) -> None: # TODO: do we need to check duplication? 
if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() @@ -396,130 +450,172 @@ def register_traits(self, trait : FmhaFwdApiTrait) -> None: @property def api(self) -> str: - tr_load_cond_map = { - "t": "has_load_tr", - "f": "true" - } - - per_tr_load =str() + tr_load_cond_map = {"t": "has_load_tr", "f": "true"} + + per_tr_load = str() for tr_load in ["t", "f"]: - per_dtypes=str() + per_dtypes = str() for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() + per_hdim_case = str() for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] - inners=str() + traits = [ + t + for t in self.pool[dtype][(hdim, hdim_v)] + if tr_load == t.tr_load + ] + inners = str() for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, 
F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) + if_k = "if" if k == 0 else "else if" + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( + F_if=if_k, + F_mode=MODE_MAP[trait.mode], + F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], + F_logits=BOOL_MAP[trait.logits], + F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], + F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], + F_dropout=BOOL_MAP[trait.dropout], + F_skip=BOOL_MAP[trait.skip], + F_trload=BOOL_MAP[trait.tr_load], + F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, + F_seqtune=trait.seqtune, + F_skcheck=trait.skcheck, + F_dcheck=trait.dcheck, + F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], + F_skpad=BOOL_MAP[trait.skpad], + F_dpad=BOOL_MAP[trait.dpad], + F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, + F_bn0=trait.bn0, + F_bk0=trait.bk0, + F_bn1=trait.bn1, + F_bk1=trait.bk1, + F_bk0max=trait.bk0max, + F_hdim=hdim, + F_dtype=FWD_DTYPE_MAP[dtype], + ) + if_j = "if" if j == 0 else "else if" + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners + ) + if_i = "if" if i == 0 else "else if" + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( + F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case + ) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format( + F_if="if", + F_trload_cond=tr_load_cond_map[tr_load], + F_dtype_case=per_dtypes, + ) if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_tr_load += ' (void)t ; (void)s ; (void)a;' - return 
FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) + per_tr_load += " (void)t ; (void)s ; (void)a;" + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_tr_load) + @dataclass class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + F_bm0: int # tile size along q seqlen (block size) + F_bn0: int # tile size along k seqlen + F_bk0: int # tile size along qk gemm unroll + F_bn1: int # tile size along v head_dim + F_bk1: int # tile size along kv gemm unroll + F_bk0max: int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0: int # number of warps for gemm0 along q seqlen + F_rn0: int # number of warps for gemm0 along k seqlen + F_rk0: int # number of warps for gemm0 along head dim q (not used) + F_rm1: int # number of warps for gemm1 along q seqlen + F_rn1: int # number of 
warps for gemm1 along head dim v + F_rk1: int # number of warps for gemm1 along k seqlen (not used) + F_wm0: int # gemm0 warp size along m + F_wn0: int # gemm0 warp size along n + F_wk0: int # gemm0 warp size along k + F_wm1: int # gemm1 warp size along m + F_wn1: int # gemm1 warp size along n + F_wk1: int # gemm1 warp size along k + F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ - f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + return ( + f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" + + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" + + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" + + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + ) + @dataclass class FmhaFwdKernel: - F_idx : int # this is not a tunable, but a counter to differentiate symbol - F_hdim : int # hdim - F_dtype : str # data type - F_mode : str # value from MODE_MAP - F_tile : FmhaFwdTileSize - F_pipeline : FmhaFwdPipeline - mask_impl : str + F_idx: int # this is not a tunable, but a counter to differentiate symbol + F_hdim: int # hdim + F_dtype: str # data type + F_mode: str # value from MODE_MAP + F_tile: FmhaFwdTileSize + F_pipeline: FmhaFwdPipeline + mask_impl: str @property def template(self) -> str: - kernel_body = str() - return FMHA_FWD_KERNEL_HEADER + \ - FMHA_FWD_KERNEL_BODY.format( - F_idx = self.F_idx, - F_hdim = self.F_hdim, - F_dtype = FWD_DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, - F_bn0 = 
self.F_tile.F_bn0, - F_bk0 = self.F_tile.F_bk0, - F_bn1 = self.F_tile.F_bn1, - F_bk1 = self.F_tile.F_bk1, - F_bk0max = self.F_tile.F_bk0max, - F_rm0 = self.F_tile.F_rm0, - F_rn0 = self.F_tile.F_rn0, - F_rk0 = self.F_tile.F_rk0, - F_rm1 = self.F_tile.F_rm1, - F_rn1 = self.F_tile.F_rn1, - F_rk1 = self.F_tile.F_rk1, - F_wm0 = self.F_tile.F_wm0, - F_wn0 = self.F_tile.F_wn0, - F_wk0 = self.F_tile.F_wk0, - F_wm1 = self.F_tile.F_wm1, - F_wn1 = self.F_tile.F_wn1, - F_wk1 = self.F_tile.F_wk1, - F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], - F_spad = BOOL_MAP[self.F_pipeline.F_spad], - F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], - F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits = BOOL_MAP[self.F_pipeline.F_logits], - F_bias = BIAS_MAP[self.F_pipeline.F_bias], - F_lse = BOOL_MAP[self.F_pipeline.F_lse], - F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], - F_squant = BOOL_MAP[self.F_pipeline.F_squant], - F_squant_enum = SQUANT_MAP[self.F_pipeline.F_squant], - F_skip = BOOL_MAP[self.F_pipeline.F_skip], - F_occupancy = self.F_tile.F_occupancy, - F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], - F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], - F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_trload = BOOL_MAP[self.F_pipeline.F_trload]) + # kernel_body removed - unused + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format( + F_idx=self.F_idx, + F_hdim=self.F_hdim, + F_dtype=FWD_DTYPE_MAP[self.F_dtype], + F_bm0=self.F_tile.F_bm0, + F_bn0=self.F_tile.F_bn0, + F_bk0=self.F_tile.F_bk0, + F_bn1=self.F_tile.F_bn1, + F_bk1=self.F_tile.F_bk1, + F_bk0max=self.F_tile.F_bk0max, + F_rm0=self.F_tile.F_rm0, + F_rn0=self.F_tile.F_rn0, + F_rk0=self.F_tile.F_rk0, + F_rm1=self.F_tile.F_rm1, + F_rn1=self.F_tile.F_rn1, + F_rk1=self.F_tile.F_rk1, + F_wm0=self.F_tile.F_wm0, + F_wn0=self.F_tile.F_wn0, + F_wk0=self.F_tile.F_wk0, + F_wm1=self.F_tile.F_wm1, + 
F_wn1=self.F_tile.F_wn1, + F_wk1=self.F_tile.F_wk1, + F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad=BOOL_MAP[self.F_pipeline.F_spad], + F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits=BOOL_MAP[self.F_pipeline.F_logits], + F_bias=BIAS_MAP[self.F_pipeline.F_bias], + F_lse=BOOL_MAP[self.F_pipeline.F_lse], + F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], + F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_squant_enum=SQUANT_MAP[self.F_pipeline.F_squant], + F_skip=BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy=self.F_tile.F_occupancy, + F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode=MODE_MAP[self.F_mode], + F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], + F_trload=BOOL_MAP[self.F_pipeline.F_trload], + ) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ - self.F_tile.name + '_' + self.F_pipeline.name + return ( + f"fmha_vsa_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + + self.F_tile.name + + "_" + + self.F_pipeline.name + ) @property def filename(self) -> str: @@ -527,57 +623,207 @@ def filename(self) -> str: def api_trait(self) -> FmhaFwdApiTrait: return FmhaFwdApiTrait( - pipeline_tag=self.F_pipeline.tag, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bk0=self.F_tile.F_bk0, - bn1=self.F_tile.F_bn1, - bk1=self.F_tile.F_bk1, - bk0max=self.F_tile.F_bk0max, - vlayout=self.F_pipeline.F_vlayout, - mask=self.F_pipeline.F_mask, - logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, - spad=self.F_pipeline.F_spad, - skpad=self.F_pipeline.F_skpad, - dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad, - 
skip=self.F_pipeline.F_skip, - tr_load=self.F_pipeline.F_trload, - constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, + ) + class KernelComponentFactory: # TODO: design a more practical way to do it # this is current supported tile size per hdim @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: + if dtype == "fp16" or dtype == "bf16": return { # (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), # FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), - FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 
32, 32, 16, 32, 32, 16, -1), - FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128, 128): [ + FmhaFwdTileSize( + 16, + 32, + 64, + 128, + 32, + 128, + 1, + 1, + 1, + 1, + 1, + 1, + 16, + 16, + 32, + 16, + 16, + 32, + -1, + ), + FmhaFwdTileSize( + 32, + 32, + 128, + 128, + 32, + 128, + 1, + 1, + 1, + 1, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + ), + FmhaFwdTileSize( + 128, + 64, + 32, + 128, + 16, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + ), + FmhaFwdTileSize( + 128, + 128, + 32, + 128, + 32, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 16, + 32, + 32, + 16, + -1, + ), + ], # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == 'fp8' or dtype == 'bf8': + elif dtype == "fp8" or dtype == "bf8": return { - (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], - (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (64, 64): [ + FmhaFwdTileSize( + 128, + 64, + 32, + 64, + 32, + 64, + 2, + 1, + 1, + 2, + 1, + 1, + 32, + 32, + 32, + 32, + 32, + 32, + -1, + ) + ], + (128, 128): [ + FmhaFwdTileSize( + 128, + 128, + 32, + 128, + 32, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 32, + 32, + 32, + 32, + -1, + ) + ], + (256, 256): [ + FmhaFwdTileSize( + 128, + 128, + 32, + 256, + 32, + 256, + 4, + 1, + 1, + 4, + 1, + 1, + 32, + 32, + 32, + 32, + 32, + 32, + -1, + ) + ], } else: return 
None @@ -590,94 +836,288 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' + squant = "t" if dtype == "fp8" else "f" pipelines = [] - if dtype in ['fp16', 'bf16']: - for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + if dtype in ["fp16", "bf16"]: + for logits, mask, bias, lse, dropout, skip in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], + ["t", "f"], + ): if hdim == 256 and hdim_v == 256: # print("vsa fmha only support dim=128 now.") continue - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "f", + "f", + "f", + "f", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "t", + "t", + "f", + "f", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "t", + "t", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) else: if bias == "bias": # print("vsa_fmha with bias is not implemented.") continue # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 
'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "f", + "f", + "f", + "f", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr", + "row", + "t", + "t", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) else: - pipelines.append(FmhaFwdPipeline('qr_async_vsa', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - pipelines.append(FmhaFwdPipeline('qr_async_vsa', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "f", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "t", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) # if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) # if receipt == 1 and bias != "bias": # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: + elif dtype in ["fp8", "bf8"]: # print("vsa fmha only support 16-bit compute.") return pipelines # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - 
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: + for logits, mask, bias in itertools.product( + ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + ): + pipelines.append( + FmhaFwdPipeline( + "qr", + "col", + "f", + "f", + "f", + "f", + logits, + bias, + "f", + "f", + squant, + mask, + "f", + "f", + ) + ) + elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None else: assert False return pipelines + class CustomFactory(KernelComponentFactory): @staticmethod - def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) - if dtype == 'fp16' or dtype == 'bf16': + if dtype == "fp16" or dtype == "bf16": if (128, 128) in result.keys(): - result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + result[(128, 128)].insert( + 0, + FmhaFwdTileSize( + 64, + 128, + 64, + 128, + 64, + 128, + 4, + 1, + 1, + 4, + 1, + 1, + 16, + 16, + 16, + 16, + 16, + 16, + -1, + CppConstraint( + "get_num_blocks(128) < num_cus * min_cu_util_rate" + ), + ), + ) return result -def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + +def get_fwd_blobs( + kernel_filter: Optional[str], receipt, optdim_list, mask_impl +) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool(mask_impl) - factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory + factory = ( + CustomFactory + if os.environ.get("CK_TILE_FMHA_FWD_CUSTOM_FACTORY", "0") == "1" + else KernelComponentFactory + ) for dtype in FWD_DTYPE_MAP.keys(): d = factory.get_hdim_tile_size_dict(dtype) - if d == None: + if d is 
None: continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): - for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): + # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for ((hdim, hdim_v), tiles), mode in itertools.product( + d.items(), MODE_MAP.keys() + ): + for tile, pipeline in itertools.product( + tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + ): if mode == "group": - if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue if (hdim, hdim_v) == (192, 128): # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': + if pipeline.F_bias != "no" or pipeline.F_dropout == "t": continue - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + if pipeline.tag != "qr_async_trload" and ( + ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) + or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128) + ): # non qr_async_trload only support km0=128 tile size when hdim is not 128 # non qr_async only support kn0=128 tile size when hdim is 128 continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + if pipeline.tag == "qr_async_trload" and ( + ((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) + or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) + ): continue # logits_soft_cap is only allowed if no bias - 
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + if not ( + (pipeline.F_logits == "t" and pipeline.F_bias == "no") + or pipeline.F_logits == "f" + ): continue if pipeline.tag != "qr_async_vsa": continue - k = FmhaFwdKernel(F_idx=1, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl) - if kernel_filter != '': + k = FmhaFwdKernel( + F_idx=1, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": if not fnmatch.fnmatch(k.name, kernel_filter): continue if optdim_list != [-1]: @@ -685,45 +1125,45 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # 2 - Flash attention integration if receipt in (2, 3): - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'alibi'] - cond &= pipeline.F_squant == 'f' - cond &= pipeline.F_skip == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_squant == "f" + cond &= pipeline.F_skip == "f" if not cond: continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_bias in ['no', 'bias'] - cond &= pipeline.F_squant == 'f' - cond &= mode == 'batch' - cond &= pipeline.F_skip == 'f' - cond &= pipeline.F_logits == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_squant == "f" + cond &= mode == "batch" + cond &= pipeline.F_skip == "f" + cond &= pipeline.F_logits == "f" if not cond: continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'batch' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] 
+ cond &= mode == "batch" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == 'group' - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" if not cond: continue @@ -732,20 +1172,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) + def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: update_file(autogen_dir / kernel.filename, kernel.template) -def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + +def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None: update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + +def write_blobs( + output_dir: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: - with file_path.open('a') as f: + +def list_blobs( + file_path: Path, kernel_filter: str, receipt, optdim_list, mask_impl +) -> None: + with file_path.open("a") as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: 
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 613c6e7fa0c..31f59ef5167 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -436,194 +436,197 @@ struct fmha_jenga_fwd_args drop_seed_offset; }; -template +template auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { - if constexpr(VSA) { + if constexpr(VSA) + { // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) { return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.lut_ptr, - args.valid_block_num_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.min_seqlen_q, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + 
args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.lut_ptr, - args.valid_block_num_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_lse, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + 
args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } - } else { + } + else + { // create group mode kernel arguments if constexpr(FmhaKernel::kIsGroupMode) { return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.block_relation_onehot_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.min_seqlen_q, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.k_ptr, + args.v_ptr, + args.block_relation_onehot_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } else { // create batch mode kernel arguments return 
FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.block_relation_onehot_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_lse, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.k_ptr, + args.v_ptr, + args.block_relation_onehot_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); } } }(); diff --git a/example/ck_tile/50_sparse_attn/generate.py b/example/ck_tile/50_sparse_attn/generate.py index eaeb555e05b..f1bf88efd77 100644 --- 
a/example/ck_tile/50_sparse_attn/generate.py +++ b/example/ck_tile/50_sparse_attn/generate.py @@ -6,7 +6,6 @@ from enum import IntEnum from pathlib import Path import pkgutil -import sys from typing import List, Optional import codegen.ops @@ -16,19 +15,35 @@ class HandlerId(IntEnum): LIST_BLOBS = 0 WRITE_BLOBS = 1 + # inspect all modules under 'codegen.ops' and register API handlers ops = [] for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): - full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) + full_module_name = "%s.%s" % (codegen.ops.__name__, module_name) ops.append(importer.find_spec(module_name).loader.load_module(module_name)) -unwanted_prefix = 'fmha_' +unwanted_prefix = "fmha_" handlers = dict( - [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, - (op.list_blobs, op.write_blobs)) for op in ops] + [ + ( + op.__name__[len(unwanted_prefix) :] + if op.__name__.startswith(unwanted_prefix) + else op.__name__, + (op.list_blobs, op.write_blobs), + ) + for op in ops + ] ) assert 0 < len(handlers) -def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: + +def write_blobs( + output_dir: Optional[str], + api_list: List[str], + filters_list: List[str], + optdim_list: List[int], + receipt, + mask_impl, +) -> None: if output_dir is None: output_dir = Path(__file__).parent else: @@ -40,8 +55,16 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : handler = handlers[api][HandlerId.WRITE_BLOBS] handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) + # list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: +def list_blobs( + output_file: Optional[str], + api_list: List[str], + filters_list: List[str], + optdim_list: 
List[int], + receipt, + mask_impl, +) -> None: assert output_file is not None file_path = Path(output_file) @@ -52,6 +75,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : handler = handlers[api][HandlerId.LIST_BLOBS] handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) + if __name__ == "__main__": parser = argparse.ArgumentParser( prog="generate", @@ -59,32 +83,29 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : ) parser.add_argument( "-d", - "--direction", # we keep 'direction' option for backward compatibility + "--direction", # we keep 'direction' option for backward compatibility "-a", "--api", - default='fwd_jenga', + default="fwd_jenga", required=False, - help="supply API(s) to generate (default: fwd). separated by comma." + help="supply API(s) to generate (default: fwd). separated by comma.", ) parser.add_argument( "-o", "--output_dir", required=False, - help="write all the blobs into a directory" + help="write all the blobs into a directory", ) parser.add_argument( - "-l", - "--list_blobs", - required=False, - help="list all the kernels to a file" + "-l", "--list_blobs", required=False, help="list all the kernels to a file" ) # TODO: if using filter, must apply same value to output_dir and list_blobs parser.add_argument( "-f", "--filter", - default='', + default="", required=False, - help="filter out kernels that need to generate, using fnmatch module" + help="filter out kernels that need to generate, using fnmatch module", ) parser.add_argument( @@ -92,7 +113,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : "--mask", default="simplified", required=False, - help="mask implementation, simplified/generic" + help="mask implementation, simplified/generic", ) parser.add_argument( @@ -100,32 +121,46 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : "--receipt", default=0, required=False, - help="codegen 
receipt. 0: generate only 8xhdim coverage\n" + \ - " 1: generate more instance to cover all hdim\n" + \ - " 2: Only generate instance for Flash attention integration\n" + \ - " 4: Only generate instance for PyTorch integration\n" + \ - " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ - " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ - " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ - " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ - " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" + help="codegen receipt. 0: generate only 8xhdim coverage\n" + + " 1: generate more instance to cover all hdim\n" + + " 2: Only generate instance for Flash attention integration\n" + + " 4: Only generate instance for PyTorch integration\n" + + " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration", ) parser.add_argument( "--optdim", - default='-1', + default="-1", required=False, - help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ - "eg. --optdim=32,64,128,256" + help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + + "eg. 
--optdim=32,64,128,256", ) args = parser.parse_args() - api_list = args.direction.split(',') - filter_list = args.filter.split(',') - filter_list.extend([''] * (len(api_list) - len(filter_list))) - optdim_list = [int(hdim) for hdim in args.optdim.split(',')] + api_list = args.direction.split(",") + filter_list = args.filter.split(",") + filter_list.extend([""] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(",")] if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) + list_blobs( + args.list_blobs, + api_list, + filter_list, + optdim_list, + int(args.receipt), + mask_impl=args.mask, + ) else: - write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) + write_blobs( + args.output_dir, + api_list, + filter_list, + optdim_list, + int(args.receipt), + mask_impl=args.mask, + ) diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu index a9e9f4f8010..3aa3c9760c0 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu @@ -6,46 +6,49 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" -ck_tile::HostTensor jenga_sparse_attention( - ck_tile::HostTensor &TQ, - ck_tile::HostTensor &TK, - ck_tile::HostTensor &TV, - ck_tile::HostTensor &Tblock_relation_onehot, - ck_tile::HostTensor &Y, - std::optional> bias = std::nullopt, - std::optional> lse = std::nullopt, - std::optional> seqstart_q = std::nullopt, - std::optional> seqstart_k = std::nullopt, - int bias_type = 0, - int batch = 0, - int nhead = 0, - int nhead_k = 0, - int seqlen_q = 0, - int seqlen_k = 0, - int hdim_q = 0, - int hdim_v = 0, - int mode = 0, - bool i_perm = true, - bool o_perm = true, - int max_seqlen_q = 0, - int max_seqlen_k = 0 -){ +ck_tile::HostTensor 
+jenga_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + std::optional> bias = std::nullopt, + std::optional> lse = std::nullopt, + std::optional> seqstart_q = std::nullopt, + std::optional> seqstart_k = std::nullopt, + int bias_type = 0, + int batch = 0, + int nhead = 0, + int nhead_k = 0, + int seqlen_q = 0, + int seqlen_k = 0, + int hdim_q = 0, + int hdim_v = 0, + int mode = 0, + bool i_perm = true, + bool o_perm = true, + int max_seqlen_q = 0, + int max_seqlen_k = 0) +{ std::string data_type = "fp16"; - if (TQ.dtype() == ck_tile::bf16_t) { + if(TQ.dtype() == ck_tile::bf16_t) + { data_type = "bf16"; } - if (max_seqlen_q == 0) max_seqlen_q = seqlen_q; - if (max_seqlen_k == 0) max_seqlen_k = seqlen_k; - bool is_v_rowmajor = true; - int seqlen_knew = 0; - float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - float scale_p = 1.f; - float scale_o = 1.f; + if(max_seqlen_q == 0) + max_seqlen_q = seqlen_q; + if(max_seqlen_k == 0) + max_seqlen_k = seqlen_k; + bool is_v_rowmajor = true; + int seqlen_knew = 0; + float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + float scale_p = 1.f; + float scale_o = 1.f; const float logits_soft_cap = 0.0; std::string msk_str = "0"; - mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); const ck_tile::index_t shape_batch = (mode == 0 ? batch : 1); const ck_tile::index_t shape_seqlen_q = (mode == 0 ? seqlen_q : max_seqlen_q); @@ -53,7 +56,7 @@ ck_tile::HostTensor jenga_sparse_attention( ck_tile::stream_config stream_config{nullptr, false, // time_kernel - 0, /* log_level = */ + 0, /* log_level = */ 0, 1, false}; @@ -80,8 +83,8 @@ ck_tile::HostTensor jenga_sparse_attention( const ck_tile::index_t stride_o_acc = (hdim_v); const ck_tile::index_t stride_o = (o_perm ? 
hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_v = [&]() { if(is_v_rowmajor) @@ -103,25 +106,25 @@ ck_tile::HostTensor jenga_sparse_attention( const ck_tile::index_t nhead_stride_o_acc = (shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const 
ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); // const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); - args.q_ptr = TQ.data_ptr(); - args.k_ptr = TK.data_ptr(); - args.v_ptr = TV.data_ptr(); + args.q_ptr = TQ.data_ptr(); + args.k_ptr = TK.data_ptr(); + args.v_ptr = TV.data_ptr(); args.block_relation_onehot_ptr = Tblock_relation_onehot.data_ptr(); args.batch = batch; @@ -147,11 +150,9 @@ ck_tile::HostTensor jenga_sparse_attention( args.lse_ptr = lse ? lse->data_ptr() : nullptr; args.o_ptr = Y.data_ptr(); - args.seqstart_q_ptr = - (mode == 1 ? seqstart_q->data_ptr() : nullptr); - args.seqstart_k_ptr = - (mode == 1 ? seqstart_k->data_ptr() : nullptr); - args.seqlen_k_ptr = nullptr; + args.seqstart_q_ptr = (mode == 1 ? seqstart_q->data_ptr() : nullptr); + args.seqstart_k_ptr = (mode == 1 ? 
seqstart_k->data_ptr() : nullptr); + args.seqlen_k_ptr = nullptr; args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) args.max_seqlen_q = max_seqlen_q; @@ -162,7 +163,7 @@ ck_tile::HostTensor jenga_sparse_attention( args.logits_soft_cap = logits_soft_cap; - args.stride_bias =stride_bias; + args.stride_bias = stride_bias; args.stride_o = stride_o; args.nhead_stride_bias = nhead_stride_bias; args.nhead_stride_lse = nhead_stride_lse; @@ -183,7 +184,6 @@ ck_tile::HostTensor jenga_sparse_attention( args.p_drop = 0.; args.s_randval = false; - }; const auto init_traits = [&](auto& traits) { @@ -192,16 +192,14 @@ ck_tile::HostTensor jenga_sparse_attention( traits.data_type = data_type; traits.is_v_rowmajor = is_v_rowmajor; - traits.is_group_mode = (mode == 1); traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = static_cast(bias_type); - traits.has_lse = lse ? true: false; + traits.has_lse = lse ? true : false; traits.do_fp8_static_quant = false; traits.has_dropout = false; - }; fmha_jenga_fwd_traits fmha_traits; @@ -209,7 +207,7 @@ ck_tile::HostTensor jenga_sparse_attention( fmha_jenga_fwd_args args; init_args(args); - + fmha_jenga_fwd(fmha_traits, args, stream_config); return Y; diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index b8fbfdc8d8e..0e8eab8e6f2 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -9,53 +9,51 @@ using DataType = ck_tile::half_t; -ck_tile::HostTensor jenga_sparse_attention( - ck_tile::HostTensor &TQ, - ck_tile::HostTensor &TK, - ck_tile::HostTensor &TV, - ck_tile::HostTensor &Tblock_relation_onehot, - ck_tile::HostTensor &Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, - int bias_type, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - 
int hdim_q, - int hdim_v, - int mode, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k -); +ck_tile::HostTensor +jenga_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, + int bias_type, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + int mode, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k); -ck_tile::HostTensor vsa_sparse_attention( - ck_tile::HostTensor &TQ, - ck_tile::HostTensor &TK, - ck_tile::HostTensor &TV, - ck_tile::HostTensor &TKV_block_idx, // LUT must be int32_t - ck_tile::HostTensor &TKV_blocks, // valid_block_num must be int32_t - ck_tile::HostTensor &Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, - int bias_type, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - int mode, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k -); +ck_tile::HostTensor +vsa_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& TKV_block_idx, // LUT must be int32_t + ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, + int bias_type, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + int mode, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k); diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp index 27b18a66960..9ac91660d2b 100644 --- 
a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp @@ -28,25 +28,30 @@ using DataType = ck_tile::half_t; // Convert block_relation_onehot to LUT format (similar to triton_block_map_to_lut_kernel) template void block_map_to_lut( - const ck_tile::HostTensor& block_map, // [B, H, Q_blocks, K_blocks] - ck_tile::HostTensor& lut, // [B, H, Q_blocks, K_blocks] - int32_t for kernel + const ck_tile::HostTensor& block_map, // [B, H, Q_blocks, K_blocks] + ck_tile::HostTensor& lut, // [B, H, Q_blocks, K_blocks] - int32_t for kernel ck_tile::HostTensor& valid_block_num, // [B, H, Q_blocks] - int32_t for kernel ck_tile::index_t num_block_k) { - auto lengths = block_map.get_lengths(); + auto lengths = block_map.get_lengths(); ck_tile::index_t B = lengths[0]; ck_tile::index_t H = lengths[1]; ck_tile::index_t Q = lengths[2]; - - for (ck_tile::index_t b = 0; b < B; ++b) { - for (ck_tile::index_t h = 0; h < H; ++h) { - for (ck_tile::index_t q = 0; q < Q; ++q) { + + for(ck_tile::index_t b = 0; b < B; ++b) + { + for(ck_tile::index_t h = 0; h < H; ++h) + { + for(ck_tile::index_t q = 0; q < Q; ++q) + { int32_t valid_count = 0; - int32_t prev_block = 0; - - for (ck_tile::index_t k = 0; k < num_block_k; ++k) { + int32_t prev_block = 0; + + for(ck_tile::index_t k = 0; k < num_block_k; ++k) + { T cur_block = block_map(b, h, q, k); - if (static_cast(cur_block) > 0.5f) { // Check if block is active + if(static_cast(cur_block) > 0.5f) + { // Check if block is active lut(b, h, q, valid_count) = static_cast(k - prev_block); valid_count++; prev_block = static_cast(k); @@ -61,59 +66,69 @@ void block_map_to_lut( // Reference implementation: blocked attention (similar to pytorch_blocked_attention) template void reference_blocked_attention( - const ck_tile::HostTensor& q, // [B, H, S_q, D] - const ck_tile::HostTensor& k, // [B, H, S_k, D] - const ck_tile::HostTensor& v, // [B, H, S_k, D_v] - const ck_tile::HostTensor& block_relation, 
// [B, H, Q_blocks, K_blocks] - const ck_tile::HostTensor& bias, // [B, H, S_q, S_k] - ck_tile::HostTensor& output, // [B, H, S_q, D_v] + const ck_tile::HostTensor& q, // [B, H, S_q, D] + const ck_tile::HostTensor& k, // [B, H, S_k, D] + const ck_tile::HostTensor& v, // [B, H, S_k, D_v] + const ck_tile::HostTensor& block_relation, // [B, H, Q_blocks, K_blocks] + const ck_tile::HostTensor& bias, // [B, H, S_q, S_k] + ck_tile::HostTensor& output, // [B, H, S_q, D_v] ck_tile::index_t BLKQ, ck_tile::index_t BLKK, AccT scale) { - auto q_lengths = q.get_lengths(); - ck_tile::index_t batch = q_lengths[0]; - ck_tile::index_t nhead = q_lengths[1]; + auto q_lengths = q.get_lengths(); + ck_tile::index_t batch = q_lengths[0]; + ck_tile::index_t nhead = q_lengths[1]; ck_tile::index_t seqlen_q = q_lengths[2]; - ck_tile::index_t hdim = q_lengths[3]; - - auto v_lengths = v.get_lengths(); + ck_tile::index_t hdim = q_lengths[3]; + + auto v_lengths = v.get_lengths(); ck_tile::index_t seqlen_k = v_lengths[2]; - ck_tile::index_t hdim_v = v_lengths[3]; - + ck_tile::index_t hdim_v = v_lengths[3]; + ck_tile::index_t num_q_blocks = seqlen_q / BLKQ; ck_tile::index_t num_k_blocks = seqlen_k / BLKK; - - for (ck_tile::index_t b = 0; b < batch; ++b) { - for (ck_tile::index_t h = 0; h < nhead; ++h) { - for (ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) { + + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { ck_tile::index_t q_start = qb * BLKQ; - ck_tile::index_t q_end = q_start + BLKQ; - + ck_tile::index_t q_end = q_start + BLKQ; + // Find relevant K blocks std::vector relevant_k_indices; - for (ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) { - if (static_cast(block_relation(b, h, qb, kb)) > 0.5f) { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + if(static_cast(block_relation(b, h, qb, kb)) > 0.5f) + { relevant_k_indices.push_back(kb); } } - - if 
(relevant_k_indices.empty()) continue; - + + if(relevant_k_indices.empty()) + continue; + // For each query position in the block - for (ck_tile::index_t sq = q_start; sq < q_end; ++sq) { + for(ck_tile::index_t sq = q_start; sq < q_end; ++sq) + { // Compute attention scores for all relevant K blocks std::vector scores; AccT max_score = -std::numeric_limits::infinity(); - - for (auto kb : relevant_k_indices) { + + for(auto kb : relevant_k_indices) + { ck_tile::index_t k_start = kb * BLKK; - ck_tile::index_t k_end = k_start + BLKK; - - for (ck_tile::index_t sk = k_start; sk < k_end; ++sk) { + ck_tile::index_t k_end = k_start + BLKK; + + for(ck_tile::index_t sk = k_start; sk < k_end; ++sk) + { AccT score = 0.0f; - for (ck_tile::index_t d = 0; d < hdim; ++d) { - score += static_cast(q(b, h, sq, d)) * + for(ck_tile::index_t d = 0; d < hdim; ++d) + { + score += static_cast(q(b, h, sq, d)) * static_cast(k(b, h, sk, d)); } score = score * scale + static_cast(bias(b, h, sq, sk)); @@ -121,29 +136,33 @@ void reference_blocked_attention( max_score = std::max(max_score, score); } } - + // Softmax AccT sum_exp = 0.0f; - for (auto& s : scores) { + for(auto& s : scores) + { s = std::exp(s - max_score); sum_exp += s; } - for (auto& s : scores) { + for(auto& s : scores) + { s /= sum_exp; } - + // Compute output: P @ V - for (ck_tile::index_t dv = 0; dv < hdim_v; ++dv) { - AccT out_val = 0.0f; + for(ck_tile::index_t dv = 0; dv < hdim_v; ++dv) + { + AccT out_val = 0.0f; size_t score_idx = 0; - - for (auto kb : relevant_k_indices) { + + for(auto kb : relevant_k_indices) + { ck_tile::index_t k_start = kb * BLKK; - ck_tile::index_t k_end = k_start + BLKK; - - for (ck_tile::index_t sk = k_start; sk < k_end; ++sk) { - out_val += scores[score_idx] * - static_cast(v(b, h, sk, dv)); + ck_tile::index_t k_end = k_start + BLKK; + + for(ck_tile::index_t sk = k_start; sk < k_end; ++sk) + { + out_val += scores[score_idx] * static_cast(v(b, h, sk, dv)); score_idx++; } } @@ -160,7 +179,7 @@ template 
auto get_error_tolerance() { double rtol = 1e-2; - double atol = 4e-2; // Higher tolerance for bf16/fp16 + double atol = 4e-2; // Higher tolerance for bf16/fp16 return ck_tile::make_tuple(rtol, atol); } @@ -170,8 +189,7 @@ auto get_error_tolerance() auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser - .insert("v", "1", "0:no validation, 1:cpu validation") + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") .insert("mode", "0", "kernel mode. 0:batch, 1:group") .insert("b", "1", "batch size") .insert("h", "4", "num of head for q") @@ -203,31 +221,34 @@ template bool run_test(const ck_tile::ArgParser& arg_parser) { // Parse arguments - int do_validation = arg_parser.get_int("v"); - int mode = arg_parser.get_int("mode"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - ck_tile::index_t block_size = arg_parser.get_int("block_size"); - float sparsity = arg_parser.get_float("sparsity"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - int bias_type = arg_parser.get_int("bias"); + int do_validation = arg_parser.get_int("v"); + int mode = arg_parser.get_int("mode"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t block_size = arg_parser.get_int("block_size"); + float sparsity = 
arg_parser.get_float("sparsity"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + int bias_type = arg_parser.get_int("bias"); [[maybe_unused]] bool store_lse = arg_parser.get_bool("lse"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - [[maybe_unused]] int kname = arg_parser.get_int("kname"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + [[maybe_unused]] int kname = arg_parser.get_int("kname"); // Handle default values - if (nhead_k < 0) nhead_k = nhead; - if (seqlen_k < 0) seqlen_k = seqlen_q; - if (hdim_v < 0) hdim_v = hdim_q; - + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; + ck_tile::index_t BLKQ = block_size; ck_tile::index_t BLKK = block_size; @@ -238,30 +259,33 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::cout << "============================================================" << std::endl; std::cout << "[VSA Sparse Attention Test]" << std::endl; std::cout << "============================================================" << std::endl; - std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k << std::endl; + std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k + << std::endl; std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; - std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")" << std::endl; - std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks << std::endl; + std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")" + << std::endl; + std::cout << " num_q_blocks: " << 
num_q_blocks << ", num_k_blocks: " << num_k_blocks + << std::endl; std::cout << " sparsity: " << sparsity << std::endl; std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; // Create host tensors (using BHSD layout when i_perm=true) // Q: [B, H, S_q, D] - // K: [B, H_k, S_k, D] + // K: [B, H_k, S_k, D] // V: [B, H_k, S_k, D_v] ck_tile::HostTensor q_host({batch, nhead, seqlen_q, hdim_q}); ck_tile::HostTensor k_host({batch, nhead_k, seqlen_k, hdim_q}); ck_tile::HostTensor v_host({batch, nhead_k, seqlen_k, hdim_v}); ck_tile::HostTensor output_host({batch, nhead, seqlen_q, hdim_v}); ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); - + // Bias tensor [B, H, S_q, S_k] ck_tile::HostTensor bias_host({batch, nhead, seqlen_q, seqlen_k}); // Block relation onehot: [B, H, Q_blocks, K_blocks] ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); - + // LUT and valid_block_num (output of block_map_to_lut) - must be int32_t for kernel ck_tile::HostTensor lut_host({batch, nhead, num_q_blocks, num_k_blocks}); ck_tile::HostTensor valid_block_num_host({batch, nhead, num_q_blocks}); @@ -271,40 +295,48 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); - + // Initialize bias to zero (as in Python test) std::fill(bias_host.mData.begin(), bias_host.mData.end(), static_cast(0.0f)); // Initialize block_relation_onehot with sparse pattern std::mt19937 rng(seed + 100); std::uniform_real_distribution dist(0.0f, 1.0f); - ck_tile::index_t total_blocks = 0; + ck_tile::index_t total_blocks = 0; ck_tile::index_t active_blocks = 0; - - for (ck_tile::index_t b = 0; b < batch; ++b) { - for (ck_tile::index_t h = 0; h < nhead; ++h) { - for (ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) { - for (ck_tile::index_t kb = 0; kb < 
num_k_blocks; ++kb) { + + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { total_blocks++; // Each Q block always attends to its diagonal K block (if exists) // Plus random blocks based on sparsity - bool is_diagonal = (qb == kb && qb < num_k_blocks); + bool is_diagonal = (qb == kb && qb < num_k_blocks); bool random_active = (dist(rng) > sparsity); - - if (is_diagonal || random_active) { + + if(is_diagonal || random_active) + { block_relation_onehot(b, h, qb, kb) = static_cast(1.0f); active_blocks++; - } else { + } + else + { block_relation_onehot(b, h, qb, kb) = static_cast(0.0f); } } } } } - - float actual_sparsity = 1.0f - static_cast(active_blocks) / static_cast(total_blocks); - std::cout << " Actual sparsity: " << actual_sparsity - << " (" << active_blocks << "/" << total_blocks << " blocks active)" << std::endl; + + float actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; // Convert block_relation_onehot to LUT format std::cout << "Converting block map to LUT format..." 
<< std::endl; @@ -313,87 +345,90 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // vsa_sparse_attention handles device memory internally // Optional tensors - std::optional> bias_opt = std::nullopt; - std::optional> lse_opt = std::nullopt; + std::optional> bias_opt = std::nullopt; + std::optional> lse_opt = std::nullopt; std::optional> seqstart_q_opt = std::nullopt; std::optional> seqstart_k_opt = std::nullopt; - if (bias_type != 0) { + if(bias_type != 0) + { bias_opt = bias_host; } // Run kernel std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; - - try { + + try + { // Warmup - for (int i = 0; i < warmup; ++i) { - vsa_sparse_attention( - q_host, - k_host, - v_host, - lut_host, - valid_block_num_host, - output_host, - bias_opt, - lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k - ); + for(int i = 0; i < warmup; ++i) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + lut_host, + valid_block_num_host, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } - + // Benchmark [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); - auto start = std::chrono::high_resolution_clock::now(); - - for (int i = 0; i < repeat; ++i) { - vsa_sparse_attention( - q_host, - k_host, - v_host, - lut_host, - valid_block_num_host, - output_host, - bias_opt, - lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k - ); + auto start = std::chrono::high_resolution_clock::now(); + + for(int i = 0; i < repeat; ++i) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + lut_host, + valid_block_num_host, + output_host, + bias_opt, + 
lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } - + [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); - auto end = std::chrono::high_resolution_clock::now(); - double avg_time_ms = std::chrono::duration(end - start).count() / repeat; - - std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" << std::endl; + auto end = std::chrono::high_resolution_clock::now(); + double avg_time_ms = + std::chrono::duration(end - start).count() / repeat; - } catch (const std::exception& e) { + std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; + } + catch(const std::exception& e) + { std::cerr << "Error during kernel execution: " << e.what() << std::endl; return false; } @@ -402,55 +437,65 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Validation bool pass = true; - if (do_validation) { + if(do_validation) + { std::cout << "\n--- Performing CPU validation ---" << std::endl; - + // Compute scale factor float scale = 1.0f / std::sqrt(static_cast(hdim_q)); - + // Run reference implementation std::cout << "Computing reference output..." 
<< std::endl; - reference_blocked_attention( - q_host, k_host, v_host, - block_relation_onehot, bias_host, - output_ref, - BLKQ, BLKK, scale - ); - + reference_blocked_attention(q_host, + k_host, + v_host, + block_relation_onehot, + bias_host, + output_ref, + BLKQ, + BLKK, + scale); + // Compare results auto [rtol, atol] = get_error_tolerance(); - - float max_diff = 0.0f; + + float max_diff = 0.0f; float max_rel_diff = 0.0f; - size_t num_errors = 0; - - for (size_t i = 0; i < output_host.mData.size(); ++i) { - float gpu_val = static_cast(output_host.mData[i]); - float ref_val = static_cast(output_ref.mData[i]); - float diff = std::abs(gpu_val - ref_val); + size_t num_errors = 0; + + for(size_t i = 0; i < output_host.mData.size(); ++i) + { + float gpu_val = static_cast(output_host.mData[i]); + float ref_val = static_cast(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; - - max_diff = std::max(max_diff, diff); + + max_diff = std::max(max_diff, diff); max_rel_diff = std::max(max_rel_diff, rel_diff); - - if (diff > atol && rel_diff > rtol) { + + if(diff > atol && rel_diff > rtol) + { num_errors++; - if (num_errors <= 5) { - std::cout << " Mismatch at index " << i - << ": GPU=" << gpu_val << ", Ref=" << ref_val - << ", Diff=" << diff << std::endl; + if(num_errors <= 5) + { + std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val + << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; } } } - + std::cout << "\nValidation results:" << std::endl; std::cout << " Max absolute difference: " << max_diff << std::endl; std::cout << " Max relative difference: " << max_rel_diff << std::endl; - std::cout << " Number of mismatches: " << num_errors << " / " << output_host.mData.size() << std::endl; - - if (num_errors == 0) { + std::cout << " Number of mismatches: " << num_errors << " / " << output_host.mData.size() + << std::endl; + + if(num_errors == 0) + { std::cout << 
"\n>>> VALIDATION PASSED <<<" << std::endl; - } else { + } + else + { std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; pass = false; } @@ -466,22 +511,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); - if (!result) { + if(!result) + { std::cerr << "Failed to parse arguments" << std::endl; return -1; } std::string prec = arg_parser.get_str("prec"); - + bool test_result = false; - if (prec == "fp16") { + if(prec == "fp16") + { test_result = run_test(arg_parser); - } else if (prec == "bf16") { + } + else if(prec == "bf16") + { std::cout << "Note: Using bf16 precision" << std::endl; // For bf16, we would need to compile with DataType = ck_tile::bf16_t // For now, run with the compiled DataType test_result = run_test(arg_parser); - } else { + } + else + { std::cerr << "Unsupported precision: " << prec << std::endl; return -1; } diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu index 3b7c3511fc2..d75b5bae657 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu @@ -7,52 +7,53 @@ #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/device_memory.hpp" - -ck_tile::HostTensor vsa_sparse_attention( - ck_tile::HostTensor &TQ, - ck_tile::HostTensor &TK, - ck_tile::HostTensor &TV, - ck_tile::HostTensor &TKV_block_idx, // LUT must be int32_t - ck_tile::HostTensor &TKV_blocks, // valid_block_num must be int32_t - ck_tile::HostTensor &Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, - int bias_type, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - int mode, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k -){ +ck_tile::HostTensor +vsa_sparse_attention(ck_tile::HostTensor& TQ, + 
ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& TKV_block_idx, // LUT must be int32_t + ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, + int bias_type, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + int mode, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k) +{ std::string data_type = "fp16"; // DataType is determined at compile time via template - if (max_seqlen_q == 0) max_seqlen_q = seqlen_q; - if (max_seqlen_k == 0) max_seqlen_k = seqlen_k; - bool is_v_rowmajor = true; - float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - float scale_p = 1.f; - float scale_o = 1.f; + if(max_seqlen_q == 0) + max_seqlen_q = seqlen_q; + if(max_seqlen_k == 0) + max_seqlen_k = seqlen_k; + bool is_v_rowmajor = true; + float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + float scale_p = 1.f; + float scale_o = 1.f; const float logits_soft_cap = 0.0; std::string msk_str = "0"; - mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); const ck_tile::index_t shape_seqlen_q = (mode == 0 ? seqlen_q : max_seqlen_q); const ck_tile::index_t shape_seqlen_k = (mode == 0 ? seqlen_k : max_seqlen_k); ck_tile::stream_config stream_config{nullptr, false, // time_kernel - 0, /* log_level = */ + 0, /* log_level = */ 0, 1, false}; @@ -74,19 +75,25 @@ ck_tile::HostTensor vsa_sparse_attention( // Optional buffers ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem lse_buf(lse ? lse->get_element_space_size_in_bytes() : 0); - ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() : 0); - ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? 
seqstart_k->get_element_space_size_in_bytes() : 0); - - if (bias) bias_buf.ToDevice(bias->data()); - if (lse) lse_buf.ToDevice(lse->data()); - if (seqstart_q) seqstart_q_buf.ToDevice(seqstart_q->data()); - if (seqstart_k) seqstart_k_buf.ToDevice(seqstart_k->data()); + ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() + : 0); + ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes() + : 0); + + if(bias) + bias_buf.ToDevice(bias->data()); + if(lse) + lse_buf.ToDevice(lse->data()); + if(seqstart_q) + seqstart_q_buf.ToDevice(seqstart_q->data()); + if(seqstart_k) + seqstart_k_buf.ToDevice(seqstart_k->data()); const auto init_args = [&](auto& args) { assert(nhead % nhead_k == 0); - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { if(is_v_rowmajor) return i_perm ? hdim_v : nhead_k * hdim_v; else @@ -98,7 +105,7 @@ ck_tile::HostTensor vsa_sparse_attention( // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_v = [&]() { + const ck_tile::index_t nhead_stride_v = [&]() { if(is_v_rowmajor) return i_perm ? shape_seqlen_k * hdim_v : hdim_v; else @@ -110,19 +117,19 @@ ck_tile::HostTensor vsa_sparse_attention( const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); // Use device buffer pointers instead of host tensor data pointers - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.v_ptr = v_buf.GetDeviceBuffer(); - args.lut_ptr = lut_buf.GetDeviceBuffer(); + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + args.lut_ptr = lut_buf.GetDeviceBuffer(); args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer(); args.batch = batch; @@ -148,7 +155,7 @@ ck_tile::HostTensor vsa_sparse_attention( args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr); args.seqstart_k_ptr = (mode == 1 ? 
seqstart_k_buf.GetDeviceBuffer() : nullptr); - args.seqlen_k_ptr = nullptr; + args.seqlen_k_ptr = nullptr; args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) args.max_seqlen_q = max_seqlen_q; @@ -159,7 +166,7 @@ ck_tile::HostTensor vsa_sparse_attention( args.logits_soft_cap = logits_soft_cap; - args.stride_bias =stride_bias; + args.stride_bias = stride_bias; args.stride_o = stride_o; args.nhead_stride_bias = nhead_stride_bias; args.nhead_stride_lse = nhead_stride_lse; @@ -180,7 +187,6 @@ ck_tile::HostTensor vsa_sparse_attention( args.p_drop = 0.; args.s_randval = false; - }; const auto init_traits = [&](auto& traits) { @@ -189,16 +195,14 @@ ck_tile::HostTensor vsa_sparse_attention( traits.data_type = data_type; traits.is_v_rowmajor = is_v_rowmajor; - traits.is_group_mode = (mode == 1); traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = static_cast(bias_type); - traits.has_lse = lse ? true: false; + traits.has_lse = lse ? 
true : false; traits.do_fp8_static_quant = false; traits.has_dropout = false; - }; fmha_jenga_fwd_traits fmha_traits; @@ -206,7 +210,7 @@ ck_tile::HostTensor vsa_sparse_attention( fmha_jenga_fwd_args args; init_args(args); - + fmha_vsa_fwd(fmha_traits, args, stream_config); // Copy output back to host diff --git a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index caac6e2126e..27cfb913279 100644 --- a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -170,7 +170,7 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga const KElementFunction& /*k_element_func*/, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, - const bool *block_relation_onehot_ptr, + const bool* block_relation_onehot_ptr, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, RandValDramBlockWindowTmp& randval_dram_block_window_tmp, @@ -237,7 +237,12 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); bool* block_relation_onehot = reinterpret_cast(smem_ptr) + GetSmemSize(); - amd_direct_load_global_to_lds(block_relation_onehot_ptr, 4*threadIdx.x, block_relation_onehot, 4*threadIdx.x, threadIdx.x/64==0, 256); + amd_direct_load_global_to_lds(block_relation_onehot_ptr, + 4 * threadIdx.x, + block_relation_onehot, + 4 * threadIdx.x, + threadIdx.x / 64 == 0, + 256); auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), @@ -284,7 +289,8 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); // if 
(threadIdx.x==0 && blockIdx.y==0) { - // printf("\nblockIdx.x : %d, seqlen_k_start: %d, seqlen_k_end: %d\n", blockIdx.x, seqlen_k_start, seqlen_k_end); + // printf("\nblockIdx.x : %d, seqlen_k_start: %d, seqlen_k_end: %d\n", blockIdx.x, + // seqlen_k_start, seqlen_k_end); // } // check early exit if no work to do @@ -352,10 +358,14 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); - if (block_relation_onehot[0]) { + if(block_relation_onehot[0]) + { // prefetch K tile - async_load_tile_raw( - k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); } @@ -374,7 +384,7 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga // main loop do { - if (!block_relation_onehot[i_total_loops]) + if(!block_relation_onehot[i_total_loops]) { i_total_loops++; if(i_total_loops < num_total_loop) @@ -383,7 +393,8 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga move_tile_window(k_dram_block_window, {kN0, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - if (block_relation_onehot[i_total_loops]) { + if(block_relation_onehot[i_total_loops]) + { async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, @@ -393,10 +404,10 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga move_tile_window(k_dram_window, {0, kK0}); move_tile_window(v_dram_window, {0, kN0}); continue; - } + } break; } - + // STAGE 1, QK gemm clear_tile(s_acc); // initialize C if constexpr(k0_loops > 1) @@ -421,9 +432,10 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); }); __shared__ int printed_flag; - if (blockIdx.x == 0 && threadIdx.x == 0 && i_total_loops==1000) { - printed_flag = 100; - } + if(blockIdx.x == 0 && threadIdx.x == 0 && i_total_loops == 1000) + { 
+ printed_flag = 100; + } } // TODO: this to fix a bug when loop smaller than 2, @@ -842,10 +854,10 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga typename AttentionVariantParams, typename BlockIndices> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const bool *block_relation_onehot_ptr, + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const bool* block_relation_onehot_ptr, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile diff --git a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index f8a623f9bda..ac5b4db3519 100644 --- a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -170,7 +170,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA const KElementFunction& /*k_element_func*/, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, - const int *kv_block_idx_ptr, + const int* kv_block_idx_ptr, int kv_blocks, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, @@ -284,7 +284,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); const auto num_total_loop = kv_blocks; - + // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) { @@ 
-389,10 +389,9 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA sequence<(LdsSeq.at(number{})) * kN0, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); }); - } //__shared__ int printed_flag; - //if (blockIdx.x == 0 && threadIdx.x == 0 && i_total_loops==1000) { + // if (blockIdx.x == 0 && threadIdx.x == 0 && i_total_loops==1000) { // printed_flag = 100; //} @@ -403,10 +402,11 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA async_load_fence(); __builtin_amdgcn_s_barrier(); - - int block_idx = kv_block_idx_ptr[i_total_loops+1]; - //if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z == 101) printf("%d %d %d\n", i_total_loops, num_total_loop, block_idx); - const auto bias_tile = load_tile(bias_dram_window); // load bias tile + + int block_idx = kv_block_idx_ptr[i_total_loops + 1]; + // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z == 101) printf("%d %d %d\n", + // i_total_loops, num_total_loop, block_idx); + const auto bias_tile = load_tile(bias_dram_window); // load bias tile auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); __builtin_amdgcn_sched_barrier(0); { // tail @@ -720,7 +720,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA i_total_loops++; if(i_total_loops < num_total_loop) { - move_tile_window(v_dram_window, {0, kN0*(block_idx-1)}); + move_tile_window(v_dram_window, {0, kN0 * (block_idx - 1)}); // v_dram_window = // make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), // v_dram_block_window_tmp.get_window_lengths(), @@ -729,11 +729,11 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA // move K tile windows move_tile_window(k_dram_block_window, {kN0 * block_idx, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - //k_dram_block_window = - // make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - // k_dram_block_window_tmp.get_window_lengths(), - // {kv_block_idx[i_total_loops], 0}); - //k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + // k_dram_block_window = + // 
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + // k_dram_block_window_tmp.get_window_lengths(), + // {kv_block_idx[i_total_loops], 0}); + // k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) @@ -825,10 +825,10 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA typename AttentionVariantParams, typename BlockIndices> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const int *kv_block_idx_ptr, + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const int* kv_block_idx_ptr, int kv_blocks, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile diff --git a/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp index 45d7af8ef98..7b6d49c0faf 100644 --- a/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp @@ -1050,9 +1050,10 @@ struct FmhaFwdJengaKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // allocate LDS - __shared__ char smem_ptr[GetSmemSize()+256*sizeof(int)]; + __shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)]; - // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", int(GetSmemSize())); + // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", + // int(GetSmemSize())); // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); @@ -1160,11 +1161,14 @@ struct FmhaFwdJengaKernel 
reinterpret_cast(kargs.v_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + batch_offset_v; - + // sparse mask - const bool* block_relation_onehot_ptr = reinterpret_cast(kargs.block_relation_onehot_ptr) + - static_cast(i_batch * kargs.num_head_q + i_nhead) * ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) - * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) + i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); + const bool* block_relation_onehot_ptr = + reinterpret_cast(kargs.block_relation_onehot_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * + ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) + + i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + static_cast(i_nhead) * kargs.nhead_stride_o + @@ -1254,7 +1258,7 @@ struct FmhaFwdJengaKernel // make_tuple(1, 1), // number<1>{}, // number<1>{}); - + // const auto valid_block_num_dram = make_naive_tensor_view( // valid_block_num_ptr, // make_tuple(kargs.seqlen_q/number{}), @@ -1262,7 +1266,6 @@ struct FmhaFwdJengaKernel // number<1>{}, // number<1>{}); - auto q_dram_window = make_tile_window( q_dram, [&]() { @@ -1286,7 +1289,7 @@ struct FmhaFwdJengaKernel // lut_dram, make_tuple(1,1), {0,0}); // auto valid_block_num_window = make_tile_window( // valid_block_num_dram, make_tuple(1), {i_tile_m}); - + /// FIXME: Before C++20, capturing structured binding variables are not supported. 
Remove /// following copy capture of the 'i_nhead' if in C++20 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { @@ -1464,20 +1467,20 @@ struct FmhaFwdJengaKernel auto o_acc_tile = [&]() { // TODO: constexpr(kDoFp8StaticQuant) return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - block_relation_onehot_ptr, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); + k_dram_window, + v_dram_window, + block_relation_onehot_ptr, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); }(); // O DRAM and O DRAM window diff --git a/include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp index 619e21b0c09..3b2c7019790 100644 --- a/include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp @@ -54,8 +54,9 @@ struct FmhaFwdVSAKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; - static constexpr bool kDoFp8StaticQuant = (QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); - static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static constexpr bool kDoFp8StaticQuant = + (QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); + static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -993,7 +994,6 @@ struct FmhaFwdVSAKernel if constexpr(kIsGroupMode) has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); - if(has_padded_seqlen_k) { // const index_t num_tile_m0 = seqlen_q / kM0; @@ -1062,9 
+1062,10 @@ struct FmhaFwdVSAKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // allocate LDS - __shared__ char smem_ptr[GetSmemSize()+256*sizeof(int)]; + __shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)]; - // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", int(GetSmemSize())); + // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", + // int(GetSmemSize())); // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); @@ -1172,13 +1173,19 @@ struct FmhaFwdVSAKernel reinterpret_cast(kargs.v_ptr) + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + batch_offset_v; - + // sparse mask - const int* lut_ptr = reinterpret_cast(kargs.lut_ptr) + - static_cast(i_batch * kargs.num_head_q + i_nhead) * ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) - * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) + i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); - const int* valid_block_num_ptr = reinterpret_cast(kargs.valid_block_num_ptr) + - static_cast(i_batch * kargs.num_head_q + i_nhead) * ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) + i_tile_m; + const int* lut_ptr = + reinterpret_cast(kargs.lut_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * + ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0) + + i_tile_m * ck_tile::integer_divide_ceil(kargs.seqlen_k, FmhaPipeline::kN0); + const int* valid_block_num_ptr = + reinterpret_cast(kargs.valid_block_num_ptr) + + static_cast(i_batch * kargs.num_head_q + i_nhead) * + ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) + + i_tile_m; // int valid_block_num_value = __builtin_amdgcn_readfirstlane(valid_block_num_ptr[0]); const int valid_block_num_value = valid_block_num_ptr[0]; @@ -1270,7 +1277,7 @@ struct FmhaFwdVSAKernel // make_tuple(1, 
1), // number<1>{}, // number<1>{}); - + // const auto valid_block_num_dram = make_naive_tensor_view( // valid_block_num_ptr, // make_tuple(kargs.seqlen_q/number{}), @@ -1278,7 +1285,6 @@ struct FmhaFwdVSAKernel // number<1>{}, // number<1>{}); - auto q_dram_window = make_tile_window( q_dram, [&]() { @@ -1302,7 +1308,7 @@ struct FmhaFwdVSAKernel // lut_dram, make_tuple(1,1), {0,0}); // auto valid_block_num_window = make_tile_window( // valid_block_num_dram, make_tuple(1), {i_tile_m}); - + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove /// following copy capture of the 'i_nhead' if in C++20 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { @@ -1480,21 +1486,21 @@ struct FmhaFwdVSAKernel auto o_acc_tile = [&]() { // TODO: constexpr(kDoFp8StaticQuant) return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - lut_ptr, - valid_block_num_value, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); + k_dram_window, + v_dram_window, + lut_ptr, + valid_block_num_value, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); }(); // O DRAM and O DRAM window From 3b00e4022ddec9fe5cbac9cde442b92e7276c145 Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Fri, 5 Dec 2025 03:07:11 +0000 Subject: [PATCH 03/22] Add jenga test and pre-commit --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 91 +++- .../codegen/ops/fmha_fwd_jenga.py | 4 +- .../50_sparse_attn/jenga_sparse_attention.cu | 130 ++--- .../50_sparse_attn/test_jenga_sparse_attn.cpp | 455 ++++++++++++++++++ .../arch/amd_buffer_addressing_builtins.hpp | 5 + include/ck_tile/ops/sparse_attn.hpp | 14 + ...ock_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 15 +- .../ops/sparse_attn/fmha_fwd_jenga_kernel.hpp | 5 +- 8 files changed, 
626 insertions(+), 93 deletions(-) create mode 100644 example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp create mode 100644 include/ck_tile/ops/sparse_attn.hpp diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 533fe6587ac..e7ca29b95e4 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -6,7 +6,7 @@ set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) set(GPU_TARGETS ${SUPPORTED_GPU_TARGETS}) -message(STATUS "VSA Sparse Attention: SUPPORTED_GPU_TARGETS=${SUPPORTED_GPU_TARGETS}, INST_TARGETS=${INST_TARGETS}") +message(STATUS "Sparse Attention: SUPPORTED_GPU_TARGETS=${SUPPORTED_GPU_TARGETS}, INST_TARGETS=${INST_TARGETS}") list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12") if(NOT INST_TARGETS) @@ -14,7 +14,7 @@ if(NOT INST_TARGETS) return() endif() -message(STATUS "Building VSA Sparse Attention for targets: ${INST_TARGETS}") +message(STATUS "Building Sparse Attention (Jenga & VSA) for targets: ${INST_TARGETS}") # Code generation scripts file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS @@ -23,14 +23,81 @@ file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS ) set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") -# Code generation for VSA (receipt 600 for aiter integration) +# ============================================================================ +# Jenga Sparse Attention +# ============================================================================ +set(SPARSE_ATTN_JENGA_CODE_GEN_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api fwd_jenga + --receipt 600 +) + +# Generate list of Jenga kernels (at configure time, only list) +execute_process( + COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "Failed to generate Jenga kernel list") +endif() + 
+file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/jenga_blob_list.txt SPARSE_ATTN_JENGA_GEN_BLOBS) + +# Generate Jenga kernel source files at build time +add_custom_command( + OUTPUT ${SPARSE_ATTN_JENGA_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_JENGA_CODE_GEN_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} + COMMENT "Generate CK Tile Jenga Sparse Attention kernels" +) + +message(STATUS "Jenga kernel files to be generated: ${SPARSE_ATTN_JENGA_GEN_BLOBS}") + +# Jenga Instances +set(SPARSE_ATTN_JENGA_INSTANCES "tile_sparse_attn_jenga_instances") + +add_library(${SPARSE_ATTN_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL + ${SPARSE_ATTN_JENGA_GEN_BLOBS} + ${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cu +) +target_include_directories(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn +) +set_source_files_properties(${SPARSE_ATTN_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) +set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cu PROPERTIES LANGUAGE HIP) +set_property(TARGET ${SPARSE_ATTN_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) + +target_compile_options(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE + -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN + -DCK_TILE_FMHA_FWD_FAST_EXP2 + -Wno-undefined-func-template + -Wno-float-equal +) + +# Jenga Example executable +set(EXAMPLE_JENGA_SPARSE_ATTN "tile_example_jenga_sparse_attn") +message(DEBUG "adding example ${EXAMPLE_JENGA_SPARSE_ATTN}") +add_executable(${EXAMPLE_JENGA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_jenga_sparse_attn.cpp) +target_link_libraries(${EXAMPLE_JENGA_SPARSE_ATTN} ${SPARSE_ATTN_JENGA_INSTANCES}) +target_include_directories(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_compile_options(${EXAMPLE_JENGA_SPARSE_ATTN} PRIVATE + -Wno-undefined-func-template + -Wno-float-equal +) + +# ============================================================================ 
+# VSA Sparse Attention +# ============================================================================ set(SPARSE_ATTN_VSA_CODE_GEN_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api fwd_vsa --receipt 600 ) -# Generate list of VSA kernels (at configure time, only list, not generate) +# Generate list of VSA kernels (at configure time, only list) execute_process( COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt @@ -42,7 +109,7 @@ endif() file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/vsa_blob_list.txt SPARSE_ATTN_VSA_GEN_BLOBS) -# Generate the kernel source files at build time (not configure time) +# Generate VSA kernel source files at build time add_custom_command( OUTPUT ${SPARSE_ATTN_VSA_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${SPARSE_ATTN_VSA_CODE_GEN_ARGS} @@ -68,7 +135,6 @@ set_source_files_properties(${SPARSE_ATTN_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cu PROPERTIES LANGUAGE HIP) set_property(TARGET ${SPARSE_ATTN_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) -# Compile options target_compile_options(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN -DCK_TILE_FMHA_FWD_FAST_EXP2 @@ -76,12 +142,13 @@ target_compile_options(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE -Wno-float-equal ) -# Test executable -set(TEST_VSA_SPARSE_ATTN "tile_test_vsa_sparse_attn") -add_executable(${TEST_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp) -target_link_libraries(${TEST_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) -target_include_directories(${TEST_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_compile_options(${TEST_VSA_SPARSE_ATTN} PRIVATE +# VSA Example executable +set(EXAMPLE_VSA_SPARSE_ATTN "tile_example_vsa_sparse_attn") +message(DEBUG "adding example ${EXAMPLE_VSA_SPARSE_ATTN}") +add_executable(${EXAMPLE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_vsa_sparse_attn.cpp) 
+target_link_libraries(${EXAMPLE_VSA_SPARSE_ATTN} ${SPARSE_ATTN_VSA_INSTANCES}) +target_include_directories(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_compile_options(${EXAMPLE_VSA_SPARSE_ATTN} PRIVATE -Wno-undefined-func-template -Wno-float-equal ) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index a3a9e6bf871..ae4fd78d0fa 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -20,6 +20,7 @@ MODE_MAP, PIPELINE_ENUM_MAP, PIPELINE_MAP, + SQUANT_MAP, get_mask_check_map, get_mask_map, ) @@ -78,7 +79,7 @@ def update_file(file_path, content): false, {F_lse}, {F_dropout}, - {F_squant}, + {F_squant_enum}, {F_occupancy}, {F_skip}>; @@ -596,6 +597,7 @@ def template(self) -> str: F_lse=BOOL_MAP[self.F_pipeline.F_lse], F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_squant_enum=SQUANT_MAP[self.F_pipeline.F_squant], F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu index 3aa3c9760c0..02f48ee0059 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu @@ -5,6 +5,7 @@ #include "fmha_fwd_trek.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/host/device_memory.hpp" ck_tile::HostTensor jenga_sparse_attention(ck_tile::HostTensor& TQ, @@ -12,36 +13,32 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& TV, ck_tile::HostTensor& Tblock_relation_onehot, ck_tile::HostTensor& Y, - std::optional> bias = std::nullopt, - std::optional> lse = std::nullopt, - std::optional> seqstart_q = 
std::nullopt, - std::optional> seqstart_k = std::nullopt, - int bias_type = 0, - int batch = 0, - int nhead = 0, - int nhead_k = 0, - int seqlen_q = 0, - int seqlen_k = 0, - int hdim_q = 0, - int hdim_v = 0, - int mode = 0, - bool i_perm = true, - bool o_perm = true, - int max_seqlen_q = 0, - int max_seqlen_k = 0) + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, + int bias_type, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + int mode, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k) { std::string data_type = "fp16"; - if(TQ.dtype() == ck_tile::bf16_t) - { - data_type = "bf16"; - } + // DataType is determined at compile time via template if(max_seqlen_q == 0) max_seqlen_q = seqlen_q; if(max_seqlen_k == 0) max_seqlen_k = seqlen_k; bool is_v_rowmajor = true; - int seqlen_knew = 0; float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); float scale_p = 1.f; float scale_o = 1.f; @@ -50,7 +47,6 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, std::string msk_str = "0"; mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - const ck_tile::index_t shape_batch = (mode == 0 ? batch : 1); const ck_tile::index_t shape_seqlen_q = (mode == 0 ? seqlen_q : max_seqlen_q); const ck_tile::index_t shape_seqlen_k = (mode == 0 ? 
seqlen_k : max_seqlen_k); @@ -61,71 +57,74 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, 1, false}; + // Create device memory and copy data to device + ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_relation_buf(Tblock_relation_onehot.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); + + q_buf.ToDevice(TQ.data()); + k_buf.ToDevice(TK.data()); + v_buf.ToDevice(TV.data()); + block_relation_buf.ToDevice(Tblock_relation_onehot.data()); + + // Optional buffers + ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); + ck_tile::DeviceMem lse_buf(lse ? lse->get_element_space_size_in_bytes() : 0); + ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() : 0); + ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes() : 0); + + if(bias) + bias_buf.ToDevice(bias->data()); + if(lse) + lse_buf.ToDevice(lse->data()); + if(seqstart_q) + seqstart_q_buf.ToDevice(seqstart_q->data()); + if(seqstart_k) + seqstart_k_buf.ToDevice(seqstart_k->data()); + const auto init_args = [&](auto& args) { assert(nhead % nhead_k == 0); - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { if(is_v_rowmajor) return i_perm ? hdim_v : nhead_k * hdim_v; else return (i_perm ? 
shape_seqlen_k : nhead_k * shape_seqlen_k); }(); - const ck_tile::index_t stride_vnew = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return i_perm ? seqlen_knew : nhead_k * seqlen_knew; - }(); const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); - const ck_tile::index_t stride_o_acc = (hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_v = [&]() { + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_v = [&]() { if(is_v_rowmajor) return i_perm ? shape_seqlen_k * hdim_v : hdim_v; else return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; }(); - const ck_tile::index_t nhead_stride_vnew = [&]() { - if(is_v_rowmajor) - return i_perm ? seqlen_knew * hdim_v : hdim_v; - else - return i_perm ? hdim_v * seqlen_knew : seqlen_knew; - }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (shape_seqlen_q); - const ck_tile::index_t nhead_stride_o_acc = (shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - // const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); - // setup split_stride_* arguments (only used in split-kv kernel) - const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); - const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); - args.q_ptr = TQ.data_ptr(); - args.k_ptr = TK.data_ptr(); - args.v_ptr = TV.data_ptr(); - args.block_relation_onehot_ptr = Tblock_relation_onehot.data_ptr(); + // Use device buffer pointers instead of host tensor data pointers + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer(); args.batch = batch; args.seqlen_q = shape_seqlen_q; // unused in group mode @@ -144,14 +143,12 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, args.batch_stride_k = batch_stride_k; args.batch_stride_v = batch_stride_v; - // args.bias_ptr = bias.type == bias_enum::alibi ? 
alibi_slope_buf.GetDeviceBuffer() - // : bias_buf.GetDeviceBuffer(); - args.bias_ptr = bias ? bias->data_ptr() : nullptr; - args.lse_ptr = lse ? lse->data_ptr() : nullptr; - args.o_ptr = Y.data_ptr(); + args.bias_ptr = bias ? bias_buf.GetDeviceBuffer() : nullptr; + args.lse_ptr = lse ? lse_buf.GetDeviceBuffer() : nullptr; + args.o_ptr = o_buf.GetDeviceBuffer(); - args.seqstart_q_ptr = (mode == 1 ? seqstart_q->data_ptr() : nullptr); - args.seqstart_k_ptr = (mode == 1 ? seqstart_k->data_ptr() : nullptr); + args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = (mode == 1 ? seqstart_k_buf.GetDeviceBuffer() : nullptr); args.seqlen_k_ptr = nullptr; args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) @@ -210,5 +207,8 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, fmha_jenga_fwd(fmha_traits, args, stream_config); + // Copy output back to host + Y = o_buf.ToHost(); + return Y; } diff --git a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp new file mode 100644 index 00000000000..75e3aa0b7d5 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp @@ -0,0 +1,455 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+// +// Test for jenga_sparse_attention function + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" + +#include "jenga_sparse_attention.h" + +// ============================================================================ +// Helper Functions +// ============================================================================ + +// Reference implementation: blocked attention +template +void reference_blocked_attention( + const ck_tile::HostTensor& q, // [B, H, S_q, D] + const ck_tile::HostTensor& k, // [B, H, S_k, D] + const ck_tile::HostTensor& v, // [B, H, S_k, D_v] + const ck_tile::HostTensor& block_relation, // [B, H, Q_blocks, K_blocks] + const ck_tile::HostTensor& bias, // [B, H, S_q, S_k] + ck_tile::HostTensor& output, // [B, H, S_q, D_v] + ck_tile::index_t BLKQ, + ck_tile::index_t BLKK, + AccT scale) +{ + auto q_lengths = q.get_lengths(); + ck_tile::index_t batch = q_lengths[0]; + ck_tile::index_t nhead = q_lengths[1]; + ck_tile::index_t seqlen_q = q_lengths[2]; + ck_tile::index_t hdim = q_lengths[3]; + + auto v_lengths = v.get_lengths(); + ck_tile::index_t seqlen_k = v_lengths[2]; + ck_tile::index_t hdim_v = v_lengths[3]; + + ck_tile::index_t num_q_blocks = seqlen_q / BLKQ; + ck_tile::index_t num_k_blocks = seqlen_k / BLKK; + + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + ck_tile::index_t q_start = qb * BLKQ; + ck_tile::index_t q_end = q_start + BLKQ; + + // Find relevant K blocks + std::vector relevant_k_indices; + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + if(static_cast(block_relation(b, h, qb, kb)) > 0.5f) + { + relevant_k_indices.push_back(kb); + } + } + + if(relevant_k_indices.empty()) + continue; + + // For each query position in the block + for(ck_tile::index_t sq = q_start; sq < q_end; ++sq) + { + std::vector scores; + AccT 
max_score = -std::numeric_limits::infinity(); + + for(auto kb : relevant_k_indices) + { + ck_tile::index_t k_start = kb * BLKK; + ck_tile::index_t k_end = k_start + BLKK; + + for(ck_tile::index_t sk = k_start; sk < k_end; ++sk) + { + AccT score = 0.0f; + for(ck_tile::index_t d = 0; d < hdim; ++d) + { + score += static_cast(q(b, h, sq, d)) * + static_cast(k(b, h, sk, d)); + } + score = score * scale + static_cast(bias(b, h, sq, sk)); + scores.push_back(score); + max_score = std::max(max_score, score); + } + } + + // Softmax + AccT sum_exp = 0.0f; + for(auto& s : scores) + { + s = std::exp(s - max_score); + sum_exp += s; + } + for(auto& s : scores) + { + s /= sum_exp; + } + + // Compute output: P @ V + for(ck_tile::index_t dv = 0; dv < hdim_v; ++dv) + { + AccT out_val = 0.0f; + size_t score_idx = 0; + + for(auto kb : relevant_k_indices) + { + ck_tile::index_t k_start = kb * BLKK; + ck_tile::index_t k_end = k_start + BLKK; + + for(ck_tile::index_t sk = k_start; sk < k_end; ++sk) + { + out_val += scores[score_idx] * static_cast(v(b, h, sk, dv)); + score_idx++; + } + } + output(b, h, sq, dv) = static_cast(out_val); + } + } + } + } + } +} + +// ============================================================================ +// Command line argument parser +// ============================================================================ +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") + .insert("mode", "0", "kernel mode. 
0:batch, 1:group") + .insert("b", "1", "batch size") + .insert("h", "4", "num of head for q") + .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") + .insert("s", "4096", "seqlen_q") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)") + .insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)") + .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi") + .insert("lse", "0", "0:not store lse, 1:store lse") + .insert("seed", "42", "random seed") + .insert("warmup", "5", "warmup iterations") + .insert("repeat", "20", "benchmark iterations"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// ============================================================================ +// Main Test Function +// ============================================================================ +bool run_test(const ck_tile::ArgParser& arg_parser) +{ + using T = DataType; // Use DataType defined in header (half_t) + + // Parse arguments + int do_validation = arg_parser.get_int("v"); + int mode = arg_parser.get_int("mode"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t block_size = arg_parser.get_int("block_size"); + float sparsity = arg_parser.get_float("sparsity"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + int 
bias_type = arg_parser.get_int("bias"); + bool store_lse = arg_parser.get_bool("lse"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + // Handle default values + if(nhead_k < 0) + nhead_k = nhead; + if(seqlen_k < 0) + seqlen_k = seqlen_q; + if(hdim_v < 0) + hdim_v = hdim_q; + + ck_tile::index_t BLKQ = block_size; + ck_tile::index_t BLKK = block_size; + + // Calculate number of Q and K blocks + ck_tile::index_t num_q_blocks = seqlen_q / BLKQ; + ck_tile::index_t num_k_blocks = seqlen_k / BLKK; + + std::cout << "============================================================" << std::endl; + std::cout << "[Jenga Sparse Attention Test]" << std::endl; + std::cout << "============================================================" << std::endl; + std::cout << " Batch: " << batch << ", nhead_q: " << nhead << ", nhead_k: " << nhead_k + << std::endl; + std::cout << " seqlen_q: " << seqlen_q << ", seqlen_k: " << seqlen_k << std::endl; + std::cout << " hdim_q: " << hdim_q << ", hdim_v: " << hdim_v << std::endl; + std::cout << " block_size: " << block_size << " (BLKQ=" << BLKQ << ", BLKK=" << BLKK << ")" + << std::endl; + std::cout << " num_q_blocks: " << num_q_blocks << ", num_k_blocks: " << num_k_blocks + << std::endl; + std::cout << " sparsity: " << sparsity << std::endl; + std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; + + // Create host tensors (using BHSD layout when i_perm=true) + ck_tile::HostTensor q_host({batch, nhead, seqlen_q, hdim_q}); + ck_tile::HostTensor k_host({batch, nhead_k, seqlen_k, hdim_q}); + ck_tile::HostTensor v_host({batch, nhead_k, seqlen_k, hdim_v}); + ck_tile::HostTensor output_host({batch, nhead, seqlen_q, hdim_v}); + ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); + + // Bias tensor [B, H, S_q, S_k] + ck_tile::HostTensor bias_host({batch, nhead, seqlen_q, seqlen_k}); + + // Block relation onehot: [B, H, Q_blocks, 
K_blocks] + ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); + + // LSE tensor (optional) + ck_tile::HostTensor lse_host({batch, nhead, seqlen_q}); + + // Initialize tensors with random values + std::cout << "\nInitializing tensors..." << std::endl; + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); + + // Initialize bias to zero + std::fill(bias_host.mData.begin(), bias_host.mData.end(), static_cast(0.0f)); + + // Initialize block_relation_onehot with sparse pattern + std::mt19937 rng(seed + 100); + std::uniform_real_distribution dist(0.0f, 1.0f); + ck_tile::index_t total_blocks = 0; + ck_tile::index_t active_blocks = 0; + + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + { + for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) + { + total_blocks++; + bool is_diagonal = (qb == kb && qb < num_k_blocks); + bool random_active = (dist(rng) > sparsity); + + if(is_diagonal || random_active) + { + block_relation_onehot(b, h, qb, kb) = static_cast(1.0f); + active_blocks++; + } + else + { + block_relation_onehot(b, h, qb, kb) = static_cast(0.0f); + } + } + } + } + } + + float actual_sparsity = + 1.0f - static_cast(active_blocks) / static_cast(total_blocks); + std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" + << total_blocks << " blocks active)" << std::endl; + + // Optional tensors + std::optional> bias_opt = std::nullopt; + std::optional> lse_opt = std::nullopt; + std::optional> seqstart_q_opt = std::nullopt; + std::optional> seqstart_k_opt = std::nullopt; + + if(bias_type != 0) + { + bias_opt = bias_host; + } + if(store_lse) + { + lse_opt = lse_host; + } + + // Run kernel + std::cout << "\n--- Running Jenga sparse attention kernel ---" << 
std::endl; + + try + { + // Warmup + for(int i = 0; i < warmup; ++i) + { + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); + } + + // Benchmark + [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); + auto start = std::chrono::high_resolution_clock::now(); + + for(int i = 0; i < repeat; ++i) + { + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); + } + + [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); + auto end = std::chrono::high_resolution_clock::now(); + double avg_time_ms = + std::chrono::duration(end - start).count() / repeat; + + std::cout << "\n>>>> Jenga sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; + } + catch(const std::exception& e) + { + std::cerr << "Error during kernel execution: " << e.what() << std::endl; + return false; + } + + // Validation + bool pass = true; + if(do_validation) + { + std::cout << "\n--- Performing CPU validation ---" << std::endl; + + float scale = 1.0f / std::sqrt(static_cast(hdim_q)); + + std::cout << "Computing reference output..." 
<< std::endl; + reference_blocked_attention(q_host, + k_host, + v_host, + block_relation_onehot, + bias_host, + output_ref, + BLKQ, + BLKK, + scale); + + // Compare results + double rtol = 1e-2; + double atol = 4e-2; + + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + size_t num_errors = 0; + + for(size_t i = 0; i < output_host.mData.size(); ++i) + { + float gpu_val = static_cast(output_host.mData[i]); + float ref_val = static_cast(output_ref.mData[i]); + float diff = std::abs(gpu_val - ref_val); + float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; + + max_diff = std::max(max_diff, diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + + if(diff > atol && rel_diff > rtol) + { + num_errors++; + if(num_errors <= 5) + { + std::cout << " Mismatch at index " << i << ": GPU=" << gpu_val + << ", Ref=" << ref_val << ", Diff=" << diff << std::endl; + } + } + } + + std::cout << "\nValidation results:" << std::endl; + std::cout << " Max absolute difference: " << max_diff << std::endl; + std::cout << " Max relative difference: " << max_rel_diff << std::endl; + std::cout << " Number of mismatches: " << num_errors << " / " << output_host.mData.size() + << std::endl; + + if(num_errors == 0) + { + std::cout << "\n>>> VALIDATION PASSED <<<" << std::endl; + } + else + { + std::cout << "\n>>> VALIDATION FAILED <<<" << std::endl; + pass = false; + } + } + + std::cout << "\n" << (pass ? "TEST PASSED" : "TEST FAILED") << std::endl; + return pass; +} + +// ============================================================================ +// Main +// ============================================================================ +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Failed to parse arguments" << std::endl; + return -1; + } + + bool test_result = run_test(arg_parser); + return test_result ? 
0 : -1; +} diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 6d7de749c90..de77cf64c11 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -43,6 +43,11 @@ __device__ inline int32_t amd_wave_read_first_lane(int32_t value) return __builtin_amdgcn_readfirstlane(value); } +__device__ inline int32_t amd_wave_read_first_lane(uintptr_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + template , int> = 0> __device__ inline auto amd_wave_read_first_lane(const Object& obj) { diff --git a/include/ck_tile/ops/sparse_attn.hpp b/include/ck_tile/ops/sparse_attn.hpp new file mode 100644 index 00000000000..8b48003e034 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" +#include "ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" +#include "ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp" +#include "ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index 27cfb913279..2644c0cfdc5 100644 --- a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -431,11 +431,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga 
sequence<(LdsSeq.at(number{})) * kN0, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); }); - __shared__ int printed_flag; - if(blockIdx.x == 0 && threadIdx.x == 0 && i_total_loops == 1000) - { - printed_flag = 100; - } } // TODO: this to fix a bug when loop smaller than 2, @@ -704,14 +699,8 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga randval_dram_window); } - const auto p = [&]() { - if constexpr(std::is_same_v) - return impl::cast_tile_pk_fp16_fp32( - tile_elementwise_in(p_compute_element_func, p_compute)); - else - return cast_tile( - tile_elementwise_in(p_compute_element_func, p_compute)); - }(); + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); // STAGE 3, KV gemm if constexpr(k1_loops > 1) diff --git a/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp index 7b6d49c0faf..72acd0f604f 100644 --- a/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp @@ -53,8 +53,9 @@ struct FmhaFwdJengaKernel static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; - static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; - static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static constexpr bool kDoFp8StaticQuant = + (FmhaPipeline::Problem::QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); + static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; From 997ec8f89cc7eea2386ed3a59dbde387db49995e Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 06:18:13 +0000 Subject: [PATCH 04/22] add bf16 for vsa --- .../50_sparse_attn/jenga_sparse_attention.h | 19 ++-- 
.../50_sparse_attn/test_vsa_sparse_attn.cpp | 100 ++++++++---------- .../50_sparse_attn/vsa_sparse_attention.cu | 55 +++++++--- 3 files changed, 99 insertions(+), 75 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index 0e8eab8e6f2..5ebc3fb94e7 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -33,17 +33,18 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, int max_seqlen_q, int max_seqlen_k); -ck_tile::HostTensor -vsa_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, +template +ck_tile::HostTensor +vsa_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, ck_tile::HostTensor& TKV_block_idx, // LUT must be int32_t ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t - ck_tile::HostTensor& Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, int bias_type, int batch, int nhead, diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp index 9ac91660d2b..a5c3f9f2b3b 100644 --- a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp @@ -15,9 +15,6 @@ #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" -// Define DataType before including the header -using DataType = ck_tile::half_t; - #include "jenga_sparse_attention.h" #include "fmha_fwd_trek.hpp" @@ -363,29 +360,29 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Warmup for(int i = 0; i < warmup; ++i) { - vsa_sparse_attention(q_host, - k_host, - v_host, - lut_host, - valid_block_num_host, - output_host, - bias_opt, - 
lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k); + vsa_sparse_attention(q_host, + k_host, + v_host, + lut_host, + valid_block_num_host, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } // Benchmark @@ -394,29 +391,29 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - vsa_sparse_attention(q_host, - k_host, - v_host, - lut_host, - valid_block_num_host, - output_host, - bias_opt, - lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k); + vsa_sparse_attention(q_host, + k_host, + v_host, + lut_host, + valid_block_num_host, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); @@ -526,10 +523,7 @@ int main(int argc, char* argv[]) } else if(prec == "bf16") { - std::cout << "Note: Using bf16 precision" << std::endl; - // For bf16, we would need to compile with DataType = ck_tile::bf16_t - // For now, run with the compiled DataType - test_result = run_test(arg_parser); + test_result = run_test(arg_parser); } else { diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu index d75b5bae657..a824e9389f1 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu @@ -6,18 +6,20 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" #include 
"ck_tile/host/device_memory.hpp" - -ck_tile::HostTensor -vsa_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, - ck_tile::HostTensor& TKV_block_idx, // LUT must be int32_t - ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t - ck_tile::HostTensor& Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, +#include + +template +ck_tile::HostTensor +vsa_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& TKV_block_idx, + ck_tile::HostTensor& TKV_blocks, + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, int bias_type, int batch, int nhead, @@ -32,8 +34,12 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, int max_seqlen_q, int max_seqlen_k) { + // Determine data type string based on template parameter std::string data_type = "fp16"; - // DataType is determined at compile time via template + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } if(max_seqlen_q == 0) max_seqlen_q = seqlen_q; @@ -218,3 +224,26 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, return Y; } + +// Explicit template instantiations +template ck_tile::HostTensor +vsa_sparse_attention( + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + int, int, int, int, int, int, int, int, int, bool, bool, int, int); + +template ck_tile::HostTensor +vsa_sparse_attention( + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + int, int, int, int, int, int, int, int, int, bool, bool, int, int); From 29d96a90f0caf7bf6a2ce2c3f20cc10046d9f28d Mon Sep 
17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 06:57:42 +0000 Subject: [PATCH 05/22] add jenga support bf16 --- .../50_sparse_attn/jenga_sparse_attention.cu | 55 ++++-- .../50_sparse_attn/jenga_sparse_attention.h | 23 ++- .../50_sparse_attn/test_jenga_sparse_attn.cpp | 162 ++++++++++-------- .../50_sparse_attn/vsa_sparse_attention.cu | 2 +- 4 files changed, 148 insertions(+), 94 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu index 02f48ee0059..925960a0a8f 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu @@ -6,17 +6,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/device_memory.hpp" - -ck_tile::HostTensor -jenga_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, - ck_tile::HostTensor& Tblock_relation_onehot, - ck_tile::HostTensor& Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, +#include + +template +ck_tile::HostTensor +jenga_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, int bias_type, int batch, int nhead, @@ -31,8 +33,12 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, int max_seqlen_q, int max_seqlen_k) { + // Determine data type string based on template parameter std::string data_type = "fp16"; - // DataType is determined at compile time via template + if constexpr(std::is_same_v) + { + data_type = "bf16"; + } if(max_seqlen_q == 0) max_seqlen_q = seqlen_q; @@ -208,7 +214,30 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, fmha_jenga_fwd(fmha_traits, args, stream_config); // Copy output back to host - Y = 
o_buf.ToHost(); + Y = o_buf.ToHost(); return Y; } + +// Explicit template instantiations +template ck_tile::HostTensor +jenga_sparse_attention( + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + int, int, int, int, int, int, int, int, int, bool, bool, int, int); + +template ck_tile::HostTensor +jenga_sparse_attention( + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + int, int, int, int, int, int, int, int, int, bool, bool, int, int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index 5ebc3fb94e7..8fad02ce04e 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -7,18 +7,17 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" -using DataType = ck_tile::half_t; - -ck_tile::HostTensor -jenga_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, - ck_tile::HostTensor& Tblock_relation_onehot, - ck_tile::HostTensor& Y, - std::optional> bias, - std::optional> lse, - std::optional> seqstart_q, - std::optional> seqstart_k, +template +ck_tile::HostTensor +jenga_sparse_attention(ck_tile::HostTensor& TQ, + ck_tile::HostTensor& TK, + ck_tile::HostTensor& TV, + ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Y, + std::optional> bias, + std::optional> lse, + std::optional> seqstart_q, + std::optional> seqstart_k, int bias_type, int batch, int nhead, diff --git a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp index 75e3aa0b7d5..fa0eea5b4fc 100644 --- 
a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp @@ -131,6 +131,15 @@ void reference_blocked_attention( } } +// Get error tolerance based on data type +template +auto get_error_tolerance() +{ + double rtol = 1e-2; + double atol = 4e-2; // Higher tolerance for bf16/fp16 + return ck_tile::make_tuple(rtol, atol); +} + // ============================================================================ // Command line argument parser // ============================================================================ @@ -148,13 +157,15 @@ auto create_args(int argc, char* argv[]) .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("block_size", "128", "block size for sparse attention (BLKQ=BLKK)") .insert("sparsity", "0.5", "sparsity ratio (0.0 = dense, 1.0 = fully sparse)") + .insert("prec", "fp16", "data type: fp16/bf16") .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") .insert("operm", "1", "permute output") .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi") .insert("lse", "0", "0:not store lse, 1:store lse") .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") - .insert("repeat", "20", "benchmark iterations"); + .insert("repeat", "20", "benchmark iterations") + .insert("kname", "0", "print kernel name"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -163,29 +174,29 @@ auto create_args(int argc, char* argv[]) // ============================================================================ // Main Test Function // ============================================================================ +template bool run_test(const ck_tile::ArgParser& arg_parser) { - using T = DataType; // Use DataType defined in header (half_t) - // Parse arguments - int do_validation = arg_parser.get_int("v"); - int mode = arg_parser.get_int("mode"); - ck_tile::index_t batch = arg_parser.get_int("b"); - 
ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - ck_tile::index_t block_size = arg_parser.get_int("block_size"); - float sparsity = arg_parser.get_float("sparsity"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - int bias_type = arg_parser.get_int("bias"); - bool store_lse = arg_parser.get_bool("lse"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); + int do_validation = arg_parser.get_int("v"); + int mode = arg_parser.get_int("mode"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t block_size = arg_parser.get_int("block_size"); + float sparsity = arg_parser.get_float("sparsity"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + int bias_type = arg_parser.get_int("bias"); + [[maybe_unused]] bool store_lse = arg_parser.get_bool("lse"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + [[maybe_unused]] int kname = arg_parser.get_int("kname"); // Handle default values if(nhead_k < 0) @@ -301,28 +312,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Warmup for(int i = 0; i < warmup; ++i) { - jenga_sparse_attention(q_host, - k_host, - v_host, - block_relation_onehot, - 
output_host, - bias_opt, - lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k); + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } // Benchmark @@ -331,28 +342,28 @@ bool run_test(const ck_tile::ArgParser& arg_parser) for(int i = 0; i < repeat; ++i) { - jenga_sparse_attention(q_host, - k_host, - v_host, - block_relation_onehot, - output_host, - bias_opt, - lse_opt, - seqstart_q_opt, - seqstart_k_opt, - bias_type, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - mode, - i_perm, - o_perm, - seqlen_q, - seqlen_k); + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + bias_opt, + lse_opt, + seqstart_q_opt, + seqstart_k_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + mode, + i_perm, + o_perm, + seqlen_q, + seqlen_k); } [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); @@ -389,8 +400,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) scale); // Compare results - double rtol = 1e-2; - double atol = 4e-2; + auto [rtol, atol] = get_error_tolerance(); float max_diff = 0.0f; float max_rel_diff = 0.0f; @@ -450,6 +460,22 @@ int main(int argc, char* argv[]) return -1; } - bool test_result = run_test(arg_parser); + std::string prec = arg_parser.get_str("prec"); + + bool test_result = false; + if(prec == "fp16") + { + test_result = run_test(arg_parser); + } + else if(prec == "bf16") + { + test_result = run_test(arg_parser); + } + else + { + std::cerr << "Unsupported precision: " << prec << std::endl; + return -1; + } + return test_result ? 
0 : -1; } diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu index a824e9389f1..e7a8fefa7a6 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu @@ -220,7 +220,7 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, fmha_vsa_fwd(fmha_traits, args, stream_config); // Copy output back to host - Y = o_buf.ToHost(); + Y = o_buf.ToHost(); return Y; } From 5e8a010fc612e4ecea9a3ecb94f7f2c7484f41d9 Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 07:40:45 +0000 Subject: [PATCH 06/22] remove lse arg --- .../50_sparse_attn/jenga_sparse_attention.cu | 10 +--- .../50_sparse_attn/jenga_sparse_attention.h | 2 - .../50_sparse_attn/test_jenga_sparse_attn.cpp | 48 +++++++------------ .../50_sparse_attn/test_vsa_sparse_attn.cpp | 41 +++++++--------- .../50_sparse_attn/vsa_sparse_attention.cu | 10 +--- 5 files changed, 40 insertions(+), 71 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu index 925960a0a8f..3be8cae198b 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu @@ -16,7 +16,6 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& Tblock_relation_onehot, ck_tile::HostTensor& Y, std::optional> bias, - std::optional> lse, std::optional> seqstart_q, std::optional> seqstart_k, int bias_type, @@ -77,14 +76,11 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, // Optional buffers ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); - ck_tile::DeviceMem lse_buf(lse ? lse->get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? 
seqstart_k->get_element_space_size_in_bytes() : 0); if(bias) bias_buf.ToDevice(bias->data()); - if(lse) - lse_buf.ToDevice(lse->data()); if(seqstart_q) seqstart_q_buf.ToDevice(seqstart_q->data()); if(seqstart_k) @@ -150,7 +146,7 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, args.batch_stride_v = batch_stride_v; args.bias_ptr = bias ? bias_buf.GetDeviceBuffer() : nullptr; - args.lse_ptr = lse ? lse_buf.GetDeviceBuffer() : nullptr; + args.lse_ptr = nullptr; args.o_ptr = o_buf.GetDeviceBuffer(); args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr); @@ -199,7 +195,7 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = static_cast(bias_type); - traits.has_lse = lse ? true : false; + traits.has_lse = false; traits.do_fp8_static_quant = false; traits.has_dropout = false; @@ -228,7 +224,6 @@ jenga_sparse_attention( std::optional>, std::optional>, std::optional>, - std::optional>, int, int, int, int, int, int, int, int, int, bool, bool, int, int); template ck_tile::HostTensor @@ -239,5 +234,4 @@ jenga_sparse_attention( std::optional>, std::optional>, std::optional>, - std::optional>, int, int, int, int, int, int, int, int, int, bool, bool, int, int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index 8fad02ce04e..2f0be76bf54 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -15,7 +15,6 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& Tblock_relation_onehot, ck_tile::HostTensor& Y, std::optional> bias, - std::optional> lse, std::optional> seqstart_q, std::optional> seqstart_k, int bias_type, @@ -41,7 +40,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t ck_tile::HostTensor& Y, std::optional> 
bias, - std::optional> lse, std::optional> seqstart_q, std::optional> seqstart_k, int bias_type, diff --git a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp index fa0eea5b4fc..0acde4bed3b 100644 --- a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp @@ -161,7 +161,6 @@ auto create_args(int argc, char* argv[]) .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") .insert("operm", "1", "permute output") .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi") - .insert("lse", "0", "0:not store lse, 1:store lse") .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") .insert("repeat", "20", "benchmark iterations") @@ -178,25 +177,24 @@ template bool run_test(const ck_tile::ArgParser& arg_parser) { // Parse arguments - int do_validation = arg_parser.get_int("v"); - int mode = arg_parser.get_int("mode"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - ck_tile::index_t block_size = arg_parser.get_int("block_size"); - float sparsity = arg_parser.get_float("sparsity"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - int bias_type = arg_parser.get_int("bias"); - [[maybe_unused]] bool store_lse = arg_parser.get_bool("lse"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - [[maybe_unused]] int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int mode = arg_parser.get_int("mode"); + 
ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t block_size = arg_parser.get_int("block_size"); + float sparsity = arg_parser.get_float("sparsity"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + int bias_type = arg_parser.get_int("bias"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + [[maybe_unused]] int kname = arg_parser.get_int("kname"); // Handle default values if(nhead_k < 0) @@ -240,9 +238,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Block relation onehot: [B, H, Q_blocks, K_blocks] ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); - // LSE tensor (optional) - ck_tile::HostTensor lse_host({batch, nhead, seqlen_q}); - // Initialize tensors with random values std::cout << "\nInitializing tensors..." 
<< std::endl; ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed}(q_host); @@ -291,7 +286,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Optional tensors std::optional> bias_opt = std::nullopt; - std::optional> lse_opt = std::nullopt; std::optional> seqstart_q_opt = std::nullopt; std::optional> seqstart_k_opt = std::nullopt; @@ -299,10 +293,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { bias_opt = bias_host; } - if(store_lse) - { - lse_opt = lse_host; - } // Run kernel std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl; @@ -318,7 +308,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) block_relation_onehot, output_host, bias_opt, - lse_opt, seqstart_q_opt, seqstart_k_opt, bias_type, @@ -348,7 +337,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) block_relation_onehot, output_host, bias_opt, - lse_opt, seqstart_q_opt, seqstart_k_opt, bias_type, diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp index a5c3f9f2b3b..be4653c994a 100644 --- a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp @@ -201,7 +201,6 @@ auto create_args(int argc, char* argv[]) .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") .insert("operm", "1", "permute output") .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi") - .insert("lse", "0", "0:not store lse, 1:store lse") .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") .insert("repeat", "20", "benchmark iterations") @@ -218,25 +217,24 @@ template bool run_test(const ck_tile::ArgParser& arg_parser) { // Parse arguments - int do_validation = arg_parser.get_int("v"); - int mode = arg_parser.get_int("mode"); - ck_tile::index_t batch = arg_parser.get_int("b"); - ck_tile::index_t nhead = arg_parser.get_int("h"); - ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - ck_tile::index_t 
seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - ck_tile::index_t hdim_q = arg_parser.get_int("d"); - ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - ck_tile::index_t block_size = arg_parser.get_int("block_size"); - float sparsity = arg_parser.get_float("sparsity"); - bool i_perm = arg_parser.get_bool("iperm"); - bool o_perm = arg_parser.get_bool("operm"); - int bias_type = arg_parser.get_int("bias"); - [[maybe_unused]] bool store_lse = arg_parser.get_bool("lse"); - uint32_t seed = arg_parser.get_uint32("seed"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - [[maybe_unused]] int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int mode = arg_parser.get_int("mode"); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + ck_tile::index_t block_size = arg_parser.get_int("block_size"); + float sparsity = arg_parser.get_float("sparsity"); + bool i_perm = arg_parser.get_bool("iperm"); + bool o_perm = arg_parser.get_bool("operm"); + int bias_type = arg_parser.get_int("bias"); + uint32_t seed = arg_parser.get_uint32("seed"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + [[maybe_unused]] int kname = arg_parser.get_int("kname"); // Handle default values if(nhead_k < 0) @@ -343,7 +341,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Optional tensors std::optional> bias_opt = std::nullopt; - std::optional> lse_opt = std::nullopt; std::optional> seqstart_q_opt = std::nullopt; std::optional> seqstart_k_opt = std::nullopt; @@ -367,7 +364,6 @@ bool run_test(const 
ck_tile::ArgParser& arg_parser) valid_block_num_host, output_host, bias_opt, - lse_opt, seqstart_q_opt, seqstart_k_opt, bias_type, @@ -398,7 +394,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) valid_block_num_host, output_host, bias_opt, - lse_opt, seqstart_q_opt, seqstart_k_opt, bias_type, diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu index e7a8fefa7a6..e3199b444fb 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu @@ -17,7 +17,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& TKV_blocks, ck_tile::HostTensor& Y, std::optional> bias, - std::optional> lse, std::optional> seqstart_q, std::optional> seqstart_k, int bias_type, @@ -80,7 +79,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, // Optional buffers ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); - ck_tile::DeviceMem lse_buf(lse ? lse->get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes() @@ -88,8 +86,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, if(bias) bias_buf.ToDevice(bias->data()); - if(lse) - lse_buf.ToDevice(lse->data()); if(seqstart_q) seqstart_q_buf.ToDevice(seqstart_q->data()); if(seqstart_k) @@ -156,7 +152,7 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, args.batch_stride_v = batch_stride_v; args.bias_ptr = bias ? bias_buf.GetDeviceBuffer() : nullptr; - args.lse_ptr = lse ? lse_buf.GetDeviceBuffer() : nullptr; + args.lse_ptr = nullptr; args.o_ptr = o_buf.GetDeviceBuffer(); args.seqstart_q_ptr = (mode == 1 ? 
seqstart_q_buf.GetDeviceBuffer() : nullptr); @@ -205,7 +201,7 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = static_cast(bias_type); - traits.has_lse = lse ? true : false; + traits.has_lse = false; traits.do_fp8_static_quant = false; traits.has_dropout = false; @@ -234,7 +230,6 @@ vsa_sparse_attention( std::optional>, std::optional>, std::optional>, - std::optional>, int, int, int, int, int, int, int, int, int, bool, bool, int, int); template ck_tile::HostTensor @@ -245,5 +240,4 @@ vsa_sparse_attention( std::optional>, std::optional>, std::optional>, - std::optional>, int, int, int, int, int, int, int, int, int, bool, bool, int, int); From d2278ab2526735e98e6a4931b18ab39c623466a4 Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 08:09:06 +0000 Subject: [PATCH 07/22] split kernel code to block & kernel --- .../ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py | 4 ++-- .../ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py | 4 ++-- include/ck_tile/ops/sparse_attn.hpp | 8 ++++---- .../block_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 0 .../block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 0 .../sparse_attn/{ => kernel}/fmha_fwd_jenga_kernel.hpp | 0 .../ops/sparse_attn/{ => kernel}/fmha_fwd_vsa_kernel.hpp | 0 7 files changed, 8 insertions(+), 8 deletions(-) rename include/ck_tile/ops/sparse_attn/{ => block}/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp (100%) rename include/ck_tile/ops/sparse_attn/{ => block}/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp (100%) rename include/ck_tile/ops/sparse_attn/{ => kernel}/fmha_fwd_jenga_kernel.hpp (100%) rename include/ck_tile/ops/sparse_attn/{ => kernel}/fmha_fwd_vsa_kernel.hpp (100%) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index ae4fd78d0fa..92d0a8c7d4a 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py 
+++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -53,8 +53,8 @@ def update_file(file_path, content): // auto generated by generate.py #include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd_trek.hpp" -#include "block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" -#include "fmha_fwd_jenga_kernel.hpp" +#include "block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" +#include "kernel/fmha_fwd_jenga_kernel.hpp" """ diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 0d4bf9b7811..998bfc91bbe 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -53,8 +53,8 @@ def update_file(file_path, content): // auto generated by generate.py #include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd_trek.hpp" -#include "block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" -#include "fmha_fwd_vsa_kernel.hpp" +#include "block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" +#include "kernel/fmha_fwd_vsa_kernel.hpp" """ diff --git a/include/ck_tile/ops/sparse_attn.hpp b/include/ck_tile/ops/sparse_attn.hpp index 8b48003e034..5e7e692a2a1 100644 --- a/include/ck_tile/ops/sparse_attn.hpp +++ b/include/ck_tile/ops/sparse_attn.hpp @@ -3,10 +3,10 @@ #pragma once -#include "ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" -#include "ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" -#include "ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp" -#include "ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp" +#include "ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" +#include "ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" +#include "ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp" +#include "ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include 
"ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp similarity index 100% rename from include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp rename to include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp diff --git a/include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp similarity index 100% rename from include/ck_tile/ops/sparse_attn/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp rename to include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp diff --git a/include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp similarity index 100% rename from include/ck_tile/ops/sparse_attn/fmha_fwd_jenga_kernel.hpp rename to include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp diff --git a/include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp similarity index 100% rename from include/ck_tile/ops/sparse_attn/fmha_fwd_vsa_kernel.hpp rename to include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp From faff9abaac72f7f3ca21c15fd9ff68e6e611eea1 Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 08:22:20 +0000 Subject: [PATCH 08/22] fix the pre-commit --- include/ck_tile/ops/sparse_attn.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/sparse_attn.hpp b/include/ck_tile/ops/sparse_attn.hpp index 5e7e692a2a1..c6922303cbb 100644 --- a/include/ck_tile/ops/sparse_attn.hpp +++ b/include/ck_tile/ops/sparse_attn.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, 
Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. #pragma once From 55d9a8e5769a45f60cd507198267daf5ba6e0861 Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 08:28:07 +0000 Subject: [PATCH 09/22] fix the pre-commit --- include/ck_tile/ops/sparse_attn.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/ck_tile/ops/sparse_attn.hpp b/include/ck_tile/ops/sparse_attn.hpp index c6922303cbb..2e3872c8249 100644 --- a/include/ck_tile/ops/sparse_attn.hpp +++ b/include/ck_tile/ops/sparse_attn.hpp @@ -1,6 +1,5 @@ -// SPDX-License-Identifier: MIT // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. - +// SPDX-License-Identifier: MIT #pragma once #include "ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" From a86fc8076ce9891eb99a1ac5dbdd561a035d74c5 Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 08:38:34 +0000 Subject: [PATCH 10/22] fix the copyrights --- example/ck_tile/50_sparse_attn/bias.hpp | 3 +-- example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 3 +-- example/ck_tile/50_sparse_attn/mask.hpp | 3 +-- include/ck_tile/ops/sparse_attn.hpp | 3 ++- .../ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp | 3 +-- include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp | 3 +-- 6 files changed, 7 insertions(+), 11 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/bias.hpp b/example/ck_tile/50_sparse_attn/bias.hpp index f9dc656f637..4f013341e8b 100644 --- a/example/ck_tile/50_sparse_attn/bias.hpp +++ b/example/ck_tile/50_sparse_attn/bias.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 31f59ef5167..1c6cb39a3c8 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/core.hpp" diff --git a/example/ck_tile/50_sparse_attn/mask.hpp b/example/ck_tile/50_sparse_attn/mask.hpp index b96482f5355..b484ccc590f 100644 --- a/example/ck_tile/50_sparse_attn/mask.hpp +++ b/example/ck_tile/50_sparse_attn/mask.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include diff --git a/include/ck_tile/ops/sparse_attn.hpp b/include/ck_tile/ops/sparse_attn.hpp index 2e3872c8249..5e7e692a2a1 100644 --- a/include/ck_tile/ops/sparse_attn.hpp +++ b/include/ck_tile/ops/sparse_attn.hpp @@ -1,5 +1,6 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once #include "ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp index 72acd0f604f..985b2f5348e 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/core.hpp" diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp index 3b2c7019790..52948f0d30b 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/core.hpp" From 12420cd6f77d956a99ed9ac7748e49828b39f5ea Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 08:59:26 +0000 Subject: [PATCH 11/22] fix the copyright --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 2 +- example/ck_tile/50_sparse_attn/codegen/__init__.py | 3 +++ example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py | 2 +- example/ck_tile/50_sparse_attn/codegen/ops/__init__.py | 3 +++ example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py | 2 +- example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py | 2 +- example/ck_tile/50_sparse_attn/generate.py | 2 +- .../block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 3 +-- .../block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 3 +-- 9 files changed, 13 insertions(+), 9 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index e7ca29b95e4..5ea00e39129 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -1,5 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
# CMakeLists.txt for sparse attention (Jenga and VSA) # Use SUPPORTED_GPU_TARGETS directly diff --git a/example/ck_tile/50_sparse_attn/codegen/__init__.py b/example/ck_tile/50_sparse_attn/codegen/__init__.py index e69de29bb2d..fb0a4926040 100644 --- a/example/ck_tile/50_sparse_attn/codegen/__init__.py +++ b/example/ck_tile/50_sparse_attn/codegen/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + diff --git a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py index 63751d43fa2..bcc66ee7ca0 100644 --- a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py +++ b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py @@ -1,5 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/__init__.py b/example/ck_tile/50_sparse_attn/codegen/ops/__init__.py index e69de29bb2d..fb0a4926040 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/__init__.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index 92d0a8c7d4a..229575ea62a 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -1,5 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
# generate kernel instances to speed up compilation import copy diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 998bfc91bbe..52c23c2055b 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -1,5 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation import copy diff --git a/example/ck_tile/50_sparse_attn/generate.py b/example/ck_tile/50_sparse_attn/generate.py index f1bf88efd77..a294eb172ec 100644 --- a/example/ck_tile/50_sparse_attn/generate.py +++ b/example/ck_tile/50_sparse_attn/generate.py @@ -1,5 +1,5 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation import argparse diff --git a/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index 2644c0cfdc5..1a5a3f77b5a 100644 --- a/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- #pragma once #include "ck_tile/core.hpp" diff --git a/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index ac5b4db3519..bf4f89e3c4f 100644 --- a/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/core.hpp" From 776664aa7b074c74a4bd1431ee4b8805c9b7a94a Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 09:18:27 +0000 Subject: [PATCH 12/22] fix the copyright & rename block to pipeline --- example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py | 2 +- example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py | 2 +- example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu | 3 +-- example/ck_tile/50_sparse_attn/jenga_sparse_attention.h | 5 ++--- example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp | 3 +-- example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp | 3 +-- example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu | 3 +-- include/ck_tile/ops/sparse_attn.hpp | 4 ++-- .../block_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 0 .../block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 0 10 files changed, 10 insertions(+), 15 deletions(-) rename include/ck_tile/ops/sparse_attn/{block => pipeline}/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp (100%) rename include/ck_tile/ops/sparse_attn/{block => pipeline}/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp (100%) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index 229575ea62a..3450e6afce6 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ 
b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -49,7 +49,7 @@ def update_file(file_path, content): K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n // auto generated by generate.py #include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd_trek.hpp" diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 52c23c2055b..169d0ce6c61 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -49,7 +49,7 @@ def update_file(file_path, content): K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.\n // auto generated by generate.py #include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd_trek.hpp" diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu index 3be8cae198b..96f502a6f76 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- #include "jenga_sparse_attention.h" #include "fmha_fwd_trek.hpp" #include "ck_tile/core.hpp" diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index 2f0be76bf54..b619926ceab 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -1,7 +1,6 @@ -#pragma once +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -// +#pragma once #include #include #include "ck_tile/core.hpp" diff --git a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp index 0acde4bed3b..c6d58d55995 100644 --- a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. -// // Test for jenga_sparse_attention function #include diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp index be4653c994a..1238f32fd85 100644 --- a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
-// // Test for vsa_sparse_attention function // Based on the Python test: test_jenga_attention.py diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu index e3199b444fb..df13084d8fa 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - #include "jenga_sparse_attention.h" #include "fmha_fwd_trek.hpp" #include "ck_tile/core.hpp" diff --git a/include/ck_tile/ops/sparse_attn.hpp b/include/ck_tile/ops/sparse_attn.hpp index 5e7e692a2a1..b6f6f4c8eca 100644 --- a/include/ck_tile/ops/sparse_attn.hpp +++ b/include/ck_tile/ops/sparse_attn.hpp @@ -3,10 +3,10 @@ #pragma once -#include "ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" -#include "ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" #include "ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp" #include "ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp" +#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" +#include "ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp similarity index 100% rename from include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp rename to include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp diff --git 
a/include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp similarity index 100% rename from include/ck_tile/ops/sparse_attn/block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp rename to include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp From 8ba592d65c913836d76a4fd6c4025cfb1f2a364f Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Tue, 16 Dec 2025 09:30:12 +0000 Subject: [PATCH 13/22] fix the copyright and pipeline --- example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py | 2 +- example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py | 2 +- include/ck_tile/ops/sparse_attn.hpp | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index 3450e6afce6..2e4a734e3b1 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -53,7 +53,7 @@ def update_file(file_path, content): // auto generated by generate.py #include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd_trek.hpp" -#include "block/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" +#include "pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp" #include "kernel/fmha_fwd_jenga_kernel.hpp" """ diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 169d0ce6c61..09953614406 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -53,7 +53,7 @@ def update_file(file_path, content): // auto generated by generate.py #include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd_trek.hpp" -#include "block/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" +#include 
"pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp" #include "kernel/fmha_fwd_vsa_kernel.hpp" """ diff --git a/include/ck_tile/ops/sparse_attn.hpp b/include/ck_tile/ops/sparse_attn.hpp index b6f6f4c8eca..3ee643d7299 100644 --- a/include/ck_tile/ops/sparse_attn.hpp +++ b/include/ck_tile/ops/sparse_attn.hpp @@ -1,6 +1,5 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - #pragma once #include "ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp" From 59902860eabeeb4ce66f39c88cf21626dd681b13 Mon Sep 17 00:00:00 2001 From: Jiangyon Date: Fri, 19 Dec 2025 08:26:17 +0000 Subject: [PATCH 14/22] remove lse & dropout & add fmt --- .../50_sparse_attn/codegen/cpp_symbol_map.py | 16 -------- .../codegen/ops/fmha_fwd_jenga.py | 39 ++++++++----------- .../codegen/ops/fmha_fwd_vsa.py | 33 +++++++--------- .../50_sparse_attn/test_jenga_sparse_attn.cpp | 8 ++-- .../50_sparse_attn/test_vsa_sparse_attn.cpp | 8 ++-- 5 files changed, 37 insertions(+), 67 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py index bcc66ee7ca0..1dad4423c6a 100644 --- a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py +++ b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py @@ -74,22 +74,6 @@ def get_mask_check_map(mask: str): "alibi": "bias_enum::alibi", } -DROPOUT_MAP = { - "no": "ck_tile::BlockDropoutBwd", - "dropout_wg32": "ck_tile::BlockDropoutBwd", - "dropout_wg32_storerandval": "ck_tile::BlockDropoutBwd", - "dropout_wg16": "ck_tile::BlockDropoutBwd", - "dropout_wg16_storerandval": "ck_tile::BlockDropoutBwd", -} - -DROPOUT_CHECK_MAP = { - "no": "t.has_dropout == false", - "dropout_wg32": "t.has_dropout == true && t.is_store_randval == false", - "dropout_wg32_storerandval": "t.has_dropout == true && t.is_store_randval == true", - "dropout_wg16": "t.has_dropout == 
true && t.is_store_randval == false", - "dropout_wg16_storerandval": "t.has_dropout == true && t.is_store_randval == true", -} - ROPE_MAP = { "no": "ck_tile::RotaryEmbeddingEnum::NONE", "inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index 2e4a734e3b1..d4221911f9f 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -405,15 +405,7 @@ def pad_name() -> str: else: n += "_nmask" - if self.F_lse == "t": - n += "_lse" - else: - n += "_nlse" - - if self.F_dropout == "t": - n += "_dropout" - else: - n += "_ndropout" + # Note: lse and dropout are not supported, so we don't add them to filename if self.F_skip == "t": n += "_skip" @@ -663,7 +655,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 16, 32, 64, @@ -684,7 +676,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 32, -1, ), - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 32, 32, 128, @@ -705,7 +697,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 16, -1, ), - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 128, 64, 32, @@ -726,7 +718,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 16, -1, ), - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 128, 128, 32, @@ -839,19 +831,20 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, skip in itertools.product( ["t", "f"], 
get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], - ["t", "f"], - ["t", "f"], ): + # Always use lse="f" and dropout="f" (not supported) + lse = "f" + dropout = "f" if hdim == 256 and hdim_v == 256: # print("jenga fmha only support dim=128 now.") continue pipelines.append( - FmhaFwdPipeline( + FmhaFwdPipeline( # fmt: skip "qr", "row", "f", @@ -870,7 +863,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ) # the below two is used for hdim vectorize load pipelines.append( - FmhaFwdPipeline( + FmhaFwdPipeline( # fmt: skip "qr", "row", "t", @@ -888,7 +881,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ) ) pipelines.append( - FmhaFwdPipeline( + FmhaFwdPipeline( # fmt: skip "qr", "row", "t", @@ -911,7 +904,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli continue # TODO: rocm 6.2 compiler problem if using qr_async for bias case pipelines.append( - FmhaFwdPipeline( + FmhaFwdPipeline( # fmt: skip "qr", "row", "f", @@ -929,7 +922,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ) ) pipelines.append( - FmhaFwdPipeline( + FmhaFwdPipeline( # fmt: skip "qr", "row", "t", @@ -948,7 +941,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ) else: pipelines.append( - FmhaFwdPipeline( + FmhaFwdPipeline( # fmt: skip "qr_async", "row", "t", @@ -966,7 +959,7 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ) ) pipelines.append( - FmhaFwdPipeline( + FmhaFwdPipeline( # fmt: skip "qr_async", "row", "t", diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 09953614406..ebcc14b5ddc 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -405,15 +405,7 @@ def pad_name() -> str: else: n += "_nmask" - 
if self.F_lse == "t": - n += "_lse" - else: - n += "_nlse" - - if self.F_dropout == "t": - n += "_dropout" - else: - n += "_ndropout" + # Note: lse and dropout are not supported, so we don't add them to filename if self.F_skip == "t": n += "_skip" @@ -663,7 +655,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128): [ - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 16, 32, 64, @@ -684,7 +676,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 32, -1, ), - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 32, 32, 128, @@ -705,7 +697,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 16, -1, ), - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 128, 64, 32, @@ -726,7 +718,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: 16, -1, ), - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 128, 128, 32, @@ -756,7 +748,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: elif dtype == "fp8" or dtype == "bf8": return { (64, 64): [ - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 128, 64, 32, @@ -779,7 +771,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: ) ], (128, 128): [ - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 128, 128, 32, @@ -802,7 +794,7 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: ) ], (256, 256): [ - FmhaFwdTileSize( + FmhaFwdTileSize( # fmt: skip 128, 128, 32, @@ -839,14 +831,15 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], - ["t", "f"], - ["t", "f"], ): + 
# Always use lse="f" and dropout="f" (not supported) + lse = "f" + dropout = "f" if hdim == 256 and hdim_v == 256: # print("vsa fmha only support dim=128 now.") continue @@ -1086,7 +1079,7 @@ def get_fwd_blobs( continue if (hdim, hdim_v) == (192, 128): # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != "no" or pipeline.F_dropout == "t": + if pipeline.F_bias != "no": continue if pipeline.tag != "qr_async_trload" and ( ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) diff --git a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp index c6d58d55995..e550a40f670 100644 --- a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp @@ -43,8 +43,8 @@ void reference_blocked_attention( ck_tile::index_t seqlen_k = v_lengths[2]; ck_tile::index_t hdim_v = v_lengths[3]; - ck_tile::index_t num_q_blocks = seqlen_q / BLKQ; - ck_tile::index_t num_k_blocks = seqlen_k / BLKK; + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; for(ck_tile::index_t b = 0; b < batch; ++b) { @@ -207,8 +207,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::index_t BLKK = block_size; // Calculate number of Q and K blocks - ck_tile::index_t num_q_blocks = seqlen_q / BLKQ; - ck_tile::index_t num_k_blocks = seqlen_k / BLKK; + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; std::cout << "============================================================" << std::endl; std::cout << "[Jenga Sparse Attention Test]" << std::endl; diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp index 1238f32fd85..32d64872cf2 100644 --- a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp +++ 
b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp @@ -82,8 +82,8 @@ void reference_blocked_attention( ck_tile::index_t seqlen_k = v_lengths[2]; ck_tile::index_t hdim_v = v_lengths[3]; - ck_tile::index_t num_q_blocks = seqlen_q / BLKQ; - ck_tile::index_t num_k_blocks = seqlen_k / BLKK; + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; for(ck_tile::index_t b = 0; b < batch; ++b) { @@ -247,8 +247,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::index_t BLKK = block_size; // Calculate number of Q and K blocks - ck_tile::index_t num_q_blocks = seqlen_q / BLKQ; - ck_tile::index_t num_k_blocks = seqlen_k / BLKK; + ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; std::cout << "============================================================" << std::endl; std::cout << "[VSA Sparse Attention Test]" << std::endl; From 404a7ef7ee43fb98df71fc936f04dc9a4ff02e07 Mon Sep 17 00:00:00 2001 From: Jiangyong Date: Mon, 26 Jan 2026 17:54:18 +0800 Subject: [PATCH 15/22] fix the jenga&VSA code review --- .../50_sparse_attn/codegen/cpp_symbol_map.py | 18 -- .../codegen/ops/fmha_fwd_jenga.py | 73 +----- .../codegen/ops/fmha_fwd_vsa.py | 128 +-------- .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 5 + .../50_sparse_attn/jenga_sparse_attention.cu | 59 ++--- .../50_sparse_attn/jenga_sparse_attention.h | 14 +- .../50_sparse_attn/test_jenga_sparse_attn.cpp | 220 +++++++--------- .../50_sparse_attn/test_vsa_sparse_attn.cpp | 244 +++++++----------- .../50_sparse_attn/vsa_sparse_attention.cu | 55 ++-- .../arch/amd_buffer_addressing_builtins.hpp | 4 +- include/ck_tile/host.hpp | 1 + .../reference/reference_blocked_attention.hpp | 137 ++++++++++ .../kernel/fmha_fwd_jenga_kernel.hpp | 96 +------ .../kernel/fmha_fwd_vsa_kernel.hpp | 29 +-- ...ock_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 22 +- 15 files changed, 427 insertions(+), 
678 deletions(-) create mode 100644 include/ck_tile/host/reference/reference_blocked_attention.hpp diff --git a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py index 1dad4423c6a..c7aec047462 100644 --- a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py +++ b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py @@ -10,13 +10,6 @@ "fp8bf16": "FmhaFwdFp8Bf16", } -BWD_DTYPE_MAP = {"fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"} - -MASK_IMPL = { - "generic": "ck_tile::GenericAttentionMask", - "simplified": "ck_tile::SimplifiedGenericAttentionMask", -} - _MASK_SIMPLIFIED_MAP = { "s_no": "ck_tile::SimplifiedGenericAttentionMask", "s_mask": "ck_tile::SimplifiedGenericAttentionMask", @@ -74,17 +67,6 @@ def get_mask_check_map(mask: str): "alibi": "bias_enum::alibi", } -ROPE_MAP = { - "no": "ck_tile::RotaryEmbeddingEnum::NONE", - "inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", - "half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED", -} - -ROPE_CHECK_MAP = { - "no": "rope_enum::none", - "inter": "rope_enum::interleaved", - "half": "rope_enum::half_rotated", -} MODE_MAP = {"batch": "false", "group": "true"} diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index d4221911f9f..e598210291c 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -359,7 +359,7 @@ class FmhaFwdPipeline: F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false - F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) + F_constraint: CppConstraint = field(default_factory=CppConstraint) @property def name(self) -> str: @@ -532,7 +532,7 @@ class FmhaFwdTileSize: F_wn1: int # gemm1 warp size along n F_wk1: int # gemm1 warp size along k F_occupancy: int # occupancy, -1 will let pipeline decide the 
occupancy, other value will overwrite occupancy - F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) + F_constraint: CppConstraint = field(default_factory=CppConstraint) @property def name(self) -> str: @@ -900,45 +900,8 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ) else: if bias == "bias": - # print("jenga_fmha with bias is not implemented.") + # jenga_fmha with bias is not implemented. continue - # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) else: pipelines.append( FmhaFwdPipeline( # fmt: skip @@ -976,36 +939,12 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "f", ) ) - # if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": - # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) - # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) - # if receipt == 1 and bias != "bias": - # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim + # TODO: consider enabling extra qr_async_trload pipelines for select + # (hdim, hdim_v) when logits/bias/dropout/lse/skip allow. + # TODO: consider enabling extra qr pipelines when receipt == 1 and bias != "bias". 
elif dtype in ["fp8", "bf8"]: # print("jenga fmha only support 16-bit compute.") return pipelines - # no need lse/dropout kernels - for logits, mask, bias in itertools.product( - ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() - ): - pipelines.append( - FmhaFwdPipeline( - "qr", - "col", - "f", - "f", - "f", - "f", - logits, - bias, - "f", - "f", - squant, - mask, - "f", - "f", - ) - ) elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index ebcc14b5ddc..9c784dddbb7 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -359,7 +359,7 @@ class FmhaFwdPipeline: F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false - F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) + F_constraint: CppConstraint = field(default_factory=CppConstraint) @property def name(self) -> str: @@ -532,7 +532,7 @@ class FmhaFwdTileSize: F_wn1: int # gemm1 warp size along n F_wk1: int # gemm1 warp size along k F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy - F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) + F_constraint: CppConstraint = field(default_factory=CppConstraint) @property def name(self) -> str: @@ -843,102 +843,10 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli if hdim == 256 and hdim_v == 256: # print("vsa fmha only support dim=128 now.") continue - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - # the below two is used for hdim vectorize load - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "t", - "t", - "f", - "f", - logits, - bias, - lse, - dropout, - 
squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) else: if bias == "bias": - # print("vsa_fmha with bias is not implemented.") + # vsa_fmha with bias is not implemented. continue - # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) else: pipelines.append( FmhaFwdPipeline( @@ -976,36 +884,12 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "f", ) ) - # if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": - # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) - # pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) - # if receipt == 1 and bias != "bias": - # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim + # TODO: consider enabling extra qr_async_trload pipelines for select + # (hdim, hdim_v) when logits/bias/dropout/lse/skip allow. + # TODO: consider enabling extra qr pipelines when receipt == 1 and bias != "bias". 
elif dtype in ["fp8", "bf8"]: # print("vsa fmha only support 16-bit compute.") return pipelines - # no need lse/dropout kernels - for logits, mask, bias in itertools.product( - ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() - ): - pipelines.append( - FmhaFwdPipeline( - "qr", - "col", - "f", - "f", - "f", - "f", - logits, - bias, - "f", - "f", - squant, - mask, - "f", - "f", - ) - ) elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 1c6cb39a3c8..641d12490b5 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -4,6 +4,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/device_prop.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/fmha.hpp" @@ -14,6 +15,10 @@ #include #include +namespace ck_tile { +inline bool is_load_tr_supported() { return is_gfx95_supported(); } +} // namespace ck_tile + struct FmhaFwdFp16 { }; diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu index 96f502a6f76..45b61b6dd8a 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu @@ -12,11 +12,9 @@ ck_tile::HostTensor jenga_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& TK, ck_tile::HostTensor& TV, - ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Tblock_relation_onehot, ck_tile::HostTensor& Y, std::optional> bias, - std::optional> seqstart_q, - std::optional> seqstart_k, int bias_type, int batch, int nhead, @@ -25,18 +23,28 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, int seqlen_k, int hdim_q, int hdim_v, - int mode, bool i_perm, bool o_perm, int max_seqlen_q, - int max_seqlen_k) + int max_seqlen_k, + int log_level) { // Determine data type string 
based on template parameter + constexpr bool is_fp8 = + std::is_same_v || std::is_same_v; std::string data_type = "fp16"; if constexpr(std::is_same_v) { data_type = "bf16"; } + else if constexpr(std::is_same_v) + { + data_type = "fp8"; + } + else if constexpr(std::is_same_v) + { + data_type = "bf8"; + } if(max_seqlen_q == 0) max_seqlen_q = seqlen_q; @@ -51,12 +59,12 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, std::string msk_str = "0"; mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - const ck_tile::index_t shape_seqlen_q = (mode == 0 ? seqlen_q : max_seqlen_q); - const ck_tile::index_t shape_seqlen_k = (mode == 0 ? seqlen_k : max_seqlen_k); + const ck_tile::index_t shape_seqlen_q = seqlen_q; + const ck_tile::index_t shape_seqlen_k = seqlen_k; ck_tile::stream_config stream_config{nullptr, false, // time_kernel - 0, /* log_level = */ + log_level, 0, 1, false}; @@ -75,16 +83,9 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, // Optional buffers ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); - ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() : 0); - ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes() : 0); if(bias) bias_buf.ToDevice(bias->data()); - if(seqstart_q) - seqstart_q_buf.ToDevice(seqstart_q->data()); - if(seqstart_k) - seqstart_k_buf.ToDevice(seqstart_k->data()); - const auto init_args = [&](auto& args) { assert(nhead % nhead_k == 0); const ck_tile::index_t stride_q = (i_perm ? 
hdim_q : nhead * hdim_q); @@ -128,7 +129,7 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, args.block_relation_onehot_ptr = block_relation_buf.GetDeviceBuffer(); args.batch = batch; - args.seqlen_q = shape_seqlen_q; // unused in group mode + args.seqlen_q = shape_seqlen_q; // batch mode only args.hdim_q = hdim_q; args.hdim_v = hdim_v; args.nhead_q = nhead; @@ -148,11 +149,11 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, args.lse_ptr = nullptr; args.o_ptr = o_buf.GetDeviceBuffer(); - args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr); - args.seqstart_k_ptr = (mode == 1 ? seqstart_k_buf.GetDeviceBuffer() : nullptr); + args.seqstart_q_ptr = nullptr; + args.seqstart_k_ptr = nullptr; args.seqlen_k_ptr = nullptr; - args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) + args.seqlen_k = shape_seqlen_k; // batch mode only args.max_seqlen_q = max_seqlen_q; args.scale_s = scale_s; @@ -190,12 +191,12 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, traits.data_type = data_type; traits.is_v_rowmajor = is_v_rowmajor; - traits.is_group_mode = (mode == 1); + traits.is_group_mode = false; traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = static_cast(bias_type); traits.has_lse = false; - traits.do_fp8_static_quant = false; + traits.do_fp8_static_quant = is_fp8; traits.has_dropout = false; }; @@ -208,8 +209,8 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, fmha_jenga_fwd(fmha_traits, args, stream_config); - // Copy output back to host - Y = o_buf.ToHost(); + // Copy output back to host without changing tensor shape + o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); return Y; } @@ -218,19 +219,15 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, template ck_tile::HostTensor jenga_sparse_attention( ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, ck_tile::HostTensor&, 
std::optional>, - std::optional>, - std::optional>, - int, int, int, int, int, int, int, int, int, bool, bool, int, int); + int, int, int, int, int, int, int, int, bool, bool, int, int, int); template ck_tile::HostTensor jenga_sparse_attention( ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, ck_tile::HostTensor&, + ck_tile::HostTensor&, ck_tile::HostTensor&, ck_tile::HostTensor&, std::optional>, - std::optional>, - std::optional>, - int, int, int, int, int, int, int, int, int, bool, bool, int, int); + int, int, int, int, int, int, int, int, bool, bool, int, int, int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index b619926ceab..601f849c232 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -11,11 +11,9 @@ ck_tile::HostTensor jenga_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& TK, ck_tile::HostTensor& TV, - ck_tile::HostTensor& Tblock_relation_onehot, + ck_tile::HostTensor& Tblock_relation_onehot, ck_tile::HostTensor& Y, std::optional> bias, - std::optional> seqstart_q, - std::optional> seqstart_k, int bias_type, int batch, int nhead, @@ -24,11 +22,11 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, int seqlen_k, int hdim_q, int hdim_v, - int mode, bool i_perm, bool o_perm, int max_seqlen_q, - int max_seqlen_k); + int max_seqlen_k, + int log_level = 0); template ck_tile::HostTensor @@ -39,8 +37,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t ck_tile::HostTensor& Y, std::optional> bias, - std::optional> seqstart_q, - std::optional> seqstart_k, int bias_type, int batch, int nhead, @@ -49,8 +45,8 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, int seqlen_k, int hdim_q, int hdim_v, - int mode, bool i_perm, bool o_perm, int max_seqlen_q, - int max_seqlen_k); + int max_seqlen_k, + int log_level = 
0); diff --git a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp index e550a40f670..46d95ebefec 100644 --- a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp @@ -13,6 +13,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "jenga_sparse_attention.h" @@ -20,114 +21,44 @@ // Helper Functions // ============================================================================ -// Reference implementation: blocked attention -template -void reference_blocked_attention( - const ck_tile::HostTensor& q, // [B, H, S_q, D] - const ck_tile::HostTensor& k, // [B, H, S_k, D] - const ck_tile::HostTensor& v, // [B, H, S_k, D_v] - const ck_tile::HostTensor& block_relation, // [B, H, Q_blocks, K_blocks] - const ck_tile::HostTensor& bias, // [B, H, S_q, S_k] - ck_tile::HostTensor& output, // [B, H, S_q, D_v] - ck_tile::index_t BLKQ, - ck_tile::index_t BLKK, - AccT scale) +template +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) { - auto q_lengths = q.get_lengths(); - ck_tile::index_t batch = q_lengths[0]; - ck_tile::index_t nhead = q_lengths[1]; - ck_tile::index_t seqlen_q = q_lengths[2]; - ck_tile::index_t hdim = q_lengths[3]; - - auto v_lengths = v.get_lengths(); - ck_tile::index_t seqlen_k = v_lengths[2]; - ck_tile::index_t hdim_v = v_lengths[3]; + if(i_perm) + { + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + } + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} - ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; - ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t 
batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); for(ck_tile::index_t b = 0; b < batch; ++b) { for(ck_tile::index_t h = 0; h < nhead; ++h) { - for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) + for(ck_tile::index_t s = 0; s < seqlen; ++s) { - ck_tile::index_t q_start = qb * BLKQ; - ck_tile::index_t q_end = q_start + BLKQ; - - // Find relevant K blocks - std::vector relevant_k_indices; - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - { - if(static_cast(block_relation(b, h, qb, kb)) > 0.5f) - { - relevant_k_indices.push_back(kb); - } - } - - if(relevant_k_indices.empty()) - continue; - - // For each query position in the block - for(ck_tile::index_t sq = q_start; sq < q_end; ++sq) + for(ck_tile::index_t d = 0; d < hdim; ++d) { - std::vector scores; - AccT max_score = -std::numeric_limits::infinity(); - - for(auto kb : relevant_k_indices) - { - ck_tile::index_t k_start = kb * BLKK; - ck_tile::index_t k_end = k_start + BLKK; - - for(ck_tile::index_t sk = k_start; sk < k_end; ++sk) - { - AccT score = 0.0f; - for(ck_tile::index_t d = 0; d < hdim; ++d) - { - score += static_cast(q(b, h, sq, d)) * - static_cast(k(b, h, sk, d)); - } - score = score * scale + static_cast(bias(b, h, sq, sk)); - scores.push_back(score); - max_score = std::max(max_score, score); - } - } - - // Softmax - AccT sum_exp = 0.0f; - for(auto& s : scores) - { - s = std::exp(s - max_score); - sum_exp += s; - } - for(auto& s : scores) - { - s /= sum_exp; - } - - // Compute output: P @ V - for(ck_tile::index_t dv = 0; dv < hdim_v; ++dv) - { - AccT out_val = 0.0f; - size_t score_idx = 0; - - for(auto kb : relevant_k_indices) - { - ck_tile::index_t k_start = kb * BLKK; - ck_tile::index_t k_end = k_start + BLKK; - - for(ck_tile::index_t sk = k_start; sk < k_end; ++sk) - { - out_val += scores[score_idx] * static_cast(v(b, h, 
sk, dv)); - score_idx++; - } - } - output(b, h, sq, dv) = static_cast(out_val); - } + out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d); } } } } + return out; } // Get error tolerance based on data type @@ -146,7 +77,6 @@ auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") - .insert("mode", "0", "kernel mode. 0:batch, 1:group") .insert("b", "1", "batch size") .insert("h", "4", "num of head for q") .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") @@ -177,7 +107,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { // Parse arguments int do_validation = arg_parser.get_int("v"); - int mode = arg_parser.get_int("mode"); ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); @@ -193,7 +122,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) uint32_t seed = arg_parser.get_uint32("seed"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); - [[maybe_unused]] int kname = arg_parser.get_int("kname"); + int kname = arg_parser.get_int("kname"); // Handle default values if(nhead_k < 0) @@ -206,6 +135,23 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::index_t BLKQ = block_size; ck_tile::index_t BLKK = block_size; + if(block_size != 128 || hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; + std::cout << "Jenga kernel instances are generated for block_size=128 and hdim=128 only." + << std::endl; + std::cout << "TEST SKIPPED" << std::endl; + return true; + } + + if(bias_type == 1) + { + std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; + std::cout << "Elementwise bias is not supported by generated Jenga kernels." 
<< std::endl; + std::cout << "TEST SKIPPED" << std::endl; + return true; + } + // Calculate number of Q and K blocks ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; @@ -225,17 +171,19 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; // Create host tensors (using BHSD layout when i_perm=true) - ck_tile::HostTensor q_host({batch, nhead, seqlen_q, hdim_q}); - ck_tile::HostTensor k_host({batch, nhead_k, seqlen_k, hdim_q}); - ck_tile::HostTensor v_host({batch, nhead_k, seqlen_k, hdim_v}); - ck_tile::HostTensor output_host({batch, nhead, seqlen_q, hdim_v}); + ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + ck_tile::HostTensor output_host = + o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); // Bias tensor [B, H, S_q, S_k] ck_tile::HostTensor bias_host({batch, nhead, seqlen_q, seqlen_k}); // Block relation onehot: [B, H, Q_blocks, K_blocks] - ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); // Initialize tensors with random values std::cout << "\nInitializing tensors..." 
<< std::endl; @@ -266,12 +214,12 @@ bool run_test(const ck_tile::ArgParser& arg_parser) if(is_diagonal || random_active) { - block_relation_onehot(b, h, qb, kb) = static_cast(1.0f); + block_relation_onehot(b, h, qb, kb) = static_cast(1); active_blocks++; } else { - block_relation_onehot(b, h, qb, kb) = static_cast(0.0f); + block_relation_onehot(b, h, qb, kb) = static_cast(0); } } } @@ -284,9 +232,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) << total_blocks << " blocks active)" << std::endl; // Optional tensors - std::optional> bias_opt = std::nullopt; - std::optional> seqstart_q_opt = std::nullopt; - std::optional> seqstart_k_opt = std::nullopt; + std::optional> bias_opt = std::nullopt; if(bias_type != 0) { @@ -298,6 +244,29 @@ bool run_test(const ck_tile::ArgParser& arg_parser) try { + if(kname) + { + jenga_sparse_attention(q_host, + k_host, + v_host, + block_relation_onehot, + output_host, + bias_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 1); + } + // Warmup for(int i = 0; i < warmup; ++i) { @@ -307,8 +276,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) block_relation_onehot, output_host, bias_opt, - seqstart_q_opt, - seqstart_k_opt, bias_type, batch, nhead, @@ -317,11 +284,11 @@ bool run_test(const ck_tile::ArgParser& arg_parser) seqlen_k, hdim_q, hdim_v, - mode, i_perm, o_perm, seqlen_q, - seqlen_k); + seqlen_k, + 0); } // Benchmark @@ -336,8 +303,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) block_relation_onehot, output_host, bias_opt, - seqstart_q_opt, - seqstart_k_opt, bias_type, batch, nhead, @@ -346,11 +311,11 @@ bool run_test(const ck_tile::ArgParser& arg_parser) seqlen_k, hdim_q, hdim_v, - mode, i_perm, o_perm, seqlen_q, - seqlen_k); + seqlen_k, + 0); } [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); @@ -376,15 +341,11 @@ bool run_test(const ck_tile::ArgParser& arg_parser) float scale = 1.0f / 
std::sqrt(static_cast(hdim_q)); std::cout << "Computing reference output..." << std::endl; - reference_blocked_attention(q_host, - k_host, - v_host, - block_relation_onehot, - bias_host, - output_ref, - BLKQ, - BLKK, - scale); + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_relation_onehot, bias_host, output_ref, BLKQ, BLKK, scale); // Compare results auto [rtol, atol] = get_error_tolerance(); @@ -393,9 +354,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser) float max_rel_diff = 0.0f; size_t num_errors = 0; - for(size_t i = 0; i < output_host.mData.size(); ++i) + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i) { - float gpu_val = static_cast(output_host.mData[i]); + float gpu_val = static_cast(output_host_bhsd.mData[i]); float ref_val = static_cast(output_ref.mData[i]); float diff = std::abs(gpu_val - ref_val); float rel_diff = (std::abs(ref_val) > 1e-6f) ? 
diff / std::abs(ref_val) : diff; @@ -417,8 +379,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::cout << "\nValidation results:" << std::endl; std::cout << " Max absolute difference: " << max_diff << std::endl; std::cout << " Max relative difference: " << max_rel_diff << std::endl; - std::cout << " Number of mismatches: " << num_errors << " / " << output_host.mData.size() - << std::endl; + std::cout << " Number of mismatches: " << num_errors << " / " + << output_host_bhsd.mData.size() << std::endl; if(num_errors == 0) { diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp index 32d64872cf2..3f0ed27c4e3 100644 --- a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp @@ -13,6 +13,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "jenga_sparse_attention.h" #include "fmha_fwd_trek.hpp" @@ -21,6 +22,46 @@ // Helper Functions // ============================================================================ +template +ck_tile::HostTensor make_qkv_tensor(ck_tile::index_t batch, + ck_tile::index_t nhead, + ck_tile::index_t seqlen, + ck_tile::index_t hdim, + bool i_perm) +{ + if(i_perm) + { + return ck_tile::HostTensor({batch, nhead, seqlen, hdim}); + } + return ck_tile::HostTensor({batch, seqlen, nhead, hdim}); +} + +template +ck_tile::HostTensor to_bhsd(const ck_tile::HostTensor& tensor, bool is_bhsd) +{ + auto lens = tensor.get_lengths(); + ck_tile::index_t batch = lens[0]; + ck_tile::index_t seqlen = is_bhsd ? lens[2] : lens[1]; + ck_tile::index_t nhead = is_bhsd ? 
lens[1] : lens[2]; + ck_tile::index_t hdim = lens[3]; + + ck_tile::HostTensor out({batch, nhead, seqlen, hdim}); + for(ck_tile::index_t b = 0; b < batch; ++b) + { + for(ck_tile::index_t h = 0; h < nhead; ++h) + { + for(ck_tile::index_t s = 0; s < seqlen; ++s) + { + for(ck_tile::index_t d = 0; d < hdim; ++d) + { + out(b, h, s, d) = is_bhsd ? tensor(b, h, s, d) : tensor(b, s, h, d); + } + } + } + } + return out; +} + // Convert block_relation_onehot to LUT format (similar to triton_block_map_to_lut_kernel) template void block_map_to_lut( @@ -59,117 +100,6 @@ void block_map_to_lut( } } -// Reference implementation: blocked attention (similar to pytorch_blocked_attention) -template -void reference_blocked_attention( - const ck_tile::HostTensor& q, // [B, H, S_q, D] - const ck_tile::HostTensor& k, // [B, H, S_k, D] - const ck_tile::HostTensor& v, // [B, H, S_k, D_v] - const ck_tile::HostTensor& block_relation, // [B, H, Q_blocks, K_blocks] - const ck_tile::HostTensor& bias, // [B, H, S_q, S_k] - ck_tile::HostTensor& output, // [B, H, S_q, D_v] - ck_tile::index_t BLKQ, - ck_tile::index_t BLKK, - AccT scale) -{ - auto q_lengths = q.get_lengths(); - ck_tile::index_t batch = q_lengths[0]; - ck_tile::index_t nhead = q_lengths[1]; - ck_tile::index_t seqlen_q = q_lengths[2]; - ck_tile::index_t hdim = q_lengths[3]; - - auto v_lengths = v.get_lengths(); - ck_tile::index_t seqlen_k = v_lengths[2]; - ck_tile::index_t hdim_v = v_lengths[3]; - - ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; - ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; - - for(ck_tile::index_t b = 0; b < batch; ++b) - { - for(ck_tile::index_t h = 0; h < nhead; ++h) - { - for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb) - { - ck_tile::index_t q_start = qb * BLKQ; - ck_tile::index_t q_end = q_start + BLKQ; - - // Find relevant K blocks - std::vector relevant_k_indices; - for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) - { - if(static_cast(block_relation(b, h, qb, kb)) > 
0.5f) - { - relevant_k_indices.push_back(kb); - } - } - - if(relevant_k_indices.empty()) - continue; - - // For each query position in the block - for(ck_tile::index_t sq = q_start; sq < q_end; ++sq) - { - // Compute attention scores for all relevant K blocks - std::vector scores; - AccT max_score = -std::numeric_limits::infinity(); - - for(auto kb : relevant_k_indices) - { - ck_tile::index_t k_start = kb * BLKK; - ck_tile::index_t k_end = k_start + BLKK; - - for(ck_tile::index_t sk = k_start; sk < k_end; ++sk) - { - AccT score = 0.0f; - for(ck_tile::index_t d = 0; d < hdim; ++d) - { - score += static_cast(q(b, h, sq, d)) * - static_cast(k(b, h, sk, d)); - } - score = score * scale + static_cast(bias(b, h, sq, sk)); - scores.push_back(score); - max_score = std::max(max_score, score); - } - } - - // Softmax - AccT sum_exp = 0.0f; - for(auto& s : scores) - { - s = std::exp(s - max_score); - sum_exp += s; - } - for(auto& s : scores) - { - s /= sum_exp; - } - - // Compute output: P @ V - for(ck_tile::index_t dv = 0; dv < hdim_v; ++dv) - { - AccT out_val = 0.0f; - size_t score_idx = 0; - - for(auto kb : relevant_k_indices) - { - ck_tile::index_t k_start = kb * BLKK; - ck_tile::index_t k_end = k_start + BLKK; - - for(ck_tile::index_t sk = k_start; sk < k_end; ++sk) - { - out_val += scores[score_idx] * static_cast(v(b, h, sk, dv)); - score_idx++; - } - } - output(b, h, sq, dv) = static_cast(out_val); - } - } - } - } - } -} - // Get error tolerance based on data type template auto get_error_tolerance() @@ -186,7 +116,6 @@ auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("v", "1", "0:no validation, 1:cpu validation") - .insert("mode", "0", "kernel mode. 
0:batch, 1:group") .insert("b", "1", "batch size") .insert("h", "4", "num of head for q") .insert("h_k", "-1", "num of head for k/v, -1 means equal to h") @@ -217,7 +146,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { // Parse arguments int do_validation = arg_parser.get_int("v"); - int mode = arg_parser.get_int("mode"); ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); @@ -233,7 +161,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) uint32_t seed = arg_parser.get_uint32("seed"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); - [[maybe_unused]] int kname = arg_parser.get_int("kname"); + int kname = arg_parser.get_int("kname"); // Handle default values if(nhead_k < 0) @@ -246,6 +174,15 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::index_t BLKQ = block_size; ck_tile::index_t BLKK = block_size; + if(block_size != 128 || hdim_q != 128 || hdim_v != 128) + { + std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; + std::cout << "VSA kernel instances are generated for block_size=128 and hdim=128 only." 
+ << std::endl; + std::cout << "TEST SKIPPED" << std::endl; + return true; + } + // Calculate number of Q and K blocks ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; @@ -268,17 +205,19 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Q: [B, H, S_q, D] // K: [B, H_k, S_k, D] // V: [B, H_k, S_k, D_v] - ck_tile::HostTensor q_host({batch, nhead, seqlen_q, hdim_q}); - ck_tile::HostTensor k_host({batch, nhead_k, seqlen_k, hdim_q}); - ck_tile::HostTensor v_host({batch, nhead_k, seqlen_k, hdim_v}); - ck_tile::HostTensor output_host({batch, nhead, seqlen_q, hdim_v}); + ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); + ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); + ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); + ck_tile::HostTensor output_host = + o_perm ? ck_tile::HostTensor({batch, nhead, seqlen_q, hdim_v}) + : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); // Bias tensor [B, H, S_q, S_k] ck_tile::HostTensor bias_host({batch, nhead, seqlen_q, seqlen_k}); // Block relation onehot: [B, H, Q_blocks, K_blocks] - ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); + ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); // LUT and valid_block_num (output of block_map_to_lut) - must be int32_t for kernel ck_tile::HostTensor lut_host({batch, nhead, num_q_blocks, num_k_blocks}); @@ -315,12 +254,12 @@ bool run_test(const ck_tile::ArgParser& arg_parser) if(is_diagonal || random_active) { - block_relation_onehot(b, h, qb, kb) = static_cast(1.0f); + block_relation_onehot(b, h, qb, kb) = static_cast(1); active_blocks++; } else { - block_relation_onehot(b, h, qb, kb) = static_cast(0.0f); + block_relation_onehot(b, h, qb, kb) = static_cast(0); 
} } } @@ -339,9 +278,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // vsa_sparse_attention handles device memory internally // Optional tensors - std::optional> bias_opt = std::nullopt; - std::optional> seqstart_q_opt = std::nullopt; - std::optional> seqstart_k_opt = std::nullopt; + std::optional> bias_opt = std::nullopt; if(bias_type != 0) { @@ -353,6 +290,30 @@ bool run_test(const ck_tile::ArgParser& arg_parser) try { + if(kname) + { + vsa_sparse_attention(q_host, + k_host, + v_host, + lut_host, + valid_block_num_host, + output_host, + bias_opt, + bias_type, + batch, + nhead, + nhead_k, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + i_perm, + o_perm, + seqlen_q, + seqlen_k, + 1); + } + // Warmup for(int i = 0; i < warmup; ++i) { @@ -363,8 +324,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) valid_block_num_host, output_host, bias_opt, - seqstart_q_opt, - seqstart_k_opt, bias_type, batch, nhead, @@ -373,11 +332,11 @@ bool run_test(const ck_tile::ArgParser& arg_parser) seqlen_k, hdim_q, hdim_v, - mode, i_perm, o_perm, seqlen_q, - seqlen_k); + seqlen_k, + 0); } // Benchmark @@ -393,8 +352,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) valid_block_num_host, output_host, bias_opt, - seqstart_q_opt, - seqstart_k_opt, bias_type, batch, nhead, @@ -403,11 +360,11 @@ bool run_test(const ck_tile::ArgParser& arg_parser) seqlen_k, hdim_q, hdim_v, - mode, i_perm, o_perm, seqlen_q, - seqlen_k); + seqlen_k, + 0); } [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); @@ -437,15 +394,11 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // Run reference implementation std::cout << "Computing reference output..." 
<< std::endl; - reference_blocked_attention(q_host, - k_host, - v_host, - block_relation_onehot, - bias_host, - output_ref, - BLKQ, - BLKK, - scale); + auto q_ref = to_bhsd(q_host, i_perm); + auto k_ref = to_bhsd(k_host, i_perm); + auto v_ref = to_bhsd(v_host, i_perm); + ck_tile::reference_blocked_attention( + q_ref, k_ref, v_ref, block_relation_onehot, bias_host, output_ref, BLKQ, BLKK, scale); // Compare results auto [rtol, atol] = get_error_tolerance(); @@ -454,9 +407,10 @@ bool run_test(const ck_tile::ArgParser& arg_parser) float max_rel_diff = 0.0f; size_t num_errors = 0; - for(size_t i = 0; i < output_host.mData.size(); ++i) + auto output_host_bhsd = to_bhsd(output_host, o_perm); + for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i) { - float gpu_val = static_cast(output_host.mData[i]); + float gpu_val = static_cast(output_host_bhsd.mData[i]); float ref_val = static_cast(output_ref.mData[i]); float diff = std::abs(gpu_val - ref_val); float rel_diff = (std::abs(ref_val) > 1e-6f) ? 
diff / std::abs(ref_val) : diff; @@ -478,8 +432,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::cout << "\nValidation results:" << std::endl; std::cout << " Max absolute difference: " << max_diff << std::endl; std::cout << " Max relative difference: " << max_rel_diff << std::endl; - std::cout << " Number of mismatches: " << num_errors << " / " << output_host.mData.size() - << std::endl; + std::cout << " Number of mismatches: " << num_errors << " / " + << output_host_bhsd.mData.size() << std::endl; if(num_errors == 0) { diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu index df13084d8fa..d3492ed6a48 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu @@ -16,8 +16,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, ck_tile::HostTensor& TKV_blocks, ck_tile::HostTensor& Y, std::optional> bias, - std::optional> seqstart_q, - std::optional> seqstart_k, int bias_type, int batch, int nhead, @@ -26,18 +24,28 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, int seqlen_k, int hdim_q, int hdim_v, - int mode, bool i_perm, bool o_perm, int max_seqlen_q, - int max_seqlen_k) + int max_seqlen_k, + int log_level) { // Determine data type string based on template parameter + constexpr bool is_fp8 = + std::is_same_v || std::is_same_v; std::string data_type = "fp16"; if constexpr(std::is_same_v) { data_type = "bf16"; } + else if constexpr(std::is_same_v) + { + data_type = "fp8"; + } + else if constexpr(std::is_same_v) + { + data_type = "bf8"; + } if(max_seqlen_q == 0) max_seqlen_q = seqlen_q; @@ -52,12 +60,12 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, std::string msk_str = "0"; mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - const ck_tile::index_t shape_seqlen_q = (mode == 0 ? seqlen_q : max_seqlen_q); - const ck_tile::index_t shape_seqlen_k = (mode == 0 ? 
seqlen_k : max_seqlen_k); + const ck_tile::index_t shape_seqlen_q = seqlen_q; + const ck_tile::index_t shape_seqlen_k = seqlen_k; ck_tile::stream_config stream_config{nullptr, false, // time_kernel - 0, /* log_level = */ + log_level, 0, 1, false}; @@ -78,18 +86,9 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, // Optional buffers ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); - ck_tile::DeviceMem seqstart_q_buf(seqstart_q ? seqstart_q->get_element_space_size_in_bytes() - : 0); - ck_tile::DeviceMem seqstart_k_buf(seqstart_k ? seqstart_k->get_element_space_size_in_bytes() - : 0); if(bias) bias_buf.ToDevice(bias->data()); - if(seqstart_q) - seqstart_q_buf.ToDevice(seqstart_q->data()); - if(seqstart_k) - seqstart_k_buf.ToDevice(seqstart_k->data()); - const auto init_args = [&](auto& args) { assert(nhead % nhead_k == 0); const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); @@ -134,7 +133,7 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer(); args.batch = batch; - args.seqlen_q = shape_seqlen_q; // unused in group mode + args.seqlen_q = shape_seqlen_q; // batch mode only args.hdim_q = hdim_q; args.hdim_v = hdim_v; args.nhead_q = nhead; @@ -154,11 +153,11 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, args.lse_ptr = nullptr; args.o_ptr = o_buf.GetDeviceBuffer(); - args.seqstart_q_ptr = (mode == 1 ? seqstart_q_buf.GetDeviceBuffer() : nullptr); - args.seqstart_k_ptr = (mode == 1 ? 
seqstart_k_buf.GetDeviceBuffer() : nullptr); + args.seqstart_q_ptr = nullptr; + args.seqstart_k_ptr = nullptr; args.seqlen_k_ptr = nullptr; - args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) + args.seqlen_k = shape_seqlen_k; // batch mode only args.max_seqlen_q = max_seqlen_q; args.scale_s = scale_s; @@ -196,12 +195,12 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, traits.data_type = data_type; traits.is_v_rowmajor = is_v_rowmajor; - traits.is_group_mode = (mode == 1); + traits.is_group_mode = false; traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = static_cast(bias_type); traits.has_lse = false; - traits.do_fp8_static_quant = false; + traits.do_fp8_static_quant = is_fp8; traits.has_dropout = false; }; @@ -214,8 +213,8 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, fmha_vsa_fwd(fmha_traits, args, stream_config); - // Copy output back to host - Y = o_buf.ToHost(); + // Copy output back to host without changing tensor shape + o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); return Y; } @@ -227,9 +226,7 @@ vsa_sparse_attention( ck_tile::HostTensor&, ck_tile::HostTensor&, ck_tile::HostTensor&, ck_tile::HostTensor&, std::optional>, - std::optional>, - std::optional>, - int, int, int, int, int, int, int, int, int, bool, bool, int, int); + int, int, int, int, int, int, int, int, bool, bool, int, int, int); template ck_tile::HostTensor vsa_sparse_attention( @@ -237,6 +234,4 @@ vsa_sparse_attention( ck_tile::HostTensor&, ck_tile::HostTensor&, ck_tile::HostTensor&, ck_tile::HostTensor&, std::optional>, - std::optional>, - std::optional>, - int, int, int, int, int, int, int, int, int, bool, bool, int, int); + int, int, int, int, int, int, int, int, bool, bool, int, int, int); diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 2e65557a151..d0d2b918312 100644 --- 
a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -44,9 +44,9 @@ __device__ inline int32_t amd_wave_read_first_lane(int32_t value) return __builtin_amdgcn_readfirstlane(value); } -__device__ inline int32_t amd_wave_read_first_lane(uintptr_t value) +__device__ inline uint32_t amd_wave_read_first_lane(uintptr_t value) { - return __builtin_amdgcn_readfirstlane(value); + return __builtin_amdgcn_readfirstlane(static_cast(value)); } template , int> = 0> diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 014fcfdd658..f04879f7cdc 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -27,6 +27,7 @@ #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_batched_transpose.hpp" +#include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" diff --git a/include/ck_tile/host/reference/reference_blocked_attention.hpp b/include/ck_tile/host/reference/reference_blocked_attention.hpp new file mode 100644 index 00000000000..ccd5845f51f --- /dev/null +++ b/include/ck_tile/host/reference/reference_blocked_attention.hpp @@ -0,0 +1,137 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +// Reference implementation: blocked attention (for sparse attention tests). 
+template +void reference_blocked_attention( + const HostTensor& q, // [B, H, S_q, D] + const HostTensor& k, // [B, H, S_k, D] + const HostTensor& v, // [B, H, S_k, D_v] + const HostTensor& block_relation, // [B, H, Q_blocks, K_blocks] + const HostTensor& bias, // [B, H, S_q, S_k] + HostTensor& output, // [B, H, S_q, D_v] + index_t BLKQ, + index_t BLKK, + AccT scale) +{ + auto q_lengths = q.get_lengths(); + index_t batch = q_lengths[0]; + index_t nhead = q_lengths[1]; + index_t seqlen_q = q_lengths[2]; + index_t hdim = q_lengths[3]; + + auto v_lengths = v.get_lengths(); + index_t seqlen_k = v_lengths[2]; + index_t hdim_v = v_lengths[3]; + + index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; + index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; + + for(index_t b = 0; b < batch; ++b) + { + for(index_t h = 0; h < nhead; ++h) + { + for(index_t qb = 0; qb < num_q_blocks; ++qb) + { + index_t q_start = qb * BLKQ; + if(q_start >= seqlen_q) + { + continue; + } + index_t q_end = std::min(q_start + BLKQ, seqlen_q); + + std::vector relevant_k_indices; + for(index_t kb = 0; kb < num_k_blocks; ++kb) + { + if(static_cast(block_relation(b, h, qb, kb)) > 0.5f) + { + relevant_k_indices.push_back(kb); + } + } + + if(relevant_k_indices.empty()) + { + continue; + } + + for(index_t sq = q_start; sq < q_end; ++sq) + { + std::vector scores; + AccT max_score = -std::numeric_limits::infinity(); + + for(auto kb : relevant_k_indices) + { + index_t k_start = kb * BLKK; + if(k_start >= seqlen_k) + { + continue; + } + index_t k_end = std::min(k_start + BLKK, seqlen_k); + + for(index_t sk = k_start; sk < k_end; ++sk) + { + AccT score = 0.0f; + for(index_t d = 0; d < hdim; ++d) + { + score += static_cast(q(b, h, sq, d)) * + static_cast(k(b, h, sk, d)); + } + score = score * scale + static_cast(bias(b, h, sq, sk)); + scores.push_back(score); + max_score = std::max(max_score, score); + } + } + + AccT sum_exp = 0.0f; + for(auto& s : scores) + { + s = std::exp(s - max_score); + sum_exp += s; + } + 
for(auto& s : scores) + { + s /= sum_exp; + } + + for(index_t dv = 0; dv < hdim_v; ++dv) + { + AccT out_val = 0.0f; + size_t score_idx = 0; + + for(auto kb : relevant_k_indices) + { + index_t k_start = kb * BLKK; + if(k_start >= seqlen_k) + { + continue; + } + index_t k_end = std::min(k_start + BLKK, seqlen_k); + + for(index_t sk = k_start; sk < k_end; ++sk) + { + out_val += scores[score_idx] * static_cast(v(b, h, sk, dv)); + score_idx++; + } + } + + output(b, h, sq, dv) = static_cast(out_val); + } + } + } + } + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp index 985b2f5348e..828a3bb072a 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp @@ -55,6 +55,8 @@ struct FmhaFwdJengaKernel static constexpr bool kDoFp8StaticQuant = (FmhaPipeline::Problem::QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static_assert(!kStoreLSE, "Jenga sparse attention does not support LSE output."); + static_assert(!kHasDropout, "Jenga sparse attention does not support dropout."); using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -90,7 +92,7 @@ struct FmhaFwdJengaKernel if (kPadHeadDimV) n += "dv"; return n.empty() ? n : std::string("p") + n; }(); return - _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + _SS_("fmha_jenga_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + @@ -101,7 +103,7 @@ struct FmhaFwdJengaKernel (kBlockPerCuInput == -1 ? 
"" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); #undef _SS_ #undef _TS_ // clang-format on @@ -200,67 +202,6 @@ struct FmhaFwdJengaKernel float scale_o; }; - struct FmhaFwdCommonLSEKargs - { - void* lse_ptr = nullptr; - ck_tile::index_t nhead_stride_lse = 0; - ck_tile::index_t batch_stride_lse = 0; - }; - - struct FmhaFwdDropoutSeedOffset - { - template - union ValueOrPointer - { - T val; - const T* ptr; - }; - - ValueOrPointer drop_seed; - ValueOrPointer drop_offset; - bool is_drop_seed_offset_from_host; - }; - - struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset - { - void init_dropout(float p_drop, uint64_t seed, uint64_t offset) - { - float p_undrop = 1.0 - p_drop; - p_undrop_in_uint8_t = - uint8_t(std::floor(p_undrop * std::numeric_limits::max())); - rp_undrop = 1.0 / p_undrop; - - this->drop_seed.val = seed; - this->drop_offset.val = offset; - this->is_drop_seed_offset_from_host = true; - } - - void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr) - { - float p_undrop = 1.0 - p_drop; - p_undrop_in_uint8_t = - uint8_t(std::floor(p_undrop * std::numeric_limits::max())); - rp_undrop = 1.0 / p_undrop; - - this->drop_seed.ptr = seed_ptr; - this->drop_offset.ptr = offset_ptr; - this->is_drop_seed_offset_from_host = false; - } - - float rp_undrop = 1; - uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - bool is_store_randval = false; - void* rand_val_ptr = nullptr; - - 
ck_tile::index_t stride_randval = 0; - ck_tile::index_t nhead_stride_randval = 0; - }; - - struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs - { - ck_tile::index_t batch_stride_randval = 0; - }; - struct FmhaFwdSkipMinSeqlenQKargs { ck_tile::index_t min_seqlen_q = 0; @@ -274,9 +215,9 @@ struct FmhaFwdJengaKernel FmhaFwdAlibiKargs, FmhaFwdEmptyKargs<0>>>, std::conditional_t>, - std::conditional_t>, + FmhaFwdEmptyKargs<2>, std::conditional_t>, - std::conditional_t>, + FmhaFwdEmptyKargs<4>, std::conditional_t> { ck_tile::index_t batch_stride_q; @@ -293,9 +234,9 @@ struct FmhaFwdJengaKernel FmhaFwdAlibiKargs, FmhaFwdEmptyKargs<0>>>, std::conditional_t>, - std::conditional_t>, + FmhaFwdEmptyKargs<2>, std::conditional_t>, - std::conditional_t>, + FmhaFwdEmptyKargs<4>, std::conditional_t>, std::conditional_t> { @@ -1050,6 +991,7 @@ struct FmhaFwdJengaKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // allocate LDS + // Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads. 
__shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)]; // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", @@ -1251,21 +1193,6 @@ struct FmhaFwdJengaKernel } }(); - // sparse mask - // const auto lut_dram = make_naive_tensor_view( - // lut_ptr, - // make_tuple(kargs.seqlen_k/number{}, 1), - // make_tuple(1, 1), - // number<1>{}, - // number<1>{}); - - // const auto valid_block_num_dram = make_naive_tensor_view( - // valid_block_num_ptr, - // make_tuple(kargs.seqlen_q/number{}), - // make_tuple(1), - // number<1>{}, - // number<1>{}); - auto q_dram_window = make_tile_window( q_dram, [&]() { @@ -1285,11 +1212,6 @@ struct FmhaFwdJengaKernel make_tuple(number{}, number{}), {i_n1, 0}); - // auto lut_dram_window = make_tile_window( - // lut_dram, make_tuple(1,1), {0,0}); - // auto valid_block_num_window = make_tile_window( - // valid_block_num_dram, make_tuple(1), {i_tile_m}); - /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove /// following copy capture of the 'i_nhead' if in C++20 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp index 52948f0d30b..e17a4da6c81 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp @@ -91,7 +91,7 @@ struct FmhaFwdVSAKernel if (kPadHeadDimV) n += "dv"; return n.empty() ? n : std::string("p") + n; }(); return - _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + _SS_("fmha_vsa_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + @@ -102,7 +102,7 @@ struct FmhaFwdVSAKernel (kBlockPerCuInput == -1 ? 
"" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); #undef _SS_ #undef _TS_ // clang-format on @@ -1061,11 +1061,9 @@ struct FmhaFwdVSAKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { // allocate LDS + // Extra LDS for staging block_relation_onehot (256 bools); keep 4B alignment for LDS loads. __shared__ char smem_ptr[GetSmemSize() + 256 * sizeof(int)]; - // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z ==0) printf("smem size: %d", - // int(GetSmemSize())); - // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); @@ -1185,7 +1183,6 @@ struct FmhaFwdVSAKernel static_cast(i_batch * kargs.num_head_q + i_nhead) * ck_tile::integer_divide_ceil(kargs.seqlen_q, FmhaPipeline::kM0) + i_tile_m; - // int valid_block_num_value = __builtin_amdgcn_readfirstlane(valid_block_num_ptr[0]); const int valid_block_num_value = valid_block_num_ptr[0]; ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + @@ -1269,21 +1266,6 @@ struct FmhaFwdVSAKernel } }(); - // sparse mask - // const auto lut_dram = make_naive_tensor_view( - // lut_ptr, - // make_tuple(kargs.seqlen_k/number{}, 1), - // make_tuple(1, 1), - // number<1>{}, - // number<1>{}); - - // const auto valid_block_num_dram = make_naive_tensor_view( - // valid_block_num_ptr, - // make_tuple(kargs.seqlen_q/number{}), - // make_tuple(1), - // number<1>{}, - // number<1>{}); - auto q_dram_window = make_tile_window( q_dram, [&]() { @@ -1303,11 +1285,6 @@ struct 
FmhaFwdVSAKernel make_tuple(number{}, number{}), {i_n1, 0}); - // auto lut_dram_window = make_tile_window( - // lut_dram, make_tuple(1,1), {0,0}); - // auto valid_block_num_window = make_tile_window( - // valid_block_num_dram, make_tuple(1), {i_tile_m}); - /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove /// following copy capture of the 'i_nhead' if in C++20 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index 1a5a3f77b5a..b198b904cde 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -235,14 +235,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - bool* block_relation_onehot = reinterpret_cast(smem_ptr) + GetSmemSize(); - amd_direct_load_global_to_lds(block_relation_onehot_ptr, - 4 * threadIdx.x, - block_relation_onehot, - 4 * threadIdx.x, - threadIdx.x / 64 == 0, - 256); - auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), @@ -287,10 +279,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - // if (threadIdx.x==0 && blockIdx.y==0) { - // printf("\nblockIdx.x : %d, seqlen_k_start: %d, seqlen_k_end: %d\n", blockIdx.x, - // seqlen_k_start, seqlen_k_end); - // } // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -315,6 +303,16 
@@ struct BlockFmhaPipelineQRKSVSAsyncJenga __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check } + const index_t num_block = num_total_loop; + bool* block_relation_onehot = reinterpret_cast(smem_ptr) + GetSmemSize(); + const index_t thread_offset = static_cast(4 * threadIdx.x); + amd_direct_load_global_to_lds(block_relation_onehot_ptr, + 4 * threadIdx.x, + block_relation_onehot, + 4 * threadIdx.x, + thread_offset < num_block, + num_block); + auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), From 83bed30c45ed47d0e601e017477ee13cf3ea5e50 Mon Sep 17 00:00:00 2001 From: Jiangyong Date: Wed, 28 Jan 2026 23:32:39 +0800 Subject: [PATCH 16/22] remove the useless code & resolved the comments --- example/ck_tile/50_sparse_attn/CMakeLists.txt | 8 +- .../50_sparse_attn/codegen/cpp_symbol_map.py | 21 +- .../codegen/ops/fmha_fwd_jenga.py | 347 +++-------- .../codegen/ops/fmha_fwd_vsa.py | 302 +++------- .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 554 +++++++++--------- ...ttention.cu => jenga_sparse_attention.cpp} | 115 ++-- .../50_sparse_attn/jenga_sparse_attention.h | 50 +- example/ck_tile/50_sparse_attn/mask.hpp | 175 ------ .../50_sparse_attn/test_jenga_sparse_attn.cpp | 61 +- .../50_sparse_attn/test_vsa_sparse_attn.cpp | 55 +- ..._attention.cu => vsa_sparse_attention.cpp} | 121 ++-- .../reference/reference_blocked_attention.hpp | 29 +- .../kernel/fmha_fwd_jenga_kernel.hpp | 412 ++----------- .../kernel/fmha_fwd_vsa_kernel.hpp | 456 ++------------ 14 files changed, 738 insertions(+), 1968 deletions(-) rename example/ck_tile/50_sparse_attn/{jenga_sparse_attention.cu => jenga_sparse_attention.cpp} (69%) delete mode 100644 example/ck_tile/50_sparse_attn/mask.hpp rename example/ck_tile/50_sparse_attn/{vsa_sparse_attention.cu => vsa_sparse_attention.cpp} (68%) diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt 
b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 5ea00e39129..65bb2077642 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -60,14 +60,14 @@ set(SPARSE_ATTN_JENGA_INSTANCES "tile_sparse_attn_jenga_instances") add_library(${SPARSE_ATTN_JENGA_INSTANCES} OBJECT EXCLUDE_FROM_ALL ${SPARSE_ATTN_JENGA_GEN_BLOBS} - ${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cu + ${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp ) target_include_directories(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR} ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn ) set_source_files_properties(${SPARSE_ATTN_JENGA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cu PROPERTIES LANGUAGE HIP) +set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/jenga_sparse_attention.cpp PROPERTIES LANGUAGE HIP) set_property(TARGET ${SPARSE_ATTN_JENGA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${SPARSE_ATTN_JENGA_INSTANCES} PRIVATE @@ -125,14 +125,14 @@ set(SPARSE_ATTN_VSA_INSTANCES "tile_sparse_attn_vsa_instances") add_library(${SPARSE_ATTN_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL ${SPARSE_ATTN_VSA_GEN_BLOBS} - ${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cu + ${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp ) target_include_directories(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR} ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn ) set_source_files_properties(${SPARSE_ATTN_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cu PROPERTIES LANGUAGE HIP) +set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparse_attention.cpp PROPERTIES LANGUAGE HIP) set_property(TARGET ${SPARSE_ATTN_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${SPARSE_ATTN_VSA_INSTANCES} PRIVATE diff --git 
a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py index c7aec047462..d2b655cfd1e 100644 --- a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py +++ b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py @@ -3,11 +3,8 @@ # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { - "fp16": "FmhaFwdFp16", - "bf16": "FmhaFwdBf16", - "fp8": "FmhaFwdFp8", - "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16", + "fp16": "FmhaSparseFwdFp16", + "bf16": "FmhaSparseFwdBf16", } _MASK_SIMPLIFIED_MAP = { @@ -56,37 +53,26 @@ def get_mask_check_map(mask: str): BIAS_MAP = { "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS", - "bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", - "alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI", } # TODO: this is ugly BIAS_CHECK_MAP = { "no": "bias_enum::no_bias", - "bias": "bias_enum::elementwise_bias", - "alibi": "bias_enum::alibi", } -MODE_MAP = {"batch": "false", "group": "true"} +MODE_MAP = {"batch": "false"} LAYOUT_MAP = {"row": "true", "col": "false"} PIPELINE_MAP = { - "qr": "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsyncJenga", - "qs": "ck_tile::BlockFmhaPipelineQSKSVS", "qr_async_vsa": "ck_tile::BlockFmhaPipelineQRKSVSAsyncVSA", } PIPELINE_ENUM_MAP = { - "qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_async_vsa": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", - "qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS", - "qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS", - "qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { @@ -97,6 +83,5 @@ def get_mask_check_map(mask: str): } SQUANT_MAP = { - "t": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", "f": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", } diff --git 
a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index e598210291c..4fb5db365ea 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -44,7 +44,7 @@ def update_file(file_path, content): file.write(content) -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} @@ -88,17 +88,17 @@ def update_file(file_path, content): using fmha_mask_{F_idx} = {F_mask}; using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, + typename FmhaSparseFwdTypeConfig::QDataType, + typename FmhaSparseFwdTypeConfig::KDataType, + typename FmhaSparseFwdTypeConfig::VDataType, + typename FmhaSparseFwdTypeConfig::SaccDataType, + typename FmhaSparseFwdTypeConfig::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig::BiasDataType, + typename FmhaSparseFwdTypeConfig::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig::LSEDataType, + typename FmhaSparseFwdTypeConfig::PDataType, + typename FmhaSparseFwdTypeConfig::OaccDataType, + typename FmhaSparseFwdTypeConfig::ODataType, fmha_shape_{F_idx}, {F_mode}, fmha_variant_{F_idx}, @@ -110,8 +110,8 @@ def update_file(file_path, content): fmha_pipeline_problem_{F_idx}>; using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename 
FmhaFwdTypeConfig<{F_dtype}>::ODataType, + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = @@ -127,7 +127,7 @@ def update_file(file_path, content): {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; + std::cout << ", " << "{F_kernel_name}" << std::flush; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; @@ -204,7 +204,7 @@ def update_file(file_path, content): }} """ -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_jenga_fwd_(s, a); @@ -265,18 +265,9 @@ def name(self) -> str: def scheck(self) -> str: if self.mode == "group": return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.pipeline_tag in ["qr_async", "qr_async_trload"]: - if self.spad == "t": - return "true" # always support - else: - return "true" - elif 
self.pipeline_tag in ["qr", "qs"]: - if self.spad == "t": - return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) - else: - return f"a.seqlen_q % {self.bm0} == 0" - else: - assert False + if self.spad == "t": + return "true" # always support + return "true" @property def seqtune(self) -> str: @@ -289,57 +280,23 @@ def seqtune(self) -> str: def skcheck(self) -> str: if self.mode == "group": return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true - if self.pipeline_tag == "qr_async": - if self.skpad == "t": - return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" - else: - return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" - elif self.pipeline_tag in ["qr", "qs"]: - if self.skpad == "t": - return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) - else: - return f"a.seqlen_k % {self.bn0} == 0" - elif self.pipeline_tag == "qr_async_trload": - if self.skpad == "t": - return "true" - else: - return "true" - else: - assert False + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" @property def dcheck(self) -> str: - if self.pipeline_tag == "qr_async": - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == "t": - return f"a.hdim_q % {vec} == 0" - else: - assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == "t": - return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) - else: - return f"a.hdim_q % {bk0submax} == 0" - else: - assert False + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False @property def dvcheck(self) -> str: - if self.pipeline_tag == "qr_async": - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == "t": - return f"a.hdim_v % {vec} == 0" - else: - assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == "t": - return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) - else: - return f"a.hdim_v % {bk0submax} == 0" - else: - assert False + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + assert False @dataclass @@ -597,6 +554,7 @@ def template(self) -> str: F_mode=MODE_MAP[self.F_mode], F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, ) @property @@ -745,78 +703,6 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == "fp8" or dtype == "bf8": - return { - (64, 64): [ - FmhaFwdTileSize( - 128, - 64, - 32, - 64, - 32, - 64, - 2, - 1, - 1, - 2, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - (128, 128): [ - FmhaFwdTileSize( - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - (256, 256): [ - FmhaFwdTileSize( - 128, - 128, - 32, - 256, - 32, - 256, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - } else: return None @@ -826,128 +712,58 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: def get_pipelines(dtype, hdim, hdim_v, receipt, 
mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = "t" if dtype == "fp8" else "f" + # FP8 static quantization is not supported in sparse attention yet. + squant = "f" pipelines = [] if dtype in ["fp16", "bf16"]: for logits, mask, bias, skip in itertools.product( - ["t", "f"], + ["f"], get_mask_map(mask_impl).keys(), - BIAS_MAP.keys(), + ["no"], ["t", "f"], ): # Always use lse="f" and dropout="f" (not supported) lse = "f" dropout = "f" if hdim == 256 and hdim_v == 256: - # print("jenga fmha only support dim=128 now.") + # jenga fmha only supports dim <= 192 for now. continue - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr", - "row", - "f", - "f", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - # the below two is used for hdim vectorize load - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr", - "row", - "t", - "t", - "f", - "f", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) + pipelines.append( + FmhaFwdPipeline( # fmt: skip + "qr_async", + "row", + "t", + "f", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", ) - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) + ) + pipelines.append( + FmhaFwdPipeline( # fmt: skip + "qr_async", + "row", + "t", + "t", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", ) - else: - if bias == "bias": - # jenga_fmha with bias is not implemented. 
- continue - else: - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr_async", - "row", - "t", - "f", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( # fmt: skip - "qr_async", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - # TODO: consider enabling extra qr_async_trload pipelines for select - # (hdim, hdim_v) when logits/bias/dropout/lse/skip allow. - # TODO: consider enabling extra qr pipelines when receipt == 1 and bias != "bias". - elif dtype in ["fp8", "bf8"]: - # print("jenga fmha only support 16-bit compute.") - return pipelines - elif dtype in ["fp8fp16", "fp8bf16"]: - # TODO - None + ) else: assert False return pipelines @@ -1001,17 +817,18 @@ def get_fwd_blobs( else KernelComponentFactory ) - for dtype in FWD_DTYPE_MAP.keys(): + # Only generate fp16/bf16 kernels for now. + for dtype in ["fp16", "bf16"]: d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for ((hdim, hdim_v), tiles), mode in itertools.product( - d.items(), MODE_MAP.keys() - ): + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): + if tile.F_bm0 != 128 or tile.F_bn0 != 128: + continue if mode == "group": if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not @@ -1020,19 +837,7 @@ def get_fwd_blobs( # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != "no" or pipeline.F_dropout == "t": continue - if pipeline.tag != "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) - or ((hdim, 
hdim_v) != (128, 128) and tile.F_bm0 != 128) - ): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) - or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) - ): - continue - # logits_soft_cap is only allowed if no bias + # logits soft-cap is not generated for sparse attention if not ( (pipeline.F_logits == "t" and pipeline.F_bias == "no") or pipeline.F_logits == "f" diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 9c784dddbb7..4afe529e1a5 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -44,7 +44,7 @@ def update_file(file_path, content): file.write(content) -DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8} +DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16} K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 192: 192, 256: 256} @@ -88,17 +88,17 @@ def update_file(file_path, content): using fmha_mask_{F_idx} = {F_mask}; using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< - typename FmhaFwdTypeConfig::QDataType, - typename FmhaFwdTypeConfig::KDataType, - typename FmhaFwdTypeConfig::VDataType, - typename FmhaFwdTypeConfig::SaccDataType, - typename FmhaFwdTypeConfig::SMPLComputeDataType, - typename FmhaFwdTypeConfig::BiasDataType, - typename FmhaFwdTypeConfig::RandValOutputDataType, - typename FmhaFwdTypeConfig::LSEDataType, - typename FmhaFwdTypeConfig::PDataType, - typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::ODataType, + typename FmhaSparseFwdTypeConfig::QDataType, + typename FmhaSparseFwdTypeConfig::KDataType, + typename FmhaSparseFwdTypeConfig::VDataType, + typename FmhaSparseFwdTypeConfig::SaccDataType, + typename 
FmhaSparseFwdTypeConfig::SMPLComputeDataType, + typename FmhaSparseFwdTypeConfig::BiasDataType, + typename FmhaSparseFwdTypeConfig::RandValOutputDataType, + typename FmhaSparseFwdTypeConfig::LSEDataType, + typename FmhaSparseFwdTypeConfig::PDataType, + typename FmhaSparseFwdTypeConfig::OaccDataType, + typename FmhaSparseFwdTypeConfig::ODataType, fmha_shape_{F_idx}, {F_mode}, fmha_variant_{F_idx}, @@ -110,8 +110,8 @@ def update_file(file_path, content): fmha_pipeline_problem_{F_idx}>; using fmha_epilogue_{F_idx} = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaSparseFwdTypeConfig<{F_dtype}>::ODataType, {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = @@ -123,12 +123,12 @@ def update_file(file_path, content): #include template<> -float fmha_vsa_fwd_(const ck_tile::stream_config& s, fmha_jenga_fwd_args a) +float fmha_vsa_fwd_(const ck_tile::stream_config& s, fmha_vsa_fwd_args a) {{ using k_ = fmha_kernel_{F_idx}; if(s.log_level_ > 0) - std::cout << ", " << k_::GetName() << std::flush; - auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + std::cout << ", " << "{F_kernel_name}" << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); @@ -169,7 +169,7 @@ def update_file(file_path, content): }} }} // namespace -float fmha_vsa_fwd(fmha_jenga_fwd_traits t, fmha_jenga_fwd_args a, const ck_tile::stream_config& s){{ +float fmha_vsa_fwd(fmha_jenga_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ float r = -1; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate @@ -204,7 +204,7 @@ def update_file(file_path, content): }} """ -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && 
(t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_vsa_fwd_(s, a); @@ -265,18 +265,9 @@ def name(self) -> str: def scheck(self) -> str: if self.mode == "group": return "true/*group mode spad always true*/" # group mode only generate spad/skpad == true - if self.pipeline_tag in ["qr_async_vsa", "qr_async_trload"]: - if self.spad == "t": - return "true" # always support - else: - return "true" - elif self.pipeline_tag in ["qr", "qs"]: - if self.spad == "t": - return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) - else: - return f"a.seqlen_q % {self.bm0} == 0" - else: - assert False + if self.spad == "t": + return "true" # always support + return "true" @property def seqtune(self) -> str: @@ -289,57 +280,23 @@ def seqtune(self) -> str: def skcheck(self) -> str: if self.mode == "group": return "true/*group mode skpad always true*/" # group mode only generate spad/skpad == true - if self.pipeline_tag == "qr_async_vsa": - if self.skpad == "t": - return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" - else: - return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" - elif self.pipeline_tag in ["qr", "qs"]: - if self.skpad == "t": - return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly) - else: - return f"a.seqlen_k % {self.bn0} == 0" - elif self.pipeline_tag == "qr_async_trload": - if self.skpad == "t": - return "true" - else: - return "true" - else: - assert False + if self.skpad == "t": + return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0" + return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0" @property def dcheck(self) -> str: - if self.pipeline_tag == "qr_async_vsa": - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dpad == "t": - return f"a.hdim_q % {vec} == 0" - else: - assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dpad == "t": - return f"true /*a.hdim_q % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! 
(ugly) - else: - return f"a.hdim_q % {bk0submax} == 0" - else: - assert False + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == "t": + return f"a.hdim_q % {vec} == 0" + assert False @property def dvcheck(self) -> str: - if self.pipeline_tag == "qr_async_vsa": - vec = int((32 * 4) / DTYPE_BITS[self.dtype]) - if self.dvpad == "t": - return f"a.hdim_v % {vec} == 0" - else: - assert False - elif self.pipeline_tag in ["qr", "qs", "qr_async_trload"]: - bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] - if self.dvpad == "t": - return f"true /*a.hdim_v % {bk0submax} != 0*/" # TODO: order of get_pipelines() matters! (ugly) - else: - return f"a.hdim_v % {bk0submax} == 0" - else: - assert False + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == "t": + return f"a.hdim_v % {vec} == 0" + assert False @dataclass @@ -597,6 +554,7 @@ def template(self) -> str: F_mode=MODE_MAP[self.F_mode], F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_kernel_name=self.name, ) @property @@ -745,78 +703,6 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: # (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], # (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == "fp8" or dtype == "bf8": - return { - (64, 64): [ - FmhaFwdTileSize( # fmt: skip - 128, - 64, - 32, - 64, - 32, - 64, - 2, - 1, - 1, - 2, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - (128, 128): [ - FmhaFwdTileSize( # fmt: skip - 128, - 128, - 32, - 128, - 32, - 128, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - (256, 256): [ - FmhaFwdTileSize( # fmt: skip - 128, - 128, - 32, - 256, - 32, - 256, - 4, - 1, - 1, - 4, - 1, - 1, - 32, - 32, - 32, - 32, - 32, - 32, - -1, - ) - ], - } else: return None @@ -826,73 +712,58 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: def 
get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later - # TODO: currently for qr pipeline, let 't' padding to appear later!! - # TODO: how to design this more generic? - squant = "t" if dtype == "fp8" else "f" + # FP8 static quantization is not supported in sparse attention yet. + squant = "f" pipelines = [] if dtype in ["fp16", "bf16"]: for logits, mask, bias, skip in itertools.product( - ["t", "f"], + ["f"], get_mask_map(mask_impl).keys(), - BIAS_MAP.keys(), + ["no"], ["t", "f"], ): # Always use lse="f" and dropout="f" (not supported) lse = "f" dropout = "f" if hdim == 256 and hdim_v == 256: - # print("vsa fmha only support dim=128 now.") + # vsa fmha only supports dim <= 192 for now. continue - else: - if bias == "bias": - # vsa_fmha with bias is not implemented. - continue - else: - pipelines.append( - FmhaFwdPipeline( - "qr_async_vsa", - "row", - "t", - "f", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - pipelines.append( - FmhaFwdPipeline( - "qr_async_vsa", - "row", - "t", - "t", - "t", - "t", - logits, - bias, - lse, - dropout, - squant, - mask, - skip, - "f", - ) - ) - # TODO: consider enabling extra qr_async_trload pipelines for select - # (hdim, hdim_v) when logits/bias/dropout/lse/skip allow. - # TODO: consider enabling extra qr pipelines when receipt == 1 and bias != "bias". 
- elif dtype in ["fp8", "bf8"]: - # print("vsa fmha only support 16-bit compute.") - return pipelines - elif dtype in ["fp8fp16", "fp8bf16"]: - # TODO - None + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "f", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) + pipelines.append( + FmhaFwdPipeline( + "qr_async_vsa", + "row", + "t", + "t", + "t", + "t", + logits, + bias, + lse, + dropout, + squant, + mask, + skip, + "f", + ) + ) else: assert False return pipelines @@ -946,17 +817,18 @@ def get_fwd_blobs( else KernelComponentFactory ) - for dtype in FWD_DTYPE_MAP.keys(): + # Only generate fp16/bf16 kernels for now. + for dtype in ["fp16", "bf16"]: d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for ((hdim, hdim_v), tiles), mode in itertools.product( - d.items(), MODE_MAP.keys() - ): + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): + if tile.F_bm0 != 128 or tile.F_bn0 != 128: + continue if mode == "group": if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not @@ -965,19 +837,7 @@ def get_fwd_blobs( # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != "no": continue - if pipeline.tag != "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) - or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128) - ): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == "qr_async_trload" and ( - ((hdim, hdim_v) == (128, 128) and 
tile.F_bn0 == 128) - or ((hdim, hdim_v) not in [(64, 64), (128, 128)]) - ): - continue - # logits_soft_cap is only allowed if no bias + # logits soft-cap is not generated for sparse attention if not ( (pipeline.F_logits == "t" and pipeline.F_bias == "no") or pipeline.F_logits == "f" diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 641d12490b5..f924aa0705d 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -8,8 +8,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/fmha.hpp" -#include "mask.hpp" -#include "bias.hpp" +#include "01_fmha/mask.hpp" #include #include @@ -19,35 +18,19 @@ namespace ck_tile { inline bool is_load_tr_supported() { return is_gfx95_supported(); } } // namespace ck_tile -struct FmhaFwdFp16 +struct FmhaSparseFwdFp16 { }; -struct FmhaFwdBf16 -{ -}; - -struct FmhaFwdFp8 -{ -}; - -struct FmhaFwdBf8 -{ -}; - -struct FmhaFwdFp8Fp16 -{ -}; - -struct FmhaFwdFp8Bf16 +struct FmhaSparseFwdBf16 { }; template -struct FmhaFwdTypeConfig; +struct FmhaSparseFwdTypeConfig; template <> -struct FmhaFwdTypeConfig +struct FmhaSparseFwdTypeConfig { using QDataType = ck_tile::half_t; using KDataType = ck_tile::half_t; @@ -63,7 +46,7 @@ struct FmhaFwdTypeConfig }; template <> -struct FmhaFwdTypeConfig +struct FmhaSparseFwdTypeConfig { using QDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t; @@ -78,38 +61,6 @@ struct FmhaFwdTypeConfig using ODataType = ck_tile::bf16_t; }; -template <> -struct FmhaFwdTypeConfig -{ - using QDataType = ck_tile::fp8_t; - using KDataType = ck_tile::fp8_t; - using VDataType = ck_tile::fp8_t; - using BiasDataType = float; - using RandValOutputDataType = uint8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, 
softmax - using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::fp8_t; -}; - -template <> -struct FmhaFwdTypeConfig -{ - using QDataType = ck_tile::bf8_t; - using KDataType = ck_tile::bf8_t; - using VDataType = ck_tile::bf8_t; - using BiasDataType = ck_tile::bf8_t; - using RandValOutputDataType = uint8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf8_t; -}; - struct FmhaMasks { using NoMask = ck_tile::GenericAttentionMask; @@ -122,9 +73,8 @@ struct fmha_sparge_fwd_args const void* q_ptr; const void* k_ptr; const void* v_ptr; - const void* lut_ptr; - const void* valid_block_num_ptr; - const void* bias_ptr; // bias or alibi_slope pointer + const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk] + const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk] void* rand_val_ptr; void* lse_ptr; void* o_ptr; @@ -148,25 +98,20 @@ struct fmha_sparge_fwd_args float scale_p; float scale_o; - float logits_soft_cap; - ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_randval; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t 
batch_stride_k; ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; @@ -176,11 +121,7 @@ struct fmha_sparge_fwd_args ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; - float p_drop; - bool s_randval; - - std::variant, std::pair> - drop_seed_offset; + // Dropout is not supported for sparse attention; keep args minimal. }; template @@ -196,7 +137,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.v_ptr, args.lut_ptr, args.valid_block_num_ptr, - args.bias_ptr, + nullptr, args.rand_val_ptr, args.lse_ptr, args.o_ptr, @@ -211,17 +152,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.scale_s, args.scale_p, args.scale_o, - args.logits_soft_cap, + 0.0f, args.stride_q, args.stride_k, args.stride_v, - args.stride_bias, + 0, args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, - args.nhead_stride_bias, + 0, args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, @@ -229,9 +170,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.window_size_right, args.mask_type, args.min_seqlen_q, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + 0.0f, + false, + std::make_pair(uint64_t{0}, uint64_t{0})); } else { // create batch mode kernel arguments @@ -240,7 +181,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.v_ptr, args.lut_ptr, args.valid_block_num_ptr, - args.bias_ptr, + nullptr, args.rand_val_ptr, args.lse_ptr, args.o_ptr, @@ -254,33 +195,33 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.scale_s, args.scale_p, args.scale_o, - args.logits_soft_cap, + 0.0f, args.stride_q, args.stride_k, args.stride_v, - args.stride_bias, + 0, args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, - args.nhead_stride_bias, + 0, 
args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, - args.batch_stride_bias, + 0, args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, args.window_size_left, args.window_size_right, args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + 0.0f, + false, + std::make_pair(uint64_t{0}, uint64_t{0})); } }(); @@ -356,9 +297,7 @@ struct fmha_sparge_fwd_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; - bool has_logits_soft_cap; mask_enum mask_type; - bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool has_dropout; bool do_fp8_static_quant; @@ -379,10 +318,7 @@ struct fmha_jenga_fwd_args const void* q_ptr; const void* k_ptr; const void* v_ptr; - const void* block_relation_onehot_ptr; - const void* lut_ptr; - const void* valid_block_num_ptr; - const void* bias_ptr; // bias or alibi_slope pointer + const void* block_relation_onehot_ptr; // one-hot block map [B,H,Q_blk,K_blk], 1=active void* rand_val_ptr; void* lse_ptr; void* o_ptr; @@ -405,25 +341,20 @@ struct fmha_jenga_fwd_args float scale_p; float scale_o; - float logits_soft_cap; - ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_randval; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; @@ -433,205 +364,256 @@ 
struct fmha_jenga_fwd_args ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; - float p_drop; - bool s_randval; + // Dropout is not supported for sparse attention; keep args minimal. +}; + +// vsa +struct fmha_vsa_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk] + const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk] + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - std::variant, std::pair> - drop_seed_offset; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; + + // Dropout is not supported for sparse attention; keep args minimal. 
}; -template +template auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { - if constexpr(VSA) + if constexpr(FmhaKernel::kIsGroupMode) { - // create group mode kernel arguments - if constexpr(FmhaKernel::kIsGroupMode) - { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.lut_ptr, - args.valid_block_num_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.min_seqlen_q, - args.p_drop, - args.s_randval, - args.drop_seed_offset); - } - else - { // create batch mode kernel arguments - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.lut_ptr, - args.valid_block_num_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_lse, - 
args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); - } + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.block_relation_onehot_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + 0.0f, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q, + 0.0f, + false, + std::make_pair(uint64_t{0}, uint64_t{0})); } else { - // create group mode kernel arguments - if constexpr(FmhaKernel::kIsGroupMode) - { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.block_relation_onehot_ptr, - args.bias_ptr, - args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.min_seqlen_q, - args.p_drop, - args.s_randval, - args.drop_seed_offset); - } - else - { // create batch mode kernel arguments - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.block_relation_onehot_ptr, - args.bias_ptr, 
- args.rand_val_ptr, - args.lse_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.scale_p, - args.scale_o, - args.logits_soft_cap, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_lse, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_lse, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); - } + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.block_relation_onehot_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + 0.0f, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + 0.0f, + false, + std::make_pair(uint64_t{0}, uint64_t{0})); + } + }(); + + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, 
args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } +} + +template +auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + 0.0f, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q, + 0.0f, + false, + std::make_pair(uint64_t{0}, uint64_t{0})); + } + else + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + 0.0f, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + 0.0f, + false, + std::make_pair(uint64_t{0}, uint64_t{0})); } }(); @@ -707,9 +689,7 @@ struct fmha_jenga_fwd_traits std::string data_type; bool is_group_mode; bool 
is_v_rowmajor; - bool has_logits_soft_cap; mask_enum mask_type; - bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool has_dropout; bool do_fp8_static_quant; @@ -724,9 +704,9 @@ float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&); -float fmha_vsa_fwd(fmha_jenga_fwd_traits, fmha_jenga_fwd_args, const ck_tile::stream_config&); +float fmha_vsa_fwd(fmha_jenga_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); template -float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); +float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args); -float fmha_vsa_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&); +float fmha_vsa_fwd(fmha_vsa_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp similarity index 69% rename from example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu rename to example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp index 45b61b6dd8a..98f5c940709 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp @@ -9,13 +9,11 @@ template ck_tile::HostTensor -jenga_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, - ck_tile::HostTensor& Tblock_relation_onehot, +jenga_sparse_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& Tblock_relation_onehot, ck_tile::HostTensor& Y, - std::optional> bias, - int bias_type, int batch, int nhead, int nhead_k, @@ -29,33 +27,24 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, int max_seqlen_k, int log_level) { + static_assert(std::is_same_v || + std::is_same_v, + "Jenga sparse attention supports 
fp16/bf16 only."); // Determine data type string based on template parameter - constexpr bool is_fp8 = - std::is_same_v || std::is_same_v; std::string data_type = "fp16"; if constexpr(std::is_same_v) { data_type = "bf16"; } - else if constexpr(std::is_same_v) - { - data_type = "fp8"; - } - else if constexpr(std::is_same_v) - { - data_type = "bf8"; - } if(max_seqlen_q == 0) max_seqlen_q = seqlen_q; if(max_seqlen_k == 0) max_seqlen_k = seqlen_k; - bool is_v_rowmajor = true; - float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - float scale_p = 1.f; - float scale_o = 1.f; - const float logits_soft_cap = 0.0; - + bool is_v_rowmajor = true; + float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + float scale_p = 1.f; + float scale_o = 1.f; std::string msk_str = "0"; mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); @@ -81,11 +70,6 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, v_buf.ToDevice(TV.data()); block_relation_buf.ToDevice(Tblock_relation_onehot.data()); - // Optional buffers - ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); - - if(bias) - bias_buf.ToDevice(bias->data()); const auto init_args = [&](auto& args) { assert(nhead % nhead_k == 0); const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); @@ -96,7 +80,6 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, else return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); }(); - const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments @@ -108,8 +91,6 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, else return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; }(); - const ck_tile::index_t nhead_stride_bias = - (i_perm ? 
0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); @@ -117,7 +98,6 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); @@ -145,9 +125,8 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, args.batch_stride_k = batch_stride_k; args.batch_stride_v = batch_stride_v; - args.bias_ptr = bias ? 
bias_buf.GetDeviceBuffer() : nullptr; - args.lse_ptr = nullptr; - args.o_ptr = o_buf.GetDeviceBuffer(); + args.lse_ptr = nullptr; + args.o_ptr = o_buf.GetDeviceBuffer(); args.seqstart_q_ptr = nullptr; args.seqstart_k_ptr = nullptr; @@ -160,16 +139,11 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, args.scale_p = scale_p; args.scale_o = scale_o; - args.logits_soft_cap = logits_soft_cap; - - args.stride_bias = stride_bias; - args.stride_o = stride_o; - args.nhead_stride_bias = nhead_stride_bias; - args.nhead_stride_lse = nhead_stride_lse; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_bias = batch_stride_bias; - args.batch_stride_lse = batch_stride_lse; - args.batch_stride_o = batch_stride_o; + args.stride_o = stride_o; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; args.window_size_left = mask.left; args.window_size_right = mask.right; @@ -181,8 +155,7 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, args.nhead_stride_randval = nhead_stride_randval; args.batch_stride_randval = batch_stride_randval; - args.p_drop = 0.; - args.s_randval = false; + // Dropout not supported for sparse attention. 
}; const auto init_traits = [&](auto& traits) { @@ -192,11 +165,9 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, traits.is_v_rowmajor = is_v_rowmajor; traits.is_group_mode = false; - traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; - traits.bias_type = static_cast(bias_type); traits.has_lse = false; - traits.do_fp8_static_quant = is_fp8; + traits.do_fp8_static_quant = false; traits.has_dropout = false; }; @@ -217,17 +188,39 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, // Explicit template instantiations template ck_tile::HostTensor -jenga_sparse_attention( - ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, - std::optional>, - int, int, int, int, int, int, int, int, bool, bool, int, int, int); +jenga_sparse_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); template ck_tile::HostTensor -jenga_sparse_attention( - ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, - std::optional>, - int, int, int, int, int, int, int, int, bool, bool, int, int, int); +jenga_sparse_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h index 601f849c232..09b5731dd01 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.h @@ -8,13 +8,11 @@ template ck_tile::HostTensor -jenga_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - 
ck_tile::HostTensor& TV, - ck_tile::HostTensor& Tblock_relation_onehot, +jenga_sparse_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& Tblock_relation_onehot, ck_tile::HostTensor& Y, - std::optional> bias, - int bias_type, int batch, int nhead, int nhead_k, @@ -29,24 +27,22 @@ jenga_sparse_attention(ck_tile::HostTensor& TQ, int log_level = 0); template -ck_tile::HostTensor -vsa_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, - ck_tile::HostTensor& TKV_block_idx, // LUT must be int32_t - ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t - ck_tile::HostTensor& Y, - std::optional> bias, - int bias_type, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level = 0); +ck_tile::HostTensor vsa_sparse_attention( + const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& TKV_block_idx, // LUT must be int32_t + const ck_tile::HostTensor& TKV_blocks, // valid_block_num must be int32_t + ck_tile::HostTensor& Y, + int batch, + int nhead, + int nhead_k, + int seqlen_q, + int seqlen_k, + int hdim_q, + int hdim_v, + bool i_perm, + bool o_perm, + int max_seqlen_q, + int max_seqlen_k, + int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/mask.hpp b/example/ck_tile/50_sparse_attn/mask.hpp deleted file mode 100644 index b484ccc590f..00000000000 --- a/example/ck_tile/50_sparse_attn/mask.hpp +++ /dev/null @@ -1,175 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT -#pragma once - -#include -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha.hpp" - -// keep this in sync with ck_tile::GenericAttentionMaskEnum -enum class mask_enum -{ - no_mask = 0, - mask_top_left, - mask_bottom_right, - window_generic, -}; - -struct mask_info -{ - mask_enum type; - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t y, x; - ck_tile::index_t left, right; // FA style SWA left/right - - void serialize(std::ostream& os) const - { - if(type == mask_enum::no_mask) - os << "n"; - else if(type == mask_enum::mask_top_left) - os << "t(" << left << ":" << right << ")"; - else if(type == mask_enum::mask_bottom_right) - os << "b(" << left << ":" << right << ")"; - else - { - os << "g(" << y << ":" << x << ")"; - } - } - static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) - { - ck_tile::index_t x_total = seqlen_k; - ck_tile::index_t y_total = seqlen_q; - mask_info tmp; - tmp.seqlen_q = seqlen_q; - tmp.seqlen_k = seqlen_k; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string t = str.substr(0, found_0); - std::string v = str.substr(found_0 + 1); - if(t == "xt" || t == "xb") - { - // xformer style sliding window attn from top-left - ck_tile::index_t window_size = atoi(v.c_str()); - ck_tile::index_t left_size = -1; - ck_tile::index_t right_size = 0; - if(window_size > 0) - { - left_size = window_size / 2; - right_size = window_size - 1 - left_size; - } - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, y_total, x_total, t == "xt"); - - tmp.type = t == "xt" ? 
mask_enum::mask_top_left : mask_enum::mask_bottom_right; - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); - tmp.left = left_size; - tmp.right = right_size; - } - else - { - auto found_1 = v.find(","); - if(found_1 == std::string::npos) - { - printf("not supported value %s, %s\n", v.c_str(), str.c_str()); - assert(0); - } - tmp.type = mask_enum::window_generic; - ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); - ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); - // TODO: some validation - if(t == "t") - { - tmp.type = mask_enum::mask_top_left; - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, true); - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); - tmp.left = v0; - tmp.right = v1; - } - else if(t == "b") - { - tmp.type = mask_enum::mask_bottom_right; - auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, false); - tmp.y = r.at(ck_tile::number<0>{}); - tmp.x = r.at(ck_tile::number<1>{}); - tmp.left = v0; - tmp.right = v1; - } - else if(t == "g") - { - tmp.y = v0; - tmp.x = v1; - tmp.left = v0; // TODO: don't use this? 
- tmp.right = v1; - } - else - { - printf("not supported type %s, %s\n", t.c_str(), str.c_str()); - assert(0); - } - } - } - else - { - auto set_causal_top_left = [&]() { - tmp.type = mask_enum::mask_top_left; - tmp.y = seqlen_q; - tmp.x = 1; - tmp.left = -1; - tmp.right = 0; - }; - auto set_causal_bottom_right = [&]() { - tmp.type = mask_enum::mask_bottom_right; - tmp.y = seqlen_q; - tmp.x = seqlen_k - seqlen_q + 1; - tmp.left = -1; - tmp.right = 0; - }; - if(str == "t") - set_causal_top_left(); - else if(str == "b") - set_causal_bottom_right(); - else - { - tmp.type = static_cast(atoi(str.c_str())); - if(tmp.type == mask_enum::mask_top_left) - { - set_causal_top_left(); - } - else if(tmp.type == mask_enum::mask_bottom_right) - { - set_causal_bottom_right(); - } - } - } - return tmp; - } - ck_tile::index_t get_unmaskarea() const - { - if(type == mask_enum::no_mask) - return seqlen_q * seqlen_k; - ck_tile::index_t area = 0; - for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y) - { - ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast(0)); - ck_tile::index_t x_end = std::min(i_y + x, seqlen_k); - if(x_end > x_start) - { - area += (x_end - x_start); - } - } - return area; - } - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) - { - mi.serialize(os); - return os; - } -}; diff --git a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp index 46d95ebefec..b2e969bdc19 100644 --- a/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_jenga_sparse_attn.cpp @@ -14,6 +14,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" #include "jenga_sparse_attention.h" @@ -66,10 +67,32 @@ template auto get_error_tolerance() { double rtol = 1e-2; - double atol = 4e-2; // Higher tolerance for bf16/fp16 + double atol = 4e-2; + 
if constexpr(std::is_same_v) + { + // bf16 accumulation/rounding can be noisier in sparse patterns + atol = 2e-1; + rtol = 2e-1; + } return ck_tile::make_tuple(rtol, atol); } +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + // ============================================================================ // Command line argument parser // ============================================================================ @@ -89,7 +112,6 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp16", "data type: fp16/bf16") .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") .insert("operm", "1", "permute output") - .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi") .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") .insert("repeat", "20", "benchmark iterations") @@ -118,7 +140,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) float sparsity = arg_parser.get_float("sparsity"); bool i_perm = arg_parser.get_bool("iperm"); bool o_perm = arg_parser.get_bool("operm"); - int bias_type = arg_parser.get_int("bias"); uint32_t seed = arg_parser.get_uint32("seed"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); @@ -144,14 +165,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) return true; } - if(bias_type == 1) - { - std::cout << "\n>>> TEST SKIPPED <<<" << std::endl; - std::cout << "Elementwise bias is not supported by generated Jenga kernels." 
<< std::endl; - std::cout << "TEST SKIPPED" << std::endl; - return true; - } - // Calculate number of Q and K blocks ck_tile::index_t num_q_blocks = (seqlen_q + BLKQ - 1) / BLKQ; ck_tile::index_t num_k_blocks = (seqlen_k + BLKK - 1) / BLKK; @@ -179,9 +192,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); - // Bias tensor [B, H, S_q, S_k] - ck_tile::HostTensor bias_host({batch, nhead, seqlen_q, seqlen_k}); - // Block relation onehot: [B, H, Q_blocks, K_blocks] ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); @@ -191,9 +201,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); - // Initialize bias to zero - std::fill(bias_host.mData.begin(), bias_host.mData.end(), static_cast(0.0f)); - // Initialize block_relation_onehot with sparse pattern std::mt19937 rng(seed + 100); std::uniform_real_distribution dist(0.0f, 1.0f); @@ -231,14 +238,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/" << total_blocks << " blocks active)" << std::endl; - // Optional tensors - std::optional> bias_opt = std::nullopt; - - if(bias_type != 0) - { - bias_opt = bias_host; - } - // Run kernel std::cout << "\n--- Running Jenga sparse attention kernel ---" << std::endl; @@ -251,8 +250,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) v_host, block_relation_onehot, output_host, - bias_opt, - bias_type, batch, nhead, nhead_k, @@ -275,8 +272,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) v_host, block_relation_onehot, output_host, - bias_opt, - bias_type, batch, nhead, nhead_k, @@ -302,8 +297,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) v_host, block_relation_onehot, output_host, - 
bias_opt, - bias_type, batch, nhead, nhead_k, @@ -345,7 +338,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) auto k_ref = to_bhsd(k_host, i_perm); auto v_ref = to_bhsd(v_host, i_perm); ck_tile::reference_blocked_attention( - q_ref, k_ref, v_ref, block_relation_onehot, bias_host, output_ref, BLKQ, BLKK, scale); + q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); // Compare results auto [rtol, atol] = get_error_tolerance(); @@ -357,8 +350,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) auto output_host_bhsd = to_bhsd(output_host, o_perm); for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i) { - float gpu_val = static_cast(output_host_bhsd.mData[i]); - float ref_val = static_cast(output_ref.mData[i]); + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); float diff = std::abs(gpu_val - ref_val); float rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; diff --git a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp index 3f0ed27c4e3..e0a1ce7998a 100644 --- a/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_vsa_sparse_attn.cpp @@ -14,6 +14,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host/reference/reference_blocked_attention.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" #include "jenga_sparse_attention.h" #include "fmha_fwd_trek.hpp" @@ -105,10 +106,32 @@ template auto get_error_tolerance() { double rtol = 1e-2; - double atol = 4e-2; // Higher tolerance for bf16/fp16 + double atol = 4e-2; + if constexpr(std::is_same_v) + { + // bf16 accumulation/rounding can be noisier in sparse patterns + atol = 2e-1; + rtol = 2e-1; + } return ck_tile::make_tuple(rtol, atol); } +template +float to_float_for_compare(T value) +{ + return static_cast(value); +} + +template <> +float 
to_float_for_compare(ck_tile::bf16_t value) +{ +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value)); +#endif +} + // ============================================================================ // Command line argument parser // ============================================================================ @@ -128,7 +151,6 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp16", "data type: fp16/bf16") .insert("iperm", "1", "permute input, 1: b*h*s*d, 0: b*s*h*d") .insert("operm", "1", "permute output") - .insert("bias", "0", "bias type: 0:no bias, 1:elementwise, 2:alibi") .insert("seed", "42", "random seed") .insert("warmup", "5", "warmup iterations") .insert("repeat", "20", "benchmark iterations") @@ -157,7 +179,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) float sparsity = arg_parser.get_float("sparsity"); bool i_perm = arg_parser.get_bool("iperm"); bool o_perm = arg_parser.get_bool("operm"); - int bias_type = arg_parser.get_int("bias"); uint32_t seed = arg_parser.get_uint32("seed"); int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); @@ -213,9 +234,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) : ck_tile::HostTensor({batch, seqlen_q, nhead, hdim_v}); ck_tile::HostTensor output_ref({batch, nhead, seqlen_q, hdim_v}); - // Bias tensor [B, H, S_q, S_k] - ck_tile::HostTensor bias_host({batch, nhead, seqlen_q, seqlen_k}); - // Block relation onehot: [B, H, Q_blocks, K_blocks] ck_tile::HostTensor block_relation_onehot({batch, nhead, num_q_blocks, num_k_blocks}); @@ -229,9 +247,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 1}(k_host); ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); - // Initialize bias to zero (as in Python test) - std::fill(bias_host.mData.begin(), bias_host.mData.end(), static_cast(0.0f)); - // Initialize block_relation_onehot 
with sparse pattern std::mt19937 rng(seed + 100); std::uniform_real_distribution dist(0.0f, 1.0f); @@ -277,19 +292,13 @@ bool run_test(const ck_tile::ArgParser& arg_parser) // vsa_sparse_attention handles device memory internally - // Optional tensors - std::optional> bias_opt = std::nullopt; - - if(bias_type != 0) - { - bias_opt = bias_host; - } - // Run kernel std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; try { + // Print kernel name once by invoking with log_level=1. + // This is separate from warmup/benchmark to avoid polluting timing. if(kname) { vsa_sparse_attention(q_host, @@ -298,8 +307,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) lut_host, valid_block_num_host, output_host, - bias_opt, - bias_type, batch, nhead, nhead_k, @@ -323,8 +330,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) lut_host, valid_block_num_host, output_host, - bias_opt, - bias_type, batch, nhead, nhead_k, @@ -351,8 +356,6 @@ bool run_test(const ck_tile::ArgParser& arg_parser) lut_host, valid_block_num_host, output_host, - bias_opt, - bias_type, batch, nhead, nhead_k, @@ -398,7 +401,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) auto k_ref = to_bhsd(k_host, i_perm); auto v_ref = to_bhsd(v_host, i_perm); ck_tile::reference_blocked_attention( - q_ref, k_ref, v_ref, block_relation_onehot, bias_host, output_ref, BLKQ, BLKK, scale); + q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale); // Compare results auto [rtol, atol] = get_error_tolerance(); @@ -410,8 +413,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser) auto output_host_bhsd = to_bhsd(output_host, o_perm); for(size_t i = 0; i < output_host_bhsd.mData.size(); ++i) { - float gpu_val = static_cast(output_host_bhsd.mData[i]); - float ref_val = static_cast(output_ref.mData[i]); + float gpu_val = to_float_for_compare(output_host_bhsd.mData[i]); + float ref_val = to_float_for_compare(output_ref.mData[i]); float diff = std::abs(gpu_val - ref_val); float 
rel_diff = (std::abs(ref_val) > 1e-6f) ? diff / std::abs(ref_val) : diff; diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp similarity index 68% rename from example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu rename to example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp index d3492ed6a48..f07e3d99ae1 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cu +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp @@ -9,14 +9,12 @@ template ck_tile::HostTensor -vsa_sparse_attention(ck_tile::HostTensor& TQ, - ck_tile::HostTensor& TK, - ck_tile::HostTensor& TV, - ck_tile::HostTensor& TKV_block_idx, - ck_tile::HostTensor& TKV_blocks, +vsa_sparse_attention(const ck_tile::HostTensor& TQ, + const ck_tile::HostTensor& TK, + const ck_tile::HostTensor& TV, + const ck_tile::HostTensor& TKV_block_idx, + const ck_tile::HostTensor& TKV_blocks, ck_tile::HostTensor& Y, - std::optional> bias, - int bias_type, int batch, int nhead, int nhead_k, @@ -30,33 +28,24 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, int max_seqlen_k, int log_level) { + static_assert(std::is_same_v || + std::is_same_v, + "VSA sparse attention supports fp16/bf16 only."); // Determine data type string based on template parameter - constexpr bool is_fp8 = - std::is_same_v || std::is_same_v; std::string data_type = "fp16"; if constexpr(std::is_same_v) { data_type = "bf16"; } - else if constexpr(std::is_same_v) - { - data_type = "fp8"; - } - else if constexpr(std::is_same_v) - { - data_type = "bf8"; - } if(max_seqlen_q == 0) max_seqlen_q = seqlen_q; if(max_seqlen_k == 0) max_seqlen_k = seqlen_k; - bool is_v_rowmajor = true; - float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - float scale_p = 1.f; - float scale_o = 1.f; - const float logits_soft_cap = 0.0; - + bool is_v_rowmajor = true; + float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); + float scale_p = 1.f; + float scale_o = 1.f; std::string 
msk_str = "0"; mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); @@ -84,11 +73,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, lut_buf.ToDevice(TKV_block_idx.data()); valid_block_num_buf.ToDevice(TKV_blocks.data()); - // Optional buffers - ck_tile::DeviceMem bias_buf(bias ? bias->get_element_space_size_in_bytes() : 0); - - if(bias) - bias_buf.ToDevice(bias->data()); const auto init_args = [&](auto& args) { assert(nhead % nhead_k == 0); const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); @@ -99,7 +83,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, else return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); }(); - const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments @@ -111,8 +94,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, else return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; }(); - const ck_tile::index_t nhead_stride_bias = - (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); @@ -120,7 +101,6 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); @@ -149,9 +129,8 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, args.batch_stride_k = batch_stride_k; args.batch_stride_v = batch_stride_v; - args.bias_ptr = bias ? bias_buf.GetDeviceBuffer() : nullptr; - args.lse_ptr = nullptr; - args.o_ptr = o_buf.GetDeviceBuffer(); + args.lse_ptr = nullptr; + args.o_ptr = o_buf.GetDeviceBuffer(); args.seqstart_q_ptr = nullptr; args.seqstart_k_ptr = nullptr; @@ -164,16 +143,11 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, args.scale_p = scale_p; args.scale_o = scale_o; - args.logits_soft_cap = logits_soft_cap; - - args.stride_bias = stride_bias; - args.stride_o = stride_o; - args.nhead_stride_bias = nhead_stride_bias; - args.nhead_stride_lse = nhead_stride_lse; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_bias = batch_stride_bias; - args.batch_stride_lse = batch_stride_lse; - args.batch_stride_o = batch_stride_o; + args.stride_o = stride_o; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; args.window_size_left = mask.left; args.window_size_right = mask.right; @@ -185,8 +159,7 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, args.nhead_stride_randval = nhead_stride_randval; args.batch_stride_randval = batch_stride_randval; - args.p_drop = 0.; - args.s_randval = 
false; + // Dropout not supported for sparse attention. }; const auto init_traits = [&](auto& traits) { @@ -196,11 +169,9 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, traits.is_v_rowmajor = is_v_rowmajor; traits.is_group_mode = false; - traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; - traits.bias_type = static_cast(bias_type); traits.has_lse = false; - traits.do_fp8_static_quant = is_fp8; + traits.do_fp8_static_quant = false; traits.has_dropout = false; }; @@ -208,7 +179,7 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, fmha_jenga_fwd_traits fmha_traits; init_traits(fmha_traits); - fmha_jenga_fwd_args args; + fmha_vsa_fwd_args args; init_args(args); fmha_vsa_fwd(fmha_traits, args, stream_config); @@ -221,17 +192,41 @@ vsa_sparse_attention(ck_tile::HostTensor& TQ, // Explicit template instantiations template ck_tile::HostTensor -vsa_sparse_attention( - ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, ck_tile::HostTensor&, - std::optional>, - int, int, int, int, int, int, int, int, bool, bool, int, int, int); +vsa_sparse_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); template ck_tile::HostTensor -vsa_sparse_attention( - ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, ck_tile::HostTensor&, - ck_tile::HostTensor&, ck_tile::HostTensor&, - std::optional>, - int, int, int, int, int, int, int, int, bool, bool, int, int, int); +vsa_sparse_attention(const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + const ck_tile::HostTensor&, + ck_tile::HostTensor&, + int, + int, + int, + int, + int, + int, + int, + bool, + bool, + int, + int, + int); diff --git 
a/include/ck_tile/host/reference/reference_blocked_attention.hpp b/include/ck_tile/host/reference/reference_blocked_attention.hpp index ccd5845f51f..2b6c1017b24 100644 --- a/include/ck_tile/host/reference/reference_blocked_attention.hpp +++ b/include/ck_tile/host/reference/reference_blocked_attention.hpp @@ -9,10 +9,29 @@ #include #include "ck_tile/core.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/host/host_tensor.hpp" namespace ck_tile { +template +CK_TILE_HOST_DEVICE constexpr AccT to_acc(T value) +{ + if constexpr(std::is_same_v) + { +#if CK_TILE_USE_CUSTOM_DATA_TYPE + return static_cast(value); +#else + return static_cast( + ck_tile::bf16_to_float_raw(ck_tile::bit_cast(value))); +#endif + } + else + { + return static_cast(value); + } +} + // Reference implementation: blocked attention (for sparse attention tests). template void reference_blocked_attention( @@ -20,7 +39,6 @@ void reference_blocked_attention( const HostTensor& k, // [B, H, S_k, D] const HostTensor& v, // [B, H, S_k, D_v] const HostTensor& block_relation, // [B, H, Q_blocks, K_blocks] - const HostTensor& bias, // [B, H, S_q, S_k] HostTensor& output, // [B, H, S_q, D_v] index_t BLKQ, index_t BLKK, @@ -55,6 +73,7 @@ void reference_blocked_attention( std::vector relevant_k_indices; for(index_t kb = 0; kb < num_k_blocks; ++kb) { + // Treat block_relation as boolean; >0.5 marks an active block. 
if(static_cast(block_relation(b, h, qb, kb)) > 0.5f) { relevant_k_indices.push_back(kb); @@ -85,10 +104,10 @@ void reference_blocked_attention( AccT score = 0.0f; for(index_t d = 0; d < hdim; ++d) { - score += static_cast(q(b, h, sq, d)) * - static_cast(k(b, h, sk, d)); + score += + to_acc(q(b, h, sq, d)) * to_acc(k(b, h, sk, d)); } - score = score * scale + static_cast(bias(b, h, sq, sk)); + score = score * scale; scores.push_back(score); max_score = std::max(max_score, score); } @@ -121,7 +140,7 @@ void reference_blocked_attention( for(index_t sk = k_start; sk < k_end; ++sk) { - out_val += scores[score_idx] * static_cast(v(b, h, sk, dv)); + out_val += scores[score_idx] * to_acc(v(b, h, sk, dv)); score_idx++; } } diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp index 828a3bb072a..b1414fce40e 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp @@ -55,8 +55,14 @@ struct FmhaFwdJengaKernel static constexpr bool kDoFp8StaticQuant = (FmhaPipeline::Problem::QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static_assert(!kIsGroupMode, "Jenga sparse attention currently supports batch mode only."); + static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, + "Jenga sparse attention does not support bias."); static_assert(!kStoreLSE, "Jenga sparse attention does not support LSE output."); static_assert(!kHasDropout, "Jenga sparse attention does not support dropout."); + static_assert(!kHasLogitsSoftCap, "Jenga sparse attention does not support logits soft-cap."); + static_assert(!kDoFp8StaticQuant, + "Jenga sparse attention does not support FP8 static quantization yet."); using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -148,47 +154,6 @@ 
struct FmhaFwdJengaKernel ck_tile::index_t nhead_stride_o; }; - struct FmhaFwdLogitsSoftCapKargs - { - FmhaFwdLogitsSoftCapKargs() = default; - - void init_logits_soft_cap(float logits_soft_cap_) - { - if(0 < logits_soft_cap_) - { - logits_soft_cap = logits_soft_cap_; - logits_soft_cap_rcp = 1.f / logits_soft_cap; - } - else - { - logits_soft_cap = 0.f; - logits_soft_cap_rcp = 0.f; - } - } - - float logits_soft_cap; - float logits_soft_cap_rcp; - }; - - struct FmhaFwdCommonBiasKargs - { - const void* bias_ptr = nullptr; - ck_tile::index_t stride_bias = 0; - ck_tile::index_t nhead_stride_bias = 0; - }; - - struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs - { - ck_tile::index_t batch_stride_bias = 0; - }; - - struct FmhaFwdAlibiKargs - { - // alibi is batch*nhead*1, no matter in batch/group mode, they are the same - const void* alibi_slope_ptr; - ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope - }; - struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; @@ -209,16 +174,9 @@ struct FmhaFwdJengaKernel struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, - std::conditional_t>>, std::conditional_t>, FmhaFwdEmptyKargs<2>, - std::conditional_t>, - FmhaFwdEmptyKargs<4>, - std::conditional_t> + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -228,17 +186,10 @@ struct FmhaFwdJengaKernel struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs, - std::conditional_t>>, std::conditional_t>, FmhaFwdEmptyKargs<2>, std::conditional_t>, - FmhaFwdEmptyKargs<4>, - std::conditional_t>, - std::conditional_t> + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -260,7 +211,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* block_relation_onehot_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -277,20 +227,17 @@ struct FmhaFwdJengaKernel ck_tile::index_t stride_q, 
ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, @@ -302,6 +249,17 @@ struct FmhaFwdJengaKernel std::variant, std::pair> drop_seed_offset) { + (void)rand_val_ptr; + (void)lse_ptr; + (void)logits_soft_cap; + (void)stride_randval; + (void)nhead_stride_randval; + (void)nhead_stride_lse; + (void)batch_stride_randval; + (void)batch_stride_lse; + (void)p_drop; + (void)s_randval; + (void)drop_seed_offset; Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -326,71 +284,25 @@ struct FmhaFwdJengaKernel nhead_stride_k, nhead_stride_v, nhead_stride_o}, // args for common karg - {}, // placeholder for bias {}, // placeholder for mask - {}, // placeholder for lse + {}, // placeholder for empty kargs {}, // placeholder for fp8_static_quant args - {}, // placeholder for dropout - {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_o}; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - kargs.bias_ptr = bias_ptr; - kargs.stride_bias = stride_bias; - kargs.nhead_stride_bias = nhead_stride_bias; - kargs.batch_stride_bias = batch_stride_bias; - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - kargs.alibi_slope_ptr = bias_ptr; - kargs.alibi_slope_stride = stride_bias; - } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); } - if 
constexpr(kStoreLSE) - { - kargs.lse_ptr = lse_ptr; - kargs.nhead_stride_lse = nhead_stride_lse; - kargs.batch_stride_lse = batch_stride_lse; - } if constexpr(kDoFp8StaticQuant) { kargs.scale_p = scale_p; kargs.scale_o = scale_o; } - if constexpr(kHasDropout) - { - if(drop_seed_offset.index() == 0) // seed & offset come from host - { - const auto& [seed, offset] = std::get<0>(drop_seed_offset); - kargs.init_dropout(p_drop, seed, offset); - } - else // seed & offset come from device - { - const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); - kargs.init_dropout(p_drop, - reinterpret_cast(seed_ptr), - reinterpret_cast(offset_ptr)); - } - - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.batch_stride_randval = batch_stride_randval; - kargs.is_store_randval = s_randval; - } - if constexpr(kHasLogitsSoftCap) - { - kargs.init_logits_soft_cap(logits_soft_cap); - } return kargs; } @@ -402,7 +314,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* block_relation_onehot_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -419,20 +330,17 @@ struct FmhaFwdJengaKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, @@ -448,7 +356,6 @@ struct FmhaFwdJengaKernel k_ptr, v_ptr, block_relation_onehot_ptr, - bias_ptr, 
rand_val_ptr, lse_ptr, o_ptr, @@ -465,20 +372,17 @@ struct FmhaFwdJengaKernel stride_q, stride_k, stride_v, - stride_bias, stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, - batch_stride_bias, batch_stride_randval, batch_stride_lse, batch_stride_o, @@ -497,7 +401,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* block_relation_onehot_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -514,20 +417,17 @@ struct FmhaFwdJengaKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, @@ -543,7 +443,6 @@ struct FmhaFwdJengaKernel k_ptr, v_ptr, block_relation_onehot_ptr, - bias_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -560,20 +459,17 @@ struct FmhaFwdJengaKernel stride_q, stride_k, stride_v, - stride_bias, stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, - batch_stride_bias, batch_stride_randval, batch_stride_lse, batch_stride_o, @@ -591,7 +487,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* block_relation_onehot_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ 
-609,13 +504,11 @@ struct FmhaFwdJengaKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, @@ -628,6 +521,15 @@ struct FmhaFwdJengaKernel std::variant, std::pair> drop_seed_offset) { + (void)rand_val_ptr; + (void)lse_ptr; + (void)logits_soft_cap; + (void)stride_randval; + (void)nhead_stride_randval; + (void)nhead_stride_lse; + (void)p_drop; + (void)s_randval; + (void)drop_seed_offset; Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -652,68 +554,25 @@ struct FmhaFwdJengaKernel nhead_stride_k, nhead_stride_v, nhead_stride_o}, // args for common karg - {}, // placeholder for bias {}, // placeholder for mask - {}, // placeholder for lse + {}, // placeholder for empty kargs {}, // placeholder for fp8_static_quant args - {}, // placeholder for dropout - {}, // placeholder for logits_soft_cap {}, // placeholder for min_seqlen_q reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - kargs.bias_ptr = bias_ptr; - kargs.stride_bias = stride_bias; - kargs.nhead_stride_bias = nhead_stride_bias; - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - kargs.alibi_slope_ptr = bias_ptr; - kargs.alibi_slope_stride = stride_bias; - } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); } - if constexpr(kStoreLSE) - { - kargs.lse_ptr = lse_ptr; - kargs.nhead_stride_lse = nhead_stride_lse; - } if constexpr(kDoFp8StaticQuant) { kargs.scale_p = scale_p; kargs.scale_o = scale_o; } - if 
constexpr(kHasDropout) - { - if(drop_seed_offset.index() == 0) // seed & offset come from host - { - const auto& [seed, offset] = std::get<0>(drop_seed_offset); - kargs.init_dropout(p_drop, seed, offset); - } - else // seed & offset come from device - { - const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); - kargs.init_dropout(p_drop, - reinterpret_cast(seed_ptr), - reinterpret_cast(offset_ptr)); - } - - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.is_store_randval = s_randval; - } - if constexpr(kHasLogitsSoftCap) - { - kargs.init_logits_soft_cap(logits_soft_cap); - } if constexpr(kSkipMinSeqlenQ) { kargs.min_seqlen_q = min_seqlen_q; @@ -729,7 +588,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* block_relation_onehot_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -747,13 +605,11 @@ struct FmhaFwdJengaKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, @@ -769,7 +625,6 @@ struct FmhaFwdJengaKernel k_ptr, v_ptr, block_relation_onehot_ptr, - bias_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -787,13 +642,11 @@ struct FmhaFwdJengaKernel stride_q, stride_k, stride_v, - stride_bias, stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, @@ -812,7 +665,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* block_relation_onehot_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -830,13 +682,11 
@@ struct FmhaFwdJengaKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, @@ -852,7 +702,6 @@ struct FmhaFwdJengaKernel k_ptr, v_ptr, block_relation_onehot_ptr, - bias_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -870,13 +719,11 @@ struct FmhaFwdJengaKernel stride_q, stride_k, stride_v, - stride_bias, stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, @@ -1003,13 +850,10 @@ struct FmhaFwdJengaKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_o = 0; if constexpr(kIsGroupMode) { @@ -1027,18 +871,6 @@ struct FmhaFwdJengaKernel { batch_offset_v = key_start; } - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = query_start * kargs.stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = query_start; - } - if constexpr(kHasDropout) - { - batch_offset_randval = query_start * kargs.stride_randval; - } batch_offset_o = query_start * kargs.stride_o; // get real # queries & # keys under group mode @@ -1075,19 +907,6 @@ struct FmhaFwdJengaKernel batch_offset_q = 
static_cast(i_batch) * kargs.batch_stride_q; batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } - if constexpr(kHasDropout) - { - batch_offset_randval = - static_cast(i_batch) * kargs.batch_stride_randval; - } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } @@ -1212,119 +1031,18 @@ struct FmhaFwdJengaKernel make_tuple(number{}, number{}), {i_n1, 0}); - /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove - /// following copy capture of the 'i_nhead' if in C++20 - const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const BiasDataType* bias_ptr = - reinterpret_cast(kargs.bias_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_bias + - batch_offset_bias; - - const auto bias_dram = [&]() { - const auto bias_dram_naive = make_naive_tensor_view( - bias_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_bias, 1), - number{}, - number<1>{}); - - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(bias_dram_window_lengths); - } - }(); + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + const auto bias_dram_window = make_null_tile_window(bias_dram_window_lengths); - // lse - auto lse_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - if constexpr(kStoreLSE) - { - LSEDataType* 
lse_ptr = - reinterpret_cast(kargs.lse_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; - - const auto lse_dram = [&]() { - const auto lse_dram_naive = make_naive_tensor_view( - lse_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - - return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); - } - else - { - return make_null_tile_window(lse_dram_window_lengths); - } - }(); + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + auto lse_dram_window = make_null_tile_window(lse_dram_window_lengths); - auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { - if constexpr(kHasDropout) - { - return BlockDropout{i_batch_, - i_nhead_, - kargs.num_head_q, - kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val - : *kargs.drop_seed.ptr, - kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val - : *kargs.drop_offset.ptr, - kargs.rp_undrop, - kargs.p_undrop_in_uint8_t, - kargs.is_store_randval}; - } - else - { - return NullBlockDropout{}; - }; - }(); + auto dropout = NullBlockDropout{}; - auto randval_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(kHasDropout) - { - RandValOutputDataType* rand_val_ptr = - reinterpret_cast(kargs.rand_val_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_randval + - batch_offset_randval; - - const auto randval_dram = [&]() { - const auto randval_dram_naive = - make_naive_tensor_view( - rand_val_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_randval, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(randval_dram_window_lengths); 
- } - }(); + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + auto randval_dram_window = make_null_tile_window(randval_dram_window_lengths); FmhaMask mask = [&]() { if constexpr(kHasMask) @@ -1339,50 +1057,10 @@ struct FmhaFwdJengaKernel }(); // WA i_batch capture structure binding before c++20 - auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - // data loading, shared by entire wg - // TODO: how to use s_read? - SaccDataType slope = - *(reinterpret_cast(kargs.alibi_slope_ptr) + - i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; -#endif - if constexpr(kHasMask) - { - return make_alibi_from_lr_mask(slope, - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); - } - else - { - return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; - } - } - else - { - return EmptyPositionEncoding{}; - } - }(); + auto position_encoding = EmptyPositionEncoding{}; AttentionVariant variant; - const auto variant_params = [&] { - if constexpr(kHasLogitsSoftCap) - { - return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; - } - else - { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; - } - }(); + const auto variant_params = ck_tile::StandardAttentionParams{mask, kargs.scale_s}; BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp index e17a4da6c81..33b42b4266d 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp @@ -56,6 +56,11 @@ struct FmhaFwdVSAKernel static constexpr bool kDoFp8StaticQuant = 
(QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, + "VSA sparse attention does not support bias."); + static_assert(!kHasLogitsSoftCap, "VSA sparse attention does not support logits soft-cap."); + static_assert(!kDoFp8StaticQuant, + "VSA sparse attention does not support FP8 static quantization yet."); using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -170,25 +175,6 @@ struct FmhaFwdVSAKernel float logits_soft_cap_rcp; }; - struct FmhaFwdCommonBiasKargs - { - const void* bias_ptr = nullptr; - ck_tile::index_t stride_bias = 0; - ck_tile::index_t nhead_stride_bias = 0; - }; - - struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs - { - ck_tile::index_t batch_stride_bias = 0; - }; - - struct FmhaFwdAlibiKargs - { - // alibi is batch*nhead*1, no matter in batch/group mode, they are the same - const void* alibi_slope_ptr; - ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope - }; - struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; @@ -202,67 +188,6 @@ struct FmhaFwdVSAKernel float scale_o; }; - struct FmhaFwdCommonLSEKargs - { - void* lse_ptr = nullptr; - ck_tile::index_t nhead_stride_lse = 0; - ck_tile::index_t batch_stride_lse = 0; - }; - - struct FmhaFwdDropoutSeedOffset - { - template - union ValueOrPointer - { - T val; - const T* ptr; - }; - - ValueOrPointer drop_seed; - ValueOrPointer drop_offset; - bool is_drop_seed_offset_from_host; - }; - - struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset - { - void init_dropout(float p_drop, uint64_t seed, uint64_t offset) - { - float p_undrop = 1.0 - p_drop; - p_undrop_in_uint8_t = - uint8_t(std::floor(p_undrop * std::numeric_limits::max())); - rp_undrop = 1.0 / p_undrop; - - this->drop_seed.val = seed; - this->drop_offset.val = offset; - 
this->is_drop_seed_offset_from_host = true; - } - - void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr) - { - float p_undrop = 1.0 - p_drop; - p_undrop_in_uint8_t = - uint8_t(std::floor(p_undrop * std::numeric_limits::max())); - rp_undrop = 1.0 / p_undrop; - - this->drop_seed.ptr = seed_ptr; - this->drop_offset.ptr = offset_ptr; - this->is_drop_seed_offset_from_host = false; - } - - float rp_undrop = 1; - uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); - bool is_store_randval = false; - void* rand_val_ptr = nullptr; - - ck_tile::index_t stride_randval = 0; - ck_tile::index_t nhead_stride_randval = 0; - }; - - struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs - { - ck_tile::index_t batch_stride_randval = 0; - }; - struct FmhaFwdSkipMinSeqlenQKargs { ck_tile::index_t min_seqlen_q = 0; @@ -270,16 +195,9 @@ struct FmhaFwdVSAKernel struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, - std::conditional_t>>, + FmhaFwdEmptyKargs<0>, std::conditional_t>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -289,17 +207,10 @@ struct FmhaFwdVSAKernel struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs, - std::conditional_t>>, + FmhaFwdEmptyKargs<0>, std::conditional_t>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -322,7 +233,6 @@ struct FmhaFwdVSAKernel const void* v_ptr, const void* lut_ptr, const void* valid_block_num_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -339,20 +249,17 @@ struct FmhaFwdVSAKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, 
ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, @@ -364,6 +271,17 @@ struct FmhaFwdVSAKernel std::variant, std::pair> drop_seed_offset) { + (void)rand_val_ptr; + (void)lse_ptr; + (void)logits_soft_cap; + (void)stride_randval; + (void)nhead_stride_randval; + (void)nhead_stride_lse; + (void)batch_stride_randval; + (void)batch_stride_lse; + (void)p_drop; + (void)s_randval; + (void)drop_seed_offset; Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -389,71 +307,25 @@ struct FmhaFwdVSAKernel nhead_stride_k, nhead_stride_v, nhead_stride_o}, // args for common karg - {}, // placeholder for bias {}, // placeholder for mask - {}, // placeholder for lse + {}, // placeholder for empty kargs {}, // placeholder for fp8_static_quant args - {}, // placeholder for dropout - {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_o}; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - kargs.bias_ptr = bias_ptr; - kargs.stride_bias = stride_bias; - kargs.nhead_stride_bias = nhead_stride_bias; - kargs.batch_stride_bias = batch_stride_bias; - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - kargs.alibi_slope_ptr = bias_ptr; - kargs.alibi_slope_stride = stride_bias; - } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); } - if constexpr(kStoreLSE) - { - kargs.lse_ptr = lse_ptr; - kargs.nhead_stride_lse = nhead_stride_lse; - kargs.batch_stride_lse = batch_stride_lse; - } if 
constexpr(kDoFp8StaticQuant) { kargs.scale_p = scale_p; kargs.scale_o = scale_o; } - if constexpr(kHasDropout) - { - if(drop_seed_offset.index() == 0) // seed & offset come from host - { - const auto& [seed, offset] = std::get<0>(drop_seed_offset); - kargs.init_dropout(p_drop, seed, offset); - } - else // seed & offset come from device - { - const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); - kargs.init_dropout(p_drop, - reinterpret_cast(seed_ptr), - reinterpret_cast(offset_ptr)); - } - - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.batch_stride_randval = batch_stride_randval; - kargs.is_store_randval = s_randval; - } - if constexpr(kHasLogitsSoftCap) - { - kargs.init_logits_soft_cap(logits_soft_cap); - } return kargs; } @@ -466,7 +338,6 @@ struct FmhaFwdVSAKernel const void* v_ptr, const void* lut_ptr, const void* valid_block_num_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -483,20 +354,17 @@ struct FmhaFwdVSAKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, @@ -513,7 +381,6 @@ struct FmhaFwdVSAKernel v_ptr, lut_ptr, valid_block_num_ptr, - bias_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -530,20 +397,17 @@ struct FmhaFwdVSAKernel stride_q, stride_k, stride_v, - stride_bias, stride_randval, stride_o, nhead_stride_q, 
nhead_stride_k, nhead_stride_v, - nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, - batch_stride_bias, batch_stride_randval, batch_stride_lse, batch_stride_o, @@ -563,7 +427,6 @@ struct FmhaFwdVSAKernel const void* v_ptr, const void* lut_ptr, const void* valid_block_num_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -580,20 +443,17 @@ struct FmhaFwdVSAKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, @@ -610,7 +470,6 @@ struct FmhaFwdVSAKernel v_ptr, lut_ptr, valid_block_num_ptr, - bias_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -627,20 +486,17 @@ struct FmhaFwdVSAKernel stride_q, stride_k, stride_v, - stride_bias, stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, batch_stride_q, batch_stride_k, batch_stride_v, - batch_stride_bias, batch_stride_randval, batch_stride_lse, batch_stride_o, @@ -659,7 +515,6 @@ struct FmhaFwdVSAKernel const void* v_ptr, const void* lut_ptr, const void* valid_block_num_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -677,13 +532,11 @@ struct FmhaFwdVSAKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t 
stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, @@ -696,6 +549,15 @@ struct FmhaFwdVSAKernel std::variant, std::pair> drop_seed_offset) { + (void)rand_val_ptr; + (void)lse_ptr; + (void)logits_soft_cap; + (void)stride_randval; + (void)nhead_stride_randval; + (void)nhead_stride_lse; + (void)p_drop; + (void)s_randval; + (void)drop_seed_offset; Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -721,68 +583,25 @@ struct FmhaFwdVSAKernel nhead_stride_k, nhead_stride_v, nhead_stride_o}, // args for common karg - {}, // placeholder for bias {}, // placeholder for mask - {}, // placeholder for lse + {}, // placeholder for empty kargs {}, // placeholder for fp8_static_quant args - {}, // placeholder for dropout - {}, // placeholder for logits_soft_cap {}, // placeholder for min_seqlen_q reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - kargs.bias_ptr = bias_ptr; - kargs.stride_bias = stride_bias; - kargs.nhead_stride_bias = nhead_stride_bias; - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - kargs.alibi_slope_ptr = bias_ptr; - kargs.alibi_slope_stride = stride_bias; - } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); } - if constexpr(kStoreLSE) - { - kargs.lse_ptr = lse_ptr; - kargs.nhead_stride_lse = nhead_stride_lse; - } if constexpr(kDoFp8StaticQuant) { kargs.scale_p = scale_p; kargs.scale_o = scale_o; } - if constexpr(kHasDropout) - { - if(drop_seed_offset.index() == 0) // seed & offset come from host - { - const auto& [seed, offset] = std::get<0>(drop_seed_offset); - kargs.init_dropout(p_drop, 
seed, offset); - } - else // seed & offset come from device - { - const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); - kargs.init_dropout(p_drop, - reinterpret_cast(seed_ptr), - reinterpret_cast(offset_ptr)); - } - - kargs.rand_val_ptr = rand_val_ptr; - kargs.stride_randval = stride_randval; - kargs.nhead_stride_randval = nhead_stride_randval; - kargs.is_store_randval = s_randval; - } - if constexpr(kHasLogitsSoftCap) - { - kargs.init_logits_soft_cap(logits_soft_cap); - } if constexpr(kSkipMinSeqlenQ) { kargs.min_seqlen_q = min_seqlen_q; @@ -799,7 +618,6 @@ struct FmhaFwdVSAKernel const void* v_ptr, const void* lut_ptr, const void* valid_block_num_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -817,13 +635,11 @@ struct FmhaFwdVSAKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, @@ -840,7 +656,6 @@ struct FmhaFwdVSAKernel v_ptr, lut_ptr, valid_block_num_ptr, - bias_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -858,13 +673,11 @@ struct FmhaFwdVSAKernel stride_q, stride_k, stride_v, - stride_bias, stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, @@ -884,7 +697,6 @@ struct FmhaFwdVSAKernel const void* v_ptr, const void* lut_ptr, const void* valid_block_num_ptr, - const void* bias_ptr, void* rand_val_ptr, void* lse_ptr, void* o_ptr, @@ -902,13 +714,11 @@ struct FmhaFwdVSAKernel ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, 
ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, @@ -925,7 +735,6 @@ struct FmhaFwdVSAKernel v_ptr, lut_ptr, valid_block_num_ptr, - bias_ptr, rand_val_ptr, lse_ptr, o_ptr, @@ -943,13 +752,11 @@ struct FmhaFwdVSAKernel stride_q, stride_k, stride_v, - stride_bias, stride_randval, stride_o, nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_bias, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, @@ -1070,13 +877,10 @@ struct FmhaFwdVSAKernel const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); - long_index_t batch_offset_q = 0; - long_index_t batch_offset_k = 0; - long_index_t batch_offset_v = 0; - long_index_t batch_offset_bias = 0; - long_index_t batch_offset_randval = 0; - long_index_t batch_offset_lse = 0; - long_index_t batch_offset_o = 0; + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_o = 0; if constexpr(kIsGroupMode) { @@ -1094,18 +898,6 @@ struct FmhaFwdVSAKernel { batch_offset_v = key_start; } - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = query_start * kargs.stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = query_start; - } - if constexpr(kHasDropout) - { - batch_offset_randval = query_start * kargs.stride_randval; - } batch_offset_o = query_start * kargs.stride_o; // get real # queries & # keys under group mode @@ -1142,19 +934,6 @@ struct FmhaFwdVSAKernel batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(BiasEnum == 
BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; - } - if constexpr(kStoreLSE) - { - batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; - } - if constexpr(kHasDropout) - { - batch_offset_randval = - static_cast(i_batch) * kargs.batch_stride_randval; - } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; } @@ -1285,119 +1064,18 @@ struct FmhaFwdVSAKernel make_tuple(number{}, number{}), {i_n1, 0}); - /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove - /// following copy capture of the 'i_nhead' if in C++20 - const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - const BiasDataType* bias_ptr = - reinterpret_cast(kargs.bias_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_bias + - batch_offset_bias; - - const auto bias_dram = [&]() { - const auto bias_dram_naive = make_naive_tensor_view( - bias_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_bias, 1), - number{}, - number<1>{}); - - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(bias_dram_window_lengths); - } - }(); + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + const auto bias_dram_window = make_null_tile_window(bias_dram_window_lengths); - // lse - auto lse_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - if constexpr(kStoreLSE) - { - LSEDataType* lse_ptr = - reinterpret_cast(kargs.lse_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; - - const auto lse_dram = [&]() { - const auto lse_dram_naive = make_naive_tensor_view( - 
lse_ptr, - make_tuple(kargs.seqlen_q), - make_tuple(1), - number<1>{}, - number<1>{}); - - return pad_tensor_view( - lse_dram_naive, lse_dram_window_lengths, sequence{}); - }(); - - return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); - } - else - { - return make_null_tile_window(lse_dram_window_lengths); - } - }(); + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + auto lse_dram_window = make_null_tile_window(lse_dram_window_lengths); - auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { - if constexpr(kHasDropout) - { - return BlockDropout{i_batch_, - i_nhead_, - kargs.num_head_q, - kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val - : *kargs.drop_seed.ptr, - kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val - : *kargs.drop_offset.ptr, - kargs.rp_undrop, - kargs.p_undrop_in_uint8_t, - kargs.is_store_randval}; - } - else - { - return NullBlockDropout{}; - }; - }(); + auto dropout = NullBlockDropout{}; - auto randval_dram_window = [&, i_nhead_ = i_nhead]() { - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - if constexpr(kHasDropout) - { - RandValOutputDataType* rand_val_ptr = - reinterpret_cast(kargs.rand_val_ptr) + - static_cast(i_nhead_) * kargs.nhead_stride_randval + - batch_offset_randval; - - const auto randval_dram = [&]() { - const auto randval_dram_naive = - make_naive_tensor_view( - rand_val_ptr, - make_tuple(kargs.seqlen_q, kargs.seqlen_k), - make_tuple(kargs.stride_randval, 1), - number<1>{}, - number<1>{}); - - return pad_tensor_view(randval_dram_naive, - randval_dram_window_lengths, - sequence{}); - }(); - - return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); - } - else - { - return make_null_tile_window(randval_dram_window_lengths); - } - }(); + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + auto randval_dram_window = make_null_tile_window(randval_dram_window_lengths); FmhaMask mask = [&]() { if 
constexpr(kHasMask) @@ -1412,50 +1090,10 @@ struct FmhaFwdVSAKernel }(); // WA i_batch capture structure binding before c++20 - auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - // data loading, shared by entire wg - // TODO: how to use s_read? - SaccDataType slope = - *(reinterpret_cast(kargs.alibi_slope_ptr) + - i_batch_ * kargs.alibi_slope_stride + i_nhead_); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; -#endif - if constexpr(kHasMask) - { - return make_alibi_from_lr_mask(slope, - kargs.window_size_left, - kargs.window_size_right, - kargs.seqlen_q, - kargs.seqlen_k, - kargs.mask_type); - } - else - { - return Alibi{ - slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; - } - } - else - { - return EmptyPositionEncoding{}; - } - }(); + auto position_encoding = EmptyPositionEncoding{}; AttentionVariant variant; - const auto variant_params = [&] { - if constexpr(kHasLogitsSoftCap) - { - return ck_tile::LogitsSoftCapParams{ - mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; - } - else - { - return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; - } - }(); + const auto variant_params = ck_tile::StandardAttentionParams{mask, kargs.scale_s}; BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; From 3eef3ad3817763841bfa39a44fe7f5fe03b0f434 Mon Sep 17 00:00:00 2001 From: Jiangyong Date: Thu, 29 Jan 2026 10:41:59 +0800 Subject: [PATCH 17/22] remove useless code --- example/ck_tile/50_sparse_attn/bias.hpp | 99 ---- .../50_sparse_attn/codegen/cpp_symbol_map.py | 14 - .../codegen/ops/fmha_fwd_jenga.py | 85 +--- .../codegen/ops/fmha_fwd_vsa.py | 85 +--- .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 142 +----- .../50_sparse_attn/jenga_sparse_attention.cpp | 46 +- .../50_sparse_attn/vsa_sparse_attention.cpp | 46 +- .../kernel/fmha_fwd_jenga_kernel.hpp | 395 +++------------- 
.../kernel/fmha_fwd_vsa_kernel.hpp | 427 +++--------------- 9 files changed, 172 insertions(+), 1167 deletions(-) delete mode 100644 example/ck_tile/50_sparse_attn/bias.hpp diff --git a/example/ck_tile/50_sparse_attn/bias.hpp b/example/ck_tile/50_sparse_attn/bias.hpp deleted file mode 100644 index 4f013341e8b..00000000000 --- a/example/ck_tile/50_sparse_attn/bias.hpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once - -#include -#include -#include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha.hpp" - -// keep sync with BlockAttentionBiasEnum -enum class bias_enum -{ - no_bias = 0, - elementwise_bias = 1, - alibi = 2, -}; - -struct bias_info -{ - bias_enum type; - /* - * simple dispatch logic - * - * if type == elementwise_bias: - * if rank_info == 0: - * bias is 1*1*s*s - * elif rank_info == 1: - * bias is 1*h*s*s - * elif rank_info == 2: - * bias is b*h*s*s - * - * elif type == alibi: - * if rank_info == 0: - * alibi in 1*h - * elif rank_info == 1: - * alibi in b*h - */ - int rank_info; - - void serialize(std::ostream& os) const - { - if(type == bias_enum::no_bias) - os << "n"; - else if(type == bias_enum::elementwise_bias) - { - os << "e"; - if(rank_info != 0) - { - os << "[" << rank_info << "]"; - } - } - else if(type == bias_enum::alibi) - { - os << "alibi"; - if(rank_info != 0) - { - os << "[" << rank_info << "]"; - } - } - } - - static bias_info decode(std::string str) - { - bias_info info{bias_enum::no_bias, 0}; - if(str == "0" || str == "n") - { - info.type = bias_enum::no_bias; - } - else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || - str.compare(0, 11, "elementwise") == 0) - { - info.type = bias_enum::elementwise_bias; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string e = str.substr(found_0 + 1); - info.rank_info = atoi(e.c_str()); - } - } - else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 
|| - str.compare(0, 5, "alibi") == 0) - { - info.type = bias_enum::alibi; - auto found_0 = str.find(':'); - if(found_0 != std::string::npos) - { - std::string e = str.substr(found_0 + 1); - info.rank_info = atoi(e.c_str()); - } - } - return info; - } - - friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) - { - bi.serialize(os); - return os; - } -}; diff --git a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py index d2b655cfd1e..8614a1ff3ba 100644 --- a/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py +++ b/example/ck_tile/50_sparse_attn/codegen/cpp_symbol_map.py @@ -51,16 +51,6 @@ def get_mask_check_map(mask: str): return None -BIAS_MAP = { - "no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS", -} - -# TODO: this is ugly -BIAS_CHECK_MAP = { - "no": "bias_enum::no_bias", -} - - MODE_MAP = {"batch": "false"} LAYOUT_MAP = {"row": "true", "col": "false"} @@ -81,7 +71,3 @@ def get_mask_check_map(mask: str): True: "true", False: "false", } - -SQUANT_MAP = { - "f": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", -} diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index 4fb5db365ea..cec3fe307a9 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -12,15 +12,12 @@ from typing import List, Optional, Tuple from codegen.cpp_symbol_map import ( - BIAS_CHECK_MAP, - BIAS_MAP, BOOL_MAP, FWD_DTYPE_MAP, LAYOUT_MAP, MODE_MAP, PIPELINE_ENUM_MAP, PIPELINE_MAP, - SQUANT_MAP, get_mask_check_map, get_mask_map, ) @@ -75,11 +72,11 @@ def update_file(file_path, content): {F_dpad}, {F_dvpad}, {F_logits}, - {F_bias}, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, - {F_lse}, - {F_dropout}, - {F_squant_enum}, + false, + false, + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, {F_occupancy}, {F_skip}>; @@ -118,7 +115,7 @@ 
def update_file(file_path, content): ck_tile::FmhaFwdJengaKernel; using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; #include @@ -204,9 +201,9 @@ def update_file(file_path, content): }} """ -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_jenga_fwd_(s, a); }} """ @@ -242,10 +239,6 @@ class FmhaFwdApiTrait: vlayout: str logits: str mask: str - bias: str # - lse: str # - dropout: str - squant: str # spad: str skpad: str dpad: str @@ -258,7 +251,7 @@ class FmhaFwdApiTrait: def name(self) -> str: return ( 
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" ) @property @@ -309,10 +302,6 @@ class FmhaFwdPipeline: F_dpad: str # F_dvpad: str # F_logits: str # t/f - F_bias: str # true/false - F_lse: str # - F_dropout: str # - F_squant: str # F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false @@ -346,10 +335,7 @@ def pad_name() -> str: else: n += "_nlogits" - if self.F_bias != "no": - n += f"_{self.F_bias}" - else: - n += "_nbias" + n += "_nbias" if self.F_mask[0:2] == "s_": if self.F_mask == "s_mask": @@ -362,17 +348,12 @@ def pad_name() -> str: else: n += "_nmask" - # Note: lse and dropout are not supported, so we don't add them to filename - if self.F_skip == "t": n += "_skip" else: n += "_nskip" - if self.F_squant == "t": - n += "_squant" - else: - n += "_nsquant" + n += "_nsquant" if self.F_trload == "t": n += "_trload" @@ -423,13 +404,8 @@ def api(self) -> str: F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], - F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, @@ -542,11 +518,6 @@ def template(self) -> str: F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], F_logits=BOOL_MAP[self.F_pipeline.F_logits], - F_bias=BIAS_MAP[self.F_pipeline.F_bias], - F_lse=BOOL_MAP[self.F_pipeline.F_lse], - 
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], - F_squant=BOOL_MAP[self.F_pipeline.F_squant], - F_squant_enum=SQUANT_MAP[self.F_pipeline.F_squant], F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], @@ -586,10 +557,6 @@ def api_trait(self) -> FmhaFwdApiTrait: vlayout=self.F_pipeline.F_vlayout, mask=self.F_pipeline.F_mask, logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, @@ -712,19 +679,13 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later - # FP8 static quantization is not supported in sparse attention yet. - squant = "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, skip in itertools.product( + for logits, mask, skip in itertools.product( ["f"], get_mask_map(mask_impl).keys(), - ["no"], ["t", "f"], ): - # Always use lse="f" and dropout="f" (not supported) - lse = "f" - dropout = "f" if hdim == 256 and hdim_v == 256: # jenga fmha only supports dim <= 192 for now. 
continue @@ -737,10 +698,6 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "t", "t", logits, - bias, - lse, - dropout, - squant, mask, skip, "f", @@ -755,10 +712,6 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "t", "t", logits, - bias, - lse, - dropout, - squant, mask, skip, "f", @@ -833,15 +786,8 @@ def get_fwd_blobs( if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if (hdim, hdim_v) == (192, 128): - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != "no" or pipeline.F_dropout == "t": - continue # logits soft-cap is not generated for sparse attention - if not ( - (pipeline.F_logits == "t" and pipeline.F_bias == "no") - or pipeline.F_logits == "f" - ): + if not (pipeline.F_logits == "f"): continue if pipeline.tag != "qr_async": continue @@ -864,8 +810,6 @@ def get_fwd_blobs( if receipt in (2, 3): cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" if not cond: continue @@ -873,8 +817,6 @@ def get_fwd_blobs( elif receipt == 4: cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_squant == "f" cond &= mode == "batch" cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" @@ -885,7 +827,6 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= mode == "batch" cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -893,14 +834,12 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= mode == "group" cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" if not cond: continue # aiter::mha_fwd C++ api 
integration elif receipt == 600: cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" if not cond: continue diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 4afe529e1a5..d1d4f901133 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -12,15 +12,12 @@ from typing import List, Optional, Tuple from codegen.cpp_symbol_map import ( - BIAS_CHECK_MAP, - BIAS_MAP, BOOL_MAP, FWD_DTYPE_MAP, LAYOUT_MAP, MODE_MAP, PIPELINE_ENUM_MAP, PIPELINE_MAP, - SQUANT_MAP, get_mask_check_map, get_mask_map, ) @@ -75,11 +72,11 @@ def update_file(file_path, content): {F_dpad}, {F_dvpad}, {F_logits}, - {F_bias}, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, - {F_lse}, - {F_dropout}, - {F_squant_enum}, + false, + false, + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, {F_occupancy}, {F_skip}>; @@ -118,7 +115,7 @@ def update_file(file_path, content): ck_tile::FmhaFwdVSAKernel; using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; #include @@ -204,9 +201,9 @@ def update_file(file_path, content): }} """ -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && 
(t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_vsa_fwd_(s, a); }} """ @@ -242,10 +239,6 @@ class FmhaFwdApiTrait: vlayout: str logits: str mask: str - bias: str # - lse: str # - dropout: str - squant: str # spad: str skpad: str dpad: str @@ -258,7 +251,7 @@ class FmhaFwdApiTrait: def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" ) @property @@ -309,10 +302,6 @@ class FmhaFwdPipeline: F_dpad: str # F_dvpad: str # F_logits: str # t/f - F_bias: str # true/false - F_lse: str # - F_dropout: str # - F_squant: str # F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false @@ -346,10 +335,7 @@ def pad_name() -> str: else: n += "_nlogits" - if self.F_bias != "no": - n += f"_{self.F_bias}" - else: - n += "_nbias" + n += "_nbias" if self.F_mask[0:2] == "s_": if self.F_mask == "s_mask": @@ -362,17 +348,12 @@ def pad_name() -> str: else: n += "_nmask" - # Note: lse and dropout are not supported, so we don't add them to filename - if self.F_skip == "t": n += 
"_skip" else: n += "_nskip" - if self.F_squant == "t": - n += "_squant" - else: - n += "_nsquant" + n += "_nsquant" if self.F_trload == "t": n += "_trload" @@ -423,13 +404,8 @@ def api(self) -> str: F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], - F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, @@ -542,11 +518,6 @@ def template(self) -> str: F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], F_logits=BOOL_MAP[self.F_pipeline.F_logits], - F_bias=BIAS_MAP[self.F_pipeline.F_bias], - F_lse=BOOL_MAP[self.F_pipeline.F_lse], - F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], - F_squant=BOOL_MAP[self.F_pipeline.F_squant], - F_squant_enum=SQUANT_MAP[self.F_pipeline.F_squant], F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], @@ -586,10 +557,6 @@ def api_trait(self) -> FmhaFwdApiTrait: vlayout=self.F_pipeline.F_vlayout, mask=self.F_pipeline.F_mask, logits=self.F_pipeline.F_logits, - bias=self.F_pipeline.F_bias, - lse=self.F_pipeline.F_lse, - dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, @@ -712,19 +679,13 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later - # FP8 static quantization is not supported in sparse attention yet. 
- squant = "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, skip in itertools.product( + for logits, mask, skip in itertools.product( ["f"], get_mask_map(mask_impl).keys(), - ["no"], ["t", "f"], ): - # Always use lse="f" and dropout="f" (not supported) - lse = "f" - dropout = "f" if hdim == 256 and hdim_v == 256: # vsa fmha only supports dim <= 192 for now. continue @@ -737,10 +698,6 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "t", "t", logits, - bias, - lse, - dropout, - squant, mask, skip, "f", @@ -755,10 +712,6 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "t", "t", logits, - bias, - lse, - dropout, - squant, mask, skip, "f", @@ -833,15 +786,8 @@ def get_fwd_blobs( if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if (hdim, hdim_v) == (192, 128): - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != "no": - continue # logits soft-cap is not generated for sparse attention - if not ( - (pipeline.F_logits == "t" and pipeline.F_bias == "no") - or pipeline.F_logits == "f" - ): + if not (pipeline.F_logits == "f"): continue if pipeline.tag != "qr_async_vsa": continue @@ -864,8 +810,6 @@ def get_fwd_blobs( if receipt in (2, 3): cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" if not cond: continue @@ -873,8 +817,6 @@ def get_fwd_blobs( elif receipt == 4: cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_squant == "f" cond &= mode == "batch" cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" @@ -885,7 +827,6 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= mode 
== "batch" cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -893,14 +834,12 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= mode == "group" cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" if not cond: continue diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index f924aa0705d..42d2c3d5fd8 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -75,8 +75,6 @@ struct fmha_sparge_fwd_args const void* v_ptr; const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk] const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk] - void* rand_val_ptr; - void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -95,25 +93,18 @@ struct fmha_sparge_fwd_args float pv_threshold; float scale_s; - float scale_p; - float scale_o; ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; - ck_tile::index_t stride_randval; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_randval; - ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_randval; - ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; ck_tile::index_t window_size_left; @@ -137,9 +128,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.v_ptr, args.lut_ptr, args.valid_block_num_ptr, - nullptr, - args.rand_val_ptr, - args.lse_ptr, 
args.o_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, @@ -149,30 +137,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.pv_threshold, - args.scale_s, - args.scale_p, - args.scale_o, - 0.0f, args.stride_q, args.stride_k, args.stride_v, - 0, - args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, - 0, - args.nhead_stride_randval, - args.nhead_stride_lse, args.nhead_stride_o, args.window_size_left, args.window_size_right, args.mask_type, - args.min_seqlen_q, - 0.0f, - false, - std::make_pair(uint64_t{0}, uint64_t{0})); + args.min_seqlen_q); } else { // create batch mode kernel arguments @@ -181,9 +157,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.v_ptr, args.lut_ptr, args.valid_block_num_ptr, - nullptr, - args.rand_val_ptr, - args.lse_ptr, args.o_ptr, args.seqlen_q, args.seqlen_k, @@ -192,36 +165,21 @@ auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.pv_threshold, - args.scale_s, - args.scale_p, - args.scale_o, - 0.0f, args.stride_q, args.stride_k, args.stride_v, - 0, - args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, - 0, - args.nhead_stride_randval, - args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, - 0, - args.batch_stride_randval, - args.batch_stride_lse, args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type, - 0.0f, - false, - std::make_pair(uint64_t{0}, uint64_t{0})); + args.mask_type); } }(); @@ -253,10 +211,6 @@ template ; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kStoreLse = kStoreLse_; - static constexpr bool kHasDropout = kHasDropout_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kPadS = kPadS_; static constexpr bool kPadSK = kPadSK_; static constexpr 
bool kPadD = kPadD_; @@ -298,9 +248,6 @@ struct fmha_sparge_fwd_traits bool is_group_mode; bool is_v_rowmajor; mask_enum mask_type; - bool has_lse; - bool has_dropout; - bool do_fp8_static_quant; bool skip_min_seqlen_q = false; // TODO: padding check is inside this api }; @@ -319,8 +266,6 @@ struct fmha_jenga_fwd_args const void* k_ptr; const void* v_ptr; const void* block_relation_onehot_ptr; // one-hot block map [B,H,Q_blk,K_blk], 1=active - void* rand_val_ptr; - void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -338,25 +283,18 @@ struct fmha_jenga_fwd_args ck_tile::index_t nhead_k; float scale_s; - float scale_p; - float scale_o; ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; - ck_tile::index_t stride_randval; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_randval; - ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_randval; - ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; ck_tile::index_t window_size_left; @@ -375,8 +313,6 @@ struct fmha_vsa_fwd_args const void* v_ptr; const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk] const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk] - void* rand_val_ptr; - void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -394,25 +330,18 @@ struct fmha_vsa_fwd_args ck_tile::index_t nhead_k; float scale_s; - float scale_p; - float scale_o; ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; - ck_tile::index_t stride_randval; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_v; - ck_tile::index_t nhead_stride_randval; - ck_tile::index_t 
nhead_stride_lse; ck_tile::index_t nhead_stride_o; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_randval; - ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; ck_tile::index_t window_size_left; @@ -434,8 +363,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args) args.k_ptr, args.v_ptr, args.block_relation_onehot_ptr, - args.rand_val_ptr, - args.lse_ptr, args.o_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, @@ -445,27 +372,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, - args.scale_p, - args.scale_o, - 0.0f, args.stride_q, args.stride_k, args.stride_v, - args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, - args.nhead_stride_randval, - args.nhead_stride_lse, args.nhead_stride_o, args.window_size_left, args.window_size_right, args.mask_type, - args.min_seqlen_q, - 0.0f, - false, - std::make_pair(uint64_t{0}, uint64_t{0})); + args.min_seqlen_q); } else { @@ -473,8 +391,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args) args.k_ptr, args.v_ptr, args.block_relation_onehot_ptr, - args.rand_val_ptr, - args.lse_ptr, args.o_ptr, args.seqlen_q, args.seqlen_k, @@ -483,32 +399,21 @@ auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, - args.scale_p, - args.scale_o, - 0.0f, args.stride_q, args.stride_k, args.stride_v, - args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, - args.nhead_stride_randval, - args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, - args.batch_stride_randval, - args.batch_stride_lse, args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type, - 0.0f, - false, - std::make_pair(uint64_t{0}, uint64_t{0})); + 
args.mask_type); } }(); @@ -538,8 +443,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args) args.v_ptr, args.lut_ptr, args.valid_block_num_ptr, - args.rand_val_ptr, - args.lse_ptr, args.o_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, @@ -549,27 +452,18 @@ auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, - args.scale_p, - args.scale_o, - 0.0f, args.stride_q, args.stride_k, args.stride_v, - args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, - args.nhead_stride_randval, - args.nhead_stride_lse, args.nhead_stride_o, args.window_size_left, args.window_size_right, args.mask_type, - args.min_seqlen_q, - 0.0f, - false, - std::make_pair(uint64_t{0}, uint64_t{0})); + args.min_seqlen_q); } else { @@ -578,8 +472,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args) args.v_ptr, args.lut_ptr, args.valid_block_num_ptr, - args.rand_val_ptr, - args.lse_ptr, args.o_ptr, args.seqlen_q, args.seqlen_k, @@ -588,32 +480,21 @@ auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s, - args.scale_p, - args.scale_o, - 0.0f, args.stride_q, args.stride_k, args.stride_v, - args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, - args.nhead_stride_randval, - args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, - args.batch_stride_randval, - args.batch_stride_lse, args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type, - 0.0f, - false, - std::make_pair(uint64_t{0}, uint64_t{0})); + args.mask_type); } }(); @@ -645,10 +526,6 @@ template ; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kStoreLse = kStoreLse_; - static constexpr bool kHasDropout = kHasDropout_; - static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kPadS 
= kPadS_; static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; @@ -690,9 +563,6 @@ struct fmha_jenga_fwd_traits bool is_group_mode; bool is_v_rowmajor; mask_enum mask_type; - bool has_lse; - bool has_dropout; - bool do_fp8_static_quant; bool skip_min_seqlen_q = false; // TODO: padding check is inside this api }; diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp index 98f5c940709..82292a78686 100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp @@ -43,8 +43,6 @@ jenga_sparse_attention(const ck_tile::HostTensor& TQ, max_seqlen_k = seqlen_k; bool is_v_rowmajor = true; float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - float scale_p = 1.f; - float scale_o = 1.f; std::string msk_str = "0"; mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); @@ -80,8 +78,7 @@ jenga_sparse_attention(const ck_tile::HostTensor& TQ, else return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); }(); - const ck_tile::index_t stride_randval = (max_seqlen_k); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; @@ -91,16 +88,12 @@ jenga_sparse_attention(const ck_tile::HostTensor& TQ, else return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; }(); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); // Use device buffer pointers instead of host tensor data pointers args.q_ptr = q_buf.GetDeviceBuffer(); @@ -125,8 +118,7 @@ jenga_sparse_attention(const ck_tile::HostTensor& TQ, args.batch_stride_k = batch_stride_k; args.batch_stride_v = batch_stride_v; - args.lse_ptr = nullptr; - args.o_ptr = o_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); args.seqstart_q_ptr = nullptr; args.seqstart_k_ptr = nullptr; @@ -136,25 +128,15 @@ jenga_sparse_attention(const ck_tile::HostTensor& TQ, args.max_seqlen_q = max_seqlen_q; args.scale_s = scale_s; - args.scale_p = scale_p; - args.scale_o = scale_o; - args.stride_o = stride_o; - args.nhead_stride_lse = nhead_stride_lse; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_lse = batch_stride_lse; - args.batch_stride_o = batch_stride_o; + args.stride_o = stride_o; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_o = batch_stride_o; args.window_size_left = mask.left; args.window_size_right = mask.right; args.mask_type = static_cast(mask.type); - args.rand_val_ptr = nullptr; - - args.stride_randval = stride_randval; - args.nhead_stride_randval = nhead_stride_randval; - 
args.batch_stride_randval = batch_stride_randval; - // Dropout not supported for sparse attention. }; @@ -164,12 +146,8 @@ jenga_sparse_attention(const ck_tile::HostTensor& TQ, traits.data_type = data_type; traits.is_v_rowmajor = is_v_rowmajor; - traits.is_group_mode = false; - traits.mask_type = mask.type; - traits.has_lse = false; - traits.do_fp8_static_quant = false; - - traits.has_dropout = false; + traits.is_group_mode = false; + traits.mask_type = mask.type; }; fmha_jenga_fwd_traits fmha_traits; diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp index f07e3d99ae1..8b54db884e1 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp @@ -44,8 +44,6 @@ vsa_sparse_attention(const ck_tile::HostTensor& TQ, max_seqlen_k = seqlen_k; bool is_v_rowmajor = true; float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - float scale_p = 1.f; - float scale_o = 1.f; std::string msk_str = "0"; mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); @@ -83,8 +81,7 @@ vsa_sparse_attention(const ck_tile::HostTensor& TQ, else return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); }(); - const ck_tile::index_t stride_randval = (max_seqlen_k); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; @@ -94,16 +91,12 @@ vsa_sparse_attention(const ck_tile::HostTensor& TQ, else return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; }(); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); // Use device buffer pointers instead of host tensor data pointers args.q_ptr = q_buf.GetDeviceBuffer(); @@ -129,8 +122,7 @@ vsa_sparse_attention(const ck_tile::HostTensor& TQ, args.batch_stride_k = batch_stride_k; args.batch_stride_v = batch_stride_v; - args.lse_ptr = nullptr; - args.o_ptr = o_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); args.seqstart_q_ptr = nullptr; args.seqstart_k_ptr = nullptr; @@ -140,25 +132,15 @@ vsa_sparse_attention(const ck_tile::HostTensor& TQ, args.max_seqlen_q = max_seqlen_q; args.scale_s = scale_s; - args.scale_p = scale_p; - args.scale_o = scale_o; - args.stride_o = stride_o; - args.nhead_stride_lse = nhead_stride_lse; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_lse = batch_stride_lse; - args.batch_stride_o = batch_stride_o; + args.stride_o = stride_o; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_o = batch_stride_o; args.window_size_left = mask.left; args.window_size_right = mask.right; args.mask_type = static_cast(mask.type); - args.rand_val_ptr = nullptr; - - 
args.stride_randval = stride_randval; - args.nhead_stride_randval = nhead_stride_randval; - args.batch_stride_randval = batch_stride_randval; - // Dropout not supported for sparse attention. }; @@ -168,12 +150,8 @@ vsa_sparse_attention(const ck_tile::HostTensor& TQ, traits.data_type = data_type; traits.is_v_rowmajor = is_v_rowmajor; - traits.is_group_mode = false; - traits.mask_type = mask.type; - traits.has_lse = false; - traits.do_fp8_static_quant = false; - - traits.has_dropout = false; + traits.is_group_mode = false; + traits.mask_type = mask.type; }; fmha_jenga_fwd_traits fmha_traits; diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp index b1414fce40e..5945dcf5aee 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp @@ -161,12 +161,6 @@ struct FmhaFwdJengaKernel ck_tile::GenericAttentionMaskEnum mask_type; }; - struct FmhaFwdFp8StaticQuantKargs - { - float scale_p; - float scale_o; - }; - struct FmhaFwdSkipMinSeqlenQKargs { ck_tile::index_t min_seqlen_q = 0; @@ -175,8 +169,7 @@ struct FmhaFwdJengaKernel struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, - FmhaFwdEmptyKargs<2>, - std::conditional_t> + FmhaFwdEmptyKargs<2> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -188,8 +181,7 @@ struct FmhaFwdJengaKernel : FmhaFwdCommonKargs, std::conditional_t>, FmhaFwdEmptyKargs<2>, - std::conditional_t>, - std::conditional_t> + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -211,8 +203,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* block_relation_onehot_ptr, - void* rand_val_ptr, - void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, @@ -221,45 +211,22 @@ struct FmhaFwdJengaKernel ck_tile::index_t num_head_q, ck_tile::index_t 
nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset) + ck_tile::index_t mask_type) { - (void)rand_val_ptr; - (void)lse_ptr; - (void)logits_soft_cap; - (void)stride_randval; - (void)nhead_stride_randval; - (void)nhead_stride_lse; - (void)batch_stride_randval; - (void)batch_stride_lse; - (void)p_drop; - (void)s_randval; - (void)drop_seed_offset; Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -286,7 +253,6 @@ struct FmhaFwdJengaKernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for empty kargs - {}, // placeholder for fp8_static_quant args batch_stride_q, batch_stride_k, batch_stride_v, @@ -298,12 +264,6 @@ struct FmhaFwdJengaKernel kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); } - if constexpr(kDoFp8StaticQuant) - { - kargs.scale_p = scale_p; - kargs.scale_o = scale_o; - } - return kargs; } @@ -314,8 +274,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* block_relation_onehot_ptr, - void* rand_val_ptr, - void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, @@ -324,161 +282,49 @@ struct FmhaFwdJengaKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, 
float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) + ck_tile::index_t mask_type) { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - block_relation_onehot_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* block_relation_onehot_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t 
num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - block_relation_onehot_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + return MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + block_relation_onehot_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o, + window_size_left, + window_size_right, + 
mask_type); } template @@ -487,8 +333,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* block_relation_onehot_ptr, - void* rand_val_ptr, - void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, @@ -498,38 +342,19 @@ struct FmhaFwdJengaKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset) + ck_tile::index_t min_seqlen_q) { - (void)rand_val_ptr; - (void)lse_ptr; - (void)logits_soft_cap; - (void)stride_randval; - (void)nhead_stride_randval; - (void)nhead_stride_lse; - (void)p_drop; - (void)s_randval; - (void)drop_seed_offset; Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -556,7 +381,6 @@ struct FmhaFwdJengaKernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for empty kargs - {}, // placeholder for fp8_static_quant args {}, // placeholder for min_seqlen_q reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), @@ -568,11 +392,6 @@ struct FmhaFwdJengaKernel kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); } - if constexpr(kDoFp8StaticQuant) - { - kargs.scale_p = scale_p; - kargs.scale_o = scale_o; - } if constexpr(kSkipMinSeqlenQ) { kargs.min_seqlen_q = min_seqlen_q; @@ -588,85 +407,6 @@ struct FmhaFwdJengaKernel const void* k_ptr, const void* v_ptr, const void* 
block_relation_onehot_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - block_relation_onehot_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* block_relation_onehot_ptr, - void* rand_val_ptr, - void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, @@ -676,63 +416,42 @@ struct FmhaFwdJengaKernel ck_tile::index_t num_head_q, ck_tile::index_t 
nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) + ck_tile::index_t mask_type) { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - block_relation_onehot_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + return MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + block_relation_onehot_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp index 33b42b4266d..c570fed1455 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp +++ 
b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp @@ -58,6 +58,8 @@ struct FmhaFwdVSAKernel static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, "VSA sparse attention does not support bias."); + static_assert(!kStoreLSE, "VSA sparse attention does not support LSE output."); + static_assert(!kHasDropout, "VSA sparse attention does not support dropout."); static_assert(!kHasLogitsSoftCap, "VSA sparse attention does not support logits soft-cap."); static_assert(!kDoFp8StaticQuant, "VSA sparse attention does not support FP8 static quantization yet."); @@ -153,28 +155,6 @@ struct FmhaFwdVSAKernel ck_tile::index_t nhead_stride_o; }; - struct FmhaFwdLogitsSoftCapKargs - { - FmhaFwdLogitsSoftCapKargs() = default; - - void init_logits_soft_cap(float logits_soft_cap_) - { - if(0 < logits_soft_cap_) - { - logits_soft_cap = logits_soft_cap_; - logits_soft_cap_rcp = 1.f / logits_soft_cap; - } - else - { - logits_soft_cap = 0.f; - logits_soft_cap_rcp = 0.f; - } - } - - float logits_soft_cap; - float logits_soft_cap_rcp; - }; - struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; @@ -182,12 +162,6 @@ struct FmhaFwdVSAKernel ck_tile::GenericAttentionMaskEnum mask_type; }; - struct FmhaFwdFp8StaticQuantKargs - { - float scale_p; - float scale_o; - }; - struct FmhaFwdSkipMinSeqlenQKargs { ck_tile::index_t min_seqlen_q = 0; @@ -195,9 +169,8 @@ struct FmhaFwdVSAKernel struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, - FmhaFwdEmptyKargs<0>, std::conditional_t>, - std::conditional_t> + FmhaFwdEmptyKargs<2> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -207,9 +180,8 @@ struct FmhaFwdVSAKernel struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs, - FmhaFwdEmptyKargs<0>, std::conditional_t>, - std::conditional_t>, + FmhaFwdEmptyKargs<2>, std::conditional_t> { const int32_t* seqstart_q_ptr; @@ -233,8 +205,6 @@ struct FmhaFwdVSAKernel 
const void* v_ptr, const void* lut_ptr, const void* valid_block_num_ptr, - void* rand_val_ptr, - void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, @@ -243,45 +213,22 @@ struct FmhaFwdVSAKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset) + ck_tile::index_t mask_type) { - (void)rand_val_ptr; - (void)lse_ptr; - (void)logits_soft_cap; - (void)stride_randval; - (void)nhead_stride_randval; - (void)nhead_stride_lse; - (void)batch_stride_randval; - (void)batch_stride_lse; - (void)p_drop; - (void)s_randval; - (void)drop_seed_offset; Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -309,7 +256,6 @@ struct FmhaFwdVSAKernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for empty kargs - {}, // placeholder for fp8_static_quant args batch_stride_q, batch_stride_k, batch_stride_v, @@ -321,12 +267,6 @@ struct FmhaFwdVSAKernel kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); } - if constexpr(kDoFp8StaticQuant) - { - kargs.scale_p = scale_p; - kargs.scale_o = scale_o; - } - return kargs; } @@ -338,97 +278,6 @@ struct FmhaFwdVSAKernel const void* v_ptr, const void* 
lut_ptr, const void* valid_block_num_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - lut_ptr, - valid_block_num_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const 
void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* lut_ptr, - const void* valid_block_num_ptr, - void* rand_val_ptr, - void* lse_ptr, void* o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, @@ -437,75 +286,50 @@ struct FmhaFwdVSAKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_randval, - ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) + ck_tile::index_t mask_type) { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - lut_ptr, - valid_block_num_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + return MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, 
+ hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o, + window_size_left, + window_size_right, + mask_type); } template @@ -515,8 +339,6 @@ struct FmhaFwdVSAKernel const void* v_ptr, const void* lut_ptr, const void* valid_block_num_ptr, - void* rand_val_ptr, - void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, @@ -526,38 +348,19 @@ struct FmhaFwdVSAKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q, - float p_drop, - bool s_randval, - std::variant, std::pair> - drop_seed_offset) + ck_tile::index_t min_seqlen_q) { - (void)rand_val_ptr; - (void)lse_ptr; - (void)logits_soft_cap; - (void)stride_randval; - (void)nhead_stride_randval; - (void)nhead_stride_lse; - (void)p_drop; - (void)s_randval; - (void)drop_seed_offset; Kargs kargs{{q_ptr, k_ptr, v_ptr, @@ -585,7 +388,6 @@ struct FmhaFwdVSAKernel nhead_stride_o}, // args for common karg {}, // placeholder for mask {}, // placeholder for empty kargs - {}, // placeholder for fp8_static_quant args {}, // placeholder for min_seqlen_q reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), @@ -597,11 +399,6 @@ struct FmhaFwdVSAKernel kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); } - if 
constexpr(kDoFp8StaticQuant) - { - kargs.scale_p = scale_p; - kargs.scale_o = scale_o; - } if constexpr(kSkipMinSeqlenQ) { kargs.min_seqlen_q = min_seqlen_q; @@ -618,87 +415,6 @@ struct FmhaFwdVSAKernel const void* v_ptr, const void* lut_ptr, const void* valid_block_num_ptr, - void* rand_val_ptr, - void* lse_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) - { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - lut_ptr, - valid_block_num_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const 
void* v_ptr, - const void* lut_ptr, - const void* valid_block_num_ptr, - void* rand_val_ptr, - void* lse_ptr, void* o_ptr, const void* seqstart_q_ptr, const void* seqstart_k_ptr, @@ -708,64 +424,43 @@ struct FmhaFwdVSAKernel ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, - float scale_p, - float scale_o, - float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, - ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_randval, - ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - float p_drop, - bool s_randval, - const std::tuple& drop_seed_offset) + ck_tile::index_t mask_type) { - return MakeKargsImpl( - q_ptr, - k_ptr, - v_ptr, - lut_ptr, - valid_block_num_ptr, - rand_val_ptr, - lse_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - scale_p, - scale_o, - logits_soft_cap, - stride_q, - stride_k, - stride_v, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type, - p_drop, - s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + return MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + lut_ptr, + valid_block_num_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, From 
a9456ba928b801a8139d7b2009738e6258cb8ea0 Mon Sep 17 00:00:00 2001 From: Jiangyong Date: Thu, 29 Jan 2026 15:48:56 +0800 Subject: [PATCH 18/22] remove useless code --- .../codegen/ops/fmha_fwd_jenga.py | 30 +- .../codegen/ops/fmha_fwd_vsa.py | 30 +- .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 420 +++------------- .../50_sparse_attn/jenga_sparse_attention.cpp | 7 +- .../50_sparse_attn/vsa_sparse_attention.cpp | 7 +- .../kernel/fmha_fwd_jenga_kernel.hpp | 465 +++--------------- .../kernel/fmha_fwd_vsa_kernel.hpp | 403 ++------------- ...ock_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 285 ++--------- ...block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 218 +------- 9 files changed, 243 insertions(+), 1622 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index cec3fe307a9..a8566d9e53e 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -78,7 +78,7 @@ def update_file(file_path, content): false, ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, {F_occupancy}, - {F_skip}>; + false>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -114,8 +114,8 @@ def update_file(file_path, content): using fmha_kernel_{F_idx} = ck_tile::FmhaFwdJengaKernel; -using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; +using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; #include @@ -201,9 +201,9 @@ def update_file(file_path, content): }} """ 
-FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; return fmha_jenga_fwd_(s, a); }} """ @@ -243,7 +243,6 @@ class FmhaFwdApiTrait: skpad: str dpad: str dvpad: str - skip: str tr_load: str constraint: CppConstraint @@ -251,7 +250,7 @@ class FmhaFwdApiTrait: def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" ) @property @@ -303,7 +302,6 @@ class FmhaFwdPipeline: F_dvpad: str # F_logits: str # t/f F_mask: str # value from MASK_MAP - F_skip: str # true/false F_trload: str # true/false F_constraint: CppConstraint = field(default_factory=CppConstraint) @@ -348,10 +346,7 @@ def pad_name() -> str: else: n += "_nmask" - if self.F_skip == "t": - n += "_skip" - else: - n += "_nskip" + n += "_nskip" n += "_nsquant" @@ -398,13 +393,11 @@ def api(self) -> str: if_k = "if" if k == 0 else "else if" inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( F_if=if_k, - F_mode=MODE_MAP[trait.mode], 
F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], F_scheck=trait.scheck, F_seqtune=trait.seqtune, @@ -518,7 +511,6 @@ def template(self) -> str: F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], F_logits=BOOL_MAP[self.F_pipeline.F_logits], - F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -561,7 +553,6 @@ def api_trait(self) -> FmhaFwdApiTrait: skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, - skip=self.F_pipeline.F_skip, tr_load=self.F_pipeline.F_trload, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, ) @@ -681,10 +672,9 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli # TODO: the order of List matters! the later in this list will be also be checked later pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, skip in itertools.product( + for logits, mask in itertools.product( ["f"], get_mask_map(mask_impl).keys(), - ["t", "f"], ): if hdim == 256 and hdim_v == 256: # jenga fmha only supports dim <= 192 for now. 
@@ -699,7 +689,6 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "t", logits, mask, - skip, "f", ) ) @@ -713,7 +702,6 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "t", logits, mask, - skip, "f", ) ) @@ -810,7 +798,6 @@ def get_fwd_blobs( if receipt in (2, 3): cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_skip == "f" if not cond: continue # PyTorch integration @@ -818,7 +805,6 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" cond &= mode == "batch" - cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" if not cond: continue diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index d1d4f901133..731e4be9b2c 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -78,7 +78,7 @@ def update_file(file_path, content): false, ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, {F_occupancy}, - {F_skip}>; + false>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -114,8 +114,8 @@ def update_file(file_path, content): using fmha_kernel_{F_idx} = ck_tile::FmhaFwdVSAKernel; -using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; +using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; #include @@ -201,9 +201,9 @@ def update_file(file_path, content): }} """ -FMHA_FWD_API_INNER_DISPATCH = 
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; return fmha_vsa_fwd_(s, a); }} """ @@ -243,7 +243,6 @@ class FmhaFwdApiTrait: skpad: str dpad: str dvpad: str - skip: str tr_load: str constraint: CppConstraint @@ -251,7 +250,7 @@ class FmhaFwdApiTrait: def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" ) @property @@ -303,7 +302,6 @@ class FmhaFwdPipeline: F_dvpad: str # F_logits: str # t/f F_mask: str # value from MASK_MAP - F_skip: str # true/false F_trload: str # true/false F_constraint: CppConstraint = field(default_factory=CppConstraint) @@ -348,10 +346,7 @@ def pad_name() -> str: else: n += "_nmask" - if self.F_skip == "t": - n += "_skip" - else: - n += "_nskip" + n += "_nskip" n += "_nsquant" @@ -398,13 +393,11 @@ def api(self) -> str: if_k = "if" if k == 0 else "else if" inners = inners + FMHA_FWD_API_INNER_DISPATCH.format( F_if=if_k, - F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], 
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], - F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], F_scheck=trait.scheck, F_seqtune=trait.seqtune, @@ -518,7 +511,6 @@ def template(self) -> str: F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], F_logits=BOOL_MAP[self.F_pipeline.F_logits], - F_skip=BOOL_MAP[self.F_pipeline.F_skip], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -561,7 +553,6 @@ def api_trait(self) -> FmhaFwdApiTrait: skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, - skip=self.F_pipeline.F_skip, tr_load=self.F_pipeline.F_trload, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, ) @@ -681,10 +672,9 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli # TODO: the order of List matters! the later in this list will be also be checked later pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, skip in itertools.product( + for logits, mask in itertools.product( ["f"], get_mask_map(mask_impl).keys(), - ["t", "f"], ): if hdim == 256 and hdim_v == 256: # vsa fmha only supports dim <= 192 for now. 
@@ -699,7 +689,6 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "t", logits, mask, - skip, "f", ) ) @@ -713,7 +702,6 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli "t", logits, mask, - skip, "f", ) ) @@ -810,7 +798,6 @@ def get_fwd_blobs( if receipt in (2, 3): cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_skip == "f" if not cond: continue # PyTorch integration @@ -818,7 +805,6 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" cond &= mode == "batch" - cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" if not cond: continue diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index 42d2c3d5fd8..e90431b22fc 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -68,197 +68,6 @@ struct FmhaMasks using CausalMask = ck_tile::GenericAttentionMask; }; -struct fmha_sparge_fwd_args -{ - const void* q_ptr; - const void* k_ptr; - const void* v_ptr; - const void* lut_ptr; // delta-encoded K-block indices per Q-block, int32 [B,H,Q_blk,K_blk] - const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk] - void* o_ptr; - - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* - seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - - ck_tile::index_t seqlen_q; - ck_tile::index_t seqlen_k; - ck_tile::index_t batch; - ck_tile::index_t max_seqlen_q; - ck_tile::index_t hdim_q; - ck_tile::index_t hdim_v; - ck_tile::index_t nhead_q; - ck_tile::index_t nhead_k; - - float pv_threshold; - float scale_s; - - ck_tile::index_t stride_q; - ck_tile::index_t stride_k; - ck_tile::index_t stride_v; - ck_tile::index_t stride_o; - ck_tile::index_t nhead_stride_q; - ck_tile::index_t nhead_stride_k; - ck_tile::index_t 
nhead_stride_v; - ck_tile::index_t nhead_stride_o; - ck_tile::index_t batch_stride_q; - ck_tile::index_t batch_stride_k; - ck_tile::index_t batch_stride_v; - ck_tile::index_t batch_stride_o; - - ck_tile::index_t window_size_left; - ck_tile::index_t window_size_right; - ck_tile::index_t mask_type; - ck_tile::index_t min_seqlen_q; - - // Dropout is not supported for sparse attention; keep args minimal. -}; - -template -auto fmha_fwd_create_kargs_and_grids(fmha_sparge_fwd_args args) -{ - assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - // create group mode kernel arguments - if constexpr(FmhaKernel::kIsGroupMode) - { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.lut_ptr, - args.valid_block_num_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.pv_threshold, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.min_seqlen_q); - } - else - { // create batch mode kernel arguments - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.lut_ptr, - args.valid_block_num_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.pv_threshold, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type); - } - }(); - - if constexpr(FmhaKernel::kIsGroupMode) - { - dim3 grids = FmhaKernel::GridSize( - args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr 
!= nullptr); - return ck_tile::make_tuple(kargs, grids); - } - else - { - dim3 grids = - FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); - return ck_tile::make_tuple(kargs, grids); - } -} - -// this is used to pattern-match internl kernel implementation, not to instantiate kernel -template -struct fmha_sparge_fwd_traits_ -{ - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kM0 = kM0_; - static constexpr ck_tile::index_t kN0 = kN0_; - static constexpr ck_tile::index_t kK0 = kK0_; - static constexpr ck_tile::index_t kN1 = kN1_; - static constexpr ck_tile::index_t kK1 = kK1_; - static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; - static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; - static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; - static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; - using FmhaMask = ck_tile::remove_cvref_t; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSK = kPadSK_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; - static constexpr bool kUseTrLoad = kUseTrLoad_; - static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; -}; - -struct fmha_sparge_fwd_traits -{ - int hdim_q; - int hdim_v; - std::string data_type; - bool is_group_mode; - bool is_v_rowmajor; - mask_enum mask_type; - bool skip_min_seqlen_q = false; - // TODO: padding check is inside this api -}; - -float fmha_sparge_fwd(fmha_sparge_fwd_traits, fmha_sparge_fwd_args, const ck_tile::stream_config&); - -template -float fmha_sparge_fwd_(const ck_tile::stream_config&, fmha_sparge_fwd_args); - -float fmha_sparge_fwd(fmha_sparge_fwd_args, const ck_tile::stream_config&); - // jenga struct fmha_jenga_fwd_args { @@ -268,11 +77,6 @@ struct fmha_jenga_fwd_args const void* block_relation_onehot_ptr; // one-hot block map 
[B,H,Q_blk,K_blk], 1=active void* o_ptr; - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* - seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -300,7 +104,6 @@ struct fmha_jenga_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; - ck_tile::index_t min_seqlen_q; // Dropout is not supported for sparse attention; keep args minimal. }; @@ -315,11 +118,6 @@ struct fmha_vsa_fwd_args const void* valid_block_num_ptr; // valid K-block count per Q-block, int32 [B,H,Q_blk] void* o_ptr; - const void* seqstart_q_ptr; - const void* seqstart_k_ptr; - const void* - seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr - ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -347,7 +145,6 @@ struct fmha_vsa_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; - ck_tile::index_t min_seqlen_q; // Dropout is not supported for sparse attention; keep args minimal. 
}; @@ -356,166 +153,78 @@ template auto fmha_fwd_create_kargs_and_grids(fmha_jenga_fwd_args args) { assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - if constexpr(FmhaKernel::kIsGroupMode) - { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.block_relation_onehot_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.min_seqlen_q); - } - else - { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.block_relation_onehot_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type); - } - }(); - - if constexpr(FmhaKernel::kIsGroupMode) - { - dim3 grids = FmhaKernel::GridSize( - args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); - return ck_tile::make_tuple(kargs, grids); - } - else - { - dim3 grids = - FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); - return ck_tile::make_tuple(kargs, grids); - } + auto kargs = FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.block_relation_onehot_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + 
args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); } template auto fmha_fwd_create_kargs_and_grids(fmha_vsa_fwd_args args) { assert(args.nhead_q % args.nhead_k == 0); - auto kargs = [&] { - if constexpr(FmhaKernel::kIsGroupMode) - { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.lut_ptr, - args.valid_block_num_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.min_seqlen_q); - } - else - { - return FmhaKernel::MakeKargsImpl(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.lut_ptr, - args.valid_block_num_ptr, - args.o_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale_s, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_o, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_o, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_o, - args.window_size_left, - args.window_size_right, - args.mask_type); - } - }(); - - if constexpr(FmhaKernel::kIsGroupMode) - { - dim3 grids = FmhaKernel::GridSize( - args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); - return ck_tile::make_tuple(kargs, grids); - 
} - else - { - dim3 grids = - FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); - return ck_tile::make_tuple(kargs, grids); - } + auto kargs = FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.lut_ptr, + args.valid_block_num_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); } // this is used to pattern-match internl kernel implementation, not to instantiate kernel template + bool kUseTrLoad_> struct fmha_jenga_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kN0 = kN0_; static constexpr ck_tile::index_t kK0 = kK0_; @@ -552,7 +259,6 @@ struct fmha_jenga_fwd_traits_ static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; static constexpr bool kUseTrLoad = kUseTrLoad_; - static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; struct fmha_jenga_fwd_traits @@ -560,10 +266,8 @@ struct fmha_jenga_fwd_traits int hdim_q; int hdim_v; std::string data_type; - bool is_group_mode; bool is_v_rowmajor; mask_enum mask_type; - bool skip_min_seqlen_q = false; // TODO: padding check is inside this api }; diff --git a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp index 82292a78686..09327211587 
100644 --- a/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp +++ b/example/ck_tile/50_sparse_attn/jenga_sparse_attention.cpp @@ -120,10 +120,6 @@ jenga_sparse_attention(const ck_tile::HostTensor& TQ, args.o_ptr = o_buf.GetDeviceBuffer(); - args.seqstart_q_ptr = nullptr; - args.seqstart_k_ptr = nullptr; - args.seqlen_k_ptr = nullptr; - args.seqlen_k = shape_seqlen_k; // batch mode only args.max_seqlen_q = max_seqlen_q; @@ -146,8 +142,7 @@ jenga_sparse_attention(const ck_tile::HostTensor& TQ, traits.data_type = data_type; traits.is_v_rowmajor = is_v_rowmajor; - traits.is_group_mode = false; - traits.mask_type = mask.type; + traits.mask_type = mask.type; }; fmha_jenga_fwd_traits fmha_traits; diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp index 8b54db884e1..2dbd265ef22 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp @@ -124,10 +124,6 @@ vsa_sparse_attention(const ck_tile::HostTensor& TQ, args.o_ptr = o_buf.GetDeviceBuffer(); - args.seqstart_q_ptr = nullptr; - args.seqstart_k_ptr = nullptr; - args.seqlen_k_ptr = nullptr; - args.seqlen_k = shape_seqlen_k; // batch mode only args.max_seqlen_q = max_seqlen_q; @@ -150,8 +146,7 @@ vsa_sparse_attention(const ck_tile::HostTensor& TQ, traits.data_type = data_type; traits.is_v_rowmajor = is_v_rowmajor; - traits.is_group_mode = false; - traits.mask_type = mask.type; + traits.mask_type = mask.type; }; fmha_jenga_fwd_traits fmha_traits; diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp index 5945dcf5aee..d0424898f09 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp @@ -43,7 +43,6 @@ struct FmhaFwdJengaKernel using VLayout = ck_tile::remove_cvref_t; - static constexpr bool 
kIsGroupMode = FmhaPipeline::kIsGroupMode; static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; @@ -54,8 +53,8 @@ struct FmhaFwdJengaKernel static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = (FmhaPipeline::Problem::QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); - static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; - static_assert(!kIsGroupMode, "Jenga sparse attention currently supports batch mode only."); + static_assert(!FmhaPipeline::kIsGroupMode, + "Jenga sparse attention currently supports batch mode only."); static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, "Jenga sparse attention does not support bias."); static_assert(!kStoreLSE, "Jenga sparse attention does not support LSE output."); @@ -70,51 +69,6 @@ struct FmhaFwdJengaKernel static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; - // clang-format off - template struct t2s; - template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; - template <> struct t2s { static constexpr const char * name = "bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8"; }; - template <> struct t2s { static constexpr const char * name = "bf8"; }; - // clang-format on - - CK_TILE_HOST static std::string GetName() - { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using g0br = typename bfs::Gemm0BlockWarps; - using g1br = typename bfs::Gemm1BlockWarps; - using g0wt = typename bfs::Gemm0WarpTile; - using g1wt = typename bfs::Gemm1WarpTile; - #define _SS_ std::string - #define _TS_ std::to_string - auto pn = [&] () { - std::string n; - if (kPadSeqLenQ) n += "s"; - if (kPadSeqLenK) n += "sk"; - if 
(kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; - return n.empty() ? n : std::string("p") + n; }(); - return - _SS_("fmha_jenga_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + - "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + - "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kDoFp8StaticQuant ? 
"_squant" : "_nsquant" ); - #undef _SS_ - #undef _TS_ - // clang-format on - } - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs @@ -161,11 +115,6 @@ struct FmhaFwdJengaKernel ck_tile::GenericAttentionMaskEnum mask_type; }; - struct FmhaFwdSkipMinSeqlenQKargs - { - ck_tile::index_t min_seqlen_q = 0; - }; - struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, @@ -176,19 +125,7 @@ struct FmhaFwdJengaKernel ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_o; }; - - struct FmhaFwdGroupModeKargs - : FmhaFwdCommonKargs, - std::conditional_t>, - FmhaFwdEmptyKargs<2>, - std::conditional_t> - { - const int32_t* seqstart_q_ptr; - const int32_t* seqstart_k_ptr; - const int32_t* seqlen_k_ptr; - }; - - using Kargs = std::conditional_t; + using Kargs = FmhaFwdBatchModeKargs; struct BlockIndices { @@ -197,35 +134,34 @@ struct FmhaFwdJengaKernel ck_tile::index_t kv_head_idx; }; - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* block_relation_onehot_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + // std::variant<> can't take in a list initializer, overload for backward compatibility + CK_TILE_HOST static constexpr Kargs MakeKargs(const 
void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* block_relation_onehot_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, k_ptr, @@ -267,283 +203,44 @@ struct FmhaFwdJengaKernel return kargs; } - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* block_relation_onehot_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) - { - return MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - block_relation_onehot_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, 
- hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_o, - window_size_left, - window_size_right, - mask_type); - } - - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* block_relation_onehot_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q) - { - Kargs kargs{{q_ptr, - k_ptr, - v_ptr, - block_relation_onehot_ptr, - o_ptr, - -1, // seqlen will be updated by another pointer - -1, // - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, -#if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * ck_tile::log2e_v<>), -#else - scale_s, -#endif - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_o}, // args for common karg - {}, // placeholder for mask - {}, // placeholder for empty kargs - {}, // placeholder for min_seqlen_q - reinterpret_cast(seqstart_q_ptr), - reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_k_ptr)}; - - if constexpr(kHasMask) - { - kargs.window_size_left = window_size_left; - kargs.window_size_right = window_size_right; - kargs.mask_type = static_cast(mask_type); - } - if constexpr(kSkipMinSeqlenQ) - { - kargs.min_seqlen_q 
= min_seqlen_q; - } - - return kargs; - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* block_relation_onehot_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) - { - return MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - block_relation_onehot_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type); - } - CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k = false) { - has_padded_seqlen_k = true; - // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr) - if(has_padded_seqlen_k) - { - // TODO: this may need tuning - return dim3(nhead_, - batch_size_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); - } - else - { - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - 
nhead_, - batch_size_); - } + (void)has_padded_seqlen_k; + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); } CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) { - bool has_padded_seqlen_k = false; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - if constexpr(kIsGroupMode) - has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; - has_padded_seqlen_k = true; - - if(has_padded_seqlen_k) - { - // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = - ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; - const index_t i_block = blockIdx.z; - const index_t i_nhead = blockIdx.x; - const index_t i_batch = blockIdx.y; + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - if constexpr(kHasMask) - { - // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); - } - else - { - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); - } + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else { - // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = - 
ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - if constexpr(kHasMask) - { - // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); - } - else - { - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); - } + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); } } @@ -574,60 +271,10 @@ struct FmhaFwdJengaKernel long_index_t batch_offset_v = 0; long_index_t batch_offset_o = 0; - if constexpr(kIsGroupMode) - { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } - batch_offset_o = query_start * kargs.stride_o; - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - - if constexpr(kSkipMinSeqlenQ) - { - if(kargs.seqlen_q <= kargs.min_seqlen_q) - { - return; - } - } - - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier - if(kargs.seqlen_q <= i_m0) - { - return; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; - } - else - { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = 
adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; - } - } - else - { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - } + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; // for simplicity, batch stride we just modify the pointer const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp index c570fed1455..686576bb9be 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp @@ -43,7 +43,6 @@ struct FmhaFwdVSAKernel using VLayout = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; @@ -55,7 +54,7 @@ struct FmhaFwdVSAKernel static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; static constexpr bool kDoFp8StaticQuant = (QScaleEnum != ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE); - static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static_assert(!FmhaPipeline::kIsGroupMode, "VSA sparse attention supports batch mode only."); static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, "VSA sparse attention does not support bias."); static_assert(!kStoreLSE, "VSA sparse attention does not support LSE output."); @@ -70,51 +69,6 @@ struct FmhaFwdVSAKernel static 
constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; - // clang-format off - template struct t2s; - template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; - template <> struct t2s { static constexpr const char * name = "bf16"; }; - template <> struct t2s { static constexpr const char * name = "fp8"; }; - template <> struct t2s { static constexpr const char * name = "bf8"; }; - // clang-format on - - CK_TILE_HOST static std::string GetName() - { - // sync with generate.py - // clang-format off - using bfs = typename FmhaPipeline::BlockFmhaShape; - using g0br = typename bfs::Gemm0BlockWarps; - using g1br = typename bfs::Gemm1BlockWarps; - using g0wt = typename bfs::Gemm0WarpTile; - using g1wt = typename bfs::Gemm1WarpTile; - #define _SS_ std::string - #define _TS_ std::to_string - auto pn = [&] () { - std::string n; - if (kPadSeqLenQ) n += "s"; - if (kPadSeqLenK) n += "sk"; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; - return n.empty() ? n : std::string("p") + n; }(); - return - _SS_("fmha_vsa_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" - "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + - _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + - "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + - "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + - "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + - (kBlockPerCuInput == -1 ? 
"" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + - "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); - #undef _SS_ - #undef _TS_ - // clang-format on - } - template // to avoid duplicated base class prblem, introduce an template // arg struct FmhaFwdEmptyKargs @@ -162,15 +116,7 @@ struct FmhaFwdVSAKernel ck_tile::GenericAttentionMaskEnum mask_type; }; - struct FmhaFwdSkipMinSeqlenQKargs - { - ck_tile::index_t min_seqlen_q = 0; - }; - - struct FmhaFwdBatchModeKargs - : FmhaFwdCommonKargs, - std::conditional_t>, - FmhaFwdEmptyKargs<2> + struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<2> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -178,18 +124,7 @@ struct FmhaFwdVSAKernel ck_tile::index_t batch_stride_o; }; - struct FmhaFwdGroupModeKargs - : FmhaFwdCommonKargs, - std::conditional_t>, - FmhaFwdEmptyKargs<2>, - std::conditional_t> - { - const int32_t* seqstart_q_ptr; - const int32_t* seqstart_k_ptr; - const int32_t* seqlen_k_ptr; - }; - - using Kargs = std::conditional_t; + using Kargs = FmhaFwdBatchModeKargs; struct BlockIndices { @@ -198,36 +133,35 @@ struct FmhaFwdVSAKernel ck_tile::index_t kv_head_idx; }; - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* lut_ptr, - const void* valid_block_num_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - 
ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + // std::variant<> can't take in a list initializer, overload for backward compatibility + CK_TILE_HOST static constexpr Kargs MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* lut_ptr, + const void* valid_block_num_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, k_ptr, @@ -270,199 +204,6 @@ struct FmhaFwdVSAKernel return kargs; } - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* lut_ptr, - const void* valid_block_num_ptr, - void* o_ptr, - ck_tile::index_t seqlen_q, - ck_tile::index_t seqlen_k, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - 
ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t batch_stride_q, - ck_tile::index_t batch_stride_k, - ck_tile::index_t batch_stride_v, - ck_tile::index_t batch_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) - { - return MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - lut_ptr, - valid_block_num_ptr, - o_ptr, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - batch_stride_o, - window_size_left, - window_size_right, - mask_type); - } - - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargsImpl(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* lut_ptr, - const void* valid_block_num_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type, - ck_tile::index_t min_seqlen_q) - { - Kargs kargs{{q_ptr, - k_ptr, - v_ptr, - lut_ptr, - valid_block_num_ptr, - o_ptr, - -1, // seqlen will be updated by another pointer - -1, // - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, -#if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * 
ck_tile::log2e_v<>), -#else - scale_s, -#endif - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_o}, // args for common karg - {}, // placeholder for mask - {}, // placeholder for empty kargs - {}, // placeholder for min_seqlen_q - reinterpret_cast(seqstart_q_ptr), - reinterpret_cast(seqstart_k_ptr), - reinterpret_cast(seqlen_k_ptr)}; - - if constexpr(kHasMask) - { - kargs.window_size_left = window_size_left; - kargs.window_size_right = window_size_right; - kargs.mask_type = static_cast(mask_type); - } - if constexpr(kSkipMinSeqlenQ) - { - kargs.min_seqlen_q = min_seqlen_q; - } - - return kargs; - } - - // std::variant<> can't take in a list initializer, overload for backward compatibility - template - CK_TILE_HOST static constexpr std::enable_if_t - MakeKargs(const void* q_ptr, - const void* k_ptr, - const void* v_ptr, - const void* lut_ptr, - const void* valid_block_num_ptr, - void* o_ptr, - const void* seqstart_q_ptr, - const void* seqstart_k_ptr, - const void* seqlen_k_ptr, - ck_tile::index_t hdim_q, - ck_tile::index_t hdim_v, - ck_tile::index_t num_head_q, - ck_tile::index_t nhead_ratio_qk, - float scale_s, - ck_tile::index_t stride_q, - ck_tile::index_t stride_k, - ck_tile::index_t stride_v, - ck_tile::index_t stride_o, - ck_tile::index_t nhead_stride_q, - ck_tile::index_t nhead_stride_k, - ck_tile::index_t nhead_stride_v, - ck_tile::index_t nhead_stride_o, - ck_tile::index_t window_size_left, - ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) - { - return MakeKargsImpl(q_ptr, - k_ptr, - v_ptr, - lut_ptr, - valid_block_num_ptr, - o_ptr, - seqstart_q_ptr, - seqstart_k_ptr, - seqlen_k_ptr, - hdim_q, - hdim_v, - num_head_q, - nhead_ratio_qk, - scale_s, - stride_q, - stride_k, - stride_v, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - nhead_stride_o, - window_size_left, - window_size_right, - mask_type); - } - CK_TILE_HOST static constexpr auto 
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, @@ -492,12 +233,8 @@ struct FmhaFwdVSAKernel { bool has_padded_seqlen_k = false; - if constexpr(kIsGroupMode) - has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); - if(has_padded_seqlen_k) { - // const index_t num_tile_m0 = seqlen_q / kM0; const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); @@ -515,7 +252,6 @@ struct FmhaFwdVSAKernel if constexpr(kHasMask) { - // assume that num_tile_n1 is always 1 return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else @@ -525,7 +261,6 @@ struct FmhaFwdVSAKernel } else { - // const index_t num_tile_m0 = seqlen_q / kM0; const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); @@ -543,7 +278,6 @@ struct FmhaFwdVSAKernel if constexpr(kHasMask) { - // assume that num_tile_n1 is always 1 return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else @@ -577,60 +311,10 @@ struct FmhaFwdVSAKernel long_index_t batch_offset_v = 0; long_index_t batch_offset_o = 0; - if constexpr(kIsGroupMode) - { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; - - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.stride_v; - } - else - { - batch_offset_v = key_start; - } - batch_offset_o = query_start * kargs.stride_o; - - // get real # queries & # keys under group mode - const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; - kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; - - if constexpr(kSkipMinSeqlenQ) - { - if(kargs.seqlen_q <= kargs.min_seqlen_q) - { - return; - } - } - - // # of required blocks is different in each groups, terminate unnecessary blocks 
- // earlier - if(kargs.seqlen_q <= i_m0) - { - return; - } - - if(kargs.seqlen_k_ptr != nullptr) - { - kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; - } - else - { - const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; - kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; - } - } - else - { - batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; - batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; - batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - } + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; // for simplicity, batch stride we just modify the pointer const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + @@ -723,21 +407,6 @@ struct FmhaFwdVSAKernel make_tuple(number{}, number{}), sequence{}); } - else - { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen_k), - make_tuple(kargs.stride_v, 1), - number{}, - number<1>{}); - - constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? 
kPadHeadDimV : false; - return pad_tensor_view( - v_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - } }(); auto q_dram_window = make_tile_window( diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index b198b904cde..7ab4cc710c7 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -62,10 +62,11 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && - (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || - !kHasLogitsSoftCap)) || - (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, + "Jenga sparse attention does not support bias."); + static_assert(!kHasDropout, "Jenga sparse attention does not support dropout."); + static_assert(!kStoreLSE, "Jenga sparse attention does not support LSE output."); + static_assert(!kHasLogitsSoftCap, "Jenga sparse attention does not support logits soft-cap."); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -78,9 +79,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); }(); static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); - #if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v; #endif @@ -91,36 +89,30 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga else { // minimize occupancy - if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) - { - return 1; - } - if constexpr(kQKHeaddim <= 32) { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && - FmhaMask::IsMasking) + if constexpr(kPadSeqLenK && FmhaMask::IsMasking) return 1; else return 2; } else if constexpr(kQKHeaddim <= 64) { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + if constexpr(kPadSeqLenK) return 2; else return 3; } else if constexpr(kQKHeaddim <= 128) { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + if constexpr(kPadSeqLenK) return 1; else return 2; } else if constexpr(kQKHeaddim <= 192) { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + if constexpr(kPadSeqLenK) return 1; else return 2; @@ -285,15 +277,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga { if(num_total_loop <= 0) { - if constexpr(kStoreLSE) - { - auto lse = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse, -numeric::infinity()); - - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); - } buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) // otherwise will have compute error(maybe compiler bug?) 
@@ -326,24 +309,15 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga // load k_dram_window.init_raw(); constexpr auto k_oob_ck = bool_constant{}; - constexpr auto k_pre_np = [&]() { - if constexpr(kPadSeqLenK && - (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) - return bool_constant{}; - else - return bool_constant{}; - }(); + constexpr auto k_pre_np = bool_constant{}; - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_window = - make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); - - auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + (void)bias_dram_block_window_tmp; + (void)bias_element_func; + (void)randval_dram_block_window_tmp; + (void)lse_dram_window_tmp; + (void)lse_element_func; + (void)position_encoding; + (void)dropout; auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -438,8 +412,7 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga async_load_fence(); __builtin_amdgcn_s_barrier(); - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); __builtin_amdgcn_sched_barrier(0); { // tail gemm_0( @@ -452,75 +425,11 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga } __builtin_amdgcn_sched_barrier(1); - // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); - tile_elementwise_inout( - [&](auto& x, const auto& y) { + // STAGE 2, scale_s, 
mask, softmax (no bias/soft-cap) + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); #endif - }, - s_acc, - bias_tile); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); - }); - }); - } - else - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - if constexpr(kHasLogitsSoftCap) - { - auto apply_logits_transform = - [&variant, &variant_params, &block_indices](auto& x) { - x = variant.LogitsTransform(variant_params, - variant.QueryTransform(variant_params, x), - block_indices.batch_idx, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }; -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } -#else - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } -#endif - } - else - { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif - } - } - move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || 
FmhaMask::IsMasking) { const auto k_origin = k_dram_block_window.get_window_origin(); @@ -562,30 +471,17 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga __builtin_amdgcn_sched_barrier(0x7F); // store & prefetch next v, after the max reduction - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_buf); - - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - - store_tile( - v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - auto v_lds_window_tmp = - get_slice_tile(v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store the prefetch - } + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch if constexpr(k1_loops > 1) { @@ -598,10 +494,7 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga __builtin_amdgcn_sched_barrier(0); static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration. alibi does not have this problem - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) + if constexpr(FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? 
type_convert(0.f) @@ -622,22 +515,7 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); - } - } + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif @@ -654,23 +532,8 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); - } - } + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); }(); #else const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); @@ -685,17 +548,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga }); }); - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.template Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); - } - const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); @@ -717,28 +569,16 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga sequence<(LdsSeq.at(number{})) * kN1, 0>{}, 
sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_buf); - auto v_lds_window_tmp = get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, - v_shuffle_tmp)); // store the prefetch - } - else - { - auto v_lds_window_tmp = get_slice_tile( - v_lds_window, - sequence<(LdsSeq.at(number{})) * kN1, 0>{}, - sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store next v_buf - } + auto v_shuffle_tmp_next = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp_next, v_buf); + auto v_lds_window_tmp_next = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp_next, + tile_elementwise_in(v_element_func, + v_shuffle_tmp_next)); // store the prefetch if constexpr(i_k1 < k1_loops - 1) move_tile_window(v_dram_window, {0, kK1}); }); @@ -773,39 +613,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga } } while(i_total_loops < num_total_loop); - // store lse - if constexpr(kStoreLSE) - { - auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - - constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); - sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - lse(i_idx) = m_[i_idx] * 
R_LOG2E + log(l_[i_idx]); - } - else - { - lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); - } - } -#else - lse(i_idx) = m_[i_idx] + log(l_[i_idx]); -#endif - }); - - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); - } - // finally, O constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index bf4f89e3c4f..05aff237bee 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -62,10 +62,11 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && - (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || - !kHasLogitsSoftCap)) || - (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + static_assert(BiasEnum == BlockAttentionBiasEnum::NO_BIAS, + "VSA sparse attention does not support bias."); + static_assert(!kHasDropout, "VSA sparse attention does not support dropout."); + static_assert(!kStoreLSE, "VSA sparse attention does not support LSE output."); + static_assert(!kHasLogitsSoftCap, "VSA sparse attention does not support logits soft-cap."); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -78,8 +79,6 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); }(); static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); - static constexpr index_t kAlignmentBias = - kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); #if CK_TILE_FMHA_FWD_FAST_EXP2 static constexpr auto R_LOG2E = 1.0 / log2e_v; @@ -91,36 +90,30 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA else { // minimize occupancy - if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) - { - return 1; - } - if constexpr(kQKHeaddim <= 32) { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && - FmhaMask::IsMasking) + if constexpr(kPadSeqLenK && FmhaMask::IsMasking) return 1; else return 2; } else if constexpr(kQKHeaddim <= 64) { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + if constexpr(kPadSeqLenK) return 2; else return 3; } else if constexpr(kQKHeaddim <= 128) { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + if constexpr(kPadSeqLenK) return 1; else return 2; } else if constexpr(kQKHeaddim <= 192) { - if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + if constexpr(kPadSeqLenK) return 1; else return 2; @@ -289,15 +282,6 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA { if(num_total_loop <= 0) { - if constexpr(kStoreLSE) - { - auto lse = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse, -numeric::infinity()); - - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); - } buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) // otherwise will have compute error(maybe compiler bug?) 
@@ -320,24 +304,14 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA // load k_dram_window.init_raw(); constexpr auto k_oob_ck = bool_constant{}; - constexpr auto k_pre_np = [&]() { - if constexpr(kPadSeqLenK && - (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) - return bool_constant{}; - else - return bool_constant{}; - }(); - - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - auto bias_dram_window = - make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); - - auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + constexpr auto k_pre_np = bool_constant{}; + (void)bias_dram_block_window_tmp; + (void)bias_element_func; + (void)randval_dram_block_window_tmp; + (void)lse_dram_window_tmp; + (void)lse_element_func; + (void)position_encoding; + (void)dropout; auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), @@ -405,8 +379,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA int block_idx = kv_block_idx_ptr[i_total_loops + 1]; // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z == 101) printf("%d %d %d\n", // i_total_loops, num_total_loop, block_idx); - const auto bias_tile = load_tile(bias_dram_window); // load bias tile - auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); __builtin_amdgcn_sched_barrier(0); { // tail gemm_0( @@ -419,75 +392,11 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA } __builtin_amdgcn_sched_barrier(1); - // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - 
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); - tile_elementwise_inout( - [&](auto& x, const auto& y) { + // STAGE 2, scale_s, mask, softmax (no bias/soft-cap) + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); #endif - }, - s_acc, - bias_tile); - } - else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - const auto k_origin = k_dram_block_window.get_window_origin(); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { - const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); - - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - constexpr auto i_j_idx = make_tuple(idx0, idx1); - - s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); - }); - }); - } - else - { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - if constexpr(kHasLogitsSoftCap) - { - auto apply_logits_transform = - [&variant, &variant_params, &block_indices](auto& x) { - x = variant.LogitsTransform(variant_params, - variant.QueryTransform(variant_params, x), - block_indices.batch_idx, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }; -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } -#else - for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) - { - apply_logits_transform(s_acc.thread_buf_[i]); - } -#endif - } - else - { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - 
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif - } - } - move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { const auto k_origin = k_dram_block_window.get_window_origin(); @@ -565,10 +474,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA __builtin_amdgcn_sched_barrier(0); static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration. alibi does not have this problem - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) + if constexpr(FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? type_convert(0.f) @@ -589,22 +495,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); - } - } + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif @@ -621,23 +512,8 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * 
m_old[i_idx] - row_max); - } - } + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); }(); #else const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); @@ -652,17 +528,6 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA }); }); - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - dropout.template Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); - } - const auto p = [&]() { if constexpr(std::is_same_v) return impl::cast_tile_pkrtz_fp16_fp32( @@ -757,39 +622,6 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA } } while(i_total_loops < num_total_loop); - // store lse - if constexpr(kStoreLSE) - { - auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - - constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); - sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); - } - else - { - if constexpr(kHasLogitsSoftCap) - { - lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); - } - else - { - lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); - } - } -#else - lse(i_idx) = m_[i_idx] + log(l_[i_idx]); -#endif - }); - - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); - } - // finally, O constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); From 4693a5cfe9623dfd35b25b79d956340b4004b78d Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 29 Jan 2026 03:03:07 -0600 Subject: [PATCH 19/22] Clean up code --- .../codegen/ops/fmha_fwd_vsa.py | 6 +- .../ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 80 ++++++++++++++----- .../50_sparse_attn/vsa_sparse_attention.cpp | 2 +- 
.../kernel/fmha_fwd_jenga_kernel.hpp | 9 +-- .../kernel/fmha_fwd_vsa_kernel.hpp | 10 +-- ...block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 25 +----- 6 files changed, 75 insertions(+), 57 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index 731e4be9b2c..c303a5ab8d5 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -114,7 +114,7 @@ def update_file(file_path, content): using fmha_kernel_{F_idx} = ck_tile::FmhaFwdVSAKernel; -using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, +using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; #include @@ -166,7 +166,7 @@ def update_file(file_path, content): }} }} // namespace -float fmha_vsa_fwd(fmha_jenga_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ +float fmha_vsa_fwd(fmha_vsa_fwd_traits t, fmha_vsa_fwd_args a, const ck_tile::stream_config& s){{ float r = -1; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate @@ -203,7 +203,7 @@ def update_file(file_path, content): FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, 
{F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; return fmha_vsa_fwd_(s, a); }} """ diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index e90431b22fc..e6d4881b820 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -32,33 +32,37 @@ struct FmhaSparseFwdTypeConfig; template <> struct FmhaSparseFwdTypeConfig { - using QDataType = ck_tile::half_t; - using KDataType = ck_tile::half_t; - using VDataType = ck_tile::half_t; + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; + // Note: The following types are required by BlockFmhaPipelineProblem but not used + // by sparse attention (bias, dropout, LSE are not supported). 
using BiasDataType = ck_tile::half_t; using RandValOutputDataType = uint8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::half_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::half_t; + using LSEDataType = float; }; template <> struct FmhaSparseFwdTypeConfig { - using QDataType = ck_tile::bf16_t; - using KDataType = ck_tile::bf16_t; - using VDataType = ck_tile::bf16_t; + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; + // Note: The following types are required by BlockFmhaPipelineProblem but not used + // by sparse attention (bias, dropout, LSE are not supported). 
using BiasDataType = ck_tile::bf16_t; using RandValOutputDataType = uint8_t; - using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax - using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm - using OaccDataType = float; // data type for second gemm accumulation - using ODataType = ck_tile::bf16_t; + using LSEDataType = float; }; struct FmhaMasks @@ -278,7 +282,45 @@ float fmha_jenga_fwd_(const ck_tile::stream_config&, fmha_jenga_fwd_args); float fmha_jenga_fwd(fmha_jenga_fwd_args, const ck_tile::stream_config&); -float fmha_vsa_fwd(fmha_jenga_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); +// VSA uses the same traits structure as Jenga; aliases for clarity +template +using fmha_vsa_fwd_traits_ = fmha_jenga_fwd_traits_; + +using fmha_vsa_fwd_traits = fmha_jenga_fwd_traits; + +float fmha_vsa_fwd(fmha_vsa_fwd_traits, fmha_vsa_fwd_args, const ck_tile::stream_config&); template float fmha_vsa_fwd_(const ck_tile::stream_config&, fmha_vsa_fwd_args); diff --git a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp index 2dbd265ef22..88c28f19a71 100644 --- a/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp +++ b/example/ck_tile/50_sparse_attn/vsa_sparse_attention.cpp @@ -149,7 +149,7 @@ vsa_sparse_attention(const ck_tile::HostTensor& TQ, traits.mask_type = mask.type; }; - fmha_jenga_fwd_traits fmha_traits; + fmha_vsa_fwd_traits fmha_traits; init_traits(fmha_traits); fmha_vsa_fwd_args args; diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp index d0424898f09..e7531a1ce62 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp +++ 
b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp @@ -110,15 +110,13 @@ struct FmhaFwdJengaKernel struct FmhaFwdMaskKargs { - // ck_tile::index_t window_size_left, window_size_right; ck_tile::index_t window_size_left, window_size_right; ck_tile::GenericAttentionMaskEnum mask_type; }; struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, - std::conditional_t>, - FmhaFwdEmptyKargs<2> + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -186,9 +184,8 @@ struct FmhaFwdJengaKernel nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_o}, // args for common karg - {}, // placeholder for mask - {}, // placeholder for empty kargs + nhead_stride_o}, // FmhaFwdCommonKargs + {}, // FmhaFwdMaskKargs or FmhaFwdEmptyKargs<1> batch_stride_q, batch_stride_k, batch_stride_v, diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp index 686576bb9be..bca8d691a71 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp @@ -111,12 +111,13 @@ struct FmhaFwdVSAKernel struct FmhaFwdMaskKargs { - // ck_tile::index_t window_size_left, window_size_right; ck_tile::index_t window_size_left, window_size_right; ck_tile::GenericAttentionMaskEnum mask_type; }; - struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<2> + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -187,9 +188,8 @@ struct FmhaFwdVSAKernel nhead_stride_q, nhead_stride_k, nhead_stride_v, - nhead_stride_o}, // args for common karg - {}, // placeholder for mask - {}, // placeholder for empty kargs + nhead_stride_o}, // FmhaFwdCommonKargs + {}, // FmhaFwdMaskKargs or FmhaFwdEmptyKargs<1> batch_stride_q, batch_stride_k, batch_stride_v, diff --git 
a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index 05aff237bee..06432f2b196 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -270,11 +270,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA clear_tile(l); __builtin_amdgcn_sched_barrier(0); - const auto q_origin = q_dram_window.get_window_origin(); - // const auto [seqlen_k_start, seqlen_k_end] = - // mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - // const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - + const auto q_origin = q_dram_window.get_window_origin(); const auto num_total_loop = kv_blocks; // check early exit if no work to do @@ -363,10 +359,6 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); }); } - //__shared__ int printed_flag; - // if (blockIdx.x == 0 && threadIdx.x == 0 && i_total_loops==1000) { - // printed_flag = 100; - //} // TODO: this to fix a bug when loop smaller than 2, // the following fence/barrier will be scheduled inside 1st loop @@ -377,9 +369,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA __builtin_amdgcn_s_barrier(); int block_idx = kv_block_idx_ptr[i_total_loops + 1]; - // if (threadIdx.x==0 && blockIdx.x==0 && blockIdx.z == 101) printf("%d %d %d\n", - // i_total_loops, num_total_loop, block_idx); - auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); __builtin_amdgcn_sched_barrier(0); { // tail gemm_0( @@ -585,19 +575,8 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA if(i_total_loops < num_total_loop) { move_tile_window(v_dram_window, {0, kN0 * (block_idx - 1)}); - // v_dram_window = - // 
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), - // v_dram_block_window_tmp.get_window_lengths(), - // {0, kv_block_idx[i_total_loops]}, - // Policy::template MakeVDramTileDistribution()); - // move K tile windows move_tile_window(k_dram_block_window, {kN0 * block_idx, 0}); k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); - // k_dram_block_window = - // make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - // k_dram_block_window_tmp.get_window_lengths(), - // {kv_block_idx[i_total_loops], 0}); - // k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) From 685842c2ea4d6ff9e18e27d712ca04171cb536ee Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 29 Jan 2026 03:43:17 -0600 Subject: [PATCH 20/22] Remove more unused code --- .../kernel/fmha_fwd_jenga_kernel.hpp | 49 ++----- .../kernel/fmha_fwd_vsa_kernel.hpp | 130 ++++-------------- ...ock_fmha_pipeline_qr_ks_vs_async_jenga.hpp | 107 +------------- ...block_fmha_pipeline_qr_ks_vs_async_vsa.hpp | 119 ++-------------- 4 files changed, 54 insertions(+), 351 deletions(-) diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp index e7531a1ce62..cd3513530d4 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_jenga_kernel.hpp @@ -203,11 +203,8 @@ struct FmhaFwdJengaKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_, - bool has_padded_seqlen_k = false) + ck_tile::index_t hdim_v_) { - (void)has_padded_seqlen_k; - // TODO: this may need tuning return dim3(nhead_, batch_size_, ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * @@ -394,19 +391,6 @@ struct FmhaFwdJengaKernel 
make_tuple(number{}, number{}), {i_n1, 0}); - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - const auto bias_dram_window = make_null_tile_window(bias_dram_window_lengths); - - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - auto lse_dram_window = make_null_tile_window(lse_dram_window_lengths); - - auto dropout = NullBlockDropout{}; - - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - auto randval_dram_window = make_null_tile_window(randval_dram_window_lengths); - FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( @@ -419,32 +403,21 @@ struct FmhaFwdJengaKernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); - // WA i_batch capture structure binding before c++20 - auto position_encoding = EmptyPositionEncoding{}; - AttentionVariant variant; const auto variant_params = ck_tile::StandardAttentionParams{mask, kargs.scale_s}; BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - auto o_acc_tile = [&]() { - // TODO: constexpr(kDoFp8StaticQuant) - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - block_relation_onehot_ptr, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - }(); + auto o_acc_tile = FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + block_relation_onehot_ptr, + mask, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr); // O DRAM and O DRAM window auto o_dram = [&]() { diff --git a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp index bca8d691a71..5caf27756ff 100644 --- a/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp +++ b/include/ck_tile/ops/sparse_attn/kernel/fmha_fwd_vsa_kernel.hpp @@ -207,83 +207,37 @@ 
struct FmhaFwdVSAKernel CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, - ck_tile::index_t hdim_v_, - bool has_padded_seqlen_k = false) + ck_tile::index_t hdim_v_) { - // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr) - if(has_padded_seqlen_k) - { - // TODO: this may need tuning - return dim3(nhead_, - batch_size_, - ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); - } - else - { - // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - nhead_, - batch_size_); - } + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); } CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) { - bool has_padded_seqlen_k = false; - - if(has_padded_seqlen_k) - { - const index_t num_tile_n1 = - ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - const index_t i_block = blockIdx.z; - const index_t i_nhead = blockIdx.x; - const index_t i_batch = blockIdx.y; + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + const auto [i_tile_m, i_tile_n] = f(i_block, 
num_tile_n1); - if constexpr(kHasMask) - { - return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); - } - else - { - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); - } + if constexpr(kHasMask) + { + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else { - const index_t num_tile_n1 = - ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - - if constexpr(kHasMask) - { - return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); - } - else - { - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); - } + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); } } @@ -428,19 +382,6 @@ struct FmhaFwdVSAKernel make_tuple(number{}, number{}), {i_n1, 0}); - constexpr auto bias_dram_window_lengths = - make_tuple(number{}, number{}); - const auto bias_dram_window = make_null_tile_window(bias_dram_window_lengths); - - constexpr auto lse_dram_window_lengths = make_tuple(number{}); - auto lse_dram_window = make_null_tile_window(lse_dram_window_lengths); - - auto dropout = NullBlockDropout{}; - - constexpr auto randval_dram_window_lengths = - make_tuple(number{}, number{}); - auto randval_dram_window = make_null_tile_window(randval_dram_window_lengths); - FmhaMask mask = [&]() { if constexpr(kHasMask) return ck_tile::make_generic_attention_mask_from_lr_window( @@ -453,33 +394,22 @@ struct FmhaFwdVSAKernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); - // WA i_batch capture structure binding before c++20 - auto position_encoding = EmptyPositionEncoding{}; - 
AttentionVariant variant; const auto variant_params = ck_tile::StandardAttentionParams{mask, kargs.scale_s}; BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; - auto o_acc_tile = [&]() { - // TODO: constexpr(kDoFp8StaticQuant) - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - lut_ptr, - valid_block_num_value, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - }(); + auto o_acc_tile = FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + lut_ptr, + valid_block_num_value, + mask, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr); // O DRAM and O DRAM window auto o_dram = [&]() { diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp index 7ab4cc710c7..67936c4353f 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_jenga.hpp @@ -130,8 +130,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga static constexpr const char* name = "qr_async"; - using DropoutType = std::conditional_t; - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -140,44 +138,19 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga template CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const KElementFunction& /*k_element_func*/, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction& v_element_func, const bool* block_relation_onehot_ptr, - const BiasDramBlockWindowTmp& 
bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction& bias_element_func, - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile - const LSEElementFunction& lse_element_func, - const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, FmhaMask mask, - PositionEncoding position_encoding, float scale_s, const AttentionVariant& variant, const AttentionVariantParams& variant_params, const BlockIndices& block_indices, - void* smem_ptr, - DropoutType& dropout) const + void* smem_ptr) const { static_assert( std::is_same_v> && @@ -189,9 +162,7 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); @@ -311,14 +282,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga constexpr auto k_oob_ck = bool_constant{}; constexpr auto k_pre_np = bool_constant{}; - (void)bias_dram_block_window_tmp; - (void)bias_element_func; - (void)randval_dram_block_window_tmp; - (void)lse_dram_window_tmp; - (void)lse_element_func; - (void)position_encoding; - (void)dropout; - auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -343,8 +306,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga // buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); buffer_load_fence(k_dram_window.get_num_of_access()); - 
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 - // auto q_tile = q; // tile_elementwise_in(q_element_func, q); index_t i_total_loops = 0; constexpr index_t k0_loops = kQKHeaddim / kK0; @@ -426,7 +387,6 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga __builtin_amdgcn_sched_barrier(1); // STAGE 2, scale_s, mask, softmax (no bias/soft-cap) - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); #endif @@ -480,8 +440,7 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + store_tile(v_lds_window_tmp, v_shuffle_tmp); if constexpr(k1_loops > 1) { @@ -548,8 +507,7 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga }); }); - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + const auto p = cast_tile(p_compute); // STAGE 3, KV gemm if constexpr(k1_loops > 1) @@ -576,9 +534,7 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp_next, - tile_elementwise_in(v_element_func, - v_shuffle_tmp_next)); // store the prefetch + store_tile(v_lds_window_tmp_next, v_shuffle_tmp_next); if constexpr(i_k1 < k1_loops - 1) move_tile_window(v_dram_window, {0, kK1}); }); @@ -632,61 +588,8 @@ struct BlockFmhaPipelineQRKSVSAsyncJenga }); }); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - return o_acc; } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const bool* 
block_relation_onehot_ptr, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - FmhaMask mask, - PositionEncoding position_encoding, - float scale_s, - const AttentionVariant& variant, - const AttentionVariantParams& variant_params, - const BlockIndices& block_indices, - void* smem_ptr, - DropoutType& dropout) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - identity{}, - v_dram_block_window_tmp, - identity{}, - block_relation_onehot_ptr, - bias_dram_block_window_tmp, - identity{}, - randval_dram_block_window_tmp, - lse_dram_block_window_tmp, - identity{}, - identity{}, - identity{}, - identity{}, - mask, - position_encoding, - scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp index 06432f2b196..2b097ae5827 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/block_fmha_pipeline_qr_ks_vs_async_vsa.hpp @@ -131,8 +131,6 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA static constexpr const char* name = "qr_async"; - using DropoutType = std::conditional_t; - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); @@ -141,45 +139,20 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA template CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const KElementFunction& /*k_element_func*/, const VDramBlockWindowTmp& v_dram_block_window_tmp, // 
N1*K1 tile - const VElementFunction& v_element_func, const int* kv_block_idx_ptr, int kv_blocks, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - const BiasElementFunction& bias_element_func, - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, - LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile - const LSEElementFunction& lse_element_func, - const SAccElementFunction& s_acc_element_func, - const PComputeElementFunction& p_compute_element_func, - const OAccElementFunction& o_acc_element_func, FmhaMask mask, - PositionEncoding position_encoding, float scale_s, const AttentionVariant& variant, const AttentionVariantParams& variant_params, const BlockIndices& block_indices, - void* smem_ptr, - DropoutType& dropout) const + void* smem_ptr) const { static_assert( std::is_same_v> && @@ -191,9 +164,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && - kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); @@ -301,14 +272,6 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA k_dram_window.init_raw(); constexpr auto k_oob_ck = bool_constant{}; constexpr auto k_pre_np = bool_constant{}; - (void)bias_dram_block_window_tmp; - (void)bias_element_func; - (void)randval_dram_block_window_tmp; - (void)lse_dram_window_tmp; - (void)lse_element_func; - (void)position_encoding; - (void)dropout; - auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), @@ -323,8 +286,6 @@ struct 
BlockFmhaPipelineQRKSVSAsyncVSA // buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); buffer_load_fence(k_dram_window.get_num_of_access()); - (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 - // auto q_tile = q; // tile_elementwise_in(q_element_func, q); index_t i_total_loops = 0; constexpr index_t k0_loops = kQKHeaddim / kK0; @@ -383,7 +344,6 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA __builtin_amdgcn_sched_barrier(1); // STAGE 2, scale_s, mask, softmax (no bias/soft-cap) - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); #endif @@ -439,9 +399,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile( - v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + store_tile(v_lds_window_tmp, v_shuffle_tmp); } else { @@ -449,8 +407,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA get_slice_tile(v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + store_tile(v_lds_window_tmp, v_buf); } if constexpr(k1_loops > 1) @@ -520,11 +477,9 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA const auto p = [&]() { if constexpr(std::is_same_v) - return impl::cast_tile_pkrtz_fp16_fp32( - tile_elementwise_in(p_compute_element_func, p_compute)); + return impl::cast_tile_pkrtz_fp16_fp32(p_compute); else - return cast_tile( - tile_elementwise_in(p_compute_element_func, p_compute)); + return cast_tile(p_compute); }(); // STAGE 3, KV gemm @@ -554,9 +509,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - 
tile_elementwise_in(v_element_func, - v_shuffle_tmp)); // store the prefetch + store_tile(v_lds_window_tmp, v_shuffle_tmp); } else { @@ -564,8 +517,7 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA v_lds_window, sequence<(LdsSeq.at(number{})) * kN1, 0>{}, sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); - store_tile(v_lds_window_tmp, - tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + store_tile(v_lds_window_tmp, v_buf); } if constexpr(i_k1 < k1_loops - 1) move_tile_window(v_dram_window, {0, kK1}); @@ -620,63 +572,8 @@ struct BlockFmhaPipelineQRKSVSAsyncVSA }); }); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); - return o_acc; } - - template - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const int* kv_block_idx_ptr, - int kv_blocks, - const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile - RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - FmhaMask mask, - PositionEncoding position_encoding, - float scale_s, - const AttentionVariant& variant, - const AttentionVariantParams& variant_params, - const BlockIndices& block_indices, - void* smem_ptr, - DropoutType& dropout) const - { - return operator()(q_dram_block_window_tmp, - identity{}, - k_dram_block_window_tmp, - identity{}, - v_dram_block_window_tmp, - identity{}, - kv_block_idx_ptr, - kv_blocks, - bias_dram_block_window_tmp, - identity{}, - randval_dram_block_window_tmp, - lse_dram_block_window_tmp, - identity{}, - identity{}, - identity{}, - identity{}, - mask, - position_encoding, - scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); - } }; } // namespace ck_tile From 368d050a29c2d838de9d778084a2522ddf2181c8 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 
29 Jan 2026 03:49:33 -0600 Subject: [PATCH 21/22] Re-format .hpp --- example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp index e6d4881b820..7349c3576e8 100644 --- a/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp +++ b/example/ck_tile/50_sparse_attn/fmha_fwd_trek.hpp @@ -35,8 +35,8 @@ struct FmhaSparseFwdTypeConfig using QDataType = ck_tile::half_t; using KDataType = ck_tile::half_t; using VDataType = ck_tile::half_t; - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck_tile::half_t; // data type for A matrix of second gemm using OaccDataType = float; // data type for second gemm accumulation using ODataType = ck_tile::half_t; @@ -53,8 +53,8 @@ struct FmhaSparseFwdTypeConfig using QDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t; using VDataType = ck_tile::bf16_t; - using SaccDataType = float; // data type for first gemm accumulation - using SMPLComputeDataType = float; // data type for reduction, softmax + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm using OaccDataType = float; // data type for second gemm accumulation using ODataType = ck_tile::bf16_t; From 565ab74629fdab19b1a60853e995444cbd58413c Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Thu, 29 Jan 2026 03:54:35 -0600 Subject: [PATCH 22/22] Refactor codegen scripts --- .../codegen/ops/fmha_fwd_jenga.py | 45 ++++++++++--------- .../codegen/ops/fmha_fwd_vsa.py | 45 ++++++++++--------- 2 files changed, 
50 insertions(+), 40 deletions(-) diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py index a8566d9e53e..7cf64849afe 100644 --- a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_jenga.py @@ -55,6 +55,15 @@ def update_file(file_path, content): """ +# NOTE: Jenga sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features accordingly. + FMHA_FWD_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; @@ -67,20 +76,22 @@ def update_file(file_path, content): ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, - {F_logits}, - ck_tile::BlockAttentionBiasEnum::NO_BIAS, - false, - false, - false, - ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported {F_occupancy}, false>; -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // 
logits_soft_cap=0 (NOT supported) using fmha_mask_{F_idx} = {F_mask}; @@ -115,7 +126,7 @@ def update_file(file_path, content): ck_tile::FmhaFwdJengaKernel; using trait_{F_idx} = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; #include @@ -203,7 +214,7 @@ def update_file(file_path, content): FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + using trait_ = fmha_jenga_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; return fmha_jenga_fwd_(s, a); }} """ @@ -395,7 +406,7 @@ def api(self) -> str: F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], - F_logits=BOOL_MAP[trait.logits], + # F_logits removed - hardcoded to false (NOT supported) F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_trload=BOOL_MAP[trait.tr_load], @@ -510,7 +521,7 @@ def template(self) -> str: F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits=BOOL_MAP[self.F_pipeline.F_logits], + # F_logits removed - hardcoded to false in template (NOT supported) F_occupancy=self.F_tile.F_occupancy, 
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -670,10 +681,11 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by Jenga sparse attention (enforced by static_assert) pipelines = [] if dtype in ["fp16", "bf16"]: for logits, mask in itertools.product( - ["f"], + ["f"], # logits soft-cap NOT supported, always false get_mask_map(mask_impl).keys(), ): if hdim == 256 and hdim_v == 256: @@ -759,24 +771,17 @@ def get_fwd_blobs( ) # Only generate fp16/bf16 kernels for now. + # NOTE: Jenga sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) for dtype in ["fp16", "bf16"]: d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue - # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): if tile.F_bm0 != 128 or tile.F_bn0 != 128: continue - if mode == "group": - if pipeline.F_spad != "t" or pipeline.F_skpad != "t": - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - # logits soft-cap is not generated for sparse attention - if not (pipeline.F_logits == "f"): - continue if pipeline.tag != "qr_async": continue k = FmhaFwdKernel( diff --git a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py index c303a5ab8d5..11b3fa743c5 100644 --- 
a/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py +++ b/example/ck_tile/50_sparse_attn/codegen/ops/fmha_fwd_vsa.py @@ -55,6 +55,15 @@ def update_file(file_path, content): """ +# NOTE: VSA sparse attention kernel has the following restrictions enforced by static_assert: +# - Group mode: NOT supported (batch mode only) +# - Bias: NOT supported (NO_BIAS only) +# - LSE output: NOT supported (false only) +# - Dropout: NOT supported (false only) +# - Logits soft-cap: NOT supported (false only) +# - FP8 static quantization: NOT supported (NO_SCALE only) +# The template below hardcodes these unsupported features accordingly. + FMHA_FWD_KERNEL_BODY = """ using fmha_dtype_{F_idx} = {F_dtype}; @@ -67,20 +76,22 @@ def update_file(file_path, content): ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; +// TileFmhaTraits: spad, skpad, dpad, dvpad, has_logits_soft_cap, bias_enum, +// store_lse, has_dropout, has_randval, quant_scale_enum, occupancy, is_v_rowmajor_skip using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, - {F_logits}, - ck_tile::BlockAttentionBiasEnum::NO_BIAS, - false, - false, - false, - ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, + false, // has_logits_soft_cap - NOT supported + ck_tile::BlockAttentionBiasEnum::NO_BIAS, // bias - NOT supported + false, // store_lse - NOT supported + false, // has_dropout - NOT supported + false, // has_randval - NOT supported + ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE, // FP8 quant - NOT supported {F_occupancy}, false>; -using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>; // logits_soft_cap=0 (NOT supported) using fmha_mask_{F_idx} = {F_mask}; @@ -115,7 +126,7 @@ def update_file(file_path, content): ck_tile::FmhaFwdVSAKernel; using trait_{F_idx} = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, 
{F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + {F_pipeline_enum}, false/*logits*/, fmha_mask_{F_idx}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; #include @@ -203,7 +214,7 @@ def update_file(file_path, content): FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; + using trait_ = fmha_vsa_fwd_traits_<{F_hdim}, {F_dtype}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, false/*logits*/, {F_mask}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}>; return fmha_vsa_fwd_(s, a); }} """ @@ -395,7 +406,7 @@ def api(self) -> str: F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], - F_logits=BOOL_MAP[trait.logits], + # F_logits removed - hardcoded to false (NOT supported) F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_trload=BOOL_MAP[trait.tr_load], @@ -510,7 +521,7 @@ def template(self) -> str: F_skpad=BOOL_MAP[self.F_pipeline.F_skpad], F_dpad=BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad], - F_logits=BOOL_MAP[self.F_pipeline.F_logits], + # F_logits removed - hardcoded to false in template (NOT supported) F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -670,10 +681,11 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]: def get_pipelines(dtype, hdim, hdim_v, receipt, 
mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later + # NOTE: logits soft-cap is NOT supported by VSA sparse attention (enforced by static_assert) pipelines = [] if dtype in ["fp16", "bf16"]: for logits, mask in itertools.product( - ["f"], + ["f"], # logits soft-cap NOT supported, always false get_mask_map(mask_impl).keys(), ): if hdim == 256 and hdim_v == 256: @@ -759,24 +771,17 @@ def get_fwd_blobs( ) # Only generate fp16/bf16 kernels for now. + # NOTE: VSA sparse attention only supports batch mode (group mode NOT supported, enforced by static_assert) for dtype in ["fp16", "bf16"]: d = factory.get_hdim_tile_size_dict(dtype) if d is None: continue - # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), ["batch"]): for tile, pipeline in itertools.product( tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) ): if tile.F_bm0 != 128 or tile.F_bn0 != 128: continue - if mode == "group": - if pipeline.F_spad != "t" or pipeline.F_skpad != "t": - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - # logits soft-cap is not generated for sparse attention - if not (pipeline.F_logits == "f"): - continue if pipeline.tag != "qr_async_vsa": continue k = FmhaFwdKernel(