Skip to content

Optimizing writes in nobag inference kernel#157

Open
aryaman-gupta wants to merge 7 commits into
aryaman/upstreamfrom
aryaman/writeopt-int-nobag-store
Open

Optimizing writes in nobag inference kernel#157
aryaman-gupta wants to merge 7 commits into
aryaman/upstreamfrom
aryaman/writeopt-int-nobag-store

Conversation

@aryaman-gupta
Copy link
Copy Markdown

@aryaman-gupta aryaman-gupta commented May 15, 2026

Summary

Adds a shifted-index store fast path for the INT-weight nobag inference forward kernel ({INT8,INT4,INT2}_split_embedding_nobag_codegen_forward_unweighted_kernel_small_L) on ROCm.

When D % (kWarpSize * kOutputsPerThread) == 0 and D_padding > 0, the original store loop wastes one full iteration: the last j iteration is fully masked out by the output_d < D guard because the per-row header offset (-D_padding) shifts every lane below D. The fast path iterates D / (kWarpSize * kOutputsPerThread) times instead of (MaxNum128BRows + 1) / 2, drops the bounds check, and reads weights at [j * kWarpSize + threadIdx.x + kHeaderScalarOffset] to absorb the header offset on the load side.

The runtime check is hoisted outside the per-input_row_idx / per-i loops to avoid ~3-5% per-iter branch overhead on non-triggering D values. Bagged path is unchanged.

Scope

  • ROCm-only (is_rocm Jinja gate); CUDA path is byte-identical to upstream.
  • INT weight types only; FP weight types unchanged.
  • Nobag path only; bagged path unchanged.

Validation

  • Bitwise correctness: 16/16 configs PASS (INT8/INT4/INT2 weights × fp16/INT8 outputs × multiple D values).
  • Bagged forward unit tests pass.
  • Perf on MI350 (INT8, fp16 out, 2 tables, L=900, 100 iters): triggering D=256 sees the win; non-triggering D ∈ {128, 384} unchanged within noise.

When D is an exact multiple of (kWarpSize * kOutputsPerThread), skip
the half2 scale/bias header at row[0] via a +1 read offset (in
scalar_t units) and drop the D_padding shift on output_d. This
eliminates the mostly-empty tail iteration that the original loop
runs, which on AMD wave-64 for D=256 wastes 63/64 lanes.

The branch is hoisted out of the per-iter loop nest — the compiler
hoists the predicate but not the branch itself, which would otherwise
add ~3-5% per-iter overhead on non-triggering D values.
@aryaman-gupta aryaman-gupta marked this pull request as ready for review May 19, 2026 11:49
Copy link
Copy Markdown

@avbokovoy avbokovoy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM with potential for further improvements

{% else %}
using scalar_t = {{ emb_weight_type.cpp_type_name }};
{% if emb_weight_type.primitive_type == "INT" and is_rocm %}
if (D % (kWarpSize * kOutputsPerThread) == 0 && D_padding > 0) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ain't always D_padding > 0 in case of quantization to INT?

auto thread_local_max = std::numeric_limits<float>::lowest();
float2 qparams;
// Pass 1: min/max scan
for (uint32_t j = 0; j < opt_iters; ++j) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Loops with opt_iters worth be investigated as a candidate for manual loop unrolling

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did investigate that. I created a dynamic if-else ladder for different embedding dimensions (256, 512, 768 etc.) and manually unrolling the loops within each case. This negatively impacted the performance, though

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants