Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/runtime_environment/device/cuda/cuda_backend.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,33 @@ CUDABackend::CUDABackend() {

}

#ifdef GAUXC_HAS_MPI
CUDABackend::CUDABackend(MPI_Comm c)
{
  // Bind this backend to an MPI communicator and pick a GPU for this rank.
  //
  // Ranks are grouped by shared-memory node (MPI_Comm_split_type with
  // MPI_COMM_TYPE_SHARED) and the node-local rank is round-robined over
  // the visible CUDA devices, so co-located ranks spread across GPUs.
  comm = c;
  MPI_Comm_split_type(comm, MPI_COMM_TYPE_SHARED, 0,
    MPI_INFO_NULL, &local_comm);
  MPI_Comm_size(local_comm, &local_size);
  MPI_Comm_rank(local_comm, &local_rank);

  int ndev;
  auto stat = cudaGetDeviceCount(&ndev);
  GAUXC_CUDA_ERROR("CUDA backend init failed", stat);

  // Guard the modulo: cudaGetDeviceCount can succeed with zero visible
  // devices, and local_rank % 0 is undefined behavior. With ndev == 0 we
  // fall through to cudaSetDevice(0), whose (now checked) error status
  // reports the no-device condition instead of crashing here.
  gpuid = (ndev > 0) ? (local_rank % ndev) : 0;
  stat = cudaSetDevice(gpuid);
  GAUXC_CUDA_ERROR("CUDA backend set device failed", stat);

  // Create CUDA stream and cuBLAS handle and make them talk to each other
  master_stream = std::make_shared< util::cuda_stream >();
  master_handle = std::make_shared< util::cublas_handle >();

  cublasSetStream( *master_handle, *master_stream );

#ifdef GAUXC_HAS_MAGMA
  // Setup MAGMA queue with CUDA stream / cuBLAS handle
  master_magma_queue_ = std::make_shared< util::magma_queue >(0, *master_stream, *master_handle);
#endif
}
#endif

CUDABackend::~CUDABackend() noexcept = default;

CUDABackend::device_buffer_t CUDABackend::allocate_device_buffer(int64_t sz) {
Expand All @@ -41,6 +68,15 @@ size_t CUDABackend::get_available_mem() {
// Query free device memory, then (under MPI) divide it among the
// node-local ranks that were round-robined onto this same GPU.
size_t cuda_avail, cuda_total;
auto stat = cudaMemGetInfo( &cuda_avail, &cuda_total );
GAUXC_CUDA_ERROR( "MemInfo Failed", stat );
#ifdef GAUXC_HAS_MPI
int ndev;
stat = cudaGetDeviceCount(&ndev);
GAUXC_CUDA_ERROR("MemInfo Failed while getting number of devices", stat);
// (local_size - 1) / ndev + 1 is integer ceil(local_size / ndev): the
// maximum number of node-local ranks sharing one device under the
// local_rank % ndev assignment, so each rank claims 1/that share.
// NOTE(review): this divides by zero if ndev == 0 — confirm a device
// check always precedes this call.
double factor = 1.0 / ((local_size - 1) / ndev + 1);
factor = (factor > 1.0 ? 1.0 : factor);  // clamp; a share can never exceed the whole
cuda_avail = size_t(cuda_avail * factor);
// Synchronize node-local ranks — presumably so every rank samples free
// memory before any sibling starts allocating; verify against callers.
MPI_Barrier(local_comm);
#endif
return cuda_avail;
}

Expand Down Expand Up @@ -137,8 +173,7 @@ void CUDABackend::check_error_(std::string msg) {
GAUXC_CUDA_ERROR("CUDA Failed ["+msg+"]", stat );
}


std::unique_ptr<DeviceBackend> make_device_backend() {
return std::make_unique<CUDABackend>();
// Factory for the platform's default DeviceBackend (CUDA here).
// The communicator argument exists only on MPI builds — GAUXC_MPI_CODE
// presumably expands to its argument when GAUXC_HAS_MPI is defined and
// to nothing otherwise (confirm in gauxc_config.hpp) — and is forwarded
// to CUDABackend for node-local GPU selection.
std::unique_ptr<DeviceBackend> make_device_backend(GAUXC_MPI_CODE(MPI_Comm c)) {
return std::make_unique<CUDABackend>(GAUXC_MPI_CODE(c));
}
}
9 changes: 9 additions & 0 deletions src/runtime_environment/device/cuda/cuda_backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ struct CUDABackend : public DeviceBackend {

std::vector<std::shared_ptr<util::cuda_stream>> blas_streams;
std::vector<std::shared_ptr<util::cublas_handle>> blas_handles;

#ifdef GAUXC_HAS_MPI
// MPI-aware state used to map node-local ranks onto distinct GPUs.
MPI_Comm comm = MPI_COMM_NULL;        // communicator this backend was constructed with
MPI_Comm local_comm = MPI_COMM_NULL;  // shared-memory (node-local) split of comm
int gpuid = 0;       // CUDA device selected for this rank
int local_rank = 0;  // rank within local_comm
int local_size = 1;  // number of ranks on this node (size of local_comm)
// Construct bound to a communicator; splits it node-locally and binds a device.
// NOTE(review): the parameter name shadows the member `comm`.
CUDABackend(MPI_Comm comm);
#endif
};

}
4 changes: 2 additions & 2 deletions src/runtime_environment/device/device_backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "device_queue.hpp"
#include "device_blas_handle.hpp"
#include <gauxc/gauxc_config.hpp>
#include "gauxc/runtime_environment.hpp"

#ifdef GAUXC_HAS_MAGMA
#include "device_specific/magma_util.hpp"
Expand Down Expand Up @@ -99,6 +100,5 @@ class DeviceBackend {


/// Generate the default device backend for this platform
std::unique_ptr<DeviceBackend> make_device_backend();

std::unique_ptr<DeviceBackend> make_device_backend(GAUXC_MPI_CODE(MPI_Comm c));
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class DeviceRuntimeEnvironmentImpl : public RuntimeEnvironmentImpl {
size_t sz) : parent_type(GAUXC_MPI_CODE(c)),
i_own_this_memory_(false), device_memory_(p),
device_memory_size_(sz),
device_backend_{make_device_backend()} {}
device_backend_{make_device_backend(GAUXC_MPI_CODE(c))} {}


explicit DeviceRuntimeEnvironmentImpl(GAUXC_MPI_CODE(MPI_Comm c,)
Expand Down
3 changes: 2 additions & 1 deletion src/runtime_environment/device/hip/hip_backend.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ void HIPBackend::check_error_(std::string msg) {
GAUXC_HIP_ERROR("HIP Failed ["+msg+"]", stat );
}

// Factory for the platform's default DeviceBackend (HIP here).
//
// NOTE(review): unlike the CUDA implementation, the communicator is NOT
// forwarded to HIPBackend — node-local GPU assignment appears not to be
// implemented for HIP yet; confirm before relying on multi-rank-per-node
// HIP runs.
std::unique_ptr<DeviceBackend> make_device_backend(GAUXC_MPI_CODE(MPI_Comm c))
{
  // Explicitly discard the communicator (present only when GAUXC_HAS_MPI
  // is defined) to avoid an unused-parameter warning on MPI builds.
  GAUXC_MPI_CODE((void)c;)
  return std::make_unique<HIPBackend>();
}
}