Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions examples/ops/dispatch_combine/test_dispatch_combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ class EpDispatchCombineTestCase {
index_t srcPe = srcTokId / config.maxNumInpTokenPerRank;
index_t localSrcTokId = srcTokId - srcPe * config.maxNumInpTokenPerRank;

T* localTokBuf = handle.shmemOutTokMemObj->template GetAs<T*>() + i * config.hiddenDim;
T* localTokBuf =
handle.shmemDispatchOutTokMemObj->template GetAs<T*>() + i * config.hiddenDim;
T* srcTokBuf = reinterpret_cast<T*>(globalInpTokBufCpu) + srcPe * inpTokEleNum +
localSrcTokId * config.hiddenDim;

Expand Down Expand Up @@ -291,7 +292,8 @@ class EpDispatchCombineTestCase {
index_t srcTokDispatchId = peSortToTokenIdxMapsVec[srcPe][peSortedId];
index_t srcTokId = srcTokDispatchId / config.numExpertPerToken;

T* localTokBuf = handle.shmemOutTokMemObj->template GetAs<T*>() + i * config.hiddenDim;
T* localTokBuf =
handle.shmemDispatchOutTokMemObj->template GetAs<T*>() + i * config.hiddenDim;

T* srcTokBuf = reinterpret_cast<T*>(globalInpTokBufCpu) + srcPe * inpTokEleNum +
srcTokId * config.hiddenDim;
Expand Down Expand Up @@ -348,7 +350,7 @@ class EpDispatchCombineTestCase {
float expected =
float(T(float(reinterpret_cast<T*>(inpTokBufCpu)[tokenOffset + j]) * weightSum));
// float got = float(handle.outTokenBuf[tokenOffset + j]);
float got = float(handle.shmemOutTokMemObj->template GetAs<T*>()[tokenOffset + j]);
float got = float(handle.shmemCombineOutTokMemObj->template GetAs<T*>()[tokenOffset + j]);
assert(weightSum != 0);
if (abs(got - expected) > runConfig.atol) {
std::cout << "Wrong result at pos " << j << ": mype " << config.rank << " tokenId " << i
Expand Down Expand Up @@ -468,7 +470,7 @@ class EpDispatchCombineTestCase {
// HIP_RUNTIME_CHECK(hipMemcpy(inpTokBuf, outTokBuf,
// config.MaxNumTokensToRecvPerRank() * config.hiddenDim *
// sizeof(T), hipMemcpyDeviceToDevice));
HIP_RUNTIME_CHECK(hipMemcpy(inpTokBuf, handle.shmemOutTokMemObj->template GetAs<T*>(),
HIP_RUNTIME_CHECK(hipMemcpy(inpTokBuf, handle.shmemDispatchOutTokMemObj->template GetAs<T*>(),
config.MaxNumTokensToRecvPerRank() * config.hiddenDim * sizeof(T),
hipMemcpyDeviceToDevice));
HIP_RUNTIME_CHECK(
Expand Down
18 changes: 12 additions & 6 deletions include/mori/ops/dispatch_combine/dispatch_combine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,14 @@ class EpDispatchCombineHandle {
// Registered buffers for tokens; shmemDispatchOutTokMemObj and shmemCombineOutTokMemObj are returned to the user as the dispatch and combine outputs, respectively
mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj;
mori::application::SymmMemObjPtr shmemCombineInpTokMemObj;
mori::application::SymmMemObjPtr shmemOutTokMemObj;
mori::application::SymmMemObjPtr shmemDispatchOutTokMemObj;
mori::application::SymmMemObjPtr shmemCombineOutTokMemObj;
mori::application::SymmMemObjPtr shmemStagingTokMemObj;

// Registered buffer used for weights, indices and scales
mori::application::SymmMemObjPtr shmemInpWeightsMemObj;
mori::application::SymmMemObjPtr shmemOutWeightsMemObj;
mori::application::SymmMemObjPtr shmemDispatchOutWeightsMemObj;
mori::application::SymmMemObjPtr shmemCombineOutWeightsMemObj;
mori::application::SymmMemObjPtr shmemInpScalesMemObj;
mori::application::SymmMemObjPtr shmemOutScalesMemObj;
mori::application::SymmMemObjPtr shmemInpIndicesMemObj;
Expand Down Expand Up @@ -231,10 +233,12 @@ struct EpDispatchCombineArgs {
uint8_t* scalesBuf{nullptr};
mori::application::SymmMemObjPtr shmemDispatchInpTokMemObj;
mori::application::SymmMemObjPtr shmemCombineInpTokMemObj;
mori::application::SymmMemObjPtr shmemOutTokMemObj;
mori::application::SymmMemObjPtr shmemDispatchOutTokMemObj;
mori::application::SymmMemObjPtr shmemCombineOutTokMemObj;
mori::application::SymmMemObjPtr shmemStagingTokMemObj;
mori::application::SymmMemObjPtr shmemInpWeightsMemObj;
mori::application::SymmMemObjPtr shmemOutWeightsMemObj;
mori::application::SymmMemObjPtr shmemDispatchOutWeightsMemObj;
mori::application::SymmMemObjPtr shmemCombineOutWeightsMemObj;
mori::application::SymmMemObjPtr shmemInpScalesMemObj;
mori::application::SymmMemObjPtr shmemOutScalesMemObj;
mori::application::SymmMemObjPtr shmemInpIndicesMemObj;
Expand Down Expand Up @@ -276,10 +280,12 @@ EpDispatchCombineArgs<T> GetEpDispatchCombineArgs(const EpDispatchCombineHandle&
args.localPeTokenCounter = handle.localPeTokenCounter;
args.shmemDispatchInpTokMemObj = handle.shmemDispatchInpTokMemObj;
args.shmemCombineInpTokMemObj = handle.shmemCombineInpTokMemObj;
args.shmemOutTokMemObj = handle.shmemOutTokMemObj;
args.shmemDispatchOutTokMemObj = handle.shmemDispatchOutTokMemObj;
args.shmemCombineOutTokMemObj = handle.shmemCombineOutTokMemObj;
args.shmemStagingTokMemObj = handle.shmemStagingTokMemObj;
args.shmemInpWeightsMemObj = handle.shmemInpWeightsMemObj;
args.shmemOutWeightsMemObj = handle.shmemOutWeightsMemObj;
args.shmemDispatchOutWeightsMemObj = handle.shmemDispatchOutWeightsMemObj;
args.shmemCombineOutWeightsMemObj = handle.shmemCombineOutWeightsMemObj;
args.shmemInpScalesMemObj = handle.shmemInpScalesMemObj;
args.shmemOutScalesMemObj = handle.shmemOutScalesMemObj;
args.shmemInpIndicesMemObj = handle.shmemInpIndicesMemObj;
Expand Down
14 changes: 10 additions & 4 deletions src/ops/dispatch_combine/dispatch_combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,16 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
shmemCombineInpTokMemObj =
ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);
shmemOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached);
shmemDispatchOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached);
shmemCombineOutTokMemObj = ShmemMallocAndReturnMemObjPtr(maxTokenSize, hipDeviceMallocUncached);
shmemStagingTokMemObj = ShmemMallocAndReturnMemObjPtr(maxStagingTokSize, hipDeviceMallocUncached);

size_t maxWeightSize = config.MaxNumTokensToRecv() * config.numExpertPerToken * sizeof(float);
shmemInpWeightsMemObj = ShmemMallocAndReturnMemObjPtr(maxWeightSize, hipDeviceMallocUncached);
shmemOutWeightsMemObj = ShmemMallocAndReturnMemObjPtr(maxWeightSize, hipDeviceMallocUncached);
shmemDispatchOutWeightsMemObj =
ShmemMallocAndReturnMemObjPtr(maxWeightSize, hipDeviceMallocUncached);
shmemCombineOutWeightsMemObj =
ShmemMallocAndReturnMemObjPtr(maxWeightSize, hipDeviceMallocUncached);

if (config.scaleDim > 0 && config.scaleTypeSize > 0) {
size_t maxScaleSize = config.MaxNumTokensToRecv() * config.scaleDim * config.scaleTypeSize;
Expand All @@ -95,10 +99,12 @@ void EpDispatchCombineHandle::InitializeShmemBuf() {
void EpDispatchCombineHandle::FinalizeShmemBuf() {
ShmemFree(shmemDispatchInpTokMemObj->localPtr);
ShmemFree(shmemCombineInpTokMemObj->localPtr);
ShmemFree(shmemOutTokMemObj->localPtr);
ShmemFree(shmemDispatchOutTokMemObj->localPtr);
ShmemFree(shmemCombineOutTokMemObj->localPtr);
ShmemFree(shmemStagingTokMemObj->localPtr);
ShmemFree(shmemInpWeightsMemObj->localPtr);
ShmemFree(shmemOutWeightsMemObj->localPtr);
ShmemFree(shmemDispatchOutWeightsMemObj->localPtr);
ShmemFree(shmemCombineOutWeightsMemObj->localPtr);
if (shmemInpScalesMemObj.IsValid()) ShmemFree(shmemInpScalesMemObj->localPtr);
if (shmemOutScalesMemObj.IsValid()) ShmemFree(shmemOutScalesMemObj->localPtr);
ShmemFree(shmemInpIndicesMemObj->localPtr);
Expand Down
18 changes: 10 additions & 8 deletions src/ops/dispatch_combine/internode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,10 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs<T> args) {
size_t localTokenOffset = size_t(localTokenIdx) * size_t(config.hiddenDim) * sizeof(T);
size_t peSortedTokenOffset = size_t(peSortedId) * stagingOffset;

core::WarpCopy(args.shmemOutTokMemObj->template GetAs<char*>() + localTokenOffset,
core::WarpCopy(args.shmemDispatchOutTokMemObj->template GetAs<char*>() + localTokenOffset,
args.shmemDispatchInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset,
config.hiddenDim * sizeof(T));
core::WarpCopy(args.shmemOutWeightsMemObj->template GetAs<char*>() +
core::WarpCopy(args.shmemDispatchOutWeightsMemObj->template GetAs<char*>() +
localTokenIdx * config.numExpertPerToken * sizeof(float),
args.shmemDispatchInpTokMemObj->template GetAs<char*>() + peSortedTokenOffset +
weightOffset,
Expand Down Expand Up @@ -348,7 +348,8 @@ inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgs<T
shmem::ShmemUint32WaitUntilEquals(args.combineGridBarrier, globalWarpNum);
}

volatile uint64_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs<volatile uint64_t*>();
volatile uint64_t* localBarrierPtr =
args.crossDeviceBarrierMemObj->template GetAs<volatile uint64_t*>();
if (thdId < args.config.worldSize) {
uint64_t currentVal = core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId);
#if DEBUG == 1
Expand Down Expand Up @@ -551,13 +552,14 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs<T> args) {
}

size_t offset = size_t(tokenId) * size_t(config.hiddenDim) + hiddenDimOffset;
core::WarpAccum<T, 8>(args.shmemOutTokMemObj->template GetAs<T*>() + offset, srcPtrs, nullptr,
config.numExpertPerToken, hiddenDimSize);
core::WarpAccum<T, 8>(args.shmemCombineOutTokMemObj->template GetAs<T*>() + offset, srcPtrs,
nullptr, config.numExpertPerToken, hiddenDimSize);

if (args.weightsBuf && inTokenPartId == warpsPerToken - 1) {
core::WarpAccum<float, 4>(
args.shmemOutWeightsMemObj->template GetAs<float*>() + tokenId * config.numExpertPerToken,
srcWeightsPtr, nullptr, config.numExpertPerToken, config.numExpertPerToken);
core::WarpAccum<float, 4>(args.shmemCombineOutWeightsMemObj->template GetAs<float*>() +
tokenId * config.numExpertPerToken,
srcWeightsPtr, nullptr, config.numExpertPerToken,
config.numExpertPerToken);
}
}
if (globalThdId == 0) {
Expand Down
17 changes: 9 additions & 8 deletions src/ops/dispatch_combine/intranode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs<T> args) {
// Write weights and indices
if (laneId < config.numExpertPerToken) {
if (args.weightsBuf) {
args.shmemOutWeightsMemObj->template GetAs<float*>(
args.shmemDispatchOutWeightsMemObj->template GetAs<float*>(
destPe)[destTokId * config.numExpertPerToken + laneId] =
args.weightsBuf[srcTokId * config.numExpertPerToken + laneId];
}
Expand All @@ -142,7 +142,7 @@ __global__ void EpDispatchIntraNodeKernel(EpDispatchCombineArgs<T> args) {

index_t srcTokOffset = srcTokId * config.hiddenDim;
index_t destTokOffset = destTokId * config.hiddenDim;
core::WarpCopy(args.shmemOutTokMemObj->template GetAs<T*>(destPe) + destTokOffset,
core::WarpCopy(args.shmemDispatchOutTokMemObj->template GetAs<T*>(destPe) + destTokOffset,
args.inpTokenBuf + srcTokOffset, config.hiddenDim);
}
}
Expand Down Expand Up @@ -261,14 +261,15 @@ __global__ void EpCombineIntraNodeKernel(EpDispatchCombineArgs<T> args) {
srcWeightsPtr[j] = nullptr;
}
}
core::WarpAccum<T, 4>(
args.shmemOutTokMemObj->template GetAs<T*>() + tokenId * config.hiddenDim + hiddenDimOffset,
srcPtrs, nullptr, config.numExpertPerToken, hiddenDimSize);
core::WarpAccum<T, 4>(args.shmemCombineOutTokMemObj->template GetAs<T*>() +
tokenId * config.hiddenDim + hiddenDimOffset,
srcPtrs, nullptr, config.numExpertPerToken, hiddenDimSize);

if (args.weightsBuf && inTokenPartId == warpsPerToken - 1) {
core::WarpAccum<float, 4>(
args.shmemOutWeightsMemObj->template GetAs<float*>() + tokenId * config.numExpertPerToken,
srcWeightsPtr, nullptr, config.numExpertPerToken, config.numExpertPerToken);
core::WarpAccum<float, 4>(args.shmemCombineOutWeightsMemObj->template GetAs<float*>() +
tokenId * config.numExpertPerToken,
srcWeightsPtr, nullptr, config.numExpertPerToken,
config.numExpertPerToken);
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/pybind/mori.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ LaunchDispatch(mori::moe::EpDispatchCombineHandle& handle, int kernelType,
at::cuda::getCurrentHIPStream());

torch::Tensor out =
torch::from_blob(handle.shmemOutTokMemObj->Get(),
torch::from_blob(handle.shmemDispatchOutTokMemObj->Get(),
{handle.config.MaxNumTokensToRecv(), handle.config.hiddenDim},
torch::TensorOptions().dtype(input.scalar_type()).device(torch::kCUDA));

std::optional<torch::Tensor> outWeights{std::nullopt};
if (weightPtr) {
outWeights = torch::from_blob(
handle.shmemOutWeightsMemObj->Get(),
handle.shmemDispatchOutWeightsMemObj->Get(),
{handle.config.MaxNumTokensToRecv(), handle.config.numExpertPerToken},
torch::TensorOptions().dtype(mori::GetTorchDataType<float>()).device(torch::kCUDA));
}
Expand Down Expand Up @@ -127,13 +127,13 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> LaunchCombine(

auto options = torch::TensorOptions().dtype(input.scalar_type()).device(torch::kCUDA);
torch::Tensor out =
torch::from_blob(handle.shmemOutTokMemObj->Get(),
torch::from_blob(handle.shmemCombineOutTokMemObj->Get(),
{handle.config.maxNumInpTokenPerRank, handle.config.hiddenDim}, options);

std::optional<torch::Tensor> outWeights{std::nullopt};
if (weightsPtr) {
outWeights =
torch::from_blob(handle.shmemOutWeightsMemObj->Get(),
torch::from_blob(handle.shmemCombineOutWeightsMemObj->Get(),
{handle.config.maxNumInpTokenPerRank, handle.config.numExpertPerToken},
torch::TensorOptions().dtype(weights->scalar_type()).device(torch::kCUDA));
}
Expand Down