From 1d15bada23798ae1a1535ee26cbe11177bafa828 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 15 May 2026 16:20:59 +0000 Subject: [PATCH 1/7] TBE inference INT-weight nobag: shifted-index store fast path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When D is an exact multiple of (kWarpSize * kOutputsPerThread), skip the half2 scale/bias header at row[0] via a +1 read offset (in scalar_t units) and drop the D_padding shift on output_d. This eliminates the mostly-empty tail iteration that the original loop runs, which on AMD wave-64 for D=256 wastes 63/64 lanes. The branch is hoisted out of the per-iter loop nest — the compiler hoists the predicate but not the branch itself, which would otherwise add ~3-5% per-iter overhead on non-triggering D values. --- ...rd_quantized_split_nbit_kernel_template.cu | 168 +++++++++++++----- 1 file changed, 119 insertions(+), 49 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index 7172236b19..125de12308 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -281,6 +281,16 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } } + {#- Hoist the writeopt branch outside the per-input_row / per-i loop nest + so the runtime check `D % (kWarpSize*kOutputsPerThread) == 0` is evaluated + once per L-tile instead of per inner iteration. The compiler doesn't + auto-hoist this; observed per-iter `s_cbranch` adds ~3-5% overhead on + non-triggering D values. Both inner-loop bodies are identical except + for the store/accumulate at the bottom — we duplicate the loop nest to + get the branch out, but the binary already contained both paths, so this + adds source duplication without adding binary size. -#} + {% if not nobag %} + {#- Bagged path is unchanged: writeopt is not (yet) applied here. -#} for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { @@ -289,22 +299,13 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no continue; } const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); - // scale and bias are at the beginning of each row. - // rationale: have scale/shift at start since these get loaded first - // and then broadcasted around so it might speed up the first cache miss. {% if emb_weight_type.primitive_type == "INT" %} - // In PackedMode, row pointer may contain several rows from different bags, so each thread/lane should - // read the certain shift_scale related to the row in the packed_bag. half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; {% endif %} - {% if weighted %} float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; {% endif %} - using scalar_t = {{ emb_weight_type.cpp_type_name }}; - - {% if not nobag %} #pragma unroll AccumulateStoreRequests for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; @@ -314,57 +315,126 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); {% endif %} } - {% else %} - const int32_t output_j = indices_starts[i] + L_start + input_row_idx; - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: - // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to - // the scale/shift handling). - // Reason: to avoid divergence the first thread in the warp computes garbage. - const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); - VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - acc.store(&output[output_j][output_d], num_valid_outputs); - } + } + } + {% else %} + {#- Nobag path: hoist the writeopt branch outside the input_row_idx / i loops. -#} + using scalar_t = {{ emb_weight_type.cpp_type_name }}; + {% if emb_weight_type.primitive_type == "INT" %} + // Hoisted runtime check — evaluated once per L-tile (vs once per inner iter + // before hoisting). When true, every (input_row_idx, i) inside takes the + // optimized store path; otherwise every (input_row_idx, i) takes the + // original path. Eliminates per-iter branch overhead on the original path. + if (D % (kWarpSize * kOutputsPerThread) == 0 && D_padding > 0) { + // ============ Optimized branch (writeopt active) ============ + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls[i]; + if (!valid) { + continue; } - } else if constexpr (std::is_same_v) { - // INT8: - // apply per feature row-wise int8 - auto thread_local_min = std::numeric_limits::max(); - auto thread_local_max = std::numeric_limits::lowest(); - float2 qparams; - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - if (output_d < D) { + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; + // Number of scalar_t-sized entries occupied by the per-row scale/bias header. + // Derived from `shift_scale`'s type so this offset auto-updates if the header type + // ever changes. Folds to 1 for all INT weight types. + constexpr uint32_t kHeaderScalarOffset = sizeof(decltype(shift_scale)) / sizeof(scalar_t); + const uint32_t opt_iters = D / (kWarpSize * kOutputsPerThread); + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + for (uint32_t j = 0; j < opt_iters; ++j) { + const auto output_d = j * kWarpSize * kOutputsPerThread + threadIdx.x * kOutputsPerThread; + // +kHeaderScalarOffset skips the half2 scale/bias header at row[0] + scalar_t v = reinterpret_cast(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale); + acc.store(&output[output_j][output_d], {{ (32 // emb_weight_type.bit_width) }}); + } + } else if constexpr (std::is_same_v) { + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + // Pass 1: min/max scan + for (uint32_t j = 0; j < opt_iters; ++j) { + scalar_t v = reinterpret_cast(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale); thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + // Pass 2: requantize and store + for (uint32_t j = 0; j < opt_iters; ++j) { + const auto output_d = j * kWarpSize * kOutputsPerThread + threadIdx.x * kOutputsPerThread; + scalar_t v = reinterpret_cast(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale); + acc.store(&output[output_j][output_d], qparams, {{ (32 // emb_weight_type.bit_width) }}); + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); + } } - qparams = warp_find_qparams(thread_local_min, thread_local_max); - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + } + } + } else + {% endif %} + { + // ============ Original branch (no writeopt) ============ + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls[i]; + if (!valid) { + continue; + } + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + {% if emb_weight_type.primitive_type == "INT" %} + half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; + {% endif %} + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], num_valid_outputs); + } + } + } else if constexpr (std::is_same_v) { + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + if (output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); + } + } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); } - } - if (threadIdx.x == 0) { - store_qparams_to_row(&output[output_j][D], qparams); } } - {% endif %} } } + {% endif %} } {% if not nobag %} From db2c70e4553b26ea8e8fb2ae9dbe544c6c59af6d Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 15 May 2026 18:11:19 +0000 Subject: [PATCH 2/7] Restore original comments and weighted block in non-writeopt branch --- ...rward_quantized_split_nbit_kernel_template.cu | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index 125de12308..d507bdcabd 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -387,13 +387,27 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no continue; } const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + // scale and bias are at the beginning of each row. + // rationale: have scale/shift at start since these get loaded first + // and then broadcasted around so it might speed up the first cache miss. {% if emb_weight_type.primitive_type == "INT" %} + // In PackedMode, row pointer may contain several rows from different bags, so each thread/lane should + // read the certain shift_scale related to the row in the packed_bag. half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; {% endif %} + + {% if weighted %} + float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; + {% endif %} + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { #pragma unroll AccumulateStoreRequests for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: + // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to + // the scale/shift handling). + // Reason: to avoid divergence the first thread in the warp computes garbage. const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; if (output_d < D) { @@ -403,6 +417,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } } } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 auto thread_local_min = std::numeric_limits::max(); auto thread_local_max = std::numeric_limits::lowest(); float2 qparams; From 46acccb3e5123f014ca00e403cf59718aa41ffa2 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 15 May 2026 18:29:19 +0000 Subject: [PATCH 3/7] Remove process-history comments from writeopt section --- ...rward_quantized_split_nbit_kernel_template.cu | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index d507bdcabd..06622d9354 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -281,16 +281,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } } - {#- Hoist the writeopt branch outside the per-input_row / per-i loop nest - so the runtime check `D % (kWarpSize*kOutputsPerThread) == 0` is evaluated - once per L-tile instead of per inner iteration. The compiler doesn't - auto-hoist this; observed per-iter `s_cbranch` adds ~3-5% overhead on - non-triggering D values. Both inner-loop bodies are identical except - for the store/accumulate at the bottom — we duplicate the loop nest to - get the branch out, but the binary already contained both paths, so this - adds source duplication without adding binary size. -#} {% if not nobag %} - {#- Bagged path is unchanged: writeopt is not (yet) applied here. -#} for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { @@ -318,15 +309,9 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } } {% else %} - {#- Nobag path: hoist the writeopt branch outside the input_row_idx / i loops. -#} using scalar_t = {{ emb_weight_type.cpp_type_name }}; {% if emb_weight_type.primitive_type == "INT" %} - // Hoisted runtime check — evaluated once per L-tile (vs once per inner iter - // before hoisting). When true, every (input_row_idx, i) inside takes the - // optimized store path; otherwise every (input_row_idx, i) takes the - // original path. Eliminates per-iter branch overhead on the original path. if (D % (kWarpSize * kOutputsPerThread) == 0 && D_padding > 0) { - // ============ Optimized branch (writeopt active) ============ for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { @@ -378,7 +363,6 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } else {% endif %} { - // ============ Original branch (no writeopt) ============ for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { From 01c7d8ab0804530008a8915d56c4b3c390c593e7 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 15 May 2026 18:29:49 +0000 Subject: [PATCH 4/7] Restore comments and whitespace in bagged path --- ...edding_forward_quantized_split_nbit_kernel_template.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index 06622d9354..7b554d7079 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -290,13 +290,21 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no continue; } const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + // scale and bias are at the beginning of each row. + // rationale: have scale/shift at start since these get loaded first + // and then broadcasted around so it might speed up the first cache miss. {% if emb_weight_type.primitive_type == "INT" %} + // In PackedMode, row pointer may contain several rows from different bags, so each thread/lane should + // read the certain shift_scale related to the row in the packed_bag. half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; {% endif %} + {% if weighted %} float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; {% endif %} + using scalar_t = {{ emb_weight_type.cpp_type_name }}; + #pragma unroll AccumulateStoreRequests for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; From 15d6906d359dd9830fdd887b3507fe6b055fa4c2 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 15 May 2026 18:30:26 +0000 Subject: [PATCH 5/7] Tighten kHeaderScalarOffset comment; document bare-scope construct --- ...mbedding_forward_quantized_split_nbit_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index 7b554d7079..f853f28335 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -329,9 +329,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; - // Number of scalar_t-sized entries occupied by the per-row scale/bias header. - // Derived from `shift_scale`'s type so this offset auto-updates if the header type - // ever changes. Folds to 1 for all INT weight types. + // Size of the per-row scale/bias header in scalar_t units; derived from shift_scale's type. constexpr uint32_t kHeaderScalarOffset = sizeof(decltype(shift_scale)) / sizeof(scalar_t); const uint32_t opt_iters = D / (kWarpSize * kOutputsPerThread); const int32_t output_j = indices_starts[i] + L_start + input_row_idx; @@ -370,6 +368,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } } else {% endif %} + {#- For non-INT weight types the {%- if %} above renders nothing; this block then + becomes a bare scope holding the unmodified loop. -#} { for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread From b56d302dd29455998624a252868655f1175b05ff Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 15 May 2026 18:34:04 +0000 Subject: [PATCH 6/7] Gate writeopt fast path to ROCm builds --- ...mbedding_forward_quantized_split_nbit_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index f853f28335..50f19682ca 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -318,7 +318,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } {% else %} using scalar_t = {{ emb_weight_type.cpp_type_name }}; - {% if emb_weight_type.primitive_type == "INT" %} + {% if emb_weight_type.primitive_type == "INT" and is_rocm %} if (D % (kWarpSize * kOutputsPerThread) == 0 && D_padding > 0) { for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread @@ -368,8 +368,8 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no } } else {% endif %} - {#- For non-INT weight types the {%- if %} above renders nothing; this block then - becomes a bare scope holding the unmodified loop. -#} + {#- For non-INT weight types or non-ROCm builds the {%- if %} above renders + nothing; this block then becomes a bare scope holding the unmodified loop. -#} { for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { #pragma unroll OutputRowsPerThread From 89d30b8064cc88d1c3783eb32e6c5dfd094f3109 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Thu, 21 May 2026 17:46:42 -0500 Subject: [PATCH 7/7] embedding_forward_quantized_split_nbit_host_template.cu: disables all packedMode effects in nobag kernel --- ...mbedding_forward_quantized_split_nbit_host_template.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index 0edd97a0c6..eb7e59f819 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -225,8 +225,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ constexpr int32_t kWarpsPerBlock = 4; const auto device_only = lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0; // PackedMode is only available for ROCm devices + {%- if not nobag %} constexpr bool kIsRocm = {{ "true" if is_rocm else "false" }}; const static bool use_rocm_packed_bag_mode = kIsRocm && fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS); + {%- else %} + constexpr bool use_rocm_packed_bag_mode = false; + {%- endif %} /* * Helper macro for run-time packed mode dispatch. Computes maximum number of bags * (num_packed_bags) that fits into NumUint4LoadsPerRow given embeddings' type and @@ -237,9 +241,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ #define PACKED_MODE_SWITCH(dev_only, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ int32_t num_packed_bags = 1; \ {%-if is_rocm and not nobag %} - const static bool use_packed_bag_mode = fbgemm_gpu::config::is_feature_enabled( \ - fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS); \ - if(use_packed_bag_mode) { \ + if(use_rocm_packed_bag_mode) { \ /* The actual maximum number of uint4 reads per row w.r.t. row size, type and alignment */ \ const int32_t num_uint4_loads_per_row = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), sizeof(uint4)); \ constexpr int32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); \