-
Notifications
You must be signed in to change notification settings - Fork 258
Add TMA TensorMapDescriptor support #1687
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
rparolin
wants to merge
26
commits into
NVIDIA:main
Choose a base branch
from
rparolin:rparolin/tma_feature
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+2,222
−29
Open
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
e3e1899
initial commit
rparolin 77a3c8e
tma wide
rparolin 19c4a0f
clean up
rparolin 35a04b9
Add comments to prepare_tensor_map_arg explaining allocation and life…
rparolin bb19e4f
Address Copilot review feedback
rparolin 23a8900
Split TMA example into two focused files
rparolin 0a1b720
pre-commit
rparolin 44fbdcf
adding stride meta data to gpu allocated memory
rparolin bdf39a2
im2col fixes
rparolin 96a3e84
Reuse CCCL TMA descriptor construction for tiled TensorMap and keep v…
cpcloud 1a6b416
Skip im2col-wide TensorMap tests when runtime support is unavailable.
cpcloud 892ee60
Align TensorMap API surface with review feedback and enforce context …
cpcloud 5a0e141
Restore cu12 feature definitions in cuda_core pixi manifest.
cpcloud eef1c7a
Handle TensorMap device validation by DLPack type
rparolin 99ff204
Merge branch 'main' into rparolin/tma_feature
rparolin d6c311a
formatting change
rparolin 9673bcf
Update cuda_core/cuda/core/_cpp/tensor_map_cccl.h
rparolin ae86192
Update cuda_core/examples/tma_replace_address.py
rparolin 232b621
Update cuda_core/cuda/core/__init__.py
rparolin 358d975
Align TensorMap creation and launch behavior with the latest review g…
cpcloud e67e9d3
Consolidate the TMA examples around the libcudacxx wrappers.
cpcloud 9ff8d0f
Teach the TMA example where to find libcudacxx headers.
cpcloud 719f0f3
Bundle tiled TensorMap options and type retained views.
cpcloud ad1c800
Keep the rebased TensorMap validation helper consistent.
cpcloud a1203ac
Apply the pre-commit fixes for the rebased TensorMap branch.
cpcloud 3c9e32d
Keep the TensorMap multi-GPU tests on the view-based API.
cpcloud File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -68,3 +68,4 @@ | |
| Stream, | ||
| StreamOptions, | ||
| ) | ||
| from cuda.core._tensor_map import TensorMapDescriptor, TensorMapDescriptorOptions | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| // | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| #include "tensor_map_cccl.h" | ||
|
|
||
| #include <string.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <exception> | ||
|
|
||
| #if defined(__has_include) | ||
| // Older CTK releases do not ship <cuda/tma>. When it is unavailable we keep | ||
| // the CCCL helper compiled out and fall back to the direct driver path. | ||
| # if __has_include(<cuda/tma>) | ||
| # include <cuda/tma> | ||
| # define CUDA_CORE_HAS_CUDA_TMA 1 | ||
| # else | ||
| # define CUDA_CORE_HAS_CUDA_TMA 0 | ||
| # endif | ||
| # if __has_include("dlpack.h") | ||
| # include "dlpack.h" | ||
| # define CUDA_CORE_HAS_DLPACK_H 1 | ||
| # elif __has_include(<dlpack/dlpack.h>) | ||
| # include <dlpack/dlpack.h> | ||
| # define CUDA_CORE_HAS_DLPACK_H 1 | ||
| # else | ||
| # define CUDA_CORE_HAS_DLPACK_H 0 | ||
| # endif | ||
| #else | ||
| # define CUDA_CORE_HAS_CUDA_TMA 0 | ||
| # define CUDA_CORE_HAS_DLPACK_H 0 | ||
| #endif | ||
|
|
||
// Copy a NUL-terminated diagnostic message into the caller-provided buffer,
// truncating to fit within `cap` bytes. Passing msg == NULL clears the buffer
// (used on the success path). A NULL buffer or zero capacity is a silent no-op.
static inline void cuda_core_write_err(char* err, size_t cap, const char* msg) noexcept
{
    if (err == nullptr || cap == 0)
    {
        return;
    }
    if (msg == nullptr)
    {
        err[0] = '\0';
        return;
    }
    const size_t copy_len = std::min(::strlen(msg), cap - 1);
    ::memcpy(err, msg, copy_len);
    err[copy_len] = '\0';
}
|
|
||
| int cuda_core_cccl_make_tma_descriptor_tiled( | ||
| void* out_tensor_map, | ||
| void* data, | ||
| int device_type, | ||
| int device_id, | ||
| int ndim, | ||
| const int64_t* shape, | ||
| const int64_t* strides, | ||
| uint8_t dtype_code, | ||
| uint8_t dtype_bits, | ||
| uint16_t dtype_lanes, | ||
| const int* box_sizes, | ||
| const int* elem_strides, | ||
| int interleave_layout, | ||
| int swizzle, | ||
| int l2_fetch_size, | ||
| int oob_fill, | ||
| char* err, | ||
| size_t err_cap) noexcept | ||
| { | ||
| #if !(CUDA_CORE_HAS_CUDA_TMA && CUDA_CORE_HAS_DLPACK_H) | ||
| (void)out_tensor_map; | ||
| (void)data; | ||
| (void)device_type; | ||
| (void)device_id; | ||
| (void)ndim; | ||
| (void)shape; | ||
| (void)strides; | ||
| (void)dtype_code; | ||
| (void)dtype_bits; | ||
| (void)dtype_lanes; | ||
| (void)box_sizes; | ||
| (void)elem_strides; | ||
| (void)interleave_layout; | ||
| (void)swizzle; | ||
| (void)l2_fetch_size; | ||
| (void)oob_fill; | ||
| cuda_core_write_err(err, err_cap, "CCCL <cuda/tma> and/or <dlpack/dlpack.h> not available at build time"); | ||
| return 1; | ||
| #else | ||
| try | ||
| { | ||
| if (!out_tensor_map) | ||
| { | ||
| cuda_core_write_err(err, err_cap, "out_tensor_map is NULL"); | ||
| return 1; | ||
| } | ||
| if (!data) | ||
| { | ||
| cuda_core_write_err(err, err_cap, "tensor data pointer is NULL"); | ||
| return 1; | ||
| } | ||
| if (!shape || !box_sizes || ndim <= 0) | ||
| { | ||
| cuda_core_write_err(err, err_cap, "invalid rank/shape/box_sizes"); | ||
| return 1; | ||
| } | ||
|
|
||
| DLTensor t{}; | ||
| t.data = data; | ||
| t.device = {static_cast<DLDeviceType>(device_type), device_id}; | ||
| t.ndim = ndim; | ||
| t.dtype.code = dtype_code; | ||
| t.dtype.bits = dtype_bits; | ||
| t.dtype.lanes = dtype_lanes; | ||
| // CCCL promises not to mutate the arrays, but DLPack uses non-const pointers. | ||
| t.shape = const_cast<int64_t*>(shape); | ||
| t.strides = const_cast<int64_t*>(strides); | ||
| t.byte_offset = 0; | ||
|
|
||
| const auto layout = static_cast<cuda::tma_interleave_layout>(interleave_layout); | ||
| const auto swz = static_cast<cuda::tma_swizzle>(swizzle); | ||
| const auto l2 = static_cast<cuda::tma_l2_fetch_size>(l2_fetch_size); | ||
| const auto oob = static_cast<cuda::tma_oob_fill>(oob_fill); | ||
|
|
||
| auto box = cuda::std::span<const int>(box_sizes, static_cast<size_t>(ndim)); | ||
|
|
||
| CUtensorMap desc{}; | ||
| if (elem_strides) | ||
| { | ||
| auto es = cuda::std::span<const int>(elem_strides, static_cast<size_t>(ndim)); | ||
| desc = cuda::make_tma_descriptor(t, box, es, layout, swz, l2, oob); | ||
| } | ||
| else | ||
| { | ||
| desc = cuda::make_tma_descriptor(t, box, layout, swz, l2, oob); | ||
| } | ||
|
|
||
| ::memcpy(out_tensor_map, &desc, sizeof(CUtensorMap)); | ||
| cuda_core_write_err(err, err_cap, nullptr); | ||
| return 0; | ||
| } | ||
| catch (const std::exception& e) | ||
| { | ||
| cuda_core_write_err(err, err_cap, e.what()); | ||
| return 1; | ||
| } | ||
| catch (...) | ||
| { | ||
| cuda_core_write_err(err, err_cap, "unknown error while building TMA descriptor"); | ||
| return 1; | ||
| } | ||
| #endif | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| // | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
#ifndef CUDA_CORE_TENSOR_MAP_CCCL_H_
#define CUDA_CORE_TENSOR_MAP_CCCL_H_

#ifdef __cplusplus
#include <cstddef>
#include <cstdint>
// The wrapper must not let C++ exceptions escape across the C ABI, so the
// C++ declaration carries noexcept. A bare `noexcept` is a syntax error in C,
// so it is hidden behind a macro that expands to nothing for C includers.
#define CUDA_CORE_TMA_NOEXCEPT noexcept
extern "C" {
#else
#include <stddef.h>
#include <stdint.h>
#define CUDA_CORE_TMA_NOEXCEPT
#endif

// Build a tiled CUtensorMap using CCCL's cuda::make_tma_descriptor (from <cuda/tma>).
//
// The source tensor is described in DLPack terms (device, rank, shape,
// strides, dtype) together with the TMA tiling parameters (box sizes,
// element strides, interleave layout, swizzle, L2 fetch size, OOB fill).
//
// Returns 0 on success; on failure returns non-zero and writes a best-effort
// human-readable message into (err, err_cap) if provided.
int cuda_core_cccl_make_tma_descriptor_tiled(
    void* out_tensor_map,      // receives a CUtensorMap; must be suitably sized and aligned
    void* data,                // base pointer of the source tensor
    int device_type,           // DLDeviceType value describing where `data` lives
    int device_id,
    int ndim,
    const int64_t* shape,      // length ndim
    const int64_t* strides,    // length ndim, or NULL for contiguous
    uint8_t dtype_code,        // DLDataType code/bits/lanes triple
    uint8_t dtype_bits,
    uint16_t dtype_lanes,
    const int* box_sizes,      // length ndim
    const int* elem_strides,   // length ndim, or NULL for all-ones overload
    int interleave_layout,     // cuda::tma_interleave_layout value
    int swizzle,               // cuda::tma_swizzle value
    int l2_fetch_size,         // cuda::tma_l2_fetch_size value
    int oob_fill,              // cuda::tma_oob_fill value
    char* err,                 // optional error buffer; may be NULL
    size_t err_cap) CUDA_CORE_TMA_NOEXCEPT;

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // CUDA_CORE_TENSOR_MAP_CCCL_H_
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from libc.stdint cimport intptr_t | ||
|
|
||
| from cuda.core._dlpack cimport DLTensor | ||
| from cuda.core._layout cimport _StridedLayout | ||
|
|
||
|
|
||
cdef class StridedMemoryView:
    # Declaration-only (.pxd): a view over a strided tensor captured from an
    # exporting object. The implementation lives in the matching .pyx file.
    cdef readonly:
        intptr_t ptr               # address of the viewed memory as an integer
        int device_id              # device ordinal associated with the buffer
        bint is_device_accessible  # whether device code can dereference ptr
        bint readonly              # True when the exporter forbids writes
        object exporting_obj       # the object this view was created from

    cdef:
        object metadata            # NOTE(review): opaque exporter metadata -- confirm contents in the .pyx
        DLTensor* dl_tensor        # DLPack tensor descriptor (see cuda.core._dlpack)
        _StridedLayout _layout     # presumably cached by get_layout() -- verify in .pyx
        object _buffer             # presumably cached by get_buffer() -- verify in .pyx
        object _dtype              # presumably cached by get_dtype() -- verify in .pyx

    # Accessors declared inline so .pyx call sites avoid Python dispatch.
    cdef inline _StridedLayout get_layout(self)
    cdef inline object get_buffer(self)
    cdef inline object get_dtype(self)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from cuda.bindings cimport cydriver | ||
| from libc.stdint cimport intptr_t | ||
| from cuda.core._memoryview cimport StridedMemoryView | ||
|
|
||
|
|
||
cdef class TensorMapDescriptor:
    # Declaration-only (.pxd): wraps a driver-level CUtensorMap together with
    # the device/context bookkeeping recorded at creation time. The
    # implementation lives in the matching .pyx file.
    cdef cydriver.CUtensorMap _tensor_map  # the encoded tensor-map descriptor
    cdef int _device_id                    # device ordinal recorded at creation
    cdef intptr_t _context                 # context handle used by _check_context_compat -- TODO confirm it is a CUcontext
    cdef object _source_ref                # NOTE(review): presumably keeps the source object alive -- verify in .pyx
    cdef StridedMemoryView _view_ref       # view of the memory the descriptor refers to
    cdef object _repr_info                 # presumably feeds __repr__ -- verify in .pyx

    # except -1 lets Cython propagate a Python exception from this check.
    cdef int _check_context_compat(self) except -1
    cdef void* _get_data_ptr(self)
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.