diff --git a/include/gauxc/runtime_environment/decl.hpp b/include/gauxc/runtime_environment/decl.hpp index 424f9d981..5ae8bc0d2 100644 --- a/include/gauxc/runtime_environment/decl.hpp +++ b/include/gauxc/runtime_environment/decl.hpp @@ -71,6 +71,7 @@ class DeviceRuntimeEnvironment : public RuntimeEnvironment { DeviceRuntimeEnvironment(GAUXC_MPI_CODE(MPI_Comm comm,) void* mem, size_t mem_sz); DeviceRuntimeEnvironment(GAUXC_MPI_CODE(MPI_Comm,) double fill_fraction); + DeviceRuntimeEnvironment(GAUXC_MPI_CODE(MPI_Comm,) size_t nbytes); ~DeviceRuntimeEnvironment() noexcept; DeviceRuntimeEnvironment( const DeviceRuntimeEnvironment& ); diff --git a/include/gauxc/xc_integrator.hpp b/include/gauxc/xc_integrator.hpp index 03feaf934..a8d091309 100644 --- a/include/gauxc/xc_integrator.hpp +++ b/include/gauxc/xc_integrator.hpp @@ -40,6 +40,7 @@ class XCIntegrator { using exc_vxc_type_gks = std::tuple< value_type, matrix_type, matrix_type, matrix_type, matrix_type >; using exc_grad_type = std::vector< value_type >; using exx_type = matrix_type; + using exx_grad_type = std::vector< value_type >; using fxc_contraction_type_rks = matrix_type; using fxc_contraction_type_uks = std::tuple< matrix_type, matrix_type >; using dd_psi_type = std::vector< value_type >; @@ -80,6 +81,8 @@ class XCIntegrator { exx_type eval_exx ( const MatrixType&, const IntegratorSettingsEXX& = IntegratorSettingsEXX{} ); + exx_grad_type eval_exx_grad( const MatrixType&, + const IntegratorSettingsEXX& = IntegratorSettingsEXX{} ); fxc_contraction_type_rks eval_fxc_contraction ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& = IntegratorSettingsXC{} ); fxc_contraction_type_uks eval_fxc_contraction ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, diff --git a/include/gauxc/xc_integrator/impl.hpp b/include/gauxc/xc_integrator/impl.hpp index 400afb7c7..3349fea61 100644 --- a/include/gauxc/xc_integrator/impl.hpp +++ b/include/gauxc/xc_integrator/impl.hpp @@ -100,6 +100,14 @@ typename XCIntegrator::exx_type return pimpl_->eval_exx(P,settings); }; +template +typename XCIntegrator::exx_grad_type + XCIntegrator::eval_exx_grad( const MatrixType& P, + const IntegratorSettingsEXX& settings ) { + if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED(); + return pimpl_->eval_exx_grad(P,settings); +}; + template typename XCIntegrator::fxc_contraction_type_rks XCIntegrator::eval_fxc_contraction( const MatrixType& P, const MatrixType& tP, diff --git a/include/gauxc/xc_integrator/replicated/impl.hpp b/include/gauxc/xc_integrator/replicated/impl.hpp index bfc95fc88..0ec1bd6af 100644 --- a/include/gauxc/xc_integrator/replicated/impl.hpp +++ b/include/gauxc/xc_integrator/replicated/impl.hpp @@ -265,5 +265,20 @@ typename ReplicatedXCIntegrator::dd_psi_potential_type } +template +typename ReplicatedXCIntegrator::exx_grad_type + ReplicatedXCIntegrator::eval_exx_grad_( const MatrixType& P, const IntegratorSettingsEXX& settings ) { + + if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED(); + + std::vector EXX_GRAD( 3*pimpl_->load_balancer().molecule().natoms() ); + pimpl_->eval_exx_grad( P.rows(), P.cols(), P.data(), P.rows(), + EXX_GRAD.data(), + settings ); + + return EXX_GRAD; + +} + } } diff --git a/include/gauxc/xc_integrator/replicated/replicated_xc_integrator_impl.hpp b/include/gauxc/xc_integrator/replicated/replicated_xc_integrator_impl.hpp index 457315122..fcec6e84e 100644 --- a/include/gauxc/xc_integrator/replicated/replicated_xc_integrator_impl.hpp +++ b/include/gauxc/xc_integrator/replicated/replicated_xc_integrator_impl.hpp @@ -103,6 +103,9 @@ class ReplicatedXCIntegratorImpl { virtual void eval_dd_psi_potential_( int64_t m, int64_t n, const value_type* X, unsigned max_Ylm, value_type* Vddx) = 0; + virtual void eval_exx_grad_( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* EXX_GRAD, + const IntegratorSettingsEXX& settings ) = 0; public: ReplicatedXCIntegratorImpl( std::shared_ptr< functional_type > func, @@ -162,6 +165,9 @@ class ReplicatedXCIntegratorImpl { int64_t ldp, value_type* K, int64_t ldk, const IntegratorSettingsEXX& settings ); + void eval_exx_grad( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* EXX_GRAD, + const IntegratorSettingsEXX& settings ); void eval_fxc_contraction( int64_t m, int64_t n, const value_type* P, int64_t ldp, const value_type* tP, int64_t ldtp, diff --git a/include/gauxc/xc_integrator/replicated_xc_integrator.hpp b/include/gauxc/xc_integrator/replicated_xc_integrator.hpp index 1ca53f917..1b6a725c3 100644 --- a/include/gauxc/xc_integrator/replicated_xc_integrator.hpp +++ b/include/gauxc/xc_integrator/replicated_xc_integrator.hpp @@ -37,6 +37,7 @@ class ReplicatedXCIntegrator : public XCIntegratorImpl { using exc_vxc_type_gks = typename XCIntegratorImpl::exc_vxc_type_gks; using exc_grad_type = typename XCIntegratorImpl::exc_grad_type; using exx_type = typename XCIntegratorImpl::exx_type; + using exx_grad_type = typename XCIntegratorImpl::exx_grad_type; using fxc_contraction_type_rks = typename XCIntegratorImpl::fxc_contraction_type_rks; using fxc_contraction_type_uks = typename XCIntegratorImpl::fxc_contraction_type_uks; using dd_psi_type = typename XCIntegratorImpl::dd_psi_type; @@ -57,6 +58,7 @@ class ReplicatedXCIntegrator : public XCIntegratorImpl { exc_grad_type eval_exc_grad_( const MatrixType&, const IntegratorSettingsXC& ) override; exc_grad_type eval_exc_grad_( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override; exx_type eval_exx_ ( const MatrixType&, const IntegratorSettingsEXX& ) override; + exx_grad_type eval_exx_grad_( const MatrixType&, const IntegratorSettingsEXX& ) override; fxc_contraction_type_rks eval_fxc_contraction_ ( const MatrixType&, const MatrixType&, const IntegratorSettingsXC& ) override; fxc_contraction_type_uks eval_fxc_contraction_ ( const MatrixType&, const MatrixType&, const MatrixType&, const MatrixType&, const IntegratorSettingsXC&) override; dd_psi_type eval_dd_psi_( const MatrixType& , unsigned ) override; diff --git a/include/gauxc/xc_integrator/xc_integrator_impl.hpp b/include/gauxc/xc_integrator/xc_integrator_impl.hpp index ba7bebebb..13b2cc50e 100644 --- a/include/gauxc/xc_integrator/xc_integrator_impl.hpp +++ b/include/gauxc/xc_integrator/xc_integrator_impl.hpp @@ -29,6 +29,7 @@ class XCIntegratorImpl { using exc_vxc_type_gks = typename XCIntegrator::exc_vxc_type_gks; using exc_grad_type = typename XCIntegrator::exc_grad_type; using exx_type = typename XCIntegrator::exx_type; + using exx_grad_type = typename XCIntegrator::exx_grad_type; using fxc_contraction_type_rks = typename XCIntegrator::fxc_contraction_type_rks; using fxc_contraction_type_uks = typename XCIntegrator::fxc_contraction_type_uks; using dd_psi_type = typename XCIntegrator::dd_psi_type; @@ -50,6 +51,8 @@ class XCIntegratorImpl { virtual exc_grad_type eval_exc_grad_( const MatrixType& Ps, const MatrixType& Pz, const IntegratorSettingsXC& ks_settings ) = 0; virtual exx_type eval_exx_ ( const MatrixType& P, const IntegratorSettingsEXX& settings ) = 0; + virtual exx_grad_type eval_exx_grad_ ( const MatrixType& P, + const IntegratorSettingsEXX& settings ) = 0; virtual fxc_contraction_type_rks eval_fxc_contraction_ ( const MatrixType& P, const MatrixType& tP, const IntegratorSettingsXC& ks_settings ) = 0; virtual fxc_contraction_type_uks eval_fxc_contraction_ ( const MatrixType& Ps, const MatrixType& Pz, @@ -151,6 +154,15 @@ class XCIntegratorImpl { return eval_exx_(P,settings); } + /** Integrate Exact Exchange nuclear + * derivatives for RHF + * + * @param[in] P The alpha density matrix + * @returns Excact Exchange Matrix + */ + exx_grad_type eval_exx_grad( const MatrixType& P, const IntegratorSettingsEXX& settings ) { + return eval_exx_grad_(P,settings); + } /** Integrate FXC contraction for RKS * diff --git a/src/runtime_environment/device/device_runtime_environment.cxx b/src/runtime_environment/device/device_runtime_environment.cxx index 88998bd9b..1634d1893 100644 --- a/src/runtime_environment/device/device_runtime_environment.cxx +++ b/src/runtime_environment/device/device_runtime_environment.cxx @@ -51,6 +51,10 @@ DeviceRuntimeEnvironment::DeviceRuntimeEnvironment( GAUXC_MPI_CODE(MPI_Comm c,) double ff) : RuntimeEnvironment(detail::make_device_runtime(GAUXC_MPI_CODE(c,)ff)) {} +DeviceRuntimeEnvironment::DeviceRuntimeEnvironment( + GAUXC_MPI_CODE(MPI_Comm c,) size_t nbytes) : + RuntimeEnvironment(detail::make_device_runtime(GAUXC_MPI_CODE(c,)nbytes)) {} + DeviceRuntimeEnvironment::~DeviceRuntimeEnvironment() noexcept = default; DeviceRuntimeEnvironment::DeviceRuntimeEnvironment( diff --git a/src/runtime_environment/device/device_runtime_environment_impl.hpp b/src/runtime_environment/device/device_runtime_environment_impl.hpp index 9831c5c21..7e49a33e9 100644 --- a/src/runtime_environment/device/device_runtime_environment_impl.hpp +++ b/src/runtime_environment/device/device_runtime_environment_impl.hpp @@ -53,6 +53,17 @@ class DeviceRuntimeEnvironmentImpl : public RuntimeEnvironmentImpl { } + explicit DeviceRuntimeEnvironmentImpl(GAUXC_MPI_CODE(MPI_Comm c,) + size_t nbytes) : + DeviceRuntimeEnvironmentImpl(GAUXC_MPI_CODE(c,) nullptr, 0) { + + std::tie( device_memory_, device_memory_size_ ) = + device_backend_->allocate_device_buffer(nbytes); + + i_own_this_memory_ = true; + + } + ~DeviceRuntimeEnvironmentImpl() noexcept { if(i_own_this_memory_ and device_memory_ and device_memory_size_) { device_backend_->free_device_buffer(device_memory_); diff --git a/src/xc_integrator/local_work_driver/device/common/device_blas.hpp b/src/xc_integrator/local_work_driver/device/common/device_blas.hpp index dc1f0d8f8..ad507dc35 100644 --- a/src/xc_integrator/local_work_driver/device/common/device_blas.hpp +++ b/src/xc_integrator/local_work_driver/device/common/device_blas.hpp @@ -47,6 +47,22 @@ void gdot( device_blas_handle handle, T* SCR, T* RES ); +template +void matrix_reduce_rows( device_blas_handle handle, + int M, + int N, + const T* A, + int LDA, + T* X ); + +template +void matrix_reduce_cols( device_blas_handle handle, + int M, + int N, + const T* A, + int LDA, + T* X ); + template void hadamard_product( device_blas_handle handle, diff --git a/src/xc_integrator/local_work_driver/device/cuda/kernels/cublas_extensions.cu b/src/xc_integrator/local_work_driver/device/cuda/kernels/cublas_extensions.cu index 947d7b184..b6c875e12 100644 --- a/src/xc_integrator/local_work_driver/device/cuda/kernels/cublas_extensions.cu +++ b/src/xc_integrator/local_work_driver/device/cuda/kernels/cublas_extensions.cu @@ -137,6 +137,92 @@ void __global__ hadamard_product_kernel( int M, } +template +void __global__ matrix_reduce_rows_kernel( int M, + int N, + const T* A, + int LDA, + T* X ) { + + auto j = blockIdx.x * blockDim.x + threadIdx.x; + if( j < M ) { + for (size_t i = 0; i < N; i++) { + X[j] += A[ j*LDA + i ]; + } + } + +} + + +template +void __global__ matrix_reduce_cols_kernel( int M, + int N, + const T* A, + int LDA, + T* X ) { + + auto j = blockIdx.x * blockDim.x + threadIdx.x; + + if( j < N ) { + for (size_t i = 0; i < M; i++) { + X[j] += A[ j + LDA * i ]; + } + } + +} + +template +void matrix_reduce_rows( device_blas_handle generic_handle, + int M, + int N, + const T* A, + int LDA, + T* X ) { + + + cublasHandle_t handle = generic_handle.blas_handle_as(); + auto stream = util::get_stream(handle); + dim3 threads(cuda::warp_size, 1, 1); + dim3 blocks( util::div_ceil( M, cuda::warp_size ), 1, 1); + + matrix_reduce_rows_kernel<<< blocks, threads, 0, stream >>>( M, N, A, LDA, X ); + +} + +template +void matrix_reduce_cols( device_blas_handle generic_handle, + int M, + int N, + const T* A, + int LDA, + T* X ) { + + + cublasHandle_t handle = generic_handle.blas_handle_as(); + auto stream = util::get_stream(handle); + dim3 threads(cuda::warp_size, 1, 1); + dim3 blocks( util::div_ceil( N, cuda::warp_size ), 1, 1); + + matrix_reduce_cols_kernel<<< blocks, threads, 0, stream >>>( M, N, A, LDA, X ); + +} + +template +void matrix_reduce_rows( device_blas_handle generic_handle, + int M, + int N, + const double* A, + int LDA, + double* X ); + +template +void matrix_reduce_cols( device_blas_handle generic_handle, + int M, + int N, + const double* A, + int LDA, + double* X ); + template diff --git a/src/xc_integrator/local_work_driver/device/local_device_work_driver.cxx b/src/xc_integrator/local_work_driver/device/local_device_work_driver.cxx index 89626b466..94dde2922 100644 --- a/src/xc_integrator/local_work_driver/device/local_device_work_driver.cxx +++ b/src/xc_integrator/local_work_driver/device/local_device_work_driver.cxx @@ -128,6 +128,8 @@ FWD_TO_PIMPL_DEN_ID_BOOL(inc_vxc) // Increment VXC_I by Z FWD_TO_PIMPL_DEN_ID_BOOL(inc_fxc) // Increment FXC_I by Z FWD_TO_PIMPL(inc_exx_k) +FWD_TO_PIMPL(eval_exx_kgrad) +FWD_TO_PIMPL(inc_exx_kgrad) FWD_TO_PIMPL_KS_SCHEME_BOOL(inc_exc_grad_lda) FWD_TO_PIMPL_KS_SCHEME_BOOL(inc_exc_grad_gga) FWD_TO_PIMPL_KS_SCHEME_BOOL_BOOL(inc_exc_grad_mgga) diff --git a/src/xc_integrator/local_work_driver/device/local_device_work_driver.hpp b/src/xc_integrator/local_work_driver/device/local_device_work_driver.hpp index 8c65c075e..ca9c5bd07 100644 --- a/src/xc_integrator/local_work_driver/device/local_device_work_driver.hpp +++ b/src/xc_integrator/local_work_driver/device/local_device_work_driver.hpp @@ -115,6 +115,8 @@ class LocalDeviceWorkDriver : public LocalWorkDriver { void inc_exc_grad_gga( XCDeviceData*, integrator_ks_scheme, bool ); void inc_exc_grad_mgga( XCDeviceData*, integrator_ks_scheme , bool, bool ); void inc_exx_k( XCDeviceData* ); + void eval_exx_kgrad( XCDeviceData* ); + void inc_exx_kgrad( XCDeviceData* ); void eval_exx_ek_screening_bfn_stats( XCDeviceData* ); void exx_ek_shellpair_collision( double eps_E, double eps_K, XCDeviceData*, diff --git a/src/xc_integrator/local_work_driver/device/local_device_work_driver_pimpl.hpp b/src/xc_integrator/local_work_driver/device/local_device_work_driver_pimpl.hpp index f7178a8f3..7fcd68253 100644 --- a/src/xc_integrator/local_work_driver/device/local_device_work_driver_pimpl.hpp +++ b/src/xc_integrator/local_work_driver/device/local_device_work_driver_pimpl.hpp @@ -73,6 +73,8 @@ struct LocalDeviceWorkDriverPIMPL { virtual void inc_exc_grad_gga( XCDeviceData*, integrator_ks_scheme, bool ) = 0; virtual void inc_exc_grad_mgga( XCDeviceData*, integrator_ks_scheme , bool, bool ) = 0; virtual void inc_exx_k( XCDeviceData* ) = 0; + virtual void eval_exx_kgrad( XCDeviceData* ) = 0; + virtual void inc_exx_kgrad( XCDeviceData* ) = 0; virtual void symmetrize_vxc( XCDeviceData*, density_id ) = 0; virtual void symmetrize_fxc( XCDeviceData*, density_id ) = 0; virtual void symmetrize_exx_k( XCDeviceData* ) = 0; diff --git a/src/xc_integrator/local_work_driver/device/scheme1_base.cxx b/src/xc_integrator/local_work_driver/device/scheme1_base.cxx index d28013070..e24df5387 100644 --- a/src/xc_integrator/local_work_driver/device/scheme1_base.cxx +++ b/src/xc_integrator/local_work_driver/device/scheme1_base.cxx @@ -2354,6 +2354,138 @@ void AoSScheme1Base::inc_exx_k( XCDeviceData* _data ) { #endif } +void AoSScheme1Base::eval_exx_kgrad( XCDeviceData* _data ) { +#ifndef GAUXC_ENABLE_EXX + GAUXC_GENERIC_EXCEPTION("EXX + non-CUDA NYI"); +#else + auto* data = dynamic_cast(_data); + if( !data ) GAUXC_BAD_LWD_DATA_CAST(); + + if( not data->device_backend_ ) GAUXC_UNINITIALIZED_DEVICE_BACKEND(); + + auto& tasks = data->host_device_tasks; + const auto ntasks = tasks.size(); + + // KX + { + // Sync blas streams with master stream + data->device_backend_->sync_blas_pool_with_master(); + + // Launch GEMM in round-robin + const auto n_blas_streams = data->device_backend_->blas_pool_size(); + for( size_t iT = 0; iT < ntasks; ++iT ) { + auto& task = tasks[iT]; + auto handle = data->device_backend_->blas_pool_handle( iT % n_blas_streams ); + auto npts = task.npts; + auto nbe_bfn = task.bfn_screening.nbe; + auto nbe_cou = task.cou_screening.nbe; + gemm( handle, DeviceBlasOp::Trans, DeviceBlasOp::NoTrans, + nbe_bfn, nbe_cou, npts, 1., task.dbfx, npts, task.gmat, npts, 0., + task.nbe_scr, nbe_bfn ); + } + + // Record completion of BLAS ops on master stream + data->device_backend_->sync_master_with_blas_pool(); + + // Increment EXX_K + const auto nbf = data->global_dims.nbf; + const auto submat_block_size = data->get_submat_chunk_size( nbf, 0 ); + auto static_stack = data->static_stack; + auto aos_stack = data->aos_stack; + asym_task_inc_potential( ntasks, aos_stack.device_tasks, + static_stack.exx_kx_device, nbf, submat_block_size, + data->device_backend_->queue() ); + } + + // KY + { + // Sync blas streams with master stream + data->device_backend_->sync_blas_pool_with_master(); + + // Launch GEMM in round-robin + const auto n_blas_streams = data->device_backend_->blas_pool_size(); + for( size_t iT = 0; iT < ntasks; ++iT ) { + auto& task = tasks[iT]; + auto handle = data->device_backend_->blas_pool_handle( iT % n_blas_streams ); + auto npts = task.npts; + auto nbe_bfn = task.bfn_screening.nbe; + auto nbe_cou = task.cou_screening.nbe; + gemm( handle, DeviceBlasOp::Trans, DeviceBlasOp::NoTrans, + nbe_bfn, nbe_cou, npts, 1., task.dbfy, npts, task.gmat, npts, 0., + task.nbe_scr, nbe_bfn ); + } + + // Record completion of BLAS ops on master stream + data->device_backend_->sync_master_with_blas_pool(); + + // Increment EXX_K + const auto nbf = data->global_dims.nbf; + const auto submat_block_size = data->get_submat_chunk_size( nbf, 0 ); + auto static_stack = data->static_stack; + auto aos_stack = data->aos_stack; + asym_task_inc_potential( ntasks, aos_stack.device_tasks, + static_stack.exx_ky_device, nbf, submat_block_size, + data->device_backend_->queue() ); + } + + // KZ + { + // Sync blas streams with master stream + data->device_backend_->sync_blas_pool_with_master(); + + // Launch GEMM in round-robin + const auto n_blas_streams = data->device_backend_->blas_pool_size(); + for( size_t iT = 0; iT < ntasks; ++iT ) { + auto& task = tasks[iT]; + auto handle = data->device_backend_->blas_pool_handle( iT % n_blas_streams ); + auto npts = task.npts; + auto nbe_bfn = task.bfn_screening.nbe; + auto nbe_cou = task.cou_screening.nbe; + gemm( handle, DeviceBlasOp::Trans, DeviceBlasOp::NoTrans, + nbe_bfn, nbe_cou, npts, 1., task.dbfz, npts, task.gmat, npts, 0., + task.nbe_scr, nbe_bfn ); + } + + // Record completion of BLAS ops on master stream + data->device_backend_->sync_master_with_blas_pool(); + + // Increment EXX_K + const auto nbf = data->global_dims.nbf; + const auto submat_block_size = data->get_submat_chunk_size( nbf, 0 ); + auto static_stack = data->static_stack; + auto aos_stack = data->aos_stack; + asym_task_inc_potential( ntasks, aos_stack.device_tasks, + static_stack.exx_kz_device, nbf, submat_block_size, + data->device_backend_->queue() ); + } +#endif +} + +void AoSScheme1Base::inc_exx_kgrad( XCDeviceData* _data ) { +#ifndef GAUXC_ENABLE_EXX + GAUXC_GENERIC_EXCEPTION("EXX + non-CUDA NYI"); +#else + auto* data = dynamic_cast(_data); + if( !data ) GAUXC_BAD_LWD_DATA_CAST(); + + if( not data->device_backend_ ) GAUXC_UNINITIALIZED_DEVICE_BACKEND(); + + const auto nbf = data->global_dims.nbf; + auto static_stack = data->static_stack; + + hadamard_product( data->device_backend_->master_blas_handle(), nbf, nbf, static_stack.dmat_s_device, nbf, static_stack.exx_kx_device, nbf ); + hadamard_product( data->device_backend_->master_blas_handle(), nbf, nbf, static_stack.dmat_s_device, nbf, static_stack.exx_ky_device, nbf ); + hadamard_product( data->device_backend_->master_blas_handle(), nbf, nbf, static_stack.dmat_s_device, nbf, static_stack.exx_kz_device, nbf ); + + matrix_reduce_cols( data->device_backend_->master_blas_handle(), nbf, nbf, static_stack.exx_kx_device, nbf, static_stack.exx_bfgrad_device ); + matrix_reduce_cols( data->device_backend_->master_blas_handle(), nbf, nbf, static_stack.exx_ky_device, nbf, static_stack.exx_bfgrad_device+nbf ); + matrix_reduce_cols( data->device_backend_->master_blas_handle(), nbf, nbf, static_stack.exx_kz_device, nbf, static_stack.exx_bfgrad_device+(2*nbf) ); + + data->device_backend_->master_queue_synchronize(); + +#endif +} + void AoSScheme1Base::symmetrize_exx_k( XCDeviceData* _data ) { #ifndef GAUXC_ENABLE_EXX GAUXC_GENERIC_EXCEPTION("EXX + non-CUDA NYI"); diff --git a/src/xc_integrator/local_work_driver/device/scheme1_base.hpp b/src/xc_integrator/local_work_driver/device/scheme1_base.hpp index 6a04d4369..6ff1b7633 100644 --- a/src/xc_integrator/local_work_driver/device/scheme1_base.hpp +++ b/src/xc_integrator/local_work_driver/device/scheme1_base.hpp @@ -85,6 +85,8 @@ struct AoSScheme1Base : public detail::LocalDeviceWorkDriverPIMPL { virtual void inc_vxc( XCDeviceData*, density_id, bool ) override; virtual void inc_fxc( XCDeviceData*, density_id, bool ) override; virtual void inc_exx_k( XCDeviceData* ) override; + virtual void eval_exx_kgrad( XCDeviceData* ) override; + virtual void inc_exx_kgrad( XCDeviceData* ) override; using Data = Scheme1DataBase; diff --git a/src/xc_integrator/local_work_driver/device/scheme1_magma_base.cxx b/src/xc_integrator/local_work_driver/device/scheme1_magma_base.cxx index 095564f4a..f2f9ca453 100644 --- a/src/xc_integrator/local_work_driver/device/scheme1_magma_base.cxx +++ b/src/xc_integrator/local_work_driver/device/scheme1_magma_base.cxx @@ -197,4 +197,12 @@ void AoSScheme1MAGMABase::inc_exx_k( XCDeviceData* _data){ #endif } +void AoSScheme1MAGMABase::eval_exx_kgrad( XCDeviceData* _data){ + GAUXC_GENERIC_EXCEPTION("EXX grad + Magma NYI"); +} + +void AoSScheme1MAGMABase::inc_exx_kgrad( XCDeviceData* _data){ + GAUXC_GENERIC_EXCEPTION("EXX grad + Magma NYI"); +} + } diff --git a/src/xc_integrator/local_work_driver/device/scheme1_magma_base.hpp b/src/xc_integrator/local_work_driver/device/scheme1_magma_base.hpp index 21242a406..d61bf5108 100644 --- a/src/xc_integrator/local_work_driver/device/scheme1_magma_base.hpp +++ b/src/xc_integrator/local_work_driver/device/scheme1_magma_base.hpp @@ -20,6 +20,8 @@ struct AoSScheme1MAGMABase : public AoSScheme1Base { void eval_exx_fmat( XCDeviceData* ) override final; void inc_vxc( XCDeviceData*, density_id den, bool ) override final; void inc_exx_k( XCDeviceData* ) override final; + void eval_exx_kgrad( XCDeviceData* ) override final; + void inc_exx_kgrad( XCDeviceData* ) override final; struct Data; diff --git a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.cxx b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.cxx index ff64d58a6..921dc23c1 100644 --- a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.cxx +++ b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.cxx @@ -14,6 +14,7 @@ #include "incore_replicated_xc_device_integrator_exc_vxc.hpp" #include "incore_replicated_xc_device_integrator_exc_grad.hpp" #include "incore_replicated_xc_device_integrator_exx.hpp" +#include "incore_replicated_xc_device_integrator_exx_grad.hpp" #include "incore_replicated_xc_device_integrator_fxc_contraction.hpp" #include "incore_replicated_xc_device_integrator_dd.hpp" diff --git a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp index 30ff47ce1..ce4d5f788 100644 --- a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp +++ b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator.hpp @@ -83,6 +83,9 @@ class IncoreReplicatedXCDeviceIntegrator : int64_t ldp, value_type* K, int64_t ldk, const IntegratorSettingsEXX& settings ) override; + void eval_exx_grad_( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* EXX_GRAD, + const IntegratorSettingsEXX& settings ) override; void eval_fxc_contraction_( int64_t m, int64_t n, const value_type* P, int64_t ldp, const value_type* tP, int64_t ldtp, @@ -169,6 +172,17 @@ class IncoreReplicatedXCDeviceIntegrator : XCDeviceData& device_data, const IntegratorSettingsEXX& settings); + void exx_grad_local_work_( const basis_type& basis, const value_type* P, int64_t ldp, + host_task_iterator task_begin, host_task_iterator task_end, + XCDeviceData& device_data, + const IntegratorSettingsEXX& settings); + + void exx_grad_local_work_( const basis_type& basis, const value_type* P, int64_t ldp, + value_type* EXX_GRAD, int64_t nbf, + host_task_iterator task_begin, host_task_iterator task_end, + XCDeviceData& device_data, + const IntegratorSettingsEXX& settings); + void exx_ek_screening_local_work_( const basis_type& basis, const value_type* P, int64_t ldp, XCDeviceData& device_data, diff --git a/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_exx_grad.hpp b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_exx_grad.hpp new file mode 100644 index 000000000..2d65f9fae --- /dev/null +++ b/src/xc_integrator/replicated/device/incore_replicated_xc_device_integrator_exx_grad.hpp @@ -0,0 +1,253 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. + * + * See LICENSE.txt for details + */ + +#include "incore_replicated_xc_device_integrator.hpp" +#include "device/local_device_work_driver.hpp" +#include "host/reference_local_host_work_driver.hpp" +#include +#include "device/xc_device_aos_data.hpp" +#include +#include + +#include "integrator_util/exx_screening.hpp" +#include "integrator_util/integral_bounds.hpp" + +namespace GauXC { +namespace detail { + +template +void IncoreReplicatedXCDeviceIntegrator:: + eval_exx_grad_( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* EXX_GRAD, + const IntegratorSettingsEXX& settings ) { + + + const auto& basis = this->load_balancer_->basis(); + + // Check that P / K are sane + const int64_t nbf = basis.nbf(); + if( m != n ) + GAUXC_GENERIC_EXCEPTION("P/K Must Be Square"); + if( m != nbf ) + GAUXC_GENERIC_EXCEPTION("P/K Must Have Same Dimension as Basis"); + if( ldp < nbf ) + GAUXC_GENERIC_EXCEPTION("Invalid LDP"); + + // Allocate Device memory + auto* lwd = dynamic_cast(this->local_work_driver_.get() ); + auto rt = detail::as_device_runtime(this->load_balancer_->runtime()); + auto device_data_ptr = lwd->create_device_data(rt); + + std::vector exx_bfgrad(3*nbf, 0); + + GAUXC_MPI_CODE(MPI_Barrier(rt.comm());) + + this->timer_.time_op("XCIntegrator.EXX_GRAD_Screening", [&]() { + exx_ek_screening_local_work_( basis, P, ldp, *device_data_ptr, settings); + }); + + + // Get Tasks + auto& tasks = this->load_balancer_->get_tasks(); + if( this->reduction_driver_->takes_device_memory() ) { + + // Compute local contributions to K and keep on device + this->timer_.time_op("XCIntegrator.LocalWork_EXX_GRAD", [&](){ + exx_grad_local_work_( basis, P, ldp, + tasks.begin(), tasks.end(), *device_data_ptr, settings); + rt.device_backend()->master_queue_synchronize(); + }); + + GAUXC_MPI_CODE( + this->timer_.time_op("XCIntegrator.ImbalanceWait_EXX_GRAD",[&](){ + MPI_Barrier(rt.comm()); + }); + ) + + // Reduce results in device memory + this->timer_.time_op("XCIntegrator.Allreduce_EXX_GRAD", [&](){ + this->reduction_driver_->allreduce_inplace( + device_data_ptr->exx_grad_device_data(), nbf, ReductionOp::Sum, + device_data_ptr->queue()); + }); + + // Receive K from host + this->timer_.time_op("XCIntegrator.DeviceToHostCopy_EXX_GRAD",[&](){ + device_data_ptr->retrieve_exx_grad(exx_bfgrad.data()); + }); + + } else { + + // Compute local contributions to K and retrieve + // data from device + this->timer_.time_op("XCIntegrator.LocalWork_EXX_GRAD", [&](){ + exx_grad_local_work_( basis, P, ldp, + tasks.begin(), tasks.end(), *device_data_ptr, settings); + }); + + GAUXC_MPI_CODE( + this->timer_.time_op("XCIntegrator.ImbalanceWait_EXX_GRAD",[&](){ + MPI_Barrier(rt.comm()); + }); + ) + + this->timer_.time_op("XCIntegrator.DeviceToHostCopy_EXX_GRAD",[&](){ + device_data_ptr->retrieve_exx_grad(exx_bfgrad.data()); + }); + + // Reduce Results in host mem + this->timer_.time_op("XCIntegrator.Allreduce_EXX_GRAD", [&](){ + this->reduction_driver_->allreduce_inplace(exx_bfgrad.data(), 3*nbf, ReductionOp::Sum ); + }); + + } + + // Sum gradient contribution of basis functions + // to nuclei + rt.device_backend()->master_queue_synchronize(); + auto& basis_map = this->load_balancer_->basis_map(); +for (size_t i = 0; i< basis.nshells(); i++) { + const auto [b0, b1] = basis_map.shell_to_ao_range()[i]; + const auto iCenter = basis_map.shell_to_center()[i]; + for (size_t bf = b0; bf < b1; bf++) { + // factor 2 is for bra-ket integral symmetry + EXX_GRAD[3*iCenter ] += 2*exx_bfgrad[bf ]; + EXX_GRAD[3*iCenter+1] += 2*exx_bfgrad[bf+1*nbf]; + EXX_GRAD[3*iCenter+2] += 2*exx_bfgrad[bf+2*nbf]; + } + +} +} + +template +void IncoreReplicatedXCDeviceIntegrator:: + exx_grad_local_work_( const basis_type& basis, const value_type* P, int64_t ldp, + value_type* EXX_GRAD, int64_t nbf, + host_task_iterator task_begin, host_task_iterator task_end, + XCDeviceData& device_data, + const IntegratorSettingsEXX& settings ) { + + + exx_local_work_(basis, P, ldp, task_begin, task_end, device_data, settings); + auto rt = detail::as_device_runtime(this->load_balancer_->runtime()); + rt.device_backend()->master_queue_synchronize(); + + // Receive K from host + this->timer_.time_op("XCIntegrator.DeviceToHostCopy_EXX_GRAD",[&](){ + device_data.retrieve_exx_integrands( EXX_GRAD, nbf ); + }); + +} + +template +void IncoreReplicatedXCDeviceIntegrator:: + exx_grad_local_work_( const basis_type& basis, const value_type* P, int64_t ldp, + host_task_iterator task_begin, host_task_iterator task_end, + XCDeviceData& device_data, + const IntegratorSettingsEXX& settings ) { + + auto* lwd = dynamic_cast(this->local_work_driver_.get() ); + IntegratorSettingsSNLinK sn_link_settings; + if( auto* tmp = dynamic_cast(&settings) ) { + sn_link_settings = *tmp; + } + + // Setup Aliases + const auto nbf = basis.nbf(); + const auto nshells = basis.nshells(); + + + // Get basis map and shell pairs + auto& basis_map = this->load_balancer_->basis_map(); + auto& shell_pairs = this->load_balancer_->shell_pairs(); + + + + // Sort tasks + auto task_comparator = []( const XCTask& a, const XCTask& b ) { + return (a.points.size() * a.bfn_screening.nbe) > (b.points.size() * b.bfn_screening.nbe); + }; + std::sort( task_begin, task_end, task_comparator ); + + + + // Check that Partition Weights have been calculated + auto& lb_state = this->load_balancer_->state(); + if( not lb_state.modified_weights_are_stored ) { + GAUXC_GENERIC_EXCEPTION("Weights Have Not Been Modified"); + } + + task_end = std::stable_partition( task_begin, task_end, + []( const auto& t ) { return t.cou_screening.shell_list.size() > 0; } ); + + std::sort(task_begin,task_end, + [](auto& a, auto& b){ return a.cou_screening.shell_pair_list.size() > + b.cou_screening.shell_pair_list.size(); }); + + // Populate submat maps + device_data.populate_submat_maps( basis.nbf(), task_begin, task_end, basis_map ); + + + + // Do EXX integration in task batches + device_data.reset_allocations(); + device_data.allocate_static_data_exx_grad( nbf, nshells, shell_pairs.npairs(), shell_pairs.nprim_pair_total(), basis_map.max_l() ); + device_data.send_static_data_density_basis( P, ldp, nullptr, 0, nullptr, 0, nullptr, 0, basis ); + device_data.send_static_data_shell_pairs( basis, shell_pairs ); + + // Zero integrands + device_data.zero_exx_integrands(); + device_data.zero_exx_grad_integrands(); + + // Processes batches in groups that saturadate available device memory + integrator_term_tracker enabled_terms; + enabled_terms.exx = true; + enabled_terms.exx_grad = true; + + auto task_it = task_begin; + while( task_it != task_end ) { + + // Determine next task batch, send relevant data to device (EXX only) + task_it = + device_data.generate_buffers( enabled_terms, basis_map, task_it, task_end ); + +#if 1 + /*** Process the batches ***/ + + // Evaluate collocation gradient + lwd->eval_collocation_gradient( &device_data ); + + // Evaluate F(mu,i) = P(mu,nu) * B(nu,i) + // mu runs over significant ek shells + // nu runs over the bfn shell list + // i runs over all points + lwd->eval_exx_fmat( &device_data ); + + // Compute G(mu,i) = w(i) * A(mu,nu,i) * F(nu,i) + // mu/nu run over significant ek shells + // i runs over all points + lwd->eval_exx_gmat( &device_data, basis_map ); + + // Increment dK(mu,nu)/dx += dB/dx(mu,i) * G(nu,i) + // mu runs over bfn shell list + // nu runs over ek shells + // i runs over all points + lwd->eval_exx_kgrad( &device_data ); +#endif + + } // Loop over batches of batches + + // Contract derivative K matrices with density + // to produce gradient: + // bfgrad(mu)_x = sum_nu dK/dx(mu,nu) * DM(mu,nu) + lwd->inc_exx_kgrad( &device_data ); + +} + +} +} diff --git a/src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.cxx b/src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.cxx index febcd7aa4..e577c6a4d 100644 --- a/src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.cxx +++ b/src/xc_integrator/replicated/device/shell_batched_replicated_xc_device_integrator.cxx @@ -15,6 +15,7 @@ #include "shell_batched_replicated_xc_integrator_exc_vxc.hpp" #include "shell_batched_replicated_xc_integrator_exc_grad.hpp" #include "shell_batched_replicated_xc_integrator_exx.hpp" +#include "shell_batched_replicated_xc_integrator_exx_grad.hpp" #include "shell_batched_replicated_xc_integrator_fxc_contraction.hpp" #include "shell_batched_replicated_xc_integrator_dd_psi.hpp" #include "shell_batched_replicated_xc_integrator_dd_psi_potential.hpp" diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.cxx b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.cxx index 6695d9121..c86d3af46 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.cxx +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.cxx @@ -14,6 +14,7 @@ #include "reference_replicated_xc_host_integrator_exc_vxc.hpp" #include "reference_replicated_xc_host_integrator_exc_grad.hpp" #include "reference_replicated_xc_host_integrator_exx.hpp" +#include "reference_replicated_xc_host_integrator_exx_grad.hpp" #include "reference_replicated_xc_host_integrator_fxc_contraction.hpp" #include "reference_replicated_xc_host_integrator_dd_psi.hpp" #include "reference_replicated_xc_host_integrator_dd_psi_potential.hpp" diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp index a32748eb5..693c34179 100644 --- a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator.hpp @@ -86,6 +86,9 @@ class ReferenceReplicatedXCHostIntegrator : void eval_exx_( int64_t m, int64_t n, const value_type* P, int64_t ldp, value_type* K, int64_t ldk, const IntegratorSettingsEXX& settings ) override; + void eval_exx_grad_( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* EXX_GRAD, + const IntegratorSettingsEXX& settings ) override; /// RKS FXC contraction void eval_fxc_contraction_( int64_t m, int64_t n, diff --git a/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exx_grad.hpp b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exx_grad.hpp new file mode 100644 index 000000000..9a0d19325 --- /dev/null +++ b/src/xc_integrator/replicated/host/reference_replicated_xc_host_integrator_exx_grad.hpp @@ -0,0 +1,27 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. + * + * See LICENSE.txt for details + */ +#pragma once +#include "reference_replicated_xc_host_integrator.hpp" +#include +#include + +namespace GauXC { +namespace detail { + +template +void ReferenceReplicatedXCHostIntegrator:: + eval_exx_grad_( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* EXC_GRAD, + const IntegratorSettingsEXX& settings ) { + + GAUXC_GENERIC_EXCEPTION("HostReplicated exc_grad NYI" ); + util::unused(m,n,P,ldp,EXC_GRAD); +} + +} +} diff --git a/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.cxx b/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.cxx index c972d30a9..19df6349b 100644 --- a/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.cxx +++ b/src/xc_integrator/replicated/host/shell_batched_replicated_xc_host_integrator.cxx @@ -15,6 +15,7 @@ #include "shell_batched_replicated_xc_integrator_exc_vxc.hpp" #include "shell_batched_replicated_xc_integrator_exc_grad.hpp" #include "shell_batched_replicated_xc_integrator_exx.hpp" +#include "shell_batched_replicated_xc_integrator_exx_grad.hpp" #include "shell_batched_replicated_xc_integrator_fxc_contraction.hpp" #include "shell_batched_replicated_xc_integrator_dd_psi.hpp" #include "shell_batched_replicated_xc_integrator_dd_psi_potential.hpp" diff --git a/src/xc_integrator/replicated/replicated_xc_integrator_impl.cxx b/src/xc_integrator/replicated/replicated_xc_integrator_impl.cxx index 071afe312..3900e03de 100644 --- a/src/xc_integrator/replicated/replicated_xc_integrator_impl.cxx +++ b/src/xc_integrator/replicated/replicated_xc_integrator_impl.cxx @@ -150,6 +150,16 @@ void ReplicatedXCIntegratorImpl:: } +template +void ReplicatedXCIntegratorImpl:: + eval_exx_grad( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* EXX_GRAD, + const IntegratorSettingsEXX& settings ) { + + eval_exx_grad_(m,n,P,ldp,EXX_GRAD,settings); + +} + template void ReplicatedXCIntegratorImpl:: eval_fxc_contraction( int64_t m, int64_t n, const value_type* P, diff --git a/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp index 5c1d4a949..d56e14ed5 100644 --- a/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp +++ b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator.hpp @@ -15,6 +15,7 @@ #ifdef GAUXC_HAS_DEVICE #include "device/xc_device_data.hpp" #endif +#include "gauxc/xc_integrator_settings.hpp" namespace GauXC { namespace detail { @@ -98,6 +99,9 @@ class ShellBatchedReplicatedXCIntegrator : int64_t ldp, value_type* K, int64_t ldk, const IntegratorSettingsEXX& settings ) override; + void eval_exx_grad_( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* EXX_GRAD, + const IntegratorSettingsEXX& settings ) override; // RKS FXC contraction void eval_fxc_contraction_( int64_t m, int64_t n, const value_type* P, int64_t ldp, diff --git a/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exx_grad.hpp b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exx_grad.hpp new file mode 100644 index 000000000..c1440fcc5 --- /dev/null +++ b/src/xc_integrator/shell_batched/shell_batched_replicated_xc_integrator_exx_grad.hpp @@ -0,0 +1,26 @@ +/** + * GauXC Copyright (c) 2020-2024, The Regents of the University of California, + * through Lawrence Berkeley National Laboratory (subject to receipt of + * any required approvals from the U.S. Dept. of Energy). All rights reserved. + * + * See LICENSE.txt for details + */ +#pragma once +#include "shell_batched_replicated_xc_integrator.hpp" +#include +#include + +namespace GauXC { +namespace detail { + +template +void ShellBatchedReplicatedXCIntegrator:: + eval_exx_grad_( int64_t m, int64_t n, const value_type* P, + int64_t ldp, value_type* EXX_GRAD, + const IntegratorSettingsEXX& settings ) { + GAUXC_GENERIC_EXCEPTION("ShellBatched EXX GRADIENT NYI"); + util::unused(m,n,P,ldp,EXX_GRAD,settings); +} + +} +} diff --git a/src/xc_integrator/xc_data/device/xc_device_data.hpp b/src/xc_integrator/xc_data/device/xc_device_data.hpp index 781e23729..504dca23f 100644 --- a/src/xc_integrator/xc_data/device/xc_device_data.hpp +++ b/src/xc_integrator/xc_data/device/xc_device_data.hpp @@ -52,6 +52,7 @@ struct integrator_term_tracker { bool exc_vxc = false; bool exc_grad = false; bool exx = false; + bool exx_grad = false; bool exx_ek_screening = false; bool fxc_contraction = false; integrator_xc_approx xc_approx = _UNDEF_APPROX; @@ -630,6 +631,10 @@ struct required_term_storage { task_to_shell_pair_cou = true; } + if(tracker.exx_grad) { + task_bfn_grad = true; + } + if(tracker.exx_ek_screening) { task_bfn = true; task_indirection = true; @@ -657,6 +662,7 @@ std::ostream& operator<<( std::ostream& out, const integrator_term_tracker& t ) out << " FXC_CONTRACTION " << t.fxc_contraction << std::endl; out << " EXC_GRAD " << t.exc_grad << std::endl; out << " EXX " << t.exx << std::endl; + out << " EXX_GRAD " << t.exx_grad << std::endl; return out; } @@ -680,6 +686,7 @@ struct XCDeviceData { virtual void allocate_static_data_den( int32_t nbf, int32_t nshells ) = 0; virtual void allocate_static_data_exc_grad( int32_t nbf, int32_t nshells, int32_t natoms, integrator_term_tracker enabled_terms ) = 0; virtual void allocate_static_data_exx( int32_t nbf, int32_t nshells, size_t nshell_pairs, size_t nprim_pair_total, int32_t max_l ) = 0; + virtual void allocate_static_data_exx_grad( int32_t nbf, int32_t nshells, size_t nshell_pairs, size_t nprim_pair_total, int32_t max_l ) = 0; virtual void allocate_static_data_exx_ek_screening( size_t ntasks, int32_t nbf, int32_t nshells, int nshell_pairs, int32_t max_l ) = 0; virtual void allocate_static_data_fxc_contraction( int32_t nbf, int32_t nshells, integrator_term_tracker enabled_terms) = 0; @@ -706,6 +713,9 @@ struct XCDeviceData { /// Zero out the EXX integrands in device memory virtual void zero_exx_integrands() = 0; + /// Zero out the EXX Gradient integrands in device memory + virtual void zero_exx_grad_integrands() = 0; + /// Zero out intermediates for EXX EK screening virtual void zero_exx_ek_screening_intermediates() = 0; @@ -762,6 +772,8 @@ struct XCDeviceData { virtual void retrieve_exx_integrands( double* K, int32_t ldk ) = 0; + virtual void retrieve_exx_grad( double* grad ) = 0; + virtual void retrieve_exx_ek_max_bfn_sum( double* MBS, int32_t nt) = 0; @@ -775,6 +787,7 @@ struct XCDeviceData { virtual double* exc_device_data() = 0; virtual double* nel_device_data() = 0; virtual double* exx_k_device_data() = 0; + virtual double* exx_grad_device_data() = 0; virtual double* fxc_z_device_data() = 0; virtual double* fxc_s_device_data() = 0; virtual double* fxc_y_device_data() = 0; diff --git a/src/xc_integrator/xc_data/device/xc_device_stack_data.cxx b/src/xc_integrator/xc_data/device/xc_device_stack_data.cxx index 96ffb888b..07c5a9802 100644 --- a/src/xc_integrator/xc_data/device/xc_device_stack_data.cxx +++ b/src/xc_integrator/xc_data/device/xc_device_stack_data.cxx @@ -45,6 +45,7 @@ double* XCDeviceStackData::vxc_x_device_data() { return static_stack.vxc_x_devic double* XCDeviceStackData::exc_device_data() { return static_stack.exc_device; } double* XCDeviceStackData::nel_device_data() { return static_stack.nel_device; } double* XCDeviceStackData::exx_k_device_data() { return static_stack.exx_k_device; } +double* XCDeviceStackData::exx_grad_device_data() { return static_stack.exx_bfgrad_device; } double* XCDeviceStackData::fxc_s_device_data() { return static_stack.fxc_s_device; } double* XCDeviceStackData::fxc_z_device_data() { return static_stack.fxc_z_device; } double* XCDeviceStackData::fxc_y_device_data() { return static_stack.fxc_y_device; } @@ -273,6 +274,27 @@ void XCDeviceStackData::allocate_static_data_exx( int32_t nbf, int32_t nshells, allocated_terms.exx = true; } +void XCDeviceStackData::allocate_static_data_exx_grad( int32_t nbf, int32_t nshells, size_t nshell_pairs, size_t nprim_pair_total, int32_t max_l ) { + + allocate_static_data_exx(nbf, nshells, nshell_pairs, nprim_pair_total, max_l); + + if( allocated_terms.exx_grad ) + GAUXC_GENERIC_EXCEPTION("Attempting to reallocate Stack EXC GRAD"); + + // Allocate static memory with proper alignment + buffer_adaptor mem( dynmem_ptr, dynmem_sz ); + static_stack.exx_kx_device = mem.aligned_alloc( nbf * nbf , csl); + static_stack.exx_ky_device = mem.aligned_alloc( nbf * nbf , csl); + static_stack.exx_kz_device = mem.aligned_alloc( nbf * nbf , csl); + static_stack.exx_bfgrad_device = mem.aligned_alloc( 3*nbf , csl); + + // Get current stack location + dynmem_ptr = mem.stack(); + dynmem_sz = mem.nleft(); + + allocated_terms.exx_grad = true; +} + void XCDeviceStackData::allocate_static_data_exx_ek_screening( size_t ntasks, int32_t nbf, int32_t nshells, int nshell_pairs, int32_t max_l ) { if( allocated_terms.exx_ek_screening ) @@ -603,6 +625,18 @@ void XCDeviceStackData::zero_exx_integrands() { } +void XCDeviceStackData::zero_exx_grad_integrands() { + + if( not device_backend_ ) GAUXC_GENERIC_EXCEPTION("Invalid Device Backend"); + + const auto nbf = global_dims.nbf; + device_backend_->set_zero( 3*nbf, static_stack.exx_bfgrad_device, "EXX Gradient Zero" ); + device_backend_->set_zero( nbf*nbf, static_stack.exx_kx_device, "Kx Zero" ); + device_backend_->set_zero( nbf*nbf, static_stack.exx_ky_device, "Ky Zero" ); + device_backend_->set_zero( nbf*nbf, static_stack.exx_kz_device, "Kz Zero" ); + +} + void XCDeviceStackData::zero_exx_ek_screening_intermediates() { if( not device_backend_ ) GAUXC_GENERIC_EXCEPTION("Invalid Device Backend"); @@ -694,6 +728,15 @@ void XCDeviceStackData::retrieve_exx_integrands( double* K, int32_t ldk ) { } +void XCDeviceStackData::retrieve_exx_grad( double* grad ) { + + const auto nbf = global_dims.nbf; + if( not device_backend_ ) GAUXC_GENERIC_EXCEPTION("Invalid Device Backend"); + + device_backend_->copy_async( 3*nbf, static_stack.exx_bfgrad_device, grad, "K grad" ); + +} + void XCDeviceStackData::retrieve_exx_ek_max_bfn_sum( double* MBS, int32_t nt ) { const auto ntask_ek = global_dims.ntask_ek; diff --git a/src/xc_integrator/xc_data/device/xc_device_stack_data.hpp b/src/xc_integrator/xc_data/device/xc_device_stack_data.hpp index cf5399a8b..790e315d6 100644 --- a/src/xc_integrator/xc_data/device/xc_device_stack_data.hpp +++ b/src/xc_integrator/xc_data/device/xc_device_stack_data.hpp @@ -57,6 +57,10 @@ struct XCDeviceStackData : public XCDeviceData { double* exc_device = nullptr; ///< EXC storage (1) double* nel_device = nullptr; ///< N_EL storage (1) double* exx_k_device = nullptr; ///< EXX K storage (nbf,nbf) + double* exx_kx_device = nullptr; ///< EXX dK/dx intermediates storage (nbf,nbf) + double* exx_ky_device = nullptr; ///< EXX dK/dy intermediates storage (nbf,nbf) + double* exx_kz_device = nullptr; ///< EXX dK/dz intermediates storage (nbf,nbf) + double* exx_bfgrad_device = nullptr; ///< EXX gradient storage (nbf) double* acc_scr_device = nullptr; ///< Accumulaion scratch (1) double* exc_grad_device = nullptr; ///< EXC Gradient storage (3*natoms) double* fxc_device = nullptr; ///< FXC contraction storage (nbf,nbf) @@ -328,6 +332,7 @@ struct XCDeviceStackData : public XCDeviceData { void allocate_static_data_den( int32_t nbf, int32_t nshells ) override final; void allocate_static_data_exc_grad( int32_t nbf, int32_t nshells, int32_t natoms, integrator_term_tracker enabled_terms ) override final; void allocate_static_data_exx( int32_t nbf, int32_t nshells, size_t nshell_pairs, size_t nprim_pair_total, int32_t max_l ) override final; + void allocate_static_data_exx_grad( int32_t nbf, int32_t nshells, size_t nshell_pairs, size_t nprim_pair_total, int32_t max_l ) override final; void allocate_static_data_exx_ek_screening( size_t ntasks, int32_t nbf, int32_t nshells, int nshell_pairs, int32_t max_l ) override final; void send_static_data_weights( const Molecule& mol, const MolMeta& meta ) override final; void send_static_data_density_basis( const double* Ps, int32_t ldps, const double* Pz, int32_t ldpz, @@ -344,6 +349,7 @@ struct XCDeviceStackData : public XCDeviceData { void zero_fxc_contraction_integrands() override final; void zero_exc_grad_integrands() override final; void zero_exx_integrands() override final; + void zero_exx_grad_integrands() override final; void zero_exx_ek_screening_intermediates() override final; void retrieve_exc_vxc_integrands( double* EXC, double* N_EL, double* VXCscalar, int32_t ldvxcscalar, double* VXCz, int32_t ldvxcz, @@ -354,6 +360,7 @@ struct XCDeviceStackData : public XCDeviceData { void retrieve_exc_grad_integrands( double* EXC_GRAD, double* N_EL ) override final; void retrieve_den_integrands( double* N_EL ) override final; void retrieve_exx_integrands( double* K, int32_t ldk ) override final; + void retrieve_exx_grad( double* grad ) override final; void retrieve_exx_ek_max_bfn_sum( double* MBS, int32_t nt) override final; void copy_weights_to_tasks( host_task_iterator task_begin, host_task_iterator task_end ) override final; @@ -364,6 +371,7 @@ struct XCDeviceStackData : public XCDeviceData { double* exc_device_data() override; double* nel_device_data() override; double* exx_k_device_data() override; + double* exx_grad_device_data() override; double* fxc_s_device_data() override; double* fxc_z_device_data() override; double* fxc_y_device_data() override; diff --git a/tests/standalone_driver.cxx b/tests/standalone_driver.cxx index 68a9c13aa..7daf41ae3 100644 --- a/tests/standalone_driver.cxx +++ b/tests/standalone_driver.cxx @@ -231,6 +231,7 @@ int main(int argc, char** argv) { matrix_type FXC_ref, FXCz_ref; double EXC_ref; std::vector EXC_GRAD_ref(3*mol.size()); + std::vector EXX_GRAD_ref(3*mol.size()); bool rks = true, uks = false, gks = false; size_t N_EL_ref = MolMeta(mol).sum_atomic_charges(); { @@ -346,6 +347,19 @@ int main(int argc, char** argv) { } std::fill( EXC_GRAD_ref.begin(), EXC_GRAD_ref.end(), 0. ); } + try { + dset = file.getDataSet("EXX_GRAD"); + auto xc_grad_dims = dset.getDimensions(); + if( xc_grad_dims[0] != mol.size() or xc_grad_dims[1] != 3 ) + GAUXC_GENERIC_EXCEPTION("Incorrect dims for EXX_GRAD"); + dset.read( EXX_GRAD_ref.data() ); + } catch(...) { + if(world_rank == 0) { + std::cout << "** Warning: Could Not Find Reference EXX_GRAD" + << std::endl; + } + std::fill( EXX_GRAD_ref.begin(), EXX_GRAD_ref.end(), 0. ); + } } if( integrate_exx ) { @@ -492,6 +506,7 @@ int main(int argc, char** argv) { } std::vector EXC_GRAD; + std::vector EXX_GRAD; if( integrate_exc_grad ) { if( rks ) { EXC_GRAD = integrator.eval_exc_grad( P ); @@ -590,6 +605,28 @@ int main(int argc, char** argv) { K = integrator.eval_exx(P, sn_link_settings); //matrix_type K_tmp = 0.5 * (K + K.transpose()); //K = -K_tmp; + if( integrate_exc_grad ) { + if( rks ) { + EXX_GRAD = integrator.eval_exx_grad( P, sn_link_settings ); + } + else if( uks ) { + std::cout << "Warning: eval_exx_grad + UKS NYI!" << std::endl; + } + else if( gks ) { + std::cout << "Warning: eval_exx_grad + GKS NYI!" << std::endl; + } + if(!world_rank) { + std::cout << "EXX Gradient:" << std::endl; + std::cout << std::scientific << std::setprecision(6); + for( auto iAt = 0; iAt < mol.size(); ++iAt ) { + std::cout << " " + << std::setw(16) << EXX_GRAD[3*iAt + 0] + << std::setw(16) << EXX_GRAD[3*iAt + 1] + << std::setw(16) << EXX_GRAD[3*iAt + 2] + << std::endl; + } + } + } } else { K = K_ref; } @@ -740,6 +777,24 @@ int main(int argc, char** argv) { std::cout << "| K (calc) |_F = " << K.norm() << std::endl; std::cout << "RMS K Diff = " << (K_ref - K).norm() / basis.nbf() << std::endl; + if(integrate_exc_grad) { + double exx_grad_ref_nrm(0.), exx_grad_calc_nrm(0.), exx_grad_diff_nrm(0.); + for( auto i = 0; i < 3*mol.size(); ++i ) { + const auto ref_val = EXX_GRAD_ref[i]; + const auto clc_val = EXX_GRAD[i]; + const auto dif_val = std::abs(ref_val - clc_val); + exx_grad_ref_nrm += ref_val*ref_val; + exx_grad_calc_nrm += clc_val*clc_val; + exx_grad_diff_nrm += dif_val*dif_val; + } + + exx_grad_ref_nrm = std::sqrt(exx_grad_ref_nrm); + exx_grad_calc_nrm = std::sqrt(exx_grad_calc_nrm); + exx_grad_diff_nrm = std::sqrt(exx_grad_diff_nrm); + std::cout << "| EXX_GRAD (ref) | = " << exx_grad_ref_nrm << std::endl; + std::cout << "| EXX_GRAD (calc) | = " << exx_grad_calc_nrm << std::endl; + std::cout << "| EXX_GRAD (diff) | = " << exx_grad_diff_nrm << std::endl; + } } if (integrate_dd_psi) { std::cout << "| DD_PSI (ref) |_F = " << ddPsi_ref.norm() << std::endl; @@ -819,6 +874,11 @@ int main(int argc, char** argv) { if( integrate_exx ) { dset = file.createDataSet( "/K", mat_space ); dset.write_raw( K.data() ); + if( integrate_exc_grad ) { + HighFive::DataSpace grad_space( mol.size(), 3 ); + dset = file.createDataSet( "/EXX_GRAD", grad_space ); + dset.write_raw( EXX_GRAD.data() ); + } } if( integrate_exc_grad ) {