From dbd8e237888369b20666f1e0ad7e599695554e00 Mon Sep 17 00:00:00 2001 From: Dongmin Ra Date: Thu, 23 Oct 2025 07:08:25 +0000 Subject: [PATCH] Split 'shmemOutTokMemObj' and 'shmemOutWeightsMemObj' into seperate objects for dispatch and combine --- .../dispatch_combine/test_dispatch_combine.cpp | 10 ++++++---- .../ops/dispatch_combine/dispatch_combine.hpp | 18 ++++++++++++------ src/ops/dispatch_combine/dispatch_combine.cpp | 14 ++++++++++---- src/ops/dispatch_combine/internode.hpp | 18 ++++++++++-------- src/ops/dispatch_combine/intranode.hpp | 17 +++++++++-------- src/pybind/mori.cpp | 8 ++++---- 6 files changed, 51 insertions(+), 34 deletions(-) diff --git a/examples/ops/dispatch_combine/test_dispatch_combine.cpp b/examples/ops/dispatch_combine/test_dispatch_combine.cpp index 39e97dfb..325a7bd5 100644 --- a/examples/ops/dispatch_combine/test_dispatch_combine.cpp +++ b/examples/ops/dispatch_combine/test_dispatch_combine.cpp @@ -249,7 +249,8 @@ class EpDispatchCombineTestCase { index_t srcPe = srcTokId / config.maxNumInpTokenPerRank; index_t localSrcTokId = srcTokId - srcPe * config.maxNumInpTokenPerRank; - T* localTokBuf = handle.shmemOutTokMemObj->template GetAs() + i * config.hiddenDim; + T* localTokBuf = + handle.shmemDispatchOutTokMemObj->template GetAs() + i * config.hiddenDim; T* srcTokBuf = reinterpret_cast(globalInpTokBufCpu) + srcPe * inpTokEleNum + localSrcTokId * config.hiddenDim; @@ -291,7 +292,8 @@ class EpDispatchCombineTestCase { index_t srcTokDispatchId = peSortToTokenIdxMapsVec[srcPe][peSortedId]; index_t srcTokId = srcTokDispatchId / config.numExpertPerToken; - T* localTokBuf = handle.shmemOutTokMemObj->template GetAs() + i * config.hiddenDim; + T* localTokBuf = + handle.shmemDispatchOutTokMemObj->template GetAs() + i * config.hiddenDim; T* srcTokBuf = reinterpret_cast(globalInpTokBufCpu) + srcPe * inpTokEleNum + srcTokId * config.hiddenDim; @@ -348,7 +350,7 @@ class EpDispatchCombineTestCase { float expected = float(T(float(reinterpret_cast(inpTokBufCpu)[tokenOffset + j]) * weightSum)); // float got = float(handle.outTokenBuf[tokenOffset + j]); - float got = float(handle.shmemOutTokMemObj->template GetAs()[tokenOffset + j]); + float got = float(handle.shmemCombineOutTokMemObj->template GetAs()[tokenOffset + j]); assert(weightSum != 0); if (abs(got - expected) > runConfig.atol) { std::cout << "Wrong result at pos " << j << ": mype " << config.rank << " tokenId " << i @@ -468,7 +470,7 @@ class EpDispatchCombineTestCase { // HIP_RUNTIME_CHECK(hipMemcpy(inpTokBuf, outTokBuf, // config.MaxNumTokensToRecvPerRank() * config.hiddenDim * // sizeof(T), hipMemcpyDeviceToDevice)); - HIP_RUNTIME_CHECK(hipMemcpy(inpTokBuf, handle.shmemOutTokMemObj->template GetAs(), + HIP_RUNTIME_CHECK(hipMemcpy(inpTokBuf, handle.shmemDispatchOutTokMemObj->template GetAs(), config.MaxNumTokensToRecvPerRank() * config.hiddenDim * sizeof(T), hipMemcpyDeviceToDevice)); HIP_RUNTIME_CHECK( diff --git a/include/mori/ops/dispatch_combine/dispatch_combine.hpp b/include/mori/ops/dispatch_combine/dispatch_combine.hpp index 0dd55622..789b0613 100644 --- a/include/mori/ops/dispatch_combine/dispatch_combine.hpp +++ b/include/mori/ops/dispatch_combine/dispatch_combine.hpp @@ -171,12 +171,14 @@ class EpDispatchCombineHandle { // Registered buffers for tokens, shmemOutTokMemObj will be returned to user as output mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj; mori::application::SymmMemObjPtr shmemCombineInpTokMemObj; - mori::application::SymmMemObjPtr shmemOutTokMemObj; + mori::application::SymmMemObjPtr shmemDispatchOutTokMemObj; + mori::application::SymmMemObjPtr shmemCombineOutTokMemObj; mori::application::SymmMemObjPtr shmemStagingTokMemObj; // Registered buffer used for weights, indices and scales mori::application::SymmMemObjPtr shmemInpWeightsMemObj; - mori::application::SymmMemObjPtr shmemOutWeightsMemObj; + mori::application::SymmMemObjPtr shmemDispatchOutWeightsMemObj; + mori::application::SymmMemObjPtr shmemCombineOutWeightsMemObj; mori::application::SymmMemObjPtr shmemInpScalesMemObj; mori::application::SymmMemObjPtr shmemOutScalesMemObj; mori::application::SymmMemObjPtr shmemInpIndicesMemObj; @@ -231,10 +233,12 @@ struct EpDispatchCombineArgs { uint8_t* scalesBuf{nullptr}; mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj; mori::application::SymmMemObjPtr shmemCombineInpTokMemObj; - mori::application::SymmMemObjPtr shmemOutTokMemObj; + mori::application::SymmMemObjPtr shmemDispatchOutTokMemObj; + mori::application::SymmMemObjPtr shmemCombineOutTokMemObj; mori::application::SymmMemObjPtr shmemStagingTokMemObj; mori::application::SymmMemObjPtr shmemInpWeightsMemObj; - mori::application::SymmMemObjPtr shmemOutWeightsMemObj; + mori::application::SymmMemObjPtr shmemDispatchOutWeightsMemObj; + mori::application::SymmMemObjPtr shmemCombineOutWeightsMemObj; mori::application::SymmMemObjPtr shmemInpScalesMemObj; mori::application::SymmMemObjPtr shmemOutScalesMemObj; mori::application::SymmMemObjPtr shmemInpIndicesMemObj; @@ -276,10 +280,12 @@ EpDispatchCombineArgs GetEpDispatchCombineArgs(const EpDispatchCombineHandle& args.localPeTokenCounter = handle.localPeTokenCounter; args.shmemDispatchInpTokMemObj = handle.shmemDispatchInpTokMemObj; args.shmemCombineInpTokMemObj = handle.shmemCombineInpTokMemObj; - args.shmemOutTokMemObj = handle.shmemOutTokMemObj; + args.shmemDispatchOutTokMemObj = handle.shmemDispatchOutTokMemObj; + args.shmemCombineOutTokMemObj = handle.shmemCombineOutTokMemObj; args.shmemStagingTokMemObj = handle.shmemStagingTokMemObj; args.shmemInpWeightsMemObj = handle.shmemInpWeightsMemObj; - args.shmemOutWeightsMemObj = handle.shmemOutWeightsMemObj; + args.shmemDispatchOutWeightsMemObj = handle.shmemDispatchOutWeightsMemObj; + args.shmemCombineOutWeightsMemObj = handle.shmemCombineOutWeightsMemObj; args.shmemInpScalesMemObj = handle.shmemInpScalesMemObj; args.shmemOutScalesMemObj = handle.shmemOutScalesMemObj; args.shmemInpIndicesMemObj = handle.shmemInpIndicesMemObj; diff --git a/src/ops/dispatch_combine/dispatch_combine.cpp b/src/ops/dispatch_combine/dispatch_combine.cpp index f7b2572b..dc7d8ecc 100644 --- a/src/ops/dispatch_combine/dispatch_combine.cpp +++ b/src/ops/dispatch_combine/dispatch_combine.cpp @@ -74,12 +74,16 @@ void EpDispatchCombineHandle::InitializeShmemBuf() { ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); shmemCombineInpTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); - shmemOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached); + shmemDispatchOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached); + shmemCombineOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached); shmemStagingTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached); size_t maxWeightSize = config.MaxNumTokensToRecv() * config.numExpertPerToken * sizeof(float); shmemInpWeightsMemObj = ShmemMallocAndReturnMemObjPtr(maxWeightSize, hipDeviceMallocUncached); - shmemOutWeightsMemObj = ShmemMallocAndReturnMemObjPtr(maxWeightSize, hipDeviceMallocUncached); + shmemDispatchOutWeightsMemObj = + ShmemMallocAndReturnMemObjPtr(maxWeightSize, hipDeviceMallocUncached); + shmemCombineOutWeightsMemObj = + ShmemMallocAndReturnMemObjPtr(maxWeightSize, hipDeviceMallocUncached); if (config.scaleDim > 0 && config.scaleTypeSize > 0) { size_t maxScaleSize = config.MaxNumTokensToRecv() * config.scaleDim * config.scaleTypeSize; @@ -95,10 +99,12 @@ void EpDispatchCombineHandle::InitializeShmemBuf() { void EpDispatchCombineHandle::FinalizeShmemBuf() { ShmemFree(shmemDispatchInpTokMemObj->localPtr); ShmemFree(shmemCombineInpTokMemObj->localPtr); - ShmemFree(shmemOutTokMemObj->localPtr); + ShmemFree(shmemDispatchOutTokMemObj->localPtr); + ShmemFree(shmemCombineOutTokMemObj->localPtr); ShmemFree(shmemStagingTokMemObj->localPtr); ShmemFree(shmemInpWeightsMemObj->localPtr); - ShmemFree(shmemOutWeightsMemObj->localPtr); + ShmemFree(shmemDispatchOutWeightsMemObj->localPtr); + ShmemFree(shmemCombineOutWeightsMemObj->localPtr); if (shmemInpScalesMemObj.IsValid()) ShmemFree(shmemInpScalesMemObj->localPtr); if (shmemOutScalesMemObj.IsValid()) ShmemFree(shmemOutScalesMemObj->localPtr); ShmemFree(shmemInpIndicesMemObj->localPtr); diff --git a/src/ops/dispatch_combine/internode.hpp b/src/ops/dispatch_combine/internode.hpp index 7dc26495..b569d8cb 100644 --- a/src/ops/dispatch_combine/internode.hpp +++ b/src/ops/dispatch_combine/internode.hpp @@ -297,10 +297,10 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { size_t localTokenOffset = size_t(localTokenIdx) * size_t(config.hiddenDim) * sizeof(T); size_t peSortedTokenOffset = size_t(peSortedId) * stagingOffset; - core::WarpCopy(args.shmemOutTokMemObj->template GetAs() + localTokenOffset, + core::WarpCopy(args.shmemDispatchOutTokMemObj->template GetAs() + localTokenOffset, args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset, config.hiddenDim * sizeof(T)); - core::WarpCopy(args.shmemOutWeightsMemObj->template GetAs() + + core::WarpCopy(args.shmemDispatchOutWeightsMemObj->template GetAs() + localTokenIdx * config.numExpertPerToken * sizeof(float), args.shmemDispatchInpTokMemObj->template GetAs() + peSortedTokenOffset + weightOffset, @@ -348,7 +348,8 @@ inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgstemplate GetAs(); + volatile uint64_t* localBarrierPtr = + args.crossDeviceBarrierMemObj->template GetAs(); if (thdId < args.config.worldSize) { uint64_t currentVal = core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId); #if DEBUG == 1 @@ -551,13 +552,14 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { } size_t offset = size_t(tokenId) * size_t(config.hiddenDim) + hiddenDimOffset; - core::WarpAccum(args.shmemOutTokMemObj->template GetAs() + offset, srcPtrs, nullptr, - config.numExpertPerToken, hiddenDimSize); + core::WarpAccum(args.shmemCombineOutTokMemObj->template GetAs() + offset, srcPtrs, + nullptr, config.numExpertPerToken, hiddenDimSize); if (args.weightsBuf && inTokenPartId == warpsPerToken - 1) { - core::WarpAccum( - args.shmemOutWeightsMemObj->template GetAs() + tokenId * config.numExpertPerToken, - srcWeightsPtr, nullptr, config.numExpertPerToken, config.numExpertPerToken); + core::WarpAccum(args.shmemCombineOutWeightsMemObj->template GetAs() + + tokenId * config.numExpertPerToken, + srcWeightsPtr, nullptr, config.numExpertPerToken, + config.numExpertPerToken); } } if (globalThdId == 0) { diff --git a/src/ops/dispatch_combine/intranode.hpp b/src/ops/dispatch_combine/intranode.hpp index 924caac8..03ef6619 100644 --- a/src/ops/dispatch_combine/intranode.hpp +++ b/src/ops/dispatch_combine/intranode.hpp @@ -122,7 +122,7 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs args) { // Write weights and indices if (laneId < config.numExpertPerToken) { if (args.weightsBuf) { - args.shmemOutWeightsMemObj->template GetAs( + args.shmemDispatchOutWeightsMemObj->template GetAs( destPe)[destTokId * config.numExpertPerToken + laneId] = args.weightsBuf[srcTokId * config.numExpertPerToken + laneId]; } @@ -142,7 +142,7 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs args) { index_t srcTokOffset = srcTokId * config.hiddenDim; index_t destTokOffset = destTokId * config.hiddenDim; - core::WarpCopy(args.shmemOutTokMemObj->template GetAs(destPe) + destTokOffset, + core::WarpCopy(args.shmemDispatchOutTokMemObj->template GetAs(destPe) + destTokOffset, args.inpTokenBuf + srcTokOffset, config.hiddenDim); } } @@ -261,14 +261,15 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs args) { srcWeightsPtr[j] = nullptr; } } - core::WarpAccum( - args.shmemOutTokMemObj->template GetAs() + tokenId * config.hiddenDim + hiddenDimOffset, - srcPtrs, nullptr, config.numExpertPerToken, hiddenDimSize); + core::WarpAccum(args.shmemCombineOutTokMemObj->template GetAs() + + tokenId * config.hiddenDim + hiddenDimOffset, + srcPtrs, nullptr, config.numExpertPerToken, hiddenDimSize); if (args.weightsBuf && inTokenPartId == warpsPerToken - 1) { - core::WarpAccum( - args.shmemOutWeightsMemObj->template GetAs() + tokenId * config.numExpertPerToken, - srcWeightsPtr, nullptr, config.numExpertPerToken, config.numExpertPerToken); + core::WarpAccum(args.shmemCombineOutWeightsMemObj->template GetAs() + + tokenId * config.numExpertPerToken, + srcWeightsPtr, nullptr, config.numExpertPerToken, + config.numExpertPerToken); } } } diff --git a/src/pybind/mori.cpp b/src/pybind/mori.cpp index eaadfe9c..f1355a46 100644 --- a/src/pybind/mori.cpp +++ b/src/pybind/mori.cpp @@ -70,14 +70,14 @@ LaunchDispatch(mori::moe::EpDispatchCombineHandle& handle, int kernelType, at::cuda::getCurrentHIPStream()); torch::Tensor out = - torch::from_blob(handle.shmemOutTokMemObj->Get(), + torch::from_blob(handle.shmemDispatchOutTokMemObj->Get(), {handle.config.MaxNumTokensToRecv(), handle.config.hiddenDim}, torch::TensorOptions().dtype(input.scalar_type()).device(torch::kCUDA)); std::optional outWeights{std::nullopt}; if (weightPtr) { outWeights = torch::from_blob( - handle.shmemOutWeightsMemObj->Get(), + handle.shmemDispatchOutWeightsMemObj->Get(), {handle.config.MaxNumTokensToRecv(), handle.config.numExpertPerToken}, torch::TensorOptions().dtype(mori::GetTorchDataType()).device(torch::kCUDA)); } @@ -127,13 +127,13 @@ std::tuple> LaunchCombine( auto options = torch::TensorOptions().dtype(input.scalar_type()).device(torch::kCUDA); torch::Tensor out = - torch::from_blob(handle.shmemOutTokMemObj->Get(), + torch::from_blob(handle.shmemCombineOutTokMemObj->Get(), {handle.config.maxNumInpTokenPerRank, handle.config.hiddenDim}, options); std::optional outWeights{std::nullopt}; if (weightsPtr) { outWeights = - torch::from_blob(handle.shmemOutWeightsMemObj->Get(), + torch::from_blob(handle.shmemCombineOutWeightsMemObj->Get(), {handle.config.maxNumInpTokenPerRank, handle.config.numExpertPerToken}, torch::TensorOptions().dtype(weights->scalar_type()).device(torch::kCUDA)); }