Skip to content

Commit f4875f6

Browse files
committed
removing redundant checks that are already in Cython
1 parent 24086c3 commit f4875f6

1 file changed

Lines changed: 142 additions & 144 deletions

File tree

cuda_core/cuda/core/_tensor_map.pyx

Lines changed: 142 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,14 @@ class TensorMapOOBFill(enum.IntEnum):
8080
NAN_REQUEST_ZERO_FMA = cydriver.CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
8181

8282

83-
IF CUDA_CORE_BUILD_MAJOR >= 13:
84-
class TensorMapIm2ColWideMode(enum.IntEnum):
85-
"""Im2col wide mode for tensor map descriptors.
83+
class TensorMapIm2ColWideMode(enum.IntEnum):
84+
"""Im2col wide mode for tensor map descriptors.
8685
87-
These correspond to the ``CUtensorMapIm2ColWideMode`` driver enum values.
88-
Supported on compute capability 10.0+.
89-
"""
90-
W = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W
91-
W128 = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W128
86+
These correspond to the ``CUtensorMapIm2ColWideMode`` driver enum values.
87+
Supported on compute capability 10.0+.
88+
"""
89+
W = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W
90+
W128 = cydriver.CU_TENSOR_MAP_IM2COL_WIDE_MODE_W128
9291

9392

9493
# Mapping from numpy dtype to TMA data type
@@ -505,142 +504,141 @@ cdef class TensorMapDescriptor:
505504

506505
return desc
507506

