Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
constexpr int32_t kWarpsPerBlock = 4;
const auto device_only = lxu_cache_weights.numel() == 0 && uvm_weights.numel() == 0;
// PackedMode is only available for ROCm devices
{%- if not nobag %}
constexpr bool kIsRocm = {{ "true" if is_rocm else "false" }};
const static bool use_rocm_packed_bag_mode = kIsRocm && fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS);
{%- else %}
constexpr bool use_rocm_packed_bag_mode = false;
{%- endif %}
/*
* Helper macro for run-time packed mode dispatch. Computes maximum number of bags
* (num_packed_bags) that fits into NumUint4LoadsPerRow given embeddings' type and
Expand All @@ -237,9 +241,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
#define PACKED_MODE_SWITCH(dev_only, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \
int32_t num_packed_bags = 1; \
{%-if is_rocm and not nobag %}
const static bool use_packed_bag_mode = fbgemm_gpu::config::is_feature_enabled( \
fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS); \
if(use_packed_bag_mode) { \
if(use_rocm_packed_bag_mode) { \
/* The actual maximum number of uint4 reads per row w.r.t. row size, type and alignment */ \
const int32_t num_uint4_loads_per_row = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), sizeof(uint4)); \
constexpr int32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
}
}

{% if not nobag %}
for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) {
#pragma unroll OutputRowsPerThread
for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
Expand All @@ -304,7 +305,6 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no

using scalar_t = {{ emb_weight_type.cpp_type_name }};

{% if not nobag %}
#pragma unroll AccumulateStoreRequests
for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
Expand All @@ -314,57 +314,135 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
{% endif %}
}
{% else %}
const int32_t output_j = indices_starts[i] + L_start + input_row_idx;
if constexpr (std::is_same_v<output_t, float> || std::is_same_v<output_t, at::Half> || std::is_same_v<output_t, at::BFloat16>) {
#pragma unroll AccumulateStoreRequests
for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
// Read the uint8/4/2 values: note that first 4 Bytes will be ditched later:
// We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to
// the scale/shift handling).
// Reason: to avoid divergence the first thread in the warp computes garbage.
const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
if (output_d < D) {
const int num_valid_outputs = min(static_cast<int>(D - output_d), static_cast<int>({{ (32 // emb_weight_type.bit_width) }}));
VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
acc.store(&output[output_j][output_d], num_valid_outputs);
}
}
}
{% else %}
using scalar_t = {{ emb_weight_type.cpp_type_name }};
{% if emb_weight_type.primitive_type == "INT" and is_rocm %}
if (D % (kWarpSize * kOutputsPerThread) == 0 && D_padding > 0) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't D_padding always > 0 when the weights are quantized to INT?

for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) {
#pragma unroll OutputRowsPerThread
for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
bool valid = L_start + input_row_idx < Ls[i];
if (!valid) {
continue;
}
} else if constexpr (std::is_same_v<output_t, uint8_t>) {
// INT8:
// apply per feature row-wise int8
auto thread_local_min = std::numeric_limits<float>::max();
auto thread_local_max = std::numeric_limits<float>::lowest();
float2 qparams;
#pragma unroll AccumulateStoreRequests
for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
if (output_d < D) {
const uint32_t* row = reinterpret_cast<const uint32_t*>(&buffers[warp_idx][i][input_row_idx][0]);
half2 shift_scale = reinterpret_cast<const half2*>(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0];
// Size of the per-row scale/bias header in scalar_t units; derived from shift_scale's type.
constexpr uint32_t kHeaderScalarOffset = sizeof(decltype(shift_scale)) / sizeof(scalar_t);
const uint32_t opt_iters = D / (kWarpSize * kOutputsPerThread);
const int32_t output_j = indices_starts[i] + L_start + input_row_idx;
if constexpr (std::is_same_v<output_t, float> || std::is_same_v<output_t, at::Half> || std::is_same_v<output_t, at::BFloat16>) {
for (uint32_t j = 0; j < opt_iters; ++j) {
const auto output_d = j * kWarpSize * kOutputsPerThread + threadIdx.x * kOutputsPerThread;
// +kHeaderScalarOffset skips the half2 scale/bias header at row[0]
scalar_t v = reinterpret_cast<const scalar_t*>(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset];
VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale);
acc.store(&output[output_j][output_d], {{ (32 // emb_weight_type.bit_width) }});
}
} else if constexpr (std::is_same_v<output_t, uint8_t>) {
auto thread_local_min = std::numeric_limits<float>::max();
auto thread_local_max = std::numeric_limits<float>::lowest();
float2 qparams;
// Pass 1: min/max scan
for (uint32_t j = 0; j < opt_iters; ++j) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loops over opt_iters are worth investigating as candidates for manual loop unrolling.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did investigate that. I created a dynamic if-else ladder for different embedding dimensions (256, 512, 768, etc.) and manually unrolled the loops within each case. However, this negatively impacted performance.

scalar_t v = reinterpret_cast<const scalar_t*>(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset];
VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale);
thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc));
thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc));
}
qparams = warp_find_qparams(thread_local_min, thread_local_max);
// Pass 2: requantize and store
for (uint32_t j = 0; j < opt_iters; ++j) {
const auto output_d = j * kWarpSize * kOutputsPerThread + threadIdx.x * kOutputsPerThread;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset];
VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale);
acc.store(&output[output_j][output_d], qparams, {{ (32 // emb_weight_type.bit_width) }});
}
if (threadIdx.x == 0) {
store_qparams_to_row(&output[output_j][D], qparams);
}
}
}
}
} else
{% endif %}
{#- For non-INT weight types or non-ROCm builds the {%- if %} above renders
nothing; this block then becomes a bare scope holding the unmodified loop. -#}
{
for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) {
#pragma unroll OutputRowsPerThread
for (uint32_t i = 0; i < OutputRowsPerThread; ++i) {
bool valid = L_start + input_row_idx < Ls[i];
if (!valid) {
continue;
}
qparams = warp_find_qparams(thread_local_min, thread_local_max);
#pragma unroll AccumulateStoreRequests
for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
if (output_d < D) {
const int num_valid_outputs = min(static_cast<int>(D - output_d), static_cast<int>({{ (32 // emb_weight_type.bit_width) }}));
const uint32_t* row = reinterpret_cast<const uint32_t*>(&buffers[warp_idx][i][input_row_idx][0]);
// scale and bias are at the beginning of each row.
// rationale: have scale/shift at start since these get loaded first
// and then broadcasted around so it might speed up the first cache miss.
{% if emb_weight_type.primitive_type == "INT" %}
// In PackedMode, the row pointer may contain several rows from different bags, so each thread/lane must
// read the shift_scale that belongs to its own row within the packed bag.
half2 shift_scale = reinterpret_cast<const half2*>(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0];
{% endif %}

{% if weighted %}
float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0];
{% endif %}

const int32_t output_j = indices_starts[i] + L_start + input_row_idx;
if constexpr (std::is_same_v<output_t, float> || std::is_same_v<output_t, at::Half> || std::is_same_v<output_t, at::BFloat16>) {
#pragma unroll AccumulateStoreRequests
for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
// Read the uint8/4/2 values: note that first 4 Bytes will be ditched later:
// We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to
// the scale/shift handling).
// Reason: to avoid divergence the first thread in the warp computes garbage.
const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
if (output_d < D) {
const int num_valid_outputs = min(static_cast<int>(D - output_d), static_cast<int>({{ (32 // emb_weight_type.bit_width) }}));
VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
acc.store(&output[output_j][output_d], num_valid_outputs);
}
}
} else if constexpr (std::is_same_v<output_t, uint8_t>) {
// INT8:
// apply per feature row-wise int8
auto thread_local_min = std::numeric_limits<float>::max();
auto thread_local_max = std::numeric_limits<float>::lowest();
float2 qparams;
#pragma unroll AccumulateStoreRequests
for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
acc.store(&output[output_j][output_d], qparams, num_valid_outputs);
if (output_d < D) {
thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc));
thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc));
}
}
qparams = warp_find_qparams(thread_local_min, thread_local_max);
#pragma unroll AccumulateStoreRequests
for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) {
const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding;
scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x];
if (output_d < D) {
const int num_valid_outputs = min(static_cast<int>(D - output_d), static_cast<int>({{ (32 // emb_weight_type.bit_width) }}));
VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %});
acc.store(&output[output_j][output_d], qparams, num_valid_outputs);
}
}
if (threadIdx.x == 0) {
store_qparams_to_row(&output[output_j][D], qparams);
}
}
if (threadIdx.x == 0) {
store_qparams_to_row(&output[output_j][D], qparams);
}
}
{% endif %}
}
}
{% endif %}
}

{% if not nobag %}
Expand Down