diff --git a/src/runtime_environment/device/cuda/cuda_backend.cxx b/src/runtime_environment/device/cuda/cuda_backend.cxx index 610f33f42..4c196cdfc 100644 --- a/src/runtime_environment/device/cuda/cuda_backend.cxx +++ b/src/runtime_environment/device/cuda/cuda_backend.cxx @@ -28,6 +28,33 @@ CUDABackend::CUDABackend() { } +#ifdef GAUXC_HAS_MPI +CUDABackend::CUDABackend(MPI_Comm c) +{ + comm = c; + MPI_Comm_split_type(comm, MPI_COMM_TYPE_SHARED, 0, + MPI_INFO_NULL, &local_comm); + MPI_Comm_size(local_comm, &local_size); + MPI_Comm_rank(local_comm, &local_rank); + int ndev; + auto stat = cudaGetDeviceCount(&ndev); + GAUXC_CUDA_ERROR("CUDA backend init failed", stat); + gpuid = local_rank % ndev; + cudaSetDevice(gpuid); + + // Create CUDA Stream and CUBLAS Handles and make them talk to eachother + master_stream = std::make_shared< util::cuda_stream >(); + master_handle = std::make_shared< util::cublas_handle >(); + + cublasSetStream( *master_handle, *master_stream ); + +#ifdef GAUXC_HAS_MAGMA + // Setup MAGMA queue with CUDA stream / cuBLAS handle + master_magma_queue_ = std::make_shared< util::magma_queue >(0, *master_stream, *master_handle); +#endif +} +#endif + CUDABackend::~CUDABackend() noexcept = default; CUDABackend::device_buffer_t CUDABackend::allocate_device_buffer(int64_t sz) { @@ -41,6 +68,15 @@ size_t CUDABackend::get_available_mem() { size_t cuda_avail, cuda_total; auto stat = cudaMemGetInfo( &cuda_avail, &cuda_total ); GAUXC_CUDA_ERROR( "MemInfo Failed", stat ); +#ifdef GAUXC_HAS_MPI + int ndev; + stat = cudaGetDeviceCount(&ndev); + GAUXC_CUDA_ERROR("MemInfo Failed while getting number of devices", stat); + double factor = 1.0 / ((local_size - 1) / ndev + 1); + factor = (factor > 1.0 ? 1.0 : factor); + cuda_avail = size_t(cuda_avail * factor); + MPI_Barrier(local_comm); +#endif return cuda_avail; } @@ -137,8 +173,7 @@ void CUDABackend::check_error_(std::string msg) { GAUXC_CUDA_ERROR("CUDA Failed ["+msg+"]", stat ); } - -std::unique_ptr make_device_backend() { - return std::make_unique(); +std::unique_ptr make_device_backend(GAUXC_MPI_CODE(MPI_Comm c)) { + return std::make_unique(GAUXC_MPI_CODE(c)); } } diff --git a/src/runtime_environment/device/cuda/cuda_backend.hpp b/src/runtime_environment/device/cuda/cuda_backend.hpp index 8e47a6a9a..916cad88c 100644 --- a/src/runtime_environment/device/cuda/cuda_backend.hpp +++ b/src/runtime_environment/device/cuda/cuda_backend.hpp @@ -52,6 +52,15 @@ struct CUDABackend : public DeviceBackend { std::vector> blas_streams; std::vector> blas_handles; + +#ifdef GAUXC_HAS_MPI + MPI_Comm comm = MPI_COMM_NULL; + MPI_Comm local_comm = MPI_COMM_NULL; + int gpuid = 0; + int local_rank = 0; + int local_size = 1; + CUDABackend(MPI_Comm comm); +#endif }; } diff --git a/src/runtime_environment/device/device_backend.hpp b/src/runtime_environment/device/device_backend.hpp index 594b79888..f506543f2 100644 --- a/src/runtime_environment/device/device_backend.hpp +++ b/src/runtime_environment/device/device_backend.hpp @@ -17,6 +17,7 @@ #include "device_queue.hpp" #include "device_blas_handle.hpp" #include +#include "gauxc/runtime_environment.hpp" #ifdef GAUXC_HAS_MAGMA #include "device_specific/magma_util.hpp" @@ -99,6 +100,5 @@ class DeviceBackend { /// Generate the default device backend for this platform -std::unique_ptr make_device_backend(); - +std::unique_ptr make_device_backend(GAUXC_MPI_CODE(MPI_Comm c)); } diff --git a/src/runtime_environment/device/device_runtime_environment_impl.hpp b/src/runtime_environment/device/device_runtime_environment_impl.hpp index 9831c5c21..7d5d0ddf5 100644 --- a/src/runtime_environment/device/device_runtime_environment_impl.hpp +++ b/src/runtime_environment/device/device_runtime_environment_impl.hpp @@ -35,7 +35,7 @@ class DeviceRuntimeEnvironmentImpl : public RuntimeEnvironmentImpl { size_t sz) : parent_type(GAUXC_MPI_CODE(c)), i_own_this_memory_(false), device_memory_(p), device_memory_size_(sz), - device_backend_{make_device_backend()} {} + device_backend_{make_device_backend(GAUXC_MPI_CODE(c))} {} explicit DeviceRuntimeEnvironmentImpl(GAUXC_MPI_CODE(MPI_Comm c,) diff --git a/src/runtime_environment/device/hip/hip_backend.cxx b/src/runtime_environment/device/hip/hip_backend.cxx index 69c3fd286..be16c8f8a 100644 --- a/src/runtime_environment/device/hip/hip_backend.cxx +++ b/src/runtime_environment/device/hip/hip_backend.cxx @@ -128,7 +128,8 @@ void HIPBackend::check_error_(std::string msg) { GAUXC_HIP_ERROR("HIP Failed ["+msg+"]", stat ); } -std::unique_ptr make_device_backend() { +std::unique_ptr make_device_backend(GAUXC_MPI_CODE(MPI_Comm c)) +{ return std::make_unique(); } }