From 2d224f5288c48c88cbc849b0b094d22ea9fc8871 Mon Sep 17 00:00:00 2001 From: Sai Hemanth Bheemreddy Date: Tue, 28 Apr 2026 23:02:02 +0000 Subject: [PATCH 1/2] Updated `replace_wmma_intrinsics` to support sm_90 --- compilation/KernelTranslation.cpp | 1 - compilation/KernelTranslation/src/tool.cpp | 1895 +++++++++++++---- ...-nvptx64-nvidia-cuda-sm_90-4b676c53.bc.tmp | 0 runtime/src/vortex/kernel/cudaKernelImpl.cpp | 144 -- 4 files changed, 1539 insertions(+), 501 deletions(-) create mode 100644 examples/sgemm_tcu/sgemm_tcu-cuda-nvptx64-nvidia-cuda-sm_90-4b676c53.bc.tmp diff --git a/compilation/KernelTranslation.cpp b/compilation/KernelTranslation.cpp index e141d4c..34a587a 100644 --- a/compilation/KernelTranslation.cpp +++ b/compilation/KernelTranslation.cpp @@ -123,7 +123,6 @@ int main(int argc, char **argv) { DBG_LOG("replace_wmma_intrinsics\n"); replace_wmma_intrinsics(program); - replace_nvvm_wmma_intrinsics(program); dumpFile(program, "5_after_wmma.ll"); VerifyModule(program); diff --git a/compilation/KernelTranslation/src/tool.cpp b/compilation/KernelTranslation/src/tool.cpp index 89e5489..84cc0c1 100644 --- a/compilation/KernelTranslation/src/tool.cpp +++ b/compilation/KernelTranslation/src/tool.cpp @@ -2643,8 +2643,7 @@ void replace_cp_async_with_dxa(llvm::Module *M) { } } -// WMMA Section - +// WMMA descriptor types enum class WmmaOpKind { Unknown, Load, Store, Mma }; enum class WmmaFragKind { Unknown, MatrixA, MatrixB, Accumulator }; enum class WmmaElemKind { Unknown, F16, F32, U8, S8, S32 }; @@ -2659,12 +2658,14 @@ struct FragmentSig { int K = 0; bool isValid() const { - return Kind != WmmaFragKind::Unknown && M > 0 && N > 0 && K > 0; + return Kind != WmmaFragKind::Unknown && Elem != WmmaElemKind::Unknown && + M > 0 && N > 0 && K > 0; } }; struct WmmaCallDesc { WmmaOpKind Op = WmmaOpKind::Unknown; + SmallVector OrderedFrags; FragmentSig D; @@ -2674,272 +2675,633 @@ struct WmmaCallDesc { WmmaLayout StoreMemLayout = WmmaLayout::None; std::string ReplacementName; + + bool FromNVVMIntrinsic = false; + + unsigned ARegCount = 0; + unsigned BRegCount = 0; + unsigned CRegCount = 0; + unsigned DRegCount = 0; }; +// Key: old NVVM producer call. +// Value: pointer to fragment storage to pass to runtime helpers. +using FragStorageMap = DenseMap; + +// WMMA Utility Functions std::string maybeDemangle(StringRef Name) { - if (Name.starts_with("_Z") || Name.starts_with("?")) { + if (Name.starts_with("_Z") || Name.starts_with("?")) return demangle(Name.str()); - } + return Name.str(); } -DenseMap collectWMMADesc(llvm::Module *M) { - DenseMap CallToDesc; +static const char *toLoadFragSuffix(WmmaFragKind K) { + switch (K) { + case WmmaFragKind::MatrixA: + return "a"; + case WmmaFragKind::MatrixB: + return "b"; + case WmmaFragKind::Accumulator: + return "c"; + default: + return "unknown"; + } +} + +static const char *toElemSuffix(WmmaElemKind K) { + switch (K) { + case WmmaElemKind::F16: + return "f16"; + case WmmaElemKind::F32: + return "f32"; + case WmmaElemKind::U8: + return "u8"; + case WmmaElemKind::S8: + return "s8"; + case WmmaElemKind::S32: + return "s32"; + default: + return "unknown"; + } +} + +static const char *toLayoutSuffix(WmmaLayout L) { + switch (L) { + case WmmaLayout::RowMajor: + return "row"; + case WmmaLayout::ColMajor: + return "col"; + case WmmaLayout::Dynamic: + return "dyn"; + case WmmaLayout::None: + return "none"; + default: + return "unknown"; + } +} + +static const FragmentSig *getLoadFragment(const WmmaCallDesc &Desc) { + if (Desc.A.isValid()) + return &Desc.A; + if (Desc.B.isValid()) + return &Desc.B; + if (Desc.C.isValid()) + return &Desc.C; + return nullptr; +} + +static unsigned getLoadRegCount(const WmmaCallDesc &Desc) { + if (Desc.A.isValid()) + return Desc.ARegCount; + if (Desc.B.isValid()) + return Desc.BRegCount; + if (Desc.C.isValid()) + return Desc.CRegCount; + return 0; +} + +static void buildReplacementName(WmmaCallDesc &Desc) { + Desc.ReplacementName.clear(); + raw_string_ostream OS(Desc.ReplacementName); + + if (Desc.Op == WmmaOpKind::Load) { + const FragmentSig *Frag = getLoadFragment(Desc); + if (!Frag) + return; + + OS << "__vx_wmma_load_" << toLoadFragSuffix(Frag->Kind) << "_m" << Frag->M + << "n" << Frag->N << "k" << Frag->K << "_" + << toLayoutSuffix(Frag->Layout) << "_" << toElemSuffix(Frag->Elem); + + } else if (Desc.Op == WmmaOpKind::Store) { + if (!Desc.D.isValid()) + return; + + OS << "__vx_wmma_store_d" + << "_m" << Desc.D.M << "n" << Desc.D.N << "k" << Desc.D.K << "_" + << toElemSuffix(Desc.D.Elem); + + } else if (Desc.Op == WmmaOpKind::Mma) { + if (!Desc.D.isValid() || !Desc.A.isValid() || !Desc.B.isValid() || + !Desc.C.isValid()) + return; + + OS << "__vx_wmma_mma" + << "_m" << Desc.D.M << "n" << Desc.D.N << "k" << Desc.D.K << "_" + << toLayoutSuffix(Desc.A.Layout) << "_" << toLayoutSuffix(Desc.B.Layout) + << "_" << toElemSuffix(Desc.A.Elem) << "_" << toElemSuffix(Desc.B.Elem) + << "_" << toElemSuffix(Desc.D.Elem); + } + + OS.flush(); +} + +static bool parseShape(StringRef Tok, int &M, int &N, int &K) { + if (!Tok.starts_with("m")) + return false; + + size_t NPos = Tok.find('n'); + size_t KPos = Tok.find('k'); + + if (NPos == StringRef::npos || KPos == StringRef::npos || NPos >= KPos) + return false; + + StringRef MStr = Tok.slice(1, NPos); + StringRef NStr = Tok.slice(NPos + 1, KPos); + StringRef KStr = Tok.drop_front(KPos + 1); + + return !MStr.getAsInteger(10, M) && !NStr.getAsInteger(10, N) && + !KStr.getAsInteger(10, K); +} + +static WmmaLayout parseLayoutToken(StringRef Tok) { + if (Tok == "row") + return WmmaLayout::RowMajor; + if (Tok == "col") + return WmmaLayout::ColMajor; + return WmmaLayout::None; +} + +static WmmaElemKind parseElemToken(StringRef Tok) { + if (Tok == "f16") + return WmmaElemKind::F16; + if (Tok == "f32") + return WmmaElemKind::F32; + if (Tok == "u8") + return WmmaElemKind::U8; + if (Tok == "s8") + return WmmaElemKind::S8; + if (Tok == "s32") + return WmmaElemKind::S32; + return WmmaElemKind::Unknown; +} + +static WmmaElemKind elemFromLLVMType(Type *Ty) { + if (auto *VT = dyn_cast(Ty)) + Ty = VT->getElementType(); + + if (Ty->isHalfTy()) + return WmmaElemKind::F16; + + if (Ty->isFloatTy()) + return WmmaElemKind::F32; + + if (Ty->isIntegerTy(32)) + return WmmaElemKind::S32; + + // LLVM integer types do not preserve signedness. + if (Ty->isIntegerTy(8)) + return WmmaElemKind::Unknown; + + return WmmaElemKind::Unknown; +} + +static unsigned getStructNumElements(Type *Ty) { + auto *ST = dyn_cast(Ty); + if (!ST) + return 0; + return ST->getNumElements(); +} + +static Type *getStructElementType(Type *Ty, unsigned Index) { + auto *ST = dyn_cast(Ty); + if (!ST || Index >= ST->getNumElements()) + return nullptr; + return ST->getElementType(Index); +} + +// Runtime fragment register storage. +// A/B f16 fragments are packed into i32 registers. +// C/D f32 fragments are float registers. +static Type *getRuntimeFragmentRegTy(LLVMContext &Ctx, const FragmentSig &Sig) { + switch (Sig.Elem) { + case WmmaElemKind::F32: + return Type::getFloatTy(Ctx); + + case WmmaElemKind::F16: + case WmmaElemKind::U8: + case WmmaElemKind::S8: + case WmmaElemKind::S32: + return Type::getInt32Ty(Ctx); + + default: + return nullptr; + } +} + +static AllocaInst *createEntryAlloca(Function *F, Type *Ty, StringRef Name) { + IRBuilder<> B(&*F->getEntryBlock().getFirstInsertionPt()); + return B.CreateAlloca(Ty, nullptr, Name); +} + +static AllocaInst *createFragmentAlloca(Function *F, const FragmentSig &Sig, + unsigned RegCount, StringRef Name) { + LLVMContext &Ctx = F->getContext(); + + Type *RegTy = getRuntimeFragmentRegTy(Ctx, Sig); + if (!RegTy) + return nullptr; - auto toFragSuffix = [](WmmaFragKind K) -> const char * { - switch (K) { - case WmmaFragKind::MatrixA: - return "a"; - case WmmaFragKind::MatrixB: - return "b"; - case WmmaFragKind::Accumulator: - return "acc"; - default: - return "unknown"; + Type *ArrayTy = ArrayType::get(RegTy, RegCount); + return createEntryAlloca(F, ArrayTy, Name); +} + +static bool inferABRegCountsForMma(const FragmentSig &A, const FragmentSig &B, + unsigned ABTotal, unsigned &ARegs, + unsigned &BRegs) { + if (A.Elem == WmmaElemKind::F16 && B.Elem == WmmaElemKind::F16) { + if (A.M == 16 && A.N == 16 && A.K == 16 && ABTotal == 16) { + ARegs = 8; + BRegs = 8; + return true; } - }; - auto toElemSuffix = [](WmmaElemKind K) -> const char * { - switch (K) { - case WmmaElemKind::F16: - return "f16"; - case WmmaElemKind::F32: - return "f32"; - case WmmaElemKind::U8: - return "u8"; - case WmmaElemKind::S8: - return "s8"; - case WmmaElemKind::S32: - return "s32"; - default: - return "unknown"; + if (A.M == 32 && A.N == 8 && A.K == 16 && ABTotal == 12) { + ARegs = 8; + BRegs = 4; + return true; } - }; - auto toLayoutSuffix = [](WmmaLayout L) -> const char * { - switch (L) { - case WmmaLayout::RowMajor: - return "row"; - case WmmaLayout::ColMajor: - return "col"; - case WmmaLayout::Dynamic: - return "dyn"; - case WmmaLayout::None: - return "none"; - default: - return "unknown"; + if (A.M == 8 && A.N == 32 && A.K == 16 && ABTotal == 12) { + ARegs = 4; + BRegs = 8; + return true; } - }; + } - for (Function &F : *M) { - for (BasicBlock &BB : F) { - for (Instruction &I : BB) { - auto *CI = dyn_cast(&I); - if (!CI) - continue; + if (ABTotal % 2 == 0) { + ARegs = ABTotal / 2; + BRegs = ABTotal / 2; + return true; + } - Function *Callee = CI->getCalledFunction(); - if (!Callee) - continue; + return false; +} - std::string Name = maybeDemangle(Callee->getName()); - StringRef Demangled(Name); +// NVVM descriptor parser +static bool fillDescFromNVVMIntrinsic(CallInst *CI, WmmaCallDesc &Desc) { + Function *Callee = CI->getCalledFunction(); + if (!Callee) + return false; - if (!Demangled.contains("nvcuda::wmma::")) - continue; + StringRef Name = Callee->getName(); - WmmaCallDesc Desc; + if (!Name.starts_with("llvm.nvvm.wmma.")) + return false; - if (Demangled.contains("load_matrix_sync(")) - Desc.Op = WmmaOpKind::Load; - else if (Demangled.contains("store_matrix_sync(")) - Desc.Op = WmmaOpKind::Store; - else if (Demangled.contains("mma_sync(")) - Desc.Op = WmmaOpKind::Mma; - else - continue; + SmallVector Parts; + Name.split(Parts, '.', -1, false); - size_t SearchFrom = 0; - while (true) { - size_t Pos = Demangled.find("fragment<", SearchFrom); - if (Pos == StringRef::npos) - break; + if (Parts.size() < 6) + return false; - size_t Start = Pos + strlen("fragment<"); - int Depth = 1; - size_t End = Start; - for (; End < Demangled.size(); ++End) { - if (Demangled[End] == '<') - ++Depth; - else if (Demangled[End] == '>') { - --Depth; - if (Depth == 0) - break; - } - } - if (End >= Demangled.size()) - break; + if (Parts[0] != "llvm" || Parts[1] != "nvvm" || Parts[2] != "wmma") + return false; - StringRef Body = Demangled.slice(Start, End); - - SmallVector Parts; - size_t PartStart = 0; - int AngleDepth = 0; - for (size_t J = 0; J < Body.size(); ++J) { - if (Body[J] == '<') - ++AngleDepth; - else if (Body[J] == '>') - --AngleDepth; - else if (Body[J] == ',' && AngleDepth == 0) { - Parts.push_back(Body.slice(PartStart, J).trim()); - PartStart = J + 1; - } - } - Parts.push_back(Body.drop_front(PartStart).trim()); - - if (Parts.size() == 6) { - FragmentSig Sig; - - if (Parts[0].contains("matrix_a")) - Sig.Kind = WmmaFragKind::MatrixA; - else if (Parts[0].contains("matrix_b")) - Sig.Kind = WmmaFragKind::MatrixB; - else if (Parts[0].contains("accumulator")) - Sig.Kind = WmmaFragKind::Accumulator; - - if (Parts[1].getAsInteger(10, Sig.M) || - Parts[2].getAsInteger(10, Sig.N) || - Parts[3].getAsInteger(10, Sig.K)) { - SearchFrom = End + 1; - continue; - } + int M = 0, N = 0, K = 0; + if (!parseShape(Parts[3], M, N, K)) + return false; - if (Parts[4] == "__half" || Parts[4].ends_with("__half")) - Sig.Elem = WmmaElemKind::F16; - else if (Parts[4] == "float") - Sig.Elem = WmmaElemKind::F32; - else if (Parts[4] == "unsigned char") - Sig.Elem = WmmaElemKind::U8; - else if (Parts[4] == "signed char") - Sig.Elem = WmmaElemKind::S8; - else if (Parts[4] == "int") - Sig.Elem = WmmaElemKind::S32; - - if (Parts[5] == "void") - Sig.Layout = WmmaLayout::None; - else if (Parts[5].contains("row_major")) - Sig.Layout = WmmaLayout::RowMajor; - else if (Parts[5].contains("col_major")) - Sig.Layout = WmmaLayout::ColMajor; - - if (Sig.isValid()) - Desc.OrderedFrags.push_back(Sig); - } + Desc = WmmaCallDesc{}; + Desc.FromNVVMIntrinsic = true; - SearchFrom = End + 1; - } + if (Parts[4] == "load") { + // llvm.nvvm.wmma.m32n8k16.load.a.row.stride.f16.p0 + if (Parts.size() < 9) + return false; - if (Desc.Op == WmmaOpKind::Load && Desc.OrderedFrags.size() == 1) { - if (Desc.OrderedFrags[0].Kind == WmmaFragKind::MatrixA) - Desc.A = Desc.OrderedFrags[0]; - else if (Desc.OrderedFrags[0].Kind == WmmaFragKind::MatrixB) - Desc.B = Desc.OrderedFrags[0]; - } else if (Desc.Op == WmmaOpKind::Store && - Desc.OrderedFrags.size() == 1) { - Desc.D = Desc.OrderedFrags[0]; - if (CI->arg_size() >= 4) { - if (auto *C = dyn_cast(CI->getArgOperand(3))) - Desc.StoreMemLayout = - C->isZero() ? WmmaLayout::RowMajor : WmmaLayout::ColMajor; - else - Desc.StoreMemLayout = WmmaLayout::Dynamic; - } - } else if (Desc.Op == WmmaOpKind::Mma && - Desc.OrderedFrags.size() == 4) { - Desc.D = Desc.OrderedFrags[0]; - Desc.A = Desc.OrderedFrags[1]; - Desc.B = Desc.OrderedFrags[2]; - Desc.C = Desc.OrderedFrags[3]; - } + Desc.Op = WmmaOpKind::Load; - raw_string_ostream OS(Desc.ReplacementName); - if (Desc.Op == WmmaOpKind::Load) { - const FragmentSig &Frag = Desc.A.isValid() ? Desc.A : Desc.B; - OS << "__vx_wmma_load_" << toFragSuffix(Frag.Kind) << "_m" << Frag.M - << "n" << Frag.N << "k" << Frag.K << "_" - << toLayoutSuffix(Frag.Layout) << "_" << toElemSuffix(Frag.Elem); - } else if (Desc.Op == WmmaOpKind::Store) { - OS << "__vx_wmma_store_d" - << "_m" << Desc.D.M << "n" << Desc.D.N << "k" << Desc.D.K << "_" - << toElemSuffix(Desc.D.Elem); - } else if (Desc.Op == WmmaOpKind::Mma) { - OS << "__vx_wmma_mma" - << "_m" << Desc.D.M << "n" << Desc.D.N << "k" << Desc.D.K << "_" - << toLayoutSuffix(Desc.A.Layout) << "_" - << toLayoutSuffix(Desc.B.Layout) << "_" - << toElemSuffix(Desc.A.Elem) << "_" << toElemSuffix(Desc.B.Elem) - << "_" << toElemSuffix(Desc.D.Elem); - } - OS.flush(); + FragmentSig Frag; + Frag.M = M; + Frag.N = N; + Frag.K = K; - CallToDesc[CI] = std::move(Desc); - } + if (Parts[5] == "a") + Frag.Kind = WmmaFragKind::MatrixA; + else if (Parts[5] == "b") + Frag.Kind = WmmaFragKind::MatrixB; + else if (Parts[5] == "c" || Parts[5] == "d") + Frag.Kind = WmmaFragKind::Accumulator; + else + return false; + + Frag.Layout = parseLayoutToken(Parts[6]); + Frag.Elem = parseElemToken(Parts[8]); + + if (Frag.Elem == WmmaElemKind::Unknown) { + if (Type *EltTy = getStructElementType(CI->getType(), 0)) + Frag.Elem = elemFromLLVMType(EltTy); + } + + if (!Frag.isValid()) + return false; + + unsigned RegCount = getStructNumElements(CI->getType()); + if (RegCount == 0) + return false; + + if (Frag.Kind == WmmaFragKind::MatrixA) { + Desc.A = Frag; + Desc.ARegCount = RegCount; + } else if (Frag.Kind == WmmaFragKind::MatrixB) { + Desc.B = Frag; + Desc.BRegCount = RegCount; + } else { + Desc.C = Frag; + Desc.CRegCount = RegCount; + } + + Desc.OrderedFrags.push_back(Frag); + + } else if (Parts[4] == "store") { + // llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p0 + if (Parts.size() < 9) + return false; + + Desc.Op = WmmaOpKind::Store; + + FragmentSig D; + D.Kind = WmmaFragKind::Accumulator; + D.M = M; + D.N = N; + D.K = K; + D.Layout = WmmaLayout::None; + D.Elem = parseElemToken(Parts[8]); + + if (D.Elem == WmmaElemKind::Unknown && CI->arg_size() >= 2) + D.Elem = elemFromLLVMType(CI->getArgOperand(1)->getType()); + + if (!D.isValid()) + return false; + + Desc.D = D; + Desc.StoreMemLayout = parseLayoutToken(Parts[6]); + + if (CI->arg_size() < 3) + return false; + + // Args: dst, d0..dN, stride + Desc.DRegCount = CI->arg_size() - 2; + Desc.OrderedFrags.push_back(D); + + } else if (Parts[4] == "mma") { + // llvm.nvvm.wmma.m32n8k16.mma.row.row.f32.f32 + if (Parts.size() < 9) + return false; + + Desc.Op = WmmaOpKind::Mma; + + FragmentSig A; + FragmentSig B; + FragmentSig C; + FragmentSig D; + + A.Kind = WmmaFragKind::MatrixA; + B.Kind = WmmaFragKind::MatrixB; + C.Kind = WmmaFragKind::Accumulator; + D.Kind = WmmaFragKind::Accumulator; + + A.M = B.M = C.M = D.M = M; + A.N = B.N = C.N = D.N = N; + A.K = B.K = C.K = D.K = K; + + A.Layout = parseLayoutToken(Parts[5]); + B.Layout = parseLayoutToken(Parts[6]); + C.Layout = WmmaLayout::None; + D.Layout = WmmaLayout::None; + + Desc.DRegCount = getStructNumElements(CI->getType()); + Desc.CRegCount = Desc.DRegCount; + + if (Desc.DRegCount == 0) + return false; + + unsigned ArgCount = CI->arg_size(); + + if (ArgCount <= Desc.CRegCount) + return false; + + unsigned ABTotal = ArgCount - Desc.CRegCount; + + if (ArgCount > 0) + A.Elem = elemFromLLVMType(CI->getArgOperand(0)->getType()); + + if (A.Elem == WmmaElemKind::Unknown) + A.Elem = WmmaElemKind::F16; + + B.Elem = A.Elem; + + unsigned ARegs = 0; + unsigned BRegs = 0; + + if (!inferABRegCountsForMma(A, B, ABTotal, ARegs, BRegs)) + return false; + + if (ARegs + BRegs + Desc.CRegCount != ArgCount) + return false; + + if (BRegs > 0) + B.Elem = elemFromLLVMType(CI->getArgOperand(ARegs)->getType()); + + if (B.Elem == WmmaElemKind::Unknown) + B.Elem = A.Elem; + + if (Desc.CRegCount > 0) { + unsigned CStart = ARegs + BRegs; + C.Elem = elemFromLLVMType(CI->getArgOperand(CStart)->getType()); } + + if (Type *DEltTy = getStructElementType(CI->getType(), 0)) + D.Elem = elemFromLLVMType(DEltTy); + + if (C.Elem == WmmaElemKind::Unknown) + C.Elem = WmmaElemKind::F32; + + if (D.Elem == WmmaElemKind::Unknown) + D.Elem = C.Elem; + + if (!A.isValid() || !B.isValid() || !C.isValid() || !D.isValid()) + return false; + + Desc.A = A; + Desc.B = B; + Desc.C = C; + Desc.D = D; + + Desc.ARegCount = ARegs; + Desc.BRegCount = BRegs; + + Desc.OrderedFrags.push_back(D); + Desc.OrderedFrags.push_back(A); + Desc.OrderedFrags.push_back(B); + Desc.OrderedFrags.push_back(C); + + } else { + return false; } - return CallToDesc; + buildReplacementName(Desc); + return !Desc.ReplacementName.empty(); } -void replace_wmma_intrinsics(llvm::Module *M) { - DenseMap CallToDesc = collectWMMADesc(M); - SmallPtrSet ToErase; +// Demangled nvcuda::wmma helper descriptor parser +static bool fillDescFromDemangledWmmaHelper(CallInst *CI, WmmaCallDesc &Desc) { + Function *Callee = CI->getCalledFunction(); + if (!Callee) + return false; - DBG(for (auto &It : CallToDesc) { - CallInst *CI = It.first; - WmmaCallDesc &Desc = It.second; - errs() << "[wmma] " << maybeDemangle(CI->getCalledFunction()->getName()) - << " -> " << Desc.ReplacementName << "\n"; - }); + std::string Name = maybeDemangle(Callee->getName()); + StringRef Demangled(Name); - for (auto &It : CallToDesc) { - CallInst *CI = It.first; - WmmaCallDesc &Desc = It.second; + if (!Demangled.contains("nvcuda::wmma::")) + return false; - Function *OldHelper = CI->getCalledFunction(); + Desc = WmmaCallDesc{}; - FunctionType *FTy = OldHelper->getFunctionType(); - FunctionCallee NewHelperCallee = - M->getOrInsertFunction(Desc.ReplacementName, FTy); - Function *NewHelper = dyn_cast(NewHelperCallee.getCallee()); - if (!NewHelper) - continue; - NewHelper->setCallingConv(CallingConv::C); - CI->setCalledFunction(NewHelperCallee); - if (OldHelper->use_empty()) { - ToErase.insert(OldHelper); + if (Demangled.contains("load_matrix_sync(")) + Desc.Op = WmmaOpKind::Load; + else if (Demangled.contains("store_matrix_sync(")) + Desc.Op = WmmaOpKind::Store; + else if (Demangled.contains("mma_sync(")) + Desc.Op = WmmaOpKind::Mma; + else + return false; + + size_t SearchFrom = 0; + + while (true) { + size_t Pos = Demangled.find("fragment<", SearchFrom); + if (Pos == StringRef::npos) + break; + + size_t Start = Pos + strlen("fragment<"); + int Depth = 1; + size_t End = Start; + + for (; End < Demangled.size(); ++End) { + if (Demangled[End] == '<') { + ++Depth; + } else if (Demangled[End] == '>') { + --Depth; + if (Depth == 0) + break; + } + } + + if (End >= Demangled.size()) + break; + + StringRef Body = Demangled.slice(Start, End); + + SmallVector Parts; + size_t PartStart = 0; + int AngleDepth = 0; + + for (size_t J = 0; J < Body.size(); ++J) { + if (Body[J] == '<') { + ++AngleDepth; + } else if (Body[J] == '>') { + --AngleDepth; + } else if (Body[J] == ',' && AngleDepth == 0) { + Parts.push_back(Body.slice(PartStart, J).trim()); + PartStart = J + 1; + } + } + + Parts.push_back(Body.drop_front(PartStart).trim()); + + if (Parts.size() == 6) { + FragmentSig Sig; + + if (Parts[0].contains("matrix_a")) + Sig.Kind = WmmaFragKind::MatrixA; + else if (Parts[0].contains("matrix_b")) + Sig.Kind = WmmaFragKind::MatrixB; + else if (Parts[0].contains("accumulator")) + Sig.Kind = WmmaFragKind::Accumulator; + + if (Parts[1].getAsInteger(10, Sig.M) || + Parts[2].getAsInteger(10, Sig.N) || + Parts[3].getAsInteger(10, Sig.K)) { + SearchFrom = End + 1; + continue; + } + + if (Parts[4] == "__half" || Parts[4].ends_with("__half")) + Sig.Elem = WmmaElemKind::F16; + else if (Parts[4] == "float") + Sig.Elem = WmmaElemKind::F32; + else if (Parts[4] == "unsigned char") + Sig.Elem = WmmaElemKind::U8; + else if (Parts[4] == "signed char") + Sig.Elem = WmmaElemKind::S8; + else if (Parts[4] == "int") + Sig.Elem = WmmaElemKind::S32; + + if (Parts[5] == "void") + Sig.Layout = WmmaLayout::None; + else if (Parts[5].contains("row_major")) + Sig.Layout = WmmaLayout::RowMajor; + else if (Parts[5].contains("col_major")) + Sig.Layout = WmmaLayout::ColMajor; + + if (Sig.isValid()) + Desc.OrderedFrags.push_back(Sig); } + + SearchFrom = End + 1; } - for (Function *F : ToErase) { - F->dropAllReferences(); - F->removeFromParent(); + if (Desc.Op == WmmaOpKind::Load && Desc.OrderedFrags.size() == 1) { + FragmentSig Frag = Desc.OrderedFrags[0]; + + if (Frag.Kind == WmmaFragKind::MatrixA) { + Desc.A = Frag; + Desc.ARegCount = 8; + } else if (Frag.Kind == WmmaFragKind::MatrixB) { + Desc.B = Frag; + Desc.BRegCount = 8; + } else if (Frag.Kind == WmmaFragKind::Accumulator) { + Desc.C = Frag; + Desc.CRegCount = 8; + } + + } else if (Desc.Op == WmmaOpKind::Store && Desc.OrderedFrags.size() == 1) { + Desc.D = Desc.OrderedFrags[0]; + Desc.DRegCount = 8; + + if (CI->arg_size() >= 4) { + if (auto *C = dyn_cast(CI->getArgOperand(3))) { + Desc.StoreMemLayout = + C->isZero() ? WmmaLayout::RowMajor : WmmaLayout::ColMajor; + } else { + Desc.StoreMemLayout = WmmaLayout::Dynamic; + } + } + + } else if (Desc.Op == WmmaOpKind::Mma && Desc.OrderedFrags.size() == 4) { + Desc.D = Desc.OrderedFrags[0]; + Desc.A = Desc.OrderedFrags[1]; + Desc.B = Desc.OrderedFrags[2]; + Desc.C = Desc.OrderedFrags[3]; + + unsigned ABTotal = 16; + inferABRegCountsForMma(Desc.A, Desc.B, ABTotal, Desc.ARegCount, + Desc.BRegCount); + + Desc.CRegCount = 8; + Desc.DRegCount = 8; } -} -// Replace raw NVVM WMMA intrinsics (llvm.nvvm.wmma.*) with Vortex runtime -// calls. These appear when clang inlines the mma.hpp C++ helper functions -// at compile time, leaving only the low-level NVVM intrinsics which -// collectWMMADesc (demangled-name matcher) cannot catch. -void replace_nvvm_wmma_intrinsics(llvm::Module *M) { - LLVMContext &Ctx = M->getContext(); - Type *F32 = Type::getFloatTy(Ctx); - Type *HalfTy = Type::getHalfTy(Ctx); - Type *I32 = Type::getInt32Ty(Ctx); - Type *PtrTy = PointerType::getUnqual(Ctx); - Type *VoidTy = Type::getVoidTy(Ctx); - Type *V2HalfTy = FixedVectorType::get(HalfTy, 2); + buildReplacementName(Desc); + return !Desc.ReplacementName.empty(); +} - SmallVector ToReplace; +// Unified collector +DenseMap collectWMMADesc(llvm::Module *M) { + DenseMap CallToDesc; for (Function &F : *M) { for (BasicBlock &BB : F) { @@ -2947,145 +3309,966 @@ void replace_nvvm_wmma_intrinsics(llvm::Module *M) { auto *CI = dyn_cast(&I); if (!CI) continue; + Function *Callee = CI->getCalledFunction(); if (!Callee) continue; - if (Callee->getName().starts_with("llvm.nvvm.wmma.")) - ToReplace.push_back(CI); + + WmmaCallDesc Desc; + + if (fillDescFromNVVMIntrinsic(CI, Desc)) { + CallToDesc[CI] = std::move(Desc); + continue; + } + + if (fillDescFromDemangledWmmaHelper(CI, Desc)) { + CallToDesc[CI] = std::move(Desc); + continue; + } } } } - if (ToReplace.empty()) - return; + return CallToDesc; +} - DBG_PRINT("[nvvm-wmma] found %zu NVVM WMMA intrinsics to replace\n", - ToReplace.size()); - - for (CallInst *CI : ToReplace) { - Function *Callee = CI->getCalledFunction(); - StringRef Name = Callee->getName(); - IRBuilder<> B(CI); - - if (Name.contains(".load.a.") || Name.contains(".load.b.")) { - // llvm.nvvm.wmma.m16n16k16.load.{a,b}.{row,col}.stride.f16.p0 - // Args: (ptr src, i32 stride) -> {<2 x half> x8} - bool isA = Name.contains(".load.a."); - bool isRow = Name.contains(".row."); - const char *rtName = isA - ? (isRow ? "__vx_wmma_load_a_m16n16k16_row_f16" - : "__vx_wmma_load_a_m16n16k16_col_f16") - : (isRow ? "__vx_wmma_load_b_m16n16k16_row_f16" - : "__vx_wmma_load_b_m16n16k16_col_f16"); - - DBG_PRINT("[nvvm-wmma] %s -> %s\n", Name.str().c_str(), rtName); - - Value *Src = CI->getArgOperand(0); - Value *Stride = CI->getArgOperand(1); - - // Alloca [8 x <2 x half>] at function entry - Type *FragTy = ArrayType::get(V2HalfTy, 8); - IRBuilder<> AB(&CI->getFunction()->getEntryBlock().front()); - Value *Frag = AB.CreateAlloca(FragTy, nullptr, "wmma.frag"); - - // void runtime(ptr frag, ptr src, i32 ldm) - FunctionType *FT = FunctionType::get(VoidTy, {PtrTy, PtrTy, I32}, false); - FunctionCallee FC = M->getOrInsertFunction(rtName, FT); - B.CreateCall(FC, {Frag, Src, Stride}); - - // Load 8 x <2 x half> and build return struct - Value *Result = UndefValue::get(CI->getType()); - for (unsigned i = 0; i < 8; i++) { - Value *GEP = B.CreateConstInBoundsGEP2_32(FragTy, Frag, 0, i); - Value *Val = B.CreateLoad(V2HalfTy, GEP); - Result = B.CreateInsertValue(Result, Val, i); - } - CI->replaceAllUsesWith(Result); - CI->eraseFromParent(); - - } else if (Name.contains(".mma.")) { - // llvm.nvvm.wmma.m16n16k16.mma.{row}.{row}.f32.f32 - // Args: (<2 x half> a0..a7, <2 x half> b0..b7, float c0..c7) -> {float x8} - DBG_PRINT("[nvvm-wmma] %s -> __vx_wmma_mma_m16n16k16_row_row_f16_f16_f32\n", - Name.str().c_str()); - - Type *AFragTy = ArrayType::get(V2HalfTy, 8); - Type *CFragTy = ArrayType::get(F32, 8); - - IRBuilder<> AB(&CI->getFunction()->getEntryBlock().front()); - Value *AAlloca = AB.CreateAlloca(AFragTy, nullptr, "wmma.a"); - Value *BAlloca = AB.CreateAlloca(AFragTy, nullptr, "wmma.b"); - Value *CAlloca = AB.CreateAlloca(CFragTy, nullptr, "wmma.c"); - Value *DAlloca = AB.CreateAlloca(CFragTy, nullptr, "wmma.d"); - - // Store a0..a7 (<2 x half> args 0-7) - for (unsigned i = 0; i < 8; i++) { - Value *GEP = B.CreateConstInBoundsGEP2_32(AFragTy, AAlloca, 0, i); - B.CreateStore(CI->getArgOperand(i), GEP); - } - // Store b0..b7 (<2 x half> args 8-15) - for (unsigned i = 0; i < 8; i++) { - Value *GEP = B.CreateConstInBoundsGEP2_32(AFragTy, BAlloca, 0, i); - B.CreateStore(CI->getArgOperand(8 + i), GEP); - } - // Store c0..c7 (float args 16-23) - for (unsigned i = 0; i < 8; i++) { - Value *GEP = B.CreateConstInBoundsGEP2_32(CFragTy, CAlloca, 0, i); - B.CreateStore(CI->getArgOperand(16 + i), GEP); - } +// Glue-pattern helpers +static bool isNVVMWMMACall(Value *V) { + auto *CI = dyn_cast(V); + if (!CI) + return false; - // void __vx_wmma_mma_...(ptr d, ptr a, ptr b, ptr c) - FunctionType *FT = FunctionType::get(VoidTy, - {PtrTy, PtrTy, PtrTy, PtrTy}, false); - FunctionCallee FC = M->getOrInsertFunction( - "__vx_wmma_mma_m16n16k16_row_row_f16_f16_f32", FT); - B.CreateCall(FC, {DAlloca, AAlloca, BAlloca, CAlloca}); - - // Load d0..d7 and build return struct - Value *Result = UndefValue::get(CI->getType()); - for (unsigned i = 0; i < 8; i++) { - Value *GEP = B.CreateConstInBoundsGEP2_32(CFragTy, DAlloca, 0, i); - Value *Val = B.CreateLoad(F32, GEP); - Result = B.CreateInsertValue(Result, Val, i); - } - CI->replaceAllUsesWith(Result); - CI->eraseFromParent(); - - } else if (Name.contains(".store.d.")) { - // llvm.nvvm.wmma.m16n16k16.store.d.{row,col}.stride.f32.p0 - // Args: (ptr dst, float f0..f7, i32 stride) -> void - int layout = Name.contains(".store.d.row.") ? 0 : 1; - DBG_PRINT("[nvvm-wmma] %s -> __vx_wmma_store_d_m16n16k16_f32 (layout=%d)\n", - Name.str().c_str(), layout); - - Type *CFragTy = ArrayType::get(F32, 8); - IRBuilder<> AB(&CI->getFunction()->getEntryBlock().front()); - Value *FragAlloca = AB.CreateAlloca(CFragTy, nullptr, "wmma.store"); - - Value *Dst = CI->getArgOperand(0); - for (unsigned i = 0; i < 8; i++) { - Value *GEP = B.CreateConstInBoundsGEP2_32(CFragTy, FragAlloca, 0, i); - B.CreateStore(CI->getArgOperand(1 + i), GEP); - } - Value *Stride = CI->getArgOperand(9); + Function *Callee = CI->getCalledFunction(); + if (!Callee) + return false; - // void __vx_wmma_store_d_m16n16k16_f32(ptr p, ptr frag, i32 ldm, i32 layout) - FunctionType *FT = FunctionType::get(VoidTy, - {PtrTy, PtrTy, I32, I32}, false); - FunctionCallee FC = M->getOrInsertFunction( - "__vx_wmma_store_d_m16n16k16_f32", FT); - B.CreateCall(FC, {Dst, FragAlloca, Stride, - ConstantInt::get(I32, layout)}); - CI->eraseFromParent(); - } - } + return Callee->getName().starts_with("llvm.nvvm.wmma."); +} - // Remove unused NVVM intrinsic declarations - SmallVector DeadDecls; - for (Function &F : *M) { - if (F.getName().starts_with("llvm.nvvm.wmma.") && F.use_empty()) - DeadDecls.push_back(&F); +static bool getConstantGEPIndex(Value *Ptr, Value *&Base, unsigned &Index, + Instruction **GEPInstOut = nullptr) { + if (GEPInstOut) + *GEPInstOut = nullptr; + + Ptr = Ptr->stripPointerCasts(); + + auto *GEP = dyn_cast(Ptr); + + if (!GEP) { + Base = Ptr; + Index = 0; + return true; } - for (Function *F : DeadDecls) - F->removeFromParent(); + + if (GEP->getNumIndices() != 1) + return false; + + auto IdxIt = GEP->idx_begin(); + auto *CI = dyn_cast(IdxIt->get()); + + if (!CI) + return false; + + Base = GEP->getPointerOperand()->stripPointerCasts(); + Index = CI->getZExtValue(); + + if (GEPInstOut) + *GEPInstOut = GEP; + + return true; +} + +static StoreInst * +findSingleStoreUserThroughCasts(Value *V, + SmallVectorImpl &Glue) { + if (auto *I = dyn_cast(V)) + Glue.push_back(I); + + Value *Cur = V; + + while (true) { + if (!Cur->hasOneUse()) + return nullptr; + + User *U = *Cur->user_begin(); + + if (auto *BC = dyn_cast(U)) { + Glue.push_back(BC); + Cur = BC; + continue; + } + + if (auto *ASC = dyn_cast(U)) { + Glue.push_back(ASC); + Cur = ASC; + continue; + } + + auto *SI = dyn_cast(U); + if (!SI) + return nullptr; + + if (SI->getValueOperand() != Cur) + return nullptr; + + Glue.push_back(SI); + return SI; + } +} + +// Handles this pattern: +// +// %frag = call {reg x N} @llvm.nvvm.wmma.* +// %x0 = extractvalue %frag, 0 +// %b0 = bitcast %x0 to i32 +// store i32 %b0, ptr %base +// %p1 = getelementptr i32, ptr %base, i32 1 +// %x1 = extractvalue %frag, 1 +// ... +// +// It returns %base as the existing fragment pointer. +static bool +collectCallResultStoreGlue(CallInst *CI, unsigned RegCount, Value *&FragmentPtr, + SmallVectorImpl &DeadGlue) { + SmallVector Stores; + Stores.resize(RegCount, nullptr); + + SmallVector, 16> PerRegGlue; + PerRegGlue.resize(RegCount); + + SmallVector GEPGlue; + + Value *CommonBase = nullptr; + unsigned Seen = 0; + + for (User *U : CI->users()) { + auto *EVI = dyn_cast(U); + if (!EVI) + return false; + + ArrayRef Indices = EVI->getIndices(); + + if (Indices.size() != 1) + return false; + + unsigned RegIdx = Indices[0]; + + if (RegIdx >= RegCount) + return false; + + SmallVector LocalGlue; + StoreInst *SI = findSingleStoreUserThroughCasts(EVI, LocalGlue); + + if (!SI) + return false; + + Value *Base = nullptr; + unsigned StoreIdx = 0; + Instruction *GEPInst = nullptr; + + if (!getConstantGEPIndex(SI->getPointerOperand(), Base, StoreIdx, &GEPInst)) + return false; + + if (StoreIdx != RegIdx) + return false; + + if (!CommonBase) + CommonBase = Base; + else if (CommonBase != Base) + return false; + + if (Stores[RegIdx]) + return false; + + Stores[RegIdx] = SI; + PerRegGlue[RegIdx] = std::move(LocalGlue); + + if (GEPInst) + GEPGlue.push_back(GEPInst); + + ++Seen; + } + + if (Seen != RegCount) + return false; + + for (unsigned I = 0; I < RegCount; ++I) { + if (!Stores[I]) + return false; + } + + for (Instruction *I : GEPGlue) + DeadGlue.push_back(I); + + for (unsigned I = 0; I < RegCount; ++I) { + for (Instruction *GI : PerRegGlue[I]) + DeadGlue.push_back(GI); + } + + FragmentPtr = CommonBase; + return FragmentPtr != nullptr; +} + +// Handles MMA/store input patterns: +// +// %p0 = getelementptr i32, ptr %frag, i32 0 +// %l0 = load i32, ptr %p0 +// %a0 = bitcast i32 %l0 to <2 x half> +// +// or: +// +// %p0 = getelementptr float, ptr %frag, i32 0 +// %c0 = load float, ptr %p0 +// +// It returns %frag. +static bool getLoadBaseIndexFromArg(Value *Arg, Value *&Base, unsigned &Index, + SmallVectorImpl &Glue) { + Value *Cur = Arg; + + while (true) { + if (auto *BC = dyn_cast(Cur)) { + Glue.push_back(BC); + Cur = BC->getOperand(0); + continue; + } + + if (auto *ASC = dyn_cast(Cur)) { + Glue.push_back(ASC); + Cur = ASC->getOperand(0); + continue; + } + + break; + } + + auto *LI = dyn_cast(Cur); + if (!LI) + return false; + + Glue.push_back(LI); + + Instruction *GEPInst = nullptr; + if (!getConstantGEPIndex(LI->getPointerOperand(), Base, Index, &GEPInst)) + return false; + + if (GEPInst) + Glue.push_back(GEPInst); + + return true; +} + +static Value * +findFragmentStorageFromMemoryArgs(CallInst *CI, unsigned ArgStart, + unsigned Count, + SmallVectorImpl &DeadGlue) { + Value *CommonBase = nullptr; + SmallVector LocalGlue; + + for (unsigned I = 0; I < Count; ++I) { + if (ArgStart + I >= CI->arg_size()) + return nullptr; + + SmallVector PerArgGlue; + + Value *Base = nullptr; + unsigned Index = 0; + + if (!getLoadBaseIndexFromArg(CI->getArgOperand(ArgStart + I), Base, Index, + PerArgGlue)) + return nullptr; + + if (Index != I) + return nullptr; + + if (!CommonBase) + CommonBase = Base; + else if (CommonBase != Base) + return nullptr; + + for (Instruction *GI : PerArgGlue) + LocalGlue.push_back(GI); + } + + if (!CommonBase) + return nullptr; + + for (Instruction *GI : LocalGlue) + DeadGlue.push_back(GI); + + return CommonBase; +} + +static Value * +findFragmentStorageFromExtractArgs(CallInst *CI, unsigned ArgStart, + unsigned Count, + FragStorageMap &FragmentStorage) { + Value *Found = nullptr; + + for (unsigned I = 0; I < Count; ++I) { + if (ArgStart + I >= CI->arg_size()) + return nullptr; + + Value *Arg = CI->getArgOperand(ArgStart + I); + + auto *EVI = dyn_cast(Arg); + if (!EVI) + return nullptr; + + ArrayRef Indices = EVI->getIndices(); + + if (Indices.size() != 1) + return nullptr; + + if (Indices[0] != I) + return nullptr; + + Value *Producer = EVI->getAggregateOperand(); + + auto It = FragmentStorage.find(Producer); + if (It == FragmentStorage.end()) + return nullptr; + + if (!Found) + Found = It->second; + else if (Found != It->second) + return nullptr; + } + + return Found; +} + +static bool canPackArgsWithoutNVVMDependency(CallInst *CI, unsigned ArgStart, + unsigned Count, + FragStorageMap &FragmentStorage) { + for (unsigned I = 0; I < Count; ++I) { + if (ArgStart + I >= CI->arg_size()) + return false; + + Value *Arg = CI->getArgOperand(ArgStart + I); + + auto *EVI = dyn_cast(Arg); + if (!EVI) + continue; + + Value *Producer = EVI->getAggregateOperand(); + + if (!isNVVMWMMACall(Producer)) + continue; + + // If this is an extract from old NVVM WMMA, we must have recovered the + // producer fragment pointer through FragmentStorage. Otherwise, packing + // would keep the old NVVM producer alive. + if (!FragmentStorage.contains(Producer)) + return false; + + return false; + } + + return true; +} + +static Value *bitcastToRuntimeReg(IRBuilder<> &B, Value *V, Type *DstTy) { + Type *SrcTy = V->getType(); + + if (SrcTy == DstTy) + return V; + + if (SrcTy->isPointerTy() || DstTy->isPointerTy()) + return B.CreateBitCast(V, DstTy); + + if (SrcTy->getPrimitiveSizeInBits() == DstTy->getPrimitiveSizeInBits()) + return B.CreateBitCast(V, DstTy); + + if (SrcTy->isIntegerTy() && DstTy->isIntegerTy()) + return B.CreateIntCast(V, DstTy, false); + + if (SrcTy->isFloatingPointTy() && DstTy->isIntegerTy()) + return B.CreateBitCast(V, DstTy); + + if (SrcTy->isIntegerTy() && DstTy->isFloatingPointTy()) + return B.CreateBitCast(V, DstTy); + + return V; +} + +static Value *packArgsToFragmentStorage(Module *M, CallInst *CI, + unsigned ArgStart, unsigned Count, + const FragmentSig &Sig, StringRef Name, + FragStorageMap &FragmentStorage) { + if (!canPackArgsWithoutNVVMDependency(CI, ArgStart, Count, FragmentStorage)) + return nullptr; + + LLVMContext &Ctx = M->getContext(); + IRBuilder<> B(CI); + + Type *RegTy = getRuntimeFragmentRegTy(Ctx, Sig); + if (!RegTy) + return nullptr; + + Type *ArrayTy = ArrayType::get(RegTy, Count); + + AllocaInst *Storage = createEntryAlloca(CI->getFunction(), ArrayTy, Name); + + for (unsigned I = 0; I < Count; ++I) { + Value *GEP = B.CreateConstInBoundsGEP2_32(ArrayTy, Storage, 0, I); + + Value *Arg = CI->getArgOperand(ArgStart + I); + Value *CastArg = bitcastToRuntimeReg(B, Arg, RegTy); + + B.CreateStore(CastArg, GEP); + } + + return Storage; +} + +// NVVM no-return lowering +static bool lowerNVVMLoadNoReturn(Module *M, CallInst *CI, + const WmmaCallDesc &Desc, + FragStorageMap &FragmentStorage, + SmallVectorImpl &DeadGlue, + SmallVectorImpl &OldCalls) { + LLVMContext &Ctx = M->getContext(); + + Type *VoidTy = Type::getVoidTy(Ctx); + Type *I32 = Type::getInt32Ty(Ctx); + Type *PtrTy = PointerType::getUnqual(Ctx); + + const FragmentSig *Frag = getLoadFragment(Desc); + if (!Frag) + return false; + + unsigned RegCount = getLoadRegCount(Desc); + if (RegCount == 0) + return false; + + IRBuilder<> B(CI); + + Value *FragmentPtr = nullptr; + SmallVector LocalDeadGlue; + + bool HasExistingFragmentStore = + collectCallResultStoreGlue(CI, RegCount, FragmentPtr, LocalDeadGlue); + + if (!HasExistingFragmentStore) { + AllocaInst *TmpStorage = createFragmentAlloca(CI->getFunction(), *Frag, + RegCount, "wmma.load.frag"); + + if (!TmpStorage) + return false; + + FragmentPtr = TmpStorage; + } + + Value *Src = CI->getArgOperand(0); + Value *Stride = CI->getArgOperand(1); + + if (Stride->getType() != I32) + Stride = B.CreateIntCast(Stride, I32, true); + + FunctionType *FT = FunctionType::get(VoidTy, {PtrTy, PtrTy, I32}, false); + + FunctionCallee FC = M->getOrInsertFunction(Desc.ReplacementName, FT); + + B.CreateCall(FC, {FragmentPtr, Src, Stride}); + + FragmentStorage[CI] = FragmentPtr; + OldCalls.push_back(CI); + + for (Instruction *I : LocalDeadGlue) + DeadGlue.push_back(I); + + return true; +} + +static bool lowerNVVMMmaNoReturn(Module *M, CallInst *CI, + const WmmaCallDesc &Desc, + FragStorageMap &FragmentStorage, + SmallVectorImpl &DeadGlue, + SmallVectorImpl &OldCalls) { + LLVMContext &Ctx = M->getContext(); + + Type *VoidTy = Type::getVoidTy(Ctx); + Type *PtrTy = PointerType::getUnqual(Ctx); + + unsigned ARegs = Desc.ARegCount; + unsigned BRegs = Desc.BRegCount; + unsigned CRegs = Desc.CRegCount; + unsigned DRegs = Desc.DRegCount; + + if (ARegs == 0 || BRegs == 0 || CRegs == 0 || DRegs == 0) + return false; + + IRBuilder<> B(CI); + + Value *AStorage = findFragmentStorageFromMemoryArgs(CI, 0, ARegs, DeadGlue); + + if (!AStorage) { + AStorage = + findFragmentStorageFromExtractArgs(CI, 0, ARegs, FragmentStorage); + } + + if (!AStorage) { + AStorage = packArgsToFragmentStorage(M, CI, 0, ARegs, Desc.A, "wmma.a.pack", + FragmentStorage); + } + + Value *BStorage = + findFragmentStorageFromMemoryArgs(CI, ARegs, BRegs, DeadGlue); + + if (!BStorage) { + BStorage = + findFragmentStorageFromExtractArgs(CI, ARegs, BRegs, FragmentStorage); + } + + if (!BStorage) { + BStorage = packArgsToFragmentStorage(M, CI, ARegs, BRegs, Desc.B, + "wmma.b.pack", FragmentStorage); + } + + Value *CStorage = + findFragmentStorageFromMemoryArgs(CI, ARegs + BRegs, CRegs, DeadGlue); + + if (!CStorage) { + CStorage = findFragmentStorageFromExtractArgs(CI, ARegs + BRegs, CRegs, + FragmentStorage); + } + + if (!CStorage) { + CStorage = packArgsToFragmentStorage(M, CI, ARegs + BRegs, CRegs, Desc.C, + "wmma.c.pack", FragmentStorage); + } + + if (!AStorage || !BStorage || !CStorage) + return false; + + Value *DStorage = nullptr; + SmallVector LocalResultGlue; + + bool HasExistingDStore = + collectCallResultStoreGlue(CI, DRegs, DStorage, LocalResultGlue); + + if (!HasExistingDStore) { + DStorage = createFragmentAlloca(CI->getFunction(), Desc.D, DRegs, "wmma.d"); + } + + if (!DStorage) + return false; + + FunctionType *FT = + FunctionType::get(VoidTy, {PtrTy, PtrTy, PtrTy, PtrTy}, false); + + FunctionCallee FC = M->getOrInsertFunction(Desc.ReplacementName, FT); + + B.CreateCall(FC, {DStorage, AStorage, BStorage, CStorage}); + + FragmentStorage[CI] = DStorage; + OldCalls.push_back(CI); + + for (Instruction *I : LocalResultGlue) + DeadGlue.push_back(I); + + return true; +} + +static bool lowerNVVMStoreNoReturn(Module *M, CallInst *CI, + const WmmaCallDesc &Desc, + FragStorageMap &FragmentStorage, + SmallVectorImpl &DeadGlue, + SmallVectorImpl &OldCalls) { + LLVMContext &Ctx = M->getContext(); + + Type *VoidTy = Type::getVoidTy(Ctx); + Type *I32 = Type::getInt32Ty(Ctx); + Type *PtrTy = PointerType::getUnqual(Ctx); + + unsigned DRegs = Desc.DRegCount; + + if (DRegs == 0) { + if (CI->arg_size() < 3) + return false; + + DRegs = CI->arg_size() - 2; + } + + IRBuilder<> B(CI); + + Value *Dst = CI->getArgOperand(0); + + Value *DStorage = findFragmentStorageFromMemoryArgs(CI, 1, DRegs, DeadGlue); + + if (!DStorage) { + DStorage = + findFragmentStorageFromExtractArgs(CI, 1, DRegs, FragmentStorage); + } + + if (!DStorage) { + DStorage = packArgsToFragmentStorage(M, CI, 1, DRegs, Desc.D, + "wmma.store.pack", FragmentStorage); + } + + if (!DStorage) + return false; + + Value *Stride = CI->getArgOperand(1 + DRegs); + + if (Stride->getType() != I32) + Stride = B.CreateIntCast(Stride, I32, true); + + int Layout = 0; + + if (Desc.StoreMemLayout == WmmaLayout::ColMajor) + Layout = 1; + + FunctionType *FT = FunctionType::get(VoidTy, {PtrTy, PtrTy, I32, I32}, false); + + FunctionCallee FC = M->getOrInsertFunction(Desc.ReplacementName, FT); + + B.CreateCall(FC, {Dst, DStorage, Stride, ConstantInt::get(I32, Layout)}); + + OldCalls.push_back(CI); + + return true; +} + +static bool lowerNVVMWMMANoReturn(Module *M, CallInst *CI, + const WmmaCallDesc &Desc, + FragStorageMap &FragmentStorage, + SmallVectorImpl &DeadGlue, + SmallVectorImpl &OldCalls) { + switch (Desc.Op) { + case WmmaOpKind::Load: + return lowerNVVMLoadNoReturn(M, CI, Desc, FragmentStorage, DeadGlue, + OldCalls); + + case WmmaOpKind::Mma: + return lowerNVVMMmaNoReturn(M, CI, Desc, FragmentStorage, DeadGlue, + OldCalls); + + case WmmaOpKind::Store: + return lowerNVVMStoreNoReturn(M, CI, Desc, FragmentStorage, DeadGlue, + OldCalls); + + default: + return false; + } +} + +// Cleanup +static void +eraseDeadGlueInstructions(SmallVectorImpl &DeadGlue) { + bool Progress = true; + + while (Progress) { + Progress = false; + + SmallVector ToEraseIdx; + SmallPtrSet SeenThisRound; + + for (unsigned Idx = 0, E = DeadGlue.size(); Idx != E; ++Idx) { + Instruction *I = DeadGlue[Idx]; + + if (!I) + continue; + + // This is only safe because we null entries immediately after erasing. + if (!I->getParent()) { + DeadGlue[Idx] = nullptr; + continue; + } + + if (!I->use_empty()) + continue; + + if (SeenThisRound.insert(I).second) + ToEraseIdx.push_back(Idx); + } + + for (unsigned Idx : ToEraseIdx) { + Instruction *I = DeadGlue[Idx]; + + if (!I) + continue; + + if (!I->getParent()) { + DeadGlue[Idx] = nullptr; + continue; + } + + if (!I->use_empty()) + continue; + + DeadGlue[Idx] = nullptr; + I->eraseFromParent(); + Progress = true; + } + } +} + +static bool +isSafeWMMADeadGlue(Instruction *I, + const SmallPtrSetImpl &OldCallSet) { + if (OldCallSet.contains(I)) + return true; + + if (isa(I)) + return true; + + if (isa(I)) + return true; + + if (isa(I)) + return true; + + if (isa(I) || isa(I) || isa(I)) + return true; + + if (isa(I) || isa(I)) + return true; + + if (isa(I)) + return true; + + return false; +} + +static bool +collectDeadWMMAUsers(Instruction *I, + const SmallPtrSetImpl &OldCallSet, + SmallPtrSetImpl &DeadSet, + SmallVectorImpl &BadUsers) { + bool OK = true; + + for (User *U : I->users()) { + auto *UserI = dyn_cast(U); + + if (!UserI) { + OK = false; + continue; + } + + if (!isSafeWMMADeadGlue(UserI, OldCallSet)) { + BadUsers.push_back(UserI); + OK = false; + continue; + } + + if (DeadSet.insert(UserI).second) { + if (!collectDeadWMMAUsers(UserI, OldCallSet, DeadSet, BadUsers)) + OK = false; + } + } + + return OK; +} + +static void eraseDeadWMMASet(SmallPtrSetImpl &DeadSet) { + bool Progress = true; + + while (Progress) { + Progress = false; + + SmallVector ToEraseNow; + + for (Instruction *I : DeadSet) { + if (!I || !I->getParent()) + continue; + + if (I->use_empty()) + ToEraseNow.push_back(I); + } + + for (Instruction *I : ToEraseNow) { + if (!I || !I->getParent()) + continue; + + DeadSet.erase(I); + I->eraseFromParent(); + Progress = true; + } + } +} + +static void cleanupOldNVVMCalls(ArrayRef OldCalls) { + SmallPtrSet OldCallSet; + + for (CallInst *CI : OldCalls) { + if (CI && CI->getParent()) + OldCallSet.insert(CI); + } + + SmallPtrSet DeadSet; + + for (CallInst *CI : OldCalls) { + if (CI && CI->getParent()) + DeadSet.insert(CI); + } + + SmallVector BadUsers; + + for (CallInst *CI : OldCalls) { + if (!CI || !CI->getParent()) + continue; + + collectDeadWMMAUsers(CI, OldCallSet, DeadSet, BadUsers); + } + + if (!BadUsers.empty()) { + errs() << "[wmma] cannot delete old NVVM WMMA graph because some users " + "are not recognized as dead WMMA glue:\n"; + + for (Instruction *I : BadUsers) { + errs() << " bad user: "; + I->print(errs()); + errs() << "\n"; + } + + report_fatal_error("old NVVM WMMA graph has unsupported remaining users"); + } + + eraseDeadWMMASet(DeadSet); + + for (CallInst *CI : OldCalls) { + if (!CI || !CI->getParent()) + continue; + + errs() << "[wmma] failed to erase old NVVM intrinsic:\n"; + CI->print(errs()); + errs() << "\n"; + + for (User *U : CI->users()) { + errs() << " remaining user: "; + U->print(errs()); + errs() << "\n"; + + if (auto *UserI = dyn_cast(U)) { + for (User *UU : UserI->users()) { + errs() << " user's user: "; + UU->print(errs()); + errs() << "\n"; + } + } + } + + report_fatal_error( + "replace_wmma_intrinsics left an NVVM WMMA intrinsic in the IR"); + } +} + +// Main entry point for replace_wmma_intrinsics +void replace_wmma_intrinsics(llvm::Module *M) { + DenseMap CallToDesc = collectWMMADesc(M); + + SmallVector OrderedCalls; + + for (Function &F : *M) { + for (BasicBlock &BB : F) { + for (Instruction &I : BB) { + auto *CI = dyn_cast(&I); + if (!CI) + continue; + + if (CallToDesc.contains(CI)) + OrderedCalls.push_back(CI); + } + } + } + + FragStorageMap FragmentStorage; + SmallVector OldNVVMCalls; + SmallVector DeadGlue; + SmallPtrSet ToErase; + + DBG(for (CallInst *CI : OrderedCalls) { + WmmaCallDesc &Desc = CallToDesc[CI]; + + errs() << "[wmma] "; + + if (Desc.FromNVVMIntrinsic) + errs() << CI->getCalledFunction()->getName(); + else + errs() << maybeDemangle(CI->getCalledFunction()->getName()); + + errs() << " -> " << Desc.ReplacementName << "\n"; + }); + + for (CallInst *CI : OrderedCalls) { + if (!CI || !CI->getParent()) + continue; + + auto It = CallToDesc.find(CI); + if (It == CallToDesc.end()) + continue; + + WmmaCallDesc &Desc = It->second; + + Function *OldHelper = CI->getCalledFunction(); + + if (Desc.FromNVVMIntrinsic) { + bool OK = lowerNVVMWMMANoReturn(M, CI, Desc, FragmentStorage, DeadGlue, + OldNVVMCalls); + + if (!OK) { + errs() << "[wmma] failed to lower NVVM WMMA intrinsic:\n"; + CI->print(errs()); + errs() << "\n"; + report_fatal_error("failed to lower NVVM WMMA intrinsic"); + } + + continue; + } + + // Non-inlined nvcuda::wmma helper calls are already pointer-style. + // For those, just replace the callee. + if (!OldHelper) + continue; + + FunctionType *FTy = OldHelper->getFunctionType(); + + FunctionCallee NewHelperCallee = + M->getOrInsertFunction(Desc.ReplacementName, FTy); + + Function *NewHelper = + dyn_cast(NewHelperCallee.getCallee()->stripPointerCasts()); + + if (!NewHelper) + continue; + + NewHelper->setCallingConv(CallingConv::C); + + CI->setCalledFunction(NewHelperCallee); + CI->setCallingConv(CallingConv::C); + + if (OldHelper->use_empty()) + ToErase.insert(OldHelper); + } + + // First erase recognized store/load glue such as: + // extractvalue -> bitcast -> store + // gep -> load -> bitcast -> old mma + eraseDeadGlueInstructions(DeadGlue); + + // Then erase old NVVM calls and their remaining extractvalue chains. + cleanupOldNVVMCalls(OldNVVMCalls); + + // After old calls are gone, some input load/bitcast glue may become dead. + eraseDeadGlueInstructions(DeadGlue); + + for (Function *F : ToErase) { + if (!F || !F->use_empty()) + continue; + + F->dropAllReferences(); + F->eraseFromParent(); + } + + SmallVector DeadNVVMDecls; + + for (Function &F : *M) { + if (F.getName().starts_with("llvm.nvvm.wmma.") && F.use_empty()) + DeadNVVMDecls.push_back(&F); + } + + for (Function *F : DeadNVVMDecls) + F->eraseFromParent(); + errs() << "End\n"; } diff --git a/examples/sgemm_tcu/sgemm_tcu-cuda-nvptx64-nvidia-cuda-sm_90-4b676c53.bc.tmp b/examples/sgemm_tcu/sgemm_tcu-cuda-nvptx64-nvidia-cuda-sm_90-4b676c53.bc.tmp new file mode 100644 index 0000000..e69de29 diff --git a/runtime/src/vortex/kernel/cudaKernelImpl.cpp b/runtime/src/vortex/kernel/cudaKernelImpl.cpp index 79a9679..eaa22fe 100644 --- a/runtime/src/vortex/kernel/cudaKernelImpl.cpp +++ b/runtime/src/vortex/kernel/cudaKernelImpl.cpp @@ -413,38 +413,6 @@ extern "C" void __vx_wmma_load_a_m16n16k16_row_f16( } } -// TODO: Fix BUG. Compute is not resulting in correct result -extern "C" void __vx_wmma_load_a_m32n8k16_row_f16( - void *frag, // fragment storage, 8 x i32 - const void *p, // base pointer to matrix A (fp16 elements) - int32_t ldm) { // stride in fp16 elements - - auto *dst = reinterpret_cast(frag); - auto *src = reinterpret_cast(p); - - uint32_t lane = __vx_get_lane_id(); - uint32_t lane16 = lane & 0xFu; // 16-lane native subgroup - uint32_t row_base = - (lane >> 4u) * 16u; // lanes 0..15 => rows 0..15, 16..31 => rows 16..31 - - // Native 16x8x16 A mapping: - // block_row = lane16 / 4, block_col = (lane16 % 4) * 2 - // row stride across regs = 4, col stride across regs = 8 - uint32_t block_row = row_base + (lane16 / 4u); - uint32_t block_col = (lane16 % 4u) * 2u; - - for (uint32_t r = 0; r < 8u; ++r) { - uint32_t elem_row = block_row + ((r / 2u) * 4u); - uint32_t elem_col = block_col + ((r % 2u) * 8u); - - uint32_t off = elem_row * static_cast(ldm) + elem_col; - - uint32_t packed; - __builtin_memcpy(&packed, src + static_cast(off) * 2u, 4); - dst[r] = packed; - } -} - extern "C" void __vx_wmma_load_b_m16n16k16_row_f16(void *frag, // 8 x i32 fragment storage const void *p, // base ptr to matrix B @@ -475,46 +443,6 @@ __vx_wmma_load_b_m16n16k16_row_f16(void *frag, // 8 x i32 fragment storage } } -// TODO: Fix BUG. Compute is not resulting in correct result -extern "C" void __vx_wmma_load_b_m32n8k16_row_f16( - void *frag, // fragment storage, keep 8 x i32 to match wrapper - const void *p, // base ptr to matrix B - int32_t ldm) { // stride in half elements - - auto *dst = reinterpret_cast(frag); - auto *src = reinterpret_cast(p); - - uint32_t lane = __vx_get_lane_id(); - uint32_t lane16 = lane & 0xFu; // both 16-lane subgroups use the same B tile - - // Native 16x8x16 B mapping: - // block_col = lane16 / 4, block_row = (lane16 % 4) * 2 - // NRB = 4 for this native shape, so only dst[0..3] are meaningful. - uint32_t block_col = lane16 / 4u; - uint32_t block_row = (lane16 % 4u) * 2u; - - for (uint32_t r = 0; r < 4u; ++r) { - uint32_t row = block_row + ((r / 2u) * 8u); - uint32_t col = block_col + ((r % 2u) * 4u); - - uint32_t off0 = row * static_cast(ldm) + col; - uint32_t off1 = off0 + static_cast(ldm); // next row - - uint16_t h0; - uint16_t h1; - __builtin_memcpy(&h0, src + static_cast(off0) * 2u, 2); - __builtin_memcpy(&h1, src + static_cast(off1) * 2u, 2); - - dst[r] = static_cast(h0) | (static_cast(h1) << 16); - } - - // Unused by the 4-reg B mma path, but keep fragment storage deterministic. - dst[4] = 0; - dst[5] = 0; - dst[6] = 0; - dst[7] = 0; -} - extern "C" void __vx_wmma_load_c_m16n16k16_row_f32(void *frag, // fragment storage: 8 x float const void *p, // base pointer to matrix C @@ -556,31 +484,6 @@ extern "C" void __vx_wmma_store_d_m16n16k16_f32(float *p, const float *frag, } } -// TODO: Fix BUG. Compute is not resulting in correct result -extern "C" void __vx_wmma_store_d_m32n8k16_f32(float *p, const float *frag, - int32_t ldm, WmmaLayout layout) { - uint32_t lane = __vx_get_lane_id(); - - uint32_t block_row = lane / 4u; - uint32_t block_col = lane % 4u; - - if (layout == WMMA_LAYOUT_ROW) { - for (uint32_t r = 0; r < 8u; ++r) { - uint32_t row = block_row + ((r / 2u) * 8u); - uint32_t col = block_col + ((r % 2u) * 4u); - uint32_t off = row * static_cast(ldm) + col; - p[off] = frag[r]; - } - } else if (layout == WMMA_LAYOUT_COL) { - for (uint32_t r = 0; r < 8u; ++r) { - uint32_t row = block_row + ((r / 2u) * 8u); - uint32_t col = block_col + ((r % 2u) * 4u); - uint32_t off = col * static_cast(ldm) + row; - p[off] = frag[r]; - } - } -} - extern "C" void __vx_wmma_mma_m16n16k16_row_row_f16_f16_f32(float *d, const uint32_t *a, const uint32_t *b, @@ -634,53 +537,6 @@ extern "C" void __vx_wmma_mma_m16n16k16_row_row_f16_f16_f32(float *d, d[7] = d7; } -// TODO: Fix BUG. Compute is not resulting in correct result -extern "C" void __vx_wmma_mma_m32n8k16_row_row_f16_f16_f32(float *d, - const uint32_t *a, - const uint32_t *b, - const float *c) { - register float d0 asm("f0") = c[0]; - register float d1 asm("f1") = c[1]; - register float d2 asm("f2") = c[2]; - register float d3 asm("f3") = c[3]; - register float d4 asm("f4") = c[4]; - register float d5 asm("f5") = c[5]; - register float d6 asm("f6") = c[6]; - register float d7 asm("f7") = c[7]; - - register float a0 asm("f10") = __vx_bitcast_u32_to_f32(a[0]); - register float a1 asm("f11") = __vx_bitcast_u32_to_f32(a[1]); - register float a2 asm("f12") = __vx_bitcast_u32_to_f32(a[2]); - register float a3 asm("f13") = __vx_bitcast_u32_to_f32(a[3]); - register float a4 asm("f14") = __vx_bitcast_u32_to_f32(a[4]); - register float a5 asm("f15") = __vx_bitcast_u32_to_f32(a[5]); - register float a6 asm("f16") = __vx_bitcast_u32_to_f32(a[6]); - register float a7 asm("f17") = __vx_bitcast_u32_to_f32(a[7]); - - // Native 16x8x16 B fragment uses 4 regs, matching the supported Vortex path. - register float b0 asm("f28") = __vx_bitcast_u32_to_f32(b[0]); - register float b1 asm("f29") = __vx_bitcast_u32_to_f32(b[1]); - register float b2 asm("f30") = __vx_bitcast_u32_to_f32(b[2]); - register float b3 asm("f31") = __vx_bitcast_u32_to_f32(b[3]); - - // x0 = fp32::id, x1 = fp16::id, x0 flags = dense - asm volatile(".insn r 0x0b, 0, 2, x0, x1, x0" - : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), "+f"(d4), "+f"(d5), - "+f"(d6), "+f"(d7) - : "f"(a0), "f"(a1), "f"(a2), "f"(a3), "f"(a4), "f"(a5), "f"(a6), - "f"(a7), "f"(b0), "f"(b1), "f"(b2), "f"(b3) - : "memory"); - - d[0] = d0; - d[1] = d1; - d[2] = d2; - d[3] = d3; - d[4] = d4; - d[5] = d5; - d[6] = d6; - d[7] = d7; -} - // Float atomic add helper. // RISC-V "A" extension has no native float atomic add. The default LLVM // lowering of `atomicrmw fadd` produces a cmpxchg (LR/SC) loop whose From 76be1f463570381ca8345972f31bfb8973f74b74 Mon Sep 17 00:00:00 2001 From: Sai Hemanth Bheemreddy Date: Sun, 3 May 2026 22:37:11 -0400 Subject: [PATCH 2/2] partial --- compilation/KernelTranslation/src/tool.cpp | 6 +++--- examples/sgemm_tcu/sgemm_tcu.cu | 11 +++-------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/compilation/KernelTranslation/src/tool.cpp b/compilation/KernelTranslation/src/tool.cpp index 84cc0c1..3716a70 100644 --- a/compilation/KernelTranslation/src/tool.cpp +++ b/compilation/KernelTranslation/src/tool.cpp @@ -4245,13 +4245,13 @@ void replace_wmma_intrinsics(llvm::Module *M) { // First erase recognized store/load glue such as: // extractvalue -> bitcast -> store // gep -> load -> bitcast -> old mma - eraseDeadGlueInstructions(DeadGlue); + // eraseDeadGlueInstructions(DeadGlue); // Then erase old NVVM calls and their remaining extractvalue chains. - cleanupOldNVVMCalls(OldNVVMCalls); + // cleanupOldNVVMCalls(OldNVVMCalls); // After old calls are gone, some input load/bitcast glue may become dead. - eraseDeadGlueInstructions(DeadGlue); + // eraseDeadGlueInstructions(DeadGlue); for (Function *F : ToErase) { if (!F || !F->use_empty()) diff --git a/examples/sgemm_tcu/sgemm_tcu.cu b/examples/sgemm_tcu/sgemm_tcu.cu index 712c6db..1beabc4 100644 --- a/examples/sgemm_tcu/sgemm_tcu.cu +++ b/examples/sgemm_tcu/sgemm_tcu.cu @@ -87,14 +87,9 @@ int main(int argc, char **argv) { srand(42); for (int i = 0; i < M * K; i++) - // h_A[i] = __float2half((float)(rand() % 5) / 5.0f); - h_A[i] = __float2half(1); - // for (int i = 0; i < K * N; i++) - // h_B[i] = __float2half((float)(rand() % 5) / 5.0f); - for (int i = 0; i < K * N; i++) { - int k = i / N; // row index in B since B is K x N row-major - h_B[i] = __float2half(k < 16 ? 1.0f : 0.0f); - } + h_A[i] = __float2half((float)(rand() % 5) / 5.0f); + for (int i = 0; i < K * N; i++) + h_B[i] = __float2half((float)(rand() % 5) / 5.0f); memset(h_C, 0, sizeC); memset(h_C_ref, 0, sizeC); gemm_cpu(h_A, h_B, h_C_ref, M, N, K);