diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index 0edd97a0c6..eb7e59f819 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -225,8 +225,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ constexpr int32_t kWarpsPerBlock = 4; const auto device_only = lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0; // PackedMode is only available for ROCm devices + {%- if not nobag %} constexpr bool kIsRocm = {{ "true" if is_rocm else "false" }}; const static bool use_rocm_packed_bag_mode = kIsRocm && fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS); + {%- else %} + constexpr bool use_rocm_packed_bag_mode = false; + {%- endif %} /* * Helper macro for run-time packed mode dispatch. Computes maximum number of bags * (num_packed_bags) that fits into NumUint4LoadsPerRow given embeddings' type and @@ -237,9 +241,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ #define PACKED_MODE_SWITCH(dev_only, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ int32_t num_packed_bags = 1; \ {%-if is_rocm and not nobag %} - const static bool use_packed_bag_mode = fbgemm_gpu::config::is_feature_enabled( \ - fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS); \ - if(use_packed_bag_mode) { \ + if(use_rocm_packed_bag_mode) { \ /* The actual maximum number of uint4 reads per row w.r.t. row size, type and alignment */ \ const int32_t num_uint4_loads_per_row = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), sizeof(uint4)); \ constexpr int32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); \ diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index 7172236b19..50f19682ca 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -281,6 +281,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } } + {% if not nobag %} for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { @@ -304,7 +305,6 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no using scalar_t = {{ emb_weight_type.cpp_type_name }}; - {% if not nobag %} #pragma unroll AccumulateStoreRequests for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; @@ -314,57 +314,135 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); {% endif %} } - {% else %} - const int32_t output_j = indices_starts[i] + L_start + input_row_idx; - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: - // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to - // the scale/shift handling). - // Reason: to avoid divergence the first thread in the warp computes garbage. - const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); - VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - acc.store(&output[output_j][output_d], num_valid_outputs); - } + } + } + {% else %} + using scalar_t = {{ emb_weight_type.cpp_type_name }}; + {% if emb_weight_type.primitive_type == "INT" and is_rocm %} + if (D % (kWarpSize * kOutputsPerThread) == 0 && D_padding > 0) { + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls[i]; + if (!valid) { + continue; } - } else if constexpr (std::is_same_v) { - // INT8: - // apply per feature row-wise int8 - auto thread_local_min = std::numeric_limits::max(); - auto thread_local_max = std::numeric_limits::lowest(); - float2 qparams; - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - if (output_d < D) { + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; + // Size of the per-row scale/bias header in scalar_t units; derived from shift_scale's type. + constexpr uint32_t kHeaderScalarOffset = sizeof(decltype(shift_scale)) / sizeof(scalar_t); + const uint32_t opt_iters = D / (kWarpSize * kOutputsPerThread); + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + for (uint32_t j = 0; j < opt_iters; ++j) { + const auto output_d = j * kWarpSize * kOutputsPerThread + threadIdx.x * kOutputsPerThread; + // +kHeaderScalarOffset skips the half2 scale/bias header at row[0] + scalar_t v = reinterpret_cast(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale); + acc.store(&output[output_j][output_d], {{ (32 // emb_weight_type.bit_width) }}); + } + } else if constexpr (std::is_same_v) { + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + // Pass 1: min/max scan + for (uint32_t j = 0; j < opt_iters; ++j) { + scalar_t v = reinterpret_cast(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale); thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + // Pass 2: requantize and store + for (uint32_t j = 0; j < opt_iters; ++j) { + const auto output_d = j * kWarpSize * kOutputsPerThread + threadIdx.x * kOutputsPerThread; + scalar_t v = reinterpret_cast(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale); + acc.store(&output[output_j][output_d], qparams, {{ (32 // emb_weight_type.bit_width) }}); + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); + } + } + } + } + } else + {% endif %} + {#- For non-INT weight types or non-ROCm builds the {%- if %} above renders + nothing; this block then becomes a bare scope holding the unmodified loop. -#} + { + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls[i]; + if (!valid) { + continue; } - qparams = warp_find_qparams(thread_local_min, thread_local_max); - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + // scale and bias are at the beginning of each row. + // rationale: have scale/shift at start since these get loaded first + // and then broadcasted around so it might speed up the first cache miss. + {% if emb_weight_type.primitive_type == "INT" %} + // In PackedMode, row pointer may contain several rows from different bags, so each thread/lane should + // read the certain shift_scale related to the row in the packed_bag. + half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; + {% endif %} + + {% if weighted %} + float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; + {% endif %} + + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: + // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to + // the scale/shift handling). + // Reason: to avoid divergence the first thread in the warp computes garbage. + const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], num_valid_outputs); + } + } + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + if (output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); + } + } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); } - } - if (threadIdx.x == 0) { - store_qparams_to_row(&output[output_j][D], qparams); } } - {% endif %} } } + {% endif %} } {% if not nobag %}