Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/implementation-status.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ selected constant values from `ISO_FORTRAN_ENV` of the hosting compiler.
| `prif_co_min_character` | **YES** | |
| `prif_co_sum` | **YES** | |
| `prif_co_reduce` | **YES** | |
| `prif_co_reduce_cptr` | **YES** | expected in PRIF 0.8 |

---

Expand Down
31 changes: 19 additions & 12 deletions src/caffeine/caffeine.c
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,12 @@ void caf_atomic_logical(int opcode, int image, void* addr, int64_t *result, int6
caf_atomic_int(opcode, image, addr, result, op1, op2);
}

//-------------------------------------------------------------------
// Type-erased collective subroutines
//-------------------------------------------------------------------

void caf_co_reduce(
CFI_cdesc_t* a_desc, int result_image, size_t num_elements, gex_Coll_ReduceFn_t user_op, void* client_data, gex_TM_t team
) {
assert(a_desc);
void caf_co_reduce_cptr( void *a_ptr, int result_image, size_t num_elements, size_t element_size,
gex_Coll_ReduceFn_t user_op, void* client_data, gex_TM_t team) {
assert(result_image >= 0);
assert(num_elements > 0);
assert(user_op);
Expand All @@ -442,22 +442,29 @@ void caf_co_reduce(
// Here we undo that incorrect extra level of indirection
user_op = *(gex_Coll_ReduceFn_t *)user_op;
#endif
char* a_address = (char*) a_desc->base_addr;
size_t c_sizeof_a = a_desc->elem_len;
gex_Event_t ev;

if (result_image) {
ev = gex_Coll_ReduceToOneNB(
team, result_image-1, a_address, a_address, GEX_DT_USER, c_sizeof_a, num_elements, GEX_OP_USER, user_op, client_data, 0
);
ev = gex_Coll_ReduceToOneNB( team, result_image-1, a_ptr, a_ptr,
GEX_DT_USER, element_size, num_elements,
GEX_OP_USER, user_op, client_data, 0 );
} else {
ev = gex_Coll_ReduceToAllNB(
team, a_address, a_address, GEX_DT_USER, c_sizeof_a, num_elements, GEX_OP_USER, user_op, client_data, 0
);
ev = gex_Coll_ReduceToAllNB( team, a_ptr, a_ptr,
GEX_DT_USER, element_size, num_elements,
GEX_OP_USER, user_op, client_data, 0);
}
gex_Event_Wait(ev);
}

// Descriptor-based entry point for the user-defined reduction.
// Checks that a descriptor was supplied, then hands its base address and
// per-element length to caf_co_reduce_cptr; every other argument is forwarded
// untouched.
void caf_co_reduce(CFI_cdesc_t* a_desc,
                   int result_image,
                   size_t num_elements,
                   gex_Coll_ReduceFn_t user_op,
                   void* client_data,
                   gex_TM_t team)
{
  assert(a_desc);
  caf_co_reduce_cptr(a_desc->base_addr, result_image, num_elements, a_desc->elem_len,
                     user_op, client_data, team);
}

void caf_co_broadcast(CFI_cdesc_t * a_desc, int source_image, int num_elements, gex_TM_t team)
{
char* c_loc_a = (char*) a_desc->base_addr;
Expand Down
28 changes: 28 additions & 0 deletions src/caffeine/co_reduce_s.F90
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,32 @@ subroutine contiguous_co_reduce(a, operation_wrapper, cdata, result_image, stat,
current_team%info%gex_team)
end subroutine

module subroutine prif_co_reduce_cptr(a_ptr, element_size, element_count, operation_wrapper, &
    cdata, result_image, stat, errmsg, errmsg_alloc)
  !! Pointer-based reduction: the data is described by a raw C address (`a_ptr`),
  !! a per-element size and an element count, and is handed to the C routine
  !! `caf_co_reduce_cptr` together with the current team's GASNet-EX team handle.
  !! `stat`, when present, is set to 0; `errmsg` and `errmsg_alloc` are not
  !! modified by this routine.
  type(c_ptr), intent(in) :: a_ptr
  integer(c_size_t), intent(in) :: element_size, element_count
  procedure(prif_operation_wrapper_interface), pointer, intent(in) :: operation_wrapper
  type(c_ptr), intent(in), value :: cdata
  integer(c_int), intent(in), optional :: result_image
  integer(c_int), intent(out), optional :: stat
  character(len=*), intent(inout), optional :: errmsg
  character(len=:), intent(inout), allocatable, optional :: errmsg_alloc
  type(c_funptr) :: wrapper_c_funptr

  if (present(stat)) stat = 0

  ! The C layer needs a C function pointer for the reduction operation.
  wrapper_c_funptr = c_funloc(operation_wrapper)
  call_assert(c_associated(wrapper_c_funptr))

  ! Argument order follows the C prototype: count first, then element size.
  ! NOTE(review): optional_value presumably yields 0 when result_image is absent,
  ! which the C side treats as "reduce to all images" - confirm.
  call caf_co_reduce_cptr(a_ptr, optional_value(result_image), element_count, element_size, &
                          wrapper_c_funptr, cdata, current_team%info%gex_team)
end subroutine prif_co_reduce_cptr



end submodule co_reduce_s
17 changes: 15 additions & 2 deletions src/caffeine/prif_private_s.F90
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,27 @@ subroutine caf_co_broadcast(a, source_image, Nelem, team) bind(C)
end subroutine

subroutine caf_co_reduce(a, result_image, num_elements, Coll_ReduceSub, client_data, team) bind(C)
  !! void caf_co_reduce(CFI_cdesc_t* a_desc, int result_image, size_t num_elements, gex_Coll_ReduceFn_t user_op, void* client_data, gex_TM_t team)
  !!
  !! Interface to the C routine of the same name. Every argument except the
  !! assumed-type, assumed-rank `a` is passed by value, and the declarations
  !! are listed in the same order as the C parameters.
  import c_int, c_ptr, c_size_t, c_funptr
  implicit none
  type(*) a(..)
  integer(c_int), value :: result_image
  integer(c_size_t), value :: num_elements
  type(c_funptr), value :: Coll_ReduceSub
  type(c_ptr), value :: client_data
  type(c_ptr), value :: team
end subroutine

subroutine caf_co_reduce_cptr(a_ptr, result_image, num_elements, element_size, &
                              Coll_ReduceSub, client_data, team) bind(C)
  !! void caf_co_reduce_cptr(void *a_ptr, int result_image, size_t num_elements, size_t element_size, gex_Coll_ReduceFn_t user_op, void* client_data, gex_TM_t team)
  !!
  !! Interface to the C routine of the same name; all arguments are passed by value.
  import :: c_int, c_ptr, c_size_t, c_funptr
  implicit none
  type(c_ptr), value :: a_ptr, client_data, team
  integer(c_int), value :: result_image
  integer(c_size_t), value :: num_elements, element_size
  type(c_funptr), value :: Coll_ReduceSub
end subroutine caf_co_reduce_cptr

Expand Down
15 changes: 14 additions & 1 deletion src/prif.F90
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ module prif
public :: prif_num_images, prif_num_images_with_team, prif_num_images_with_team_number
public :: prif_failed_images, prif_stopped_images, prif_image_status
public :: prif_local_data_pointer, prif_set_context_data, prif_get_context_data, prif_size_bytes
public :: prif_co_sum, prif_co_max, prif_co_min, prif_co_reduce, prif_co_broadcast
public :: prif_co_sum, prif_co_max, prif_co_min, prif_co_reduce, prif_co_reduce_cptr, prif_co_broadcast
public :: prif_co_min_character, prif_co_max_character
public :: prif_operation_wrapper_interface
public :: prif_form_team, prif_change_team, prif_end_team, prif_get_team, prif_team_number
Expand Down Expand Up @@ -741,6 +741,19 @@ module subroutine prif_co_reduce(a, operation_wrapper, cdata, result_image, stat
character(len=:), intent(inout), allocatable, optional :: errmsg_alloc
end subroutine

module subroutine prif_co_reduce_cptr(a_ptr, element_size, element_count, operation_wrapper, &
    cdata, result_image, stat, errmsg, errmsg_alloc)
  !! Variant of `prif_co_reduce` in which the data is identified by a C address
  !! (`a_ptr`) plus an element size and element count instead of an assumed-type
  !! array argument. The remaining arguments mirror those of `prif_co_reduce`.
  implicit none
  type(c_ptr), intent(in) :: a_ptr
  integer(c_size_t), intent(in) :: element_size, element_count
  procedure(prif_operation_wrapper_interface), pointer, intent(in) :: operation_wrapper
  type(c_ptr), intent(in), value :: cdata
  integer(c_int), intent(in), optional :: result_image
  integer(c_int), intent(out), optional :: stat
  character(len=*), intent(inout), optional :: errmsg
  character(len=:), intent(inout), allocatable, optional :: errmsg_alloc
end subroutine prif_co_reduce_cptr

module subroutine prif_co_broadcast(a, source_image, stat, errmsg, errmsg_alloc)
implicit none
type(*), intent(inout), target :: a(..)
Expand Down
61 changes: 51 additions & 10 deletions test/prif_co_reduce_test.F90
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#include "julienne-assert-macros.h"

module prif_co_reduce_test_m
use iso_c_binding, only: c_ptr, c_funptr, c_size_t, c_f_pointer, c_f_procpointer, c_funloc, c_loc, c_null_ptr, c_associated
use prif, only : prif_co_reduce, prif_num_images, prif_this_image_no_coarray, prif_operation_wrapper_interface
use iso_c_binding, only: c_ptr, c_funptr, c_size_t, c_f_pointer, c_f_procpointer, c_funloc, c_loc, c_null_ptr, c_associated, c_int8_t
use prif, only : prif_co_reduce, prif_co_reduce_cptr, prif_num_images, prif_this_image_no_coarray, prif_operation_wrapper_interface
use julienne_m, only : &
call_julienne_assert_ &
,operator(.all.) &
Expand Down Expand Up @@ -122,6 +122,8 @@ function check_derived_type_reduction() result(diag)
procedure(prif_operation_wrapper_interface), pointer :: op
real, parameter :: tolerance = 0D0

diag = .true.

op => pair_adder
call prif_this_image_no_coarray(this_image=me)
call prif_num_images(ni)
Expand All @@ -145,8 +147,23 @@ function check_derived_type_reduction() result(diag)
#else
expected = reduce(tmp, add_pair, dim=2)
#endif
diag = .all. (my_val%fst .equalsExpected. expected%fst) &
.also. (.all. ( my_val%snd .approximates. expected%snd .within. tolerance))
ALSO(.all. (my_val%fst .equalsExpected. expected%fst))
ALSO(.all. (my_val%snd .approximates. expected%snd .within. tolerance))

! now repeat the same test using the prif_co_reduce_cptr variant:
my_val = values(:, mod(me-1, size(values,2))+1)
block
integer(c_size_t) :: element_size, element_count
integer(c_int8_t), allocatable, target :: bytes(:)
element_size = storage_size(my_val(1))/8
element_count = size(my_val)
bytes = transfer(my_val, bytes)
call prif_co_reduce_cptr(c_loc(bytes), element_size, element_count, op, c_loc(dummy))
my_val = transfer(bytes, my_val, element_count)
end block
ALSO(.all. (my_val%fst .equalsExpected. expected%fst))
ALSO(.all. (my_val%snd .approximates. expected%snd .within. tolerance))

end function

pure function add_pair(lhs, rhs) result(total)
Expand Down Expand Up @@ -175,11 +192,10 @@ subroutine pair_adder(arg1, arg2_and_out, count, cdata) bind(C)
end subroutine

#if HAVE_PARAM_DERIVED
! As of LLVM20, flang does not implement the types used by this test:
! As of LLVM21, flang does not implement the types used by this test:
! flang/lib/Lower/ConvertType.cpp:482: not yet implemented: parameterized derived types
! error: Actual argument associated with TYPE(*) dummy argument 'a=' may not have a parameterized derived type

! Gfortran 14.2 also lacks the type support for this test:
! Gfortran 14.2..15.2 also lack the type support for this test:
! Error: Derived type 'pdtarray' at (1) is being used before it is defined

function check_type_parameter_reduction() result(diag)
Expand All @@ -196,17 +212,42 @@ function check_type_parameter_reduction() result(diag)
procedure(prif_operation_wrapper_interface), pointer :: op
type(reduction_context_data), target :: context

diag = .true.

op => array_wrapper
context%user_op = c_funloc(add_array)
context%length = values%length
context%length = values(1,1)%length
call prif_this_image_no_coarray(this_image=me)
call prif_num_images(ni)

my_val = values(:, mod(me-1, size(values,2))+1)
call prif_co_reduce(my_val, op, c_loc(context))

# if ALLOW_ASSUMED_TYPE_PDT
! Ideally here we'd directly pass the user data `my_val` to prif_co_reduce as follows:
call prif_co_reduce(my_val, op, c_loc(context))
! Unfortunately the code above is not strictly standards-conformant, because Fortran forbids
! passing an actual argument of derived type with type parameters to a procedure where the
! corresponding dummy argument has assumed type (the first argument to `prif_co_reduce`).
! Example errors from gfortran and flang:
! error: Actual argument associated with TYPE(*) dummy argument 'a=' may not have a parameterized derived type
! Error: Actual argument at (1) to assumed-type dummy has type parameters or is of derived type with type-bound or FINAL procedures
# else
! So instead, we stage the data through a type-erased buffer and call the _cptr variant
block
integer(c_size_t) :: element_size, element_count
integer(c_int8_t), allocatable, target :: bytes(:)
element_size = storage_size(my_val(1))/8
element_count = size(my_val)
bytes = transfer(my_val, bytes)
call prif_co_reduce_cptr(c_loc(bytes), element_size, element_count, op, c_loc(context))
my_val = transfer(bytes, my_val, element_count)
end block
# endif

expected = reduce(reshape([(values(:, mod(i-1,size(values,2))+1), i = 1, ni)], [size(values,1),ni]), add_array, dim=2)
diag = .all. (my_val%elements .equalsExpected. expected%elements)
do i = 1, size(my_val)
ALSO(.all. (my_val(i)%elements .equalsExpected. expected(i)%elements))
end do
end function

pure function add_array(lhs, rhs) result(total)
Expand Down