
Scalar global memref.load/store hardcodes SRD s[8:11], so multi-buffer kernels alias arguments #1117

@powderluv


Summary

waveasm-translate's lowering of scalar global memref.load / memref.store appears to hardcode the SRD register range to s[8:11]. In kernels with more than one global memref argument, every load and store therefore goes through the same descriptor, and the buffers alias each other.

I hit this while integrating WaveASM as a Baybridge backend on both gfx942 (MI300) and gfx950 (MI355).

Minimal repro

Input MLIR:

module {
  gpu.module @kernels {
    gpu.func @add_kernel(%a: memref<4xf32>, %b: memref<4xf32>, %c: memref<4xf32>) kernel {
      %idx = gpu.thread_id x
      %a_val = memref.load %a[%idx] : memref<4xf32>
      %b_val = memref.load %b[%idx] : memref<4xf32>
      %sum = arith.addf %a_val, %b_val : f32
      memref.store %sum, %c[%idx] : memref<4xf32>
      gpu.return
    }
  }
}

Commands:

waveasm-translate --target=gfx942 add.mlir > add.waveasm.mlir
waveasm-translate --target=gfx942 \
  --waveasm-scoped-cse \
  --waveasm-peephole \
  --waveasm-scale-pack-elimination \
  --loop-invariant-code-motion \
  --waveasm-m0-redundancy-elim \
  --waveasm-buffer-load-strength-reduction \
  --waveasm-memory-offset-opt \
  --canonicalize \
  --waveasm-scoped-cse \
  --waveasm-loop-address-promotion \
  '--waveasm-linear-scan=max-vgprs=512 max-agprs=512' \
  --waveasm-insert-waitcnt=ticketed-waitcnt=true \
  --waveasm-hazard-mitigation=target=gfx942 \
  --emit-assembly \
  add.waveasm.mlir > add.s

The intermediate WaveASM IR looks reasonable: it contains three SRD setups with distinct precolored SGPR ranges starting at s8, s12 and s16:

waveasm.program @add_kernel target = <#waveasm.gfx942> abi = <> attributes {num_kernel_args = 3 : i64} {
  ...
  %8 = waveasm.precolored.sreg 8, 4 : !waveasm.sreg<4, 4>
  ...
  %9 = waveasm.precolored.sreg 12, 4 : !waveasm.sreg<4, 4>
  ...
  %10 = waveasm.precolored.sreg 16, 4 : !waveasm.sreg<4, 4>
  ...
}

But the emitted assembly uses s[8:11] for both loads and the store:

buffer_load_dword v0, v1, s[8:11], 0 offen
buffer_load_dword v2, v1, s[8:11], 0 offen
v_add_f32 v3, v0, v2
buffer_store_dword v3, v1, s[8:11], 0 offen

I would expect one distinct SRD per argument, e.g.:

buffer_load_dword ..., s[8:11], ...   ; %a
buffer_load_dword ..., s[12:15], ...  ; %b
buffer_store_dword ..., s[16:19], ... ; %c

Source pointer

The current scalar handlers appear to create a fresh precolored SRD, always starting at s8, on every call:

auto sregType = ctx.createSRegType(4, 4);
auto srd = PrecoloredSRegOp::create(builder, loc, sregType, 8, 4);

This occurs in both handleMemRefLoad and handleMemRefStore in waveasm/lib/Transforms/handlers/MemRefHandlers.cpp.

Observed runtime impact

This is not just cosmetic; it breaks execution:

  • on MI355 (gfx950), a simple pointwise add kernel launched successfully but produced all zeros
  • on MI300 (gfx942), the same pattern hit HIP status 700 illegal memory access

I also tried routing the same kernel through vector.load / vector.store with vector<1xf32> to avoid the scalar memref path. That produced distinct SRDs in assembly, so this looks specific to the scalar global memref.load/store handlers.

Request

Can you confirm whether this is an expected limitation of the scalar global memref path, or a bug in handleMemRefLoad / handleMemRefStore register/SRD selection?

If it is a bug, I can help test a fix quickly on gfx942 and gfx950.
