Potential Shared Memory Misalignment Issues in GEMM Kernels #42

@trobr

Hi,

While reproducing your GEMM code on SM120 (Blackwell), I ran into a problem in the shared memory setup: shm_data was not aligned to 128 bytes. Could your implementation be exposed to the same risk?

In the code for group_gemm/kernels.cuh:170-173:

__shared__ uint64_t writable[kStage];
__shared__ uint64_t readable[kStage];

extern __shared__ uint8_t shm_data[] alignas(128);
auto *shm_a = reinterpret_cast<Tin *>(shm_data);

This implementation mixes static shared memory (writable/readable) and dynamic shared memory (shm_data). I haven't found explicit NVIDIA documentation explaining the alignment behavior when both types are used together. In my tests on SM120, I tried the following:

__shared__ __align__(16) uint64_t tma_load_mbar[Stage];
extern __shared__ __align__(128) char smem[];
T* a_ptr = reinterpret_cast<T*>(smem);

In this case, smem was only 32-byte aligned (with Stage = 3) instead of 128-byte aligned, which caused the TMA load to fail (misaligned). nvcc version: Cuda compilation tools, release 12.8, V12.8.61
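
To make the observation easy to reproduce, here is a minimal probe sketch that reports where the dynamic buffer lands in the shared window when a static barrier array is also declared. The kernel name probe_alignment and the output buffer are hypothetical, not taken from the repository. One reading consistent with the 32-byte result, which I have not seen confirmed in NVIDIA documentation, is that the dynamic region starts right after the 24 bytes of static barriers, rounded up to the next boundary; the probe only measures the offset and does not prove that explanation.

```cuda
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical probe: kStage static barriers plus a dynamic buffer,
// mirroring the static + dynamic mix described above.
template <int kStage>
__global__ void probe_alignment(unsigned *out) {
    __shared__ __align__(16) uint64_t mbar[kStage];
    extern __shared__ __align__(128) unsigned char smem[];

    if (threadIdx.x == 0) {
        // __cvta_generic_to_shared yields the address inside the shared
        // window, which is the address space TMA alignment refers to.
        out[0] = static_cast<unsigned>(__cvta_generic_to_shared(smem));
        // Using the address of mbar keeps the static array from being dropped.
        out[1] = static_cast<unsigned>(__cvta_generic_to_shared(mbar));
    }
}

int main() {
    unsigned *d_out = nullptr;
    if (cudaMalloc(&d_out, 2 * sizeof(unsigned)) != cudaSuccess) return 1;

    // 1024 bytes of dynamic shared memory is an arbitrary size for the probe.
    probe_alignment<3><<<1, 32, 1024>>>(d_out);
    cudaError_t err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        printf("launch failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    unsigned h_out[2] = {0, 0};
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("mbar offset = %u\n", h_out[1]);
    printf("smem offset = %u (offset %% 128 = %u)\n", h_out[0], h_out[0] % 128);

    cudaFree(d_out);
    return 0;
}
```

A nonzero remainder in the last line would indicate the same misalignment; the exact offsets may differ between compiler versions and architectures.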

I noticed that CUTLASS typically uses a SharedStorage structure to manage everything (both barriers and tensor data) within dynamic shared memory, like this:

template <class TypeA, class TypeB, class ASmemLayout, class BSmemLayout>
struct SharedStorage {
  alignas(128) cute::ArrayEngine<TypeA, cute::cosize_v<ASmemLayout>> A;
  alignas(128) cute::ArrayEngine<TypeB, cute::cosize_v<BSmemLayout>> B;
  alignas(16) cute::uint64_t mma_barrier;
  alignas(16) cute::uint64_t tma_barrier;
  // ...
};

I tested a similar "all-in-dynamic-struct" approach on SM120:

template <class T, int STAGE, class ASmemLayout, class BSmemLayout>
struct SharedStorageT {
    __align__(128) char A[cosize(ASmemLayout{})];
    __align__(128) char B[cosize(BSmemLayout{})];
    __align__(16) uint64_t tma_load_mbar[STAGE];
};

__global__ void kernel() {
    // ...

    extern __shared__ __align__(128) char smem[];
    using SharedStorage = SharedStorageT<T, Stage, SmemLayoutA, SmemLayoutB>;
    SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem);

    // ...
}
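
As a way to check this approach rather than rely on it silently, the sketch below adds compile-time assertions on the struct layout and a runtime check of the member addresses. It assumes fixed byte sizes in place of the CuTe layouts so that it compiles without CUTLASS; SharedStorageSketch, kernel_sketch, and launch_sketch are hypothetical names. It also assumes that, with no static shared memory declared, the dynamic buffer starts at a 128-byte-aligned offset, which matches my test but is not something I have found documented.

```cuda
#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical stand-in for SharedStorageT: plain byte counts replace
// cosize(ASmemLayout{}) and cosize(BSmemLayout{}).
template <int kBytesA, int kBytesB, int kStage>
struct SharedStorageSketch {
    alignas(128) unsigned char A[kBytesA];
    alignas(128) unsigned char B[kBytesB];
    alignas(16) uint64_t tma_load_mbar[kStage];
};

using Storage = SharedStorageSketch<4096, 4096, 3>;

// Layout checks: these only constrain offsets relative to the struct base.
static_assert(alignof(Storage) == 128, "struct must request 128-byte alignment");
static_assert(offsetof(Storage, A) % 128 == 0, "A must sit on a 128-byte offset");
static_assert(offsetof(Storage, B) % 128 == 0, "B must sit on a 128-byte offset");

__global__ void kernel_sketch(int *misaligned) {
    extern __shared__ __align__(128) unsigned char smem[];
    Storage &storage = *reinterpret_cast<Storage *>(smem);

    if (threadIdx.x == 0) {
        // Runtime check of the actual shared-window addresses.
        size_t a = __cvta_generic_to_shared(storage.A);
        size_t b = __cvta_generic_to_shared(storage.B);
        *misaligned = ((a % 128) != 0) || ((b % 128) != 0);
    }
}

// Launches one block with exactly sizeof(Storage) bytes of dynamic shared memory.
cudaError_t launch_sketch(int *d_misaligned) {
    // The opt-in is required only above 48 KB; it is set here for generality.
    cudaError_t err = cudaFuncSetAttribute(
        kernel_sketch, cudaFuncAttributeMaxDynamicSharedMemorySize,
        static_cast<int>(sizeof(Storage)));
    if (err != cudaSuccess) return err;

    kernel_sketch<<<1, 32, sizeof(Storage)>>>(d_misaligned);
    return cudaGetLastError();
}
```

The static assertions guard the layout inside the struct, while the runtime flag covers the part the language does not guarantee: the alignment of the base address itself.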

This approach resolved the misalignment issues. Since I haven't been able to find an official NVIDIA explanation regarding the alignment guarantees when mixing static and dynamic shared memory, I wanted to reach out and ask for your insights on this matter.
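
For completeness, another workaround I am aware of keeps the static barriers and rounds the dynamic base up by hand. The sketch below is an assumption on my side and not something I have verified against your kernels; align_up, kernel_manual_align, and launch_manual_align are hypothetical names. It requires requesting 128 extra bytes of dynamic shared memory at launch so the rounded-up pointer stays in bounds.

```cuda
#include <cstddef>
#include <cstdint>
#include <cuda_runtime.h>

constexpr size_t kSmemAlign = 128;

// Rounds x up to the next multiple of align; align must be a power of two.
__host__ __device__ constexpr size_t align_up(size_t x, size_t align) {
    return (x + align - 1) & ~(align - 1);
}

template <int kStage>
__global__ void kernel_manual_align(unsigned *out) {
    __shared__ __align__(16) uint64_t mbar[kStage];
    extern __shared__ unsigned char smem_raw[];

    // Align using the shared-window address rather than the generic one.
    size_t raw = __cvta_generic_to_shared(smem_raw);
    unsigned char *smem = smem_raw + (align_up(raw, kSmemAlign) - raw);

    if (threadIdx.x == 0) {
        // Remainder of the adjusted pointer modulo 128.
        out[0] = static_cast<unsigned>(__cvta_generic_to_shared(smem) % kSmemAlign);
        // Using the address of mbar keeps the static array from being dropped.
        out[1] = static_cast<unsigned>(__cvta_generic_to_shared(mbar));
    }
}

// payload_bytes is the shared memory the kernel actually needs;
// kSmemAlign extra bytes leave room for the manual round-up.
cudaError_t launch_manual_align(unsigned *d_out, size_t payload_bytes) {
    kernel_manual_align<3><<<1, 32, payload_bytes + kSmemAlign>>>(d_out);
    return cudaGetLastError();
}
```

This costs up to 128 bytes of padding per block, so the single-struct approach above still looks cleaner to me when the kernel can be restructured.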

Thank you for your time and help!