508-
IF CUDA_CORE_BUILD_MAJOR >= 13:
509-
@classmethod
510-
def from_im2col_wide(cls, tensor, pixel_box_lower_corner_width, pixel_box_upper_corner_width,
511-
channels_per_pixel, pixels_per_column, *,
512-
element_strides=None,
513-
data_type=None,
514-
interleave=TensorMapInterleave.NONE,
515-
mode=TensorMapIm2ColWideMode.W,
516-
swizzle=TensorMapSwizzle.SWIZZLE_128B,
517-
l2_promotion=TensorMapL2Promotion.NONE,
518-
oob_fill=TensorMapOOBFill.NONE):
519-
"""Create an im2col-wide TMA descriptor from a tensor object.
520-
521-
Im2col-wide layout loads elements exclusively along the W (width)
522-
dimension. This variant is supported on compute capability 10.0+
523-
(Blackwell and later).
524-
525-
Parameters
526-
----------
527-
tensor : object
528-
Any object supporting DLPack or ``__cuda_array_interface__``,
529-
or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
530-
device-accessible memory with a 16-byte-aligned pointer.
531-
pixel_box_lower_corner_width : int
532-
Lower corner of the pixel bounding box along the W dimension.
533-
pixel_box_upper_corner_width : int
534-
Upper corner of the pixel bounding box along the W dimension.
535-
channels_per_pixel : int
536-
Number of channels per pixel.
537-
pixels_per_column : int
538-
Number of pixels per column.
539-
element_strides : tuple of int, optional
540-
Per-dimension element traversal strides. Default is all 1s.
541-
data_type : TensorMapDataType, optional
542-
Explicit data type override. If ``None``, inferred from the
543-
tensor's dtype.
544-
interleave : TensorMapInterleave
545-
Interleave layout. Default ``NONE``.
546-
mode : TensorMapIm2ColWideMode
547-
Im2col wide mode. Default ``W``.
548-
swizzle : TensorMapSwizzle
549-
Swizzle mode. Default ``SWIZZLE_128B``.
550-
l2_promotion : TensorMapL2Promotion
551-
L2 promotion mode. Default ``NONE``.
552-
oob_fill : TensorMapOOBFill
553-
Out-of-bounds fill mode. Default ``NONE``.
554-
555-
Returns
556-
-------
557-
TensorMapDescriptor
558-
559-
Raises
560-
------
561-
ValueError
562-
If the tensor rank is outside [3, 5], the pointer is not
563-
16-byte aligned, or other constraints are violated.
564-
"""
565-
cdef TensorMapDescriptor desc = cls.__new__(cls)
566-
567-
view = _get_validated_view(tensor)
568-
desc._source_ref = (tensor, view)
569-
570-
tma_dt = _resolve_data_type(view, data_type)
571-
cdef int c_data_type_int = int(tma_dt)
572-
cdef cydriver.CUtensorMapDataType c_data_type = <cydriver.CUtensorMapDataType>c_data_type_int
573-
574-
cdef intptr_t global_address = view.ptr
575-
shape = view.shape
576-
577-
cdef int rank = len(shape)
578-
if rank < 3 or rank > 5:
579-
raise ValueError(
580-
f"Im2col-wide tensor rank must be between 3 and 5, got {rank}")
581-
582-
element_strides = _validate_element_strides(element_strides, rank)
583-
584-
cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
585-
byte_strides = _compute_byte_strides(shape, view.strides, elem_size)
586-
587-
# Reverse all dimension arrays for column-major convention
588-
cdef uint64_t[5] c_global_dim
589-
cdef uint64_t[4] c_global_strides
590-
cdef uint32_t[5] c_element_strides
591-
cdef int i_c
592-
593-
for i_c in range(rank):
594-
c_global_dim[i_c] = <uint64_t>shape[rank - 1 - i_c]
595-
c_element_strides[i_c] = <uint32_t>element_strides[rank - 1 - i_c]
596-
597-
for i_c in range(rank - 1):
598-
c_global_strides[i_c] = <uint64_t>byte_strides[rank - 2 - i_c]
599-
600-
cdef uint32_t c_rank = <uint32_t>rank
601-
cdef int c_lower_w = <int>pixel_box_lower_corner_width
602-
cdef int c_upper_w = <int>pixel_box_upper_corner_width
603-
cdef uint32_t c_channels = <uint32_t>channels_per_pixel
604-
cdef uint32_t c_pixels = <uint32_t>pixels_per_column
605-
cdef int c_interleave_int = int(interleave)
606-
cdef int c_mode_int = int(mode)
607-
cdef int c_swizzle_int = int(swizzle)
608-
cdef int c_l2_promotion_int = int(l2_promotion)
609-
cdef int c_oob_fill_int = int(oob_fill)
610-
cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>c_interleave_int
611-
cdef cydriver.CUtensorMapIm2ColWideMode c_mode = <cydriver.CUtensorMapIm2ColWideMode>c_mode_int
612-
cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>c_swizzle_int
613-
cdef cydriver.CUtensorMapL2promotion c_l2_promotion = <cydriver.CUtensorMapL2promotion>c_l2_promotion_int
614-
cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = <cydriver.CUtensorMapFloatOOBfill>c_oob_fill_int
615-
616-
with nogil:
617-
HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2colWide(
618-
&desc._tensor_map,
619-
c_data_type,
620-
c_rank,
621-
<void*>global_address,
622-
c_global_dim,
623-
c_global_strides,
624-
c_lower_w,
625-
c_upper_w,
626-
c_channels,
627-
c_pixels,
628-
c_element_strides,
629-
c_interleave,
630-
c_mode,
631-
c_swizzle,
632-
c_l2_promotion,
633-
c_oob_fill,
634-
))
635-
636-
desc._repr_info = {
637-
"method": "im2col_wide",
638-
"rank": rank,
639-
"data_type": tma_dt,
640-
"swizzle": swizzle,
641-
}
642-
643-
return desc
507+
@classmethod
508+
def from_im2col_wide(cls, tensor, pixel_box_lower_corner_width, pixel_box_upper_corner_width,
509+
channels_per_pixel, pixels_per_column, *,
510+
element_strides=None,
511+
data_type=None,
512+
interleave=TensorMapInterleave.NONE,
513+
mode=TensorMapIm2ColWideMode.W,
514+
swizzle=TensorMapSwizzle.SWIZZLE_128B,
515+
l2_promotion=TensorMapL2Promotion.NONE,
516+
oob_fill=TensorMapOOBFill.NONE):
517+
"""Create an im2col-wide TMA descriptor from a tensor object.
518+
519+
Im2col-wide layout loads elements exclusively along the W (width)
520+
dimension. This variant is supported on compute capability 10.0+
521+
(Blackwell and later).
522+
523+
Parameters
524+
----------
525+
tensor : object
526+
Any object supporting DLPack or ``__cuda_array_interface__``,
527+
or a :obj:`~cuda.core.StridedMemoryView`. Must refer to
528+
device-accessible memory with a 16-byte-aligned pointer.
529+
pixel_box_lower_corner_width : int
530+
Lower corner of the pixel bounding box along the W dimension.
531+
pixel_box_upper_corner_width : int
532+
Upper corner of the pixel bounding box along the W dimension.
533+
channels_per_pixel : int
534+
Number of channels per pixel.
535+
pixels_per_column : int
536+
Number of pixels per column.
537+
element_strides : tuple of int, optional
538+
Per-dimension element traversal strides. Default is all 1s.
539+
data_type : TensorMapDataType, optional
540+
Explicit data type override. If ``None``, inferred from the
541+
tensor's dtype.
542+
interleave : TensorMapInterleave
543+
Interleave layout. Default ``NONE``.
544+
mode : TensorMapIm2ColWideMode
545+
Im2col wide mode. Default ``W``.
546+
swizzle : TensorMapSwizzle
547+
Swizzle mode. Default ``SWIZZLE_128B``.
548+
l2_promotion : TensorMapL2Promotion
549+
L2 promotion mode. Default ``NONE``.
550+
oob_fill : TensorMapOOBFill
551+
Out-of-bounds fill mode. Default ``NONE``.
552+
553+
Returns
554+
-------
555+
TensorMapDescriptor
556+
557+
Raises
558+
------
559+
ValueError
560+
If the tensor rank is outside [3, 5], the pointer is not
561+
16-byte aligned, or other constraints are violated.
562+
"""
563+
cdef TensorMapDescriptor desc = cls.__new__(cls)
564+
565+
view = _get_validated_view(tensor)
566+
desc._source_ref = (tensor, view)
567+
568+
tma_dt = _resolve_data_type(view, data_type)
569+
cdef int c_data_type_int = int(tma_dt)
570+
cdef cydriver.CUtensorMapDataType c_data_type = <cydriver.CUtensorMapDataType>c_data_type_int
571+
572+
cdef intptr_t global_address = view.ptr
573+
shape = view.shape
574+
575+
cdef int rank = len(shape)
576+
if rank < 3 or rank > 5:
577+
raise ValueError(
578+
f"Im2col-wide tensor rank must be between 3 and 5, got {rank}")
579+
580+
element_strides = _validate_element_strides(element_strides, rank)
581+
582+
cdef int elem_size = _TMA_DATA_TYPE_SIZE[tma_dt]
583+
byte_strides = _compute_byte_strides(shape, view.strides, elem_size)
584+
585+
# Reverse all dimension arrays for column-major convention
586+
cdef uint64_t[5] c_global_dim
587+
cdef uint64_t[4] c_global_strides
588+
cdef uint32_t[5] c_element_strides
589+
cdef int i_c
590+
591+
for i_c in range(rank):
592+
c_global_dim[i_c] = <uint64_t>shape[rank - 1 - i_c]
593+
c_element_strides[i_c] = <uint32_t>element_strides[rank - 1 - i_c]
594+
595+
for i_c in range(rank - 1):
596+
c_global_strides[i_c] = <uint64_t>byte_strides[rank - 2 - i_c]
597+
598+
cdef uint32_t c_rank = <uint32_t>rank
599+
cdef int c_lower_w = <int>pixel_box_lower_corner_width
600+
cdef int c_upper_w = <int>pixel_box_upper_corner_width
601+
cdef uint32_t c_channels = <uint32_t>channels_per_pixel
602+
cdef uint32_t c_pixels = <uint32_t>pixels_per_column
603+
cdef int c_interleave_int = int(interleave)
604+
cdef int c_mode_int = int(mode)
605+
cdef int c_swizzle_int = int(swizzle)
606+
cdef int c_l2_promotion_int = int(l2_promotion)
607+
cdef int c_oob_fill_int = int(oob_fill)
608+
cdef cydriver.CUtensorMapInterleave c_interleave = <cydriver.CUtensorMapInterleave>c_interleave_int
609+
cdef cydriver.CUtensorMapIm2ColWideMode c_mode = <cydriver.CUtensorMapIm2ColWideMode>c_mode_int
610+
cdef cydriver.CUtensorMapSwizzle c_swizzle = <cydriver.CUtensorMapSwizzle>c_swizzle_int
611+
cdef cydriver.CUtensorMapL2promotion c_l2_promotion = <cydriver.CUtensorMapL2promotion>c_l2_promotion_int
612+
cdef cydriver.CUtensorMapFloatOOBfill c_oob_fill = <cydriver.CUtensorMapFloatOOBfill>c_oob_fill_int
613+
614+
with nogil:
615+
HANDLE_RETURN(cydriver.cuTensorMapEncodeIm2colWide(
616+
&desc._tensor_map,
617+
c_data_type,
618+
c_rank,
619+
<void*>global_address,
620+
c_global_dim,
621+
c_global_strides,
622+
c_lower_w,
623+
c_upper_w,
624+
c_channels,
625+
c_pixels,
626+
c_element_strides,
627+
c_interleave,
628+
c_mode,
629+
c_swizzle,
630+
c_l2_promotion,
631+
c_oob_fill,
632+
))
633+
634+
desc._repr_info = {
635+
"method": "im2col_wide",
636+
"rank": rank,
637+
"data_type": tma_dt,
638+
"swizzle": swizzle,
639+
}
640+
641+
return desc
644642

645643
def replace_address(self, tensor):
646644
"""Replace the global memory address in this tensor map descriptor.

0 commit comments

Comments
 (0)