@@ -43,52 +43,85 @@ __all__ += [
4343]
4444
4545
46- ctypedef c_dpctl.DPCTLSyclEventRef(* fptr_dpnp_partition_t)(c_dpctl.DPCTLSyclQueueRef,
47- void * ,
48- void * ,
49- void * ,
50- const size_t,
51- const shape_elem_type * ,
52- const size_t,
53- const c_dpctl.DPCTLEventVectorRef)
54-
55-
56- cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, axis = - 1 , kind = ' introselect' , order = None ):
46+ ctypedef c_dpctl.DPCTLSyclEventRef(
47+ * fptr_dpnp_partition_t)(
48+ c_dpctl.DPCTLSyclQueueRef,
49+ void * ,
50+ void * ,
51+ void * ,
52+ const size_t,
53+ const shape_elem_type * ,
54+ const size_t,
55+ const c_dpctl.DPCTLEventVectorRef,
56+ )
57+
58+
59+ cpdef utils.dpnp_descriptor dpnp_partition(
60+ utils.dpnp_descriptor arr, int kth,
61+ axis = - 1 , kind = " introselect" , order = None ,
62+ ):
5763 cdef shape_type_c shape1 = arr.shape
5864
59- cdef size_t kth_ = kth if kth >= 0 else (arr.ndim + kth)
60- cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
61-
62- cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PARTITION_EXT, param1_type, param1_type)
63-
64- cdef utils.dpnp_descriptor arr2 = dpnp.get_dpnp_descriptor(arr.get_pyobj().copy(), copy_when_nondefault_queue = False )
65+ cdef size_t kth_ = (
66+ kth if kth >= 0 else (arr.ndim + kth)
67+ )
68+ cdef DPNPFuncType param1_type = (
69+ dpnp_dtype_to_DPNPFuncType(arr.dtype)
70+ )
71+
72+ cdef DPNPFuncData kernel_data = (
73+ get_dpnp_function_ptr(
74+ DPNP_FN_PARTITION_EXT,
75+ param1_type, param1_type,
76+ )
77+ )
78+
79+ cdef utils.dpnp_descriptor arr2 = (
80+ dpnp.get_dpnp_descriptor(
81+ arr.get_pyobj().copy(),
82+ copy_when_nondefault_queue = False ,
83+ )
84+ )
6585
6686 arr_obj = arr.get_array()
6787
68- cdef utils.dpnp_descriptor result = utils.create_output_descriptor(arr.shape,
69- kernel_data.return_type,
70- None ,
71- device = arr_obj.sycl_device,
72- usm_type = arr_obj.usm_type,
73- sycl_queue = arr_obj.sycl_queue)
88+ cdef utils.dpnp_descriptor result = (
89+ utils.create_output_descriptor(
90+ arr.shape,
91+ kernel_data.return_type,
92+ None ,
93+ device = arr_obj.sycl_device,
94+ usm_type = arr_obj.usm_type,
95+ sycl_queue = arr_obj.sycl_queue,
96+ )
97+ )
7498
7599 result_sycl_queue = result.get_array().sycl_queue
76100
77- cdef c_dpctl.SyclQueue q = < c_dpctl.SyclQueue> result_sycl_queue
78- cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
79-
80- cdef fptr_dpnp_partition_t func = < fptr_dpnp_partition_t > kernel_data.ptr
81-
82- cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
83- arr.get_data(),
84- arr2.get_data(),
85- result.get_data(),
86- kth_,
87- shape1.data(),
88- arr.ndim,
89- NULL ) # dep_events_ref
90-
91- with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
101+ cdef c_dpctl.SyclQueue q = (
102+ < c_dpctl.SyclQueue> result_sycl_queue
103+ )
104+ cdef c_dpctl.DPCTLSyclQueueRef q_ref = (
105+ q.get_queue_ref()
106+ )
107+
108+ cdef fptr_dpnp_partition_t func = (
109+ < fptr_dpnp_partition_t> kernel_data.ptr
110+ )
111+
112+ cdef c_dpctl.DPCTLSyclEventRef event_ref = func(
113+ q_ref,
114+ arr.get_data(),
115+ arr2.get_data(),
116+ result.get_data(),
117+ kth_,
118+ shape1.data(),
119+ arr.ndim,
120+ NULL ,
121+ )
122+
123+ with nogil:
124+ c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
92125 c_dpctl.DPCTLEvent_Delete(event_ref)
93126
94127 return result
0 commit comments