diff --git a/src/xc_integrator/integrator_util/exx_screening.cxx b/src/xc_integrator/integrator_util/exx_screening.cxx index 5c7efcd13..6bed9ec84 100644 --- a/src/xc_integrator/integrator_util/exx_screening.cxx +++ b/src/xc_integrator/integrator_util/exx_screening.cxx @@ -253,7 +253,6 @@ void exx_ek_screening( const auto nshells = basis.nshells(); const size_t ntasks = std::distance(task_begin, task_end); - const size_t task_batch_size = 10000; // Setup EXX EK Screening memory on the device device_data.reset_allocations(); @@ -268,39 +267,25 @@ void exx_ek_screening( - auto task_batch_begin = task_begin; - while(task_batch_begin != task_end) { - - size_t nleft = std::distance(task_batch_begin, task_end); - exx_detail::host_task_iterator task_batch_end; - if(nleft > task_batch_size) - task_batch_end = task_batch_begin + task_batch_size; - else - task_batch_end = task_end; - - device_data.zero_exx_ek_screening_intermediates(); - - // Loop over tasks and form basis-related buffers - auto task_it = task_batch_begin; - while( task_it != task_batch_end ) { - - // Determine next task patch, send relevant data (EXX_EK only) - task_it = device_data.generate_buffers( enabled_terms, basis_map, task_it, - task_batch_end ); + auto task_it = task_begin; + while (task_it != task_end) { - // Evaluate collocation - lwd->eval_collocation( &device_data ); + device_data.zero_exx_ek_screening_intermediates(); + auto task_batch_begin = task_it; - // Evaluate EXX EK Screening Basis Statistics - lwd->eval_exx_ek_screening_bfn_stats( &device_data ); + // Determine next task patch, send relevant data (EXX_EK only) + task_it = device_data.generate_buffers(enabled_terms, basis_map, task_it, + task_end); - } + // Evaluate collocation + lwd->eval_collocation(&device_data); + // Evaluate EXX EK Screening Basis Statistics + lwd->eval_exx_ek_screening_bfn_stats(&device_data); - lwd->exx_ek_shellpair_collision( eps_E, eps_K, &device_data, task_batch_begin, - task_batch_end, shpairs ); - task_batch_begin = task_batch_end; + lwd->exx_ek_shellpair_collision(eps_E, eps_K, &device_data, task_batch_begin, + task_it, shpairs); } //GAUXC_CUDA_ERROR("End Sync", cudaDeviceSynchronize()); diff --git a/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu b/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu index 86799ad25..81cda6952 100644 --- a/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu +++ b/src/xc_integrator/local_work_driver/device/cuda/kernels/exx_ek_screening_bfn_stats.cu @@ -199,8 +199,8 @@ __global__ void exx_ek_shellpair_collision_shared_kernel( int LD_coll, uint32_t* rc_collisions, int LD_rc, - uint32_t* counts, - uint32_t* rc_counts + uint64_t* counts, + uint64_t* rc_counts ) { extern __shared__ uint32_t s_rc_collisions[]; @@ -253,13 +253,13 @@ __global__ void exx_ek_shellpair_collision_shared_kernel( // TODO use thread block level reduction before writing to global memory - uint32_t count = 0; + unsigned long long count = 0; for(int ij = threadIdx.x; ij < LD_coll; ij+=blockDim.x) count += __popc(collisions[i_task * LD_coll + ij]); - atomicAdd(&(counts[i_task]), count); + atomicAdd((unsigned long long *)&(counts[i_task]), count); count = 0; for(int ij = threadIdx.x; ij < LD_rc; ij+=blockDim.x) count += __popc(rc_collisions[i_task * LD_rc + ij]); - atomicAdd(&(rc_counts[i_task]), count); + atomicAdd((unsigned long long *)&(rc_counts[i_task]), count); __syncthreads(); } @@ -289,7 +289,7 @@ __global__ void print_coll(size_t ntasks, size_t nshells, uint32_t* collisions, } } -__global__ void print_counts(size_t ntasks, uint32_t* counts) { +__global__ void print_counts(size_t ntasks, uint64_t* counts) { for(auto i_task = 0 ; i_task < ntasks; ++i_task) { @@ -308,8 +308,8 @@ __global__ void bitvector_to_position_list_shellpair( size_t nsp, size_t LD_bit, const uint32_t* collisions, - const uint32_t* counts, - uint32_t* position_list + const uint64_t* counts, + uint64_t* position_list ) { constexpr auto warp_size = cuda::warp_size; @@ -370,9 +370,9 @@ __global__ void bitvector_to_position_list_shells( size_t nshells, size_t LD_bit, const uint32_t* collisions, - const uint32_t* counts, + const uint64_t* counts, const int32_t* shell_size, - uint32_t* position_list, + uint64_t* position_list, size_t* nbe_list ) { constexpr auto warp_size = cuda::warp_size; @@ -500,8 +500,8 @@ void exx_ek_shellpair_collision( using dur_t = std::chrono::duration; cudaStream_t stream = queue.queue_as(); - std::vector counts_host (ntasks); - std::vector rc_counts_host (ntasks); + std::vector counts_host (ntasks); + std::vector rc_counts_host (ntasks); const size_t nshell_pairs = shpairs.npairs(); const size_t LD_coll = util::div_ceil(nshell_pairs, 32); @@ -533,9 +533,9 @@ void exx_ek_shellpair_collision( buffer_adaptor full_stack(dyn_stack, dyn_size); auto collisions = full_stack.aligned_alloc(ntasks * LD_coll); - auto counts = full_stack.aligned_alloc(ntasks); + auto counts = full_stack.aligned_alloc(ntasks); auto rc_collisions = full_stack.aligned_alloc(ntasks * LD_rc); - auto rc_counts = full_stack.aligned_alloc(ntasks); + auto rc_counts = full_stack.aligned_alloc(ntasks); auto sp_check_st = hrt_t::now(); util::cuda_set_zero_async( ntasks * LD_coll,collisions.ptr, stream, "Zero Coll"); @@ -641,8 +641,8 @@ void exx_ek_shellpair_collision( auto scan_en = hrt_t::now(); dur_t scan_dur = scan_en - scan_st; - uint32_t total_sp_count = counts_host[ntasks-1]; - uint32_t total_s_count = rc_counts_host[ntasks-1]; + uint64_t total_sp_count = counts_host[ntasks-1]; + uint64_t total_s_count = rc_counts_host[ntasks-1]; //size_t global_sp_count = total_sp_count; //MPI_Allreduce(MPI_IN_PLACE, &global_sp_count, 1, MPI_UINT64_T, MPI_SUM, @@ -653,8 +653,8 @@ void exx_ek_shellpair_collision( auto bv_st = hrt_t::now(); - auto position_sp_list_device = full_stack.aligned_alloc(total_sp_count); - auto position_s_list_device = full_stack.aligned_alloc(total_s_count); + auto position_sp_list_device = full_stack.aligned_alloc(total_sp_count); + auto position_s_list_device = full_stack.aligned_alloc(total_s_count); auto nbe_list = full_stack.aligned_alloc(ntasks); { dim3 threads(32,32); @@ -668,7 +668,7 @@ void exx_ek_shellpair_collision( ); } - std::vector position_sp_list(total_sp_count); + std::vector position_sp_list(total_sp_count); util::cuda_copy(total_sp_count, position_sp_list.data(), position_sp_list_device.ptr, "Position List ShellPair"); auto bv_en = hrt_t::now(); @@ -676,7 +676,7 @@ void exx_ek_shellpair_collision( auto d2h_st = hrt_t::now(); - std::vector position_s_list(total_s_count); + std::vector position_s_list(total_s_count); std::vector nbe_list_host(ntasks); util::cuda_copy(total_s_count, position_s_list.data(), position_s_list_device.ptr, "Position List Shell"); util::cuda_copy(ntasks, nbe_list_host.data(), nbe_list.ptr, "NBE List"); diff --git a/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx b/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx index 7818a5a83..3d0c6656a 100644 --- a/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx +++ b/src/xc_integrator/local_work_driver/device/scheme1_data_base.cxx @@ -58,7 +58,7 @@ size_t Scheme1DataBase::get_static_mem_requirement() { nsp * sizeof(int32_t) + // nprim_pairs nsp * sizeof(shell_pair*) + // shell_pair pointer nsp * 3 * sizeof(double) + // X_AB, Y_AB, Z_AB - 1024 * 1024; // additional memory for alignment padding + 4 * 1024 * 1024; // additional memory for alignment padding return size; } diff --git a/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx b/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx index 2e043842f..c7c76d245 100644 --- a/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx +++ b/src/xc_integrator/xc_data/device/xc_device_aos_data.cxx @@ -52,6 +52,7 @@ size_t XCDeviceAoSData::get_mem_req( integrator_term_tracker terms, const size_t nbe_cou = task.cou_screening.nbe; const size_t ncut_cou = submat_cut_cou.size(); const size_t nblock_cou = submat_block_cou.size(); + const size_t nshells = global_dims.nshells; return base_size + // Collocation + Derivatives @@ -88,6 +89,9 @@ size_t XCDeviceAoSData::get_mem_req( integrator_term_tracker terms, // Map from packed to unpacked indices reqt.task_bfn_shell_indirection_size( nbe_bfn ) * sizeof(int32_t) + + // Scratch memory to store shell pairs + reqt.task_exx_collision_size( nshells ) * sizeof(int64_t) + + // Memory associated with task indirection: valid for both AoS and SoA reqt.task_indirection_size() * sizeof(XCDeviceTask); } diff --git a/src/xc_integrator/xc_data/device/xc_device_data.hpp b/src/xc_integrator/xc_data/device/xc_device_data.hpp index 781e23729..a2579bf44 100644 --- a/src/xc_integrator/xc_data/device/xc_device_data.hpp +++ b/src/xc_integrator/xc_data/device/xc_device_data.hpp @@ -376,6 +376,7 @@ struct required_term_storage { bool task_gmat = false; bool task_nbe_scr = false; bool task_bfn_shell_indirection = false; + bool task_exx_collision = false; inline size_t task_bfn_size(size_t nbe, size_t npts) { @@ -506,6 +507,12 @@ struct required_term_storage { const size_t num_subtasks = util::div_ceil(npts, subtask_size); return PRDVL(task_to_shell_pair_cou, num_subtasks); } + inline size_t task_exx_collision_size(size_t nshells) { + const size_t nslt = (nshells * (nshells+1)) / 2 + + nshells + ; + return PRDVL(task_exx_collision, nslt); + } @@ -638,6 +645,7 @@ struct required_term_storage { task_shell_offs_bfn = true; task_bfn_shell_indirection = true; shell_to_task_bfn = true; + task_exx_collision = true; } }