-
Notifications
You must be signed in to change notification settings - Fork 258
Add explicit CUDA graph construction API #1729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
478101c
ee55795
3b3a715
7559d2f
ae5d706
ba7d2f9
77a5dac
1a790ce
3e25c64
228de38
506850b
4ee0ed7
82f0ec7
5cf85c4
d993f9c
da73536
d228830
b463ccb
f1cceb2
133719b
e7ebe53
f0bbf66
b830e6e
8830b97
296ef6e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,6 +56,9 @@ decltype(&cuLibraryLoadData) p_cuLibraryLoadData = nullptr; | |
| decltype(&cuLibraryUnload) p_cuLibraryUnload = nullptr; | ||
| decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr; | ||
|
|
||
| // Graph | ||
| decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr; | ||
|
|
||
| // Linker | ||
| decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr; | ||
|
|
||
|
|
@@ -160,6 +163,49 @@ class GILAcquireGuard { | |
|
|
||
| } // namespace | ||
|
|
||
| // ============================================================================ | ||
| // Handle reverse-lookup registry | ||
| // | ||
| // Maps raw CUDA handles (CUevent, CUkernel, etc.) back to their owning | ||
| // shared_ptr so that _ref constructors can recover full metadata. | ||
| // Uses weak_ptr to avoid preventing destruction. | ||
| // ============================================================================ | ||
|
|
||
| template<typename Key, typename Handle, typename Hash = std::hash<Key>> | ||
| class HandleRegistry { | ||
| public: | ||
| void register_handle(const Key& key, const Handle& h) { | ||
| std::lock_guard<std::mutex> lock(mutex_); | ||
| map_[key] = h; | ||
| } | ||
|
|
||
| void unregister_handle(const Key& key) noexcept { | ||
| try { | ||
| std::lock_guard<std::mutex> lock(mutex_); | ||
| auto it = map_.find(key); | ||
| if (it != map_.end() && it->second.expired()) { | ||
| map_.erase(it); | ||
| } | ||
| } catch (...) {} | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should probably provide some feedback to the user that the `unregister_handle` operation failed.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes. This would be an ideal place to use a logger. Do we have plans to add one to cuda.core? I'm being cautious here because this is typically called from a destructor. |
||
| } | ||
|
|
||
| Handle lookup(const Key& key) { | ||
| std::lock_guard<std::mutex> lock(mutex_); | ||
| auto it = map_.find(key); | ||
| if (it != map_.end()) { | ||
| if (auto h = it->second.lock()) { | ||
| return h; | ||
| } | ||
| map_.erase(it); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think you want to erase the found element in your `lookup` method.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The found element is an expired `weak_ptr`. |
||
| } | ||
| return {}; | ||
| } | ||
|
|
||
| private: | ||
| std::mutex mutex_; | ||
| std::unordered_map<Key, std::weak_ptr<typename Handle::element_type>, Hash> map_; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider giving this [remainder of the comment was lost in extraction] |
||
| }; | ||
|
|
||
| // ============================================================================ | ||
| // Thread-local error handling | ||
| // ============================================================================ | ||
|
|
@@ -318,47 +364,98 @@ StreamHandle get_per_thread_stream() { | |
| namespace { | ||
| struct EventBox { | ||
| CUevent resource; | ||
| bool timing_disabled; | ||
| bool busy_waited; | ||
| bool ipc_enabled; | ||
| int device_id; | ||
| ContextHandle h_context; | ||
|
Comment on lines 365 to +371
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These properties are set at event creation time and cannot be queried through the driver API. Moreover, graph-attached events are returned from the driver as plain `CUevent` handles. The solution is to move the property metadata into C++ and set up a reverse look-up so that the driver-returned `CUevent` can be mapped back to its owning handle. Graph-attached kernels are handled similarly.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider memory layout density -- order members primarily by size and alignment. |
||
| }; | ||
| } // namespace | ||
|
|
||
| EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags) { | ||
| static const EventBox* get_box(const EventHandle& h) { | ||
| const CUevent* p = h.get(); | ||
| return reinterpret_cast<const EventBox*>( | ||
| reinterpret_cast<const char*>(p) - offsetof(EventBox, resource) | ||
| ); | ||
| } | ||
|
|
||
| bool get_event_timing_disabled(const EventHandle& h) noexcept { | ||
| return h ? get_box(h)->timing_disabled : true; | ||
| } | ||
|
|
||
| bool get_event_busy_waited(const EventHandle& h) noexcept { | ||
| return h ? get_box(h)->busy_waited : false; | ||
| } | ||
|
|
||
| bool get_event_ipc_enabled(const EventHandle& h) noexcept { | ||
| return h ? get_box(h)->ipc_enabled : false; | ||
| } | ||
|
|
||
| int get_event_device_id(const EventHandle& h) noexcept { | ||
| return h ? get_box(h)->device_id : -1; | ||
| } | ||
|
|
||
| ContextHandle get_event_context(const EventHandle& h) noexcept { | ||
| return h ? get_box(h)->h_context : ContextHandle{}; | ||
| } | ||
|
|
||
| static HandleRegistry<CUevent, EventHandle> event_registry; | ||
|
|
||
| EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags, | ||
| bool timing_disabled, bool busy_waited, | ||
| bool ipc_enabled, int device_id) { | ||
| GILReleaseGuard gil; | ||
| CUevent event; | ||
| if (CUDA_SUCCESS != (err = p_cuEventCreate(&event, flags))) { | ||
| return {}; | ||
| } | ||
|
|
||
| auto box = std::shared_ptr<const EventBox>( | ||
| new EventBox{event}, | ||
| new EventBox{event, timing_disabled, busy_waited, ipc_enabled, device_id, h_ctx}, | ||
| [h_ctx](const EventBox* b) { | ||
| event_registry.unregister_handle(b->resource); | ||
| GILReleaseGuard gil; | ||
| p_cuEventDestroy(b->resource); | ||
| delete b; | ||
| } | ||
| ); | ||
| return EventHandle(box, &box->resource); | ||
| EventHandle h(box, &box->resource); | ||
| event_registry.register_handle(event, h); | ||
| return h; | ||
| } | ||
|
|
||
| EventHandle create_event_handle_noctx(unsigned int flags) { | ||
| return create_event_handle(ContextHandle{}, flags); | ||
| return create_event_handle(ContextHandle{}, flags, true, false, false, -1); | ||
| } | ||
|
|
||
| EventHandle create_event_handle_ref(CUevent event) { | ||
| if (auto h = event_registry.lookup(event)) { | ||
| return h; | ||
| } | ||
| auto box = std::make_shared<const EventBox>(EventBox{event, true, false, false, -1, {}}); | ||
| return EventHandle(box, &box->resource); | ||
| } | ||
|
|
||
| EventHandle create_event_handle_ipc(const CUipcEventHandle& ipc_handle) { | ||
| EventHandle create_event_handle_ipc(const CUipcEventHandle& ipc_handle, | ||
| bool busy_waited) { | ||
| GILReleaseGuard gil; | ||
| CUevent event; | ||
| if (CUDA_SUCCESS != (err = p_cuIpcOpenEventHandle(&event, ipc_handle))) { | ||
| return {}; | ||
| } | ||
|
|
||
| auto box = std::shared_ptr<const EventBox>( | ||
| new EventBox{event}, | ||
| new EventBox{event, true, busy_waited, true, -1, {}}, | ||
| [](const EventBox* b) { | ||
| event_registry.unregister_handle(b->resource); | ||
| GILReleaseGuard gil; | ||
| p_cuEventDestroy(b->resource); | ||
| delete b; | ||
| } | ||
| ); | ||
| return EventHandle(box, &box->resource); | ||
| EventHandle h(box, &box->resource); | ||
| event_registry.register_handle(event, h); | ||
| return h; | ||
| } | ||
|
|
||
| // ============================================================================ | ||
|
|
@@ -665,61 +762,43 @@ struct ExportDataKeyHash { | |
|
|
||
| } | ||
|
|
||
| static std::mutex ipc_ptr_cache_mutex; | ||
| static std::unordered_map<ExportDataKey, std::weak_ptr<DevicePtrBox>, ExportDataKeyHash> ipc_ptr_cache; | ||
| static HandleRegistry<ExportDataKey, DevicePtrHandle, ExportDataKeyHash> ipc_ptr_cache; | ||
| static std::mutex ipc_import_mutex; | ||
|
|
||
| DevicePtrHandle deviceptr_import_ipc(const MemoryPoolHandle& h_pool, const void* export_data, const StreamHandle& h_stream) { | ||
| auto data = const_cast<CUmemPoolPtrExportData*>( | ||
| reinterpret_cast<const CUmemPoolPtrExportData*>(export_data)); | ||
|
|
||
| if (use_ipc_ptr_cache()) { | ||
| // Check cache before calling cuMemPoolImportPointer | ||
| ExportDataKey key; | ||
| std::memcpy(&key.data, data, sizeof(key.data)); | ||
|
|
||
| std::lock_guard<std::mutex> lock(ipc_ptr_cache_mutex); | ||
| std::lock_guard<std::mutex> lock(ipc_import_mutex); | ||
|
|
||
| auto it = ipc_ptr_cache.find(key); | ||
| if (it != ipc_ptr_cache.end()) { | ||
| if (auto box = it->second.lock()) { | ||
| // Cache hit - return existing handle | ||
| return DevicePtrHandle(box, &box->resource); | ||
| } | ||
| ipc_ptr_cache.erase(it); // Expired entry | ||
| if (auto h = ipc_ptr_cache.lookup(key)) { | ||
| return h; | ||
| } | ||
|
|
||
| // Cache miss - import the pointer | ||
| GILReleaseGuard gil; | ||
| CUdeviceptr ptr; | ||
| if (CUDA_SUCCESS != (err = p_cuMemPoolImportPointer(&ptr, *h_pool, data))) { | ||
| return {}; | ||
| } | ||
|
|
||
| // Create new handle with cache-clearing deleter | ||
| auto box = std::shared_ptr<DevicePtrBox>( | ||
| new DevicePtrBox{ptr, h_stream}, | ||
| [h_pool, key](DevicePtrBox* b) { | ||
| ipc_ptr_cache.unregister_handle(key); | ||
| GILReleaseGuard gil; | ||
| try { | ||
| std::lock_guard<std::mutex> lock(ipc_ptr_cache_mutex); | ||
| // Only erase if expired - avoids race where another thread | ||
| // replaced the entry with a new import before we acquired the lock. | ||
| auto it = ipc_ptr_cache.find(key); | ||
| if (it != ipc_ptr_cache.end() && it->second.expired()) { | ||
| ipc_ptr_cache.erase(it); | ||
| } | ||
| } catch (...) { | ||
| // Cache cleanup is best-effort - swallow exceptions in destructor context | ||
| } | ||
| p_cuMemFreeAsync(b->resource, as_cu(b->h_stream)); | ||
| delete b; | ||
| } | ||
| ); | ||
| ipc_ptr_cache[key] = box; | ||
| return DevicePtrHandle(box, &box->resource); | ||
| DevicePtrHandle h(box, &box->resource); | ||
| ipc_ptr_cache.register_handle(key, h); | ||
| return h; | ||
|
|
||
| } else { | ||
| // No caching - simple handle creation | ||
| GILReleaseGuard gil; | ||
| CUdeviceptr ptr; | ||
| if (CUDA_SUCCESS != (err = p_cuMemPoolImportPointer(&ptr, *h_pool, data))) { | ||
|
|
@@ -798,25 +877,96 @@ LibraryHandle create_library_handle_ref(CUlibrary library) { | |
| namespace { | ||
| struct KernelBox { | ||
| CUkernel resource; | ||
| LibraryHandle h_library; // Keeps library alive | ||
| LibraryHandle h_library; | ||
| }; | ||
| } // namespace | ||
|
|
||
| static const KernelBox* get_box(const KernelHandle& h) { | ||
| const CUkernel* p = h.get(); | ||
| return reinterpret_cast<const KernelBox*>( | ||
| reinterpret_cast<const char*>(p) - offsetof(KernelBox, resource) | ||
| ); | ||
| } | ||
|
|
||
| static HandleRegistry<CUkernel, KernelHandle> kernel_registry; | ||
|
|
||
| KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) { | ||
| GILReleaseGuard gil; | ||
| CUkernel kernel; | ||
| if (CUDA_SUCCESS != (err = p_cuLibraryGetKernel(&kernel, *h_library, name))) { | ||
| return {}; | ||
| } | ||
|
|
||
| return create_kernel_handle_ref(kernel, h_library); | ||
| auto box = std::make_shared<const KernelBox>(KernelBox{kernel, h_library}); | ||
| KernelHandle h(box, &box->resource); | ||
| kernel_registry.register_handle(kernel, h); | ||
| return h; | ||
| } | ||
|
|
||
| KernelHandle create_kernel_handle_ref(CUkernel kernel, const LibraryHandle& h_library) { | ||
| auto box = std::make_shared<const KernelBox>(KernelBox{kernel, h_library}); | ||
| KernelHandle create_kernel_handle_ref(CUkernel kernel) { | ||
| if (auto h = kernel_registry.lookup(kernel)) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Consider using trailing new if-statement:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think [remainder of the comment was lost in extraction] |
||
| return h; | ||
| } | ||
| auto box = std::make_shared<const KernelBox>(KernelBox{kernel, {}}); | ||
| return KernelHandle(box, &box->resource); | ||
| } | ||
|
|
||
| LibraryHandle get_kernel_library(const KernelHandle& h) noexcept { | ||
| if (!h) return {}; | ||
| return get_box(h)->h_library; | ||
| } | ||
|
|
||
| // ============================================================================ | ||
| // Graph Handles | ||
| // ============================================================================ | ||
|
|
||
| namespace { | ||
| struct GraphBox { | ||
| CUgraph resource; | ||
| GraphHandle h_parent; // Keeps parent alive for child/branch graphs | ||
| }; | ||
| } // namespace | ||
|
|
||
| GraphHandle create_graph_handle(CUgraph graph) { | ||
| auto box = std::shared_ptr<const GraphBox>( | ||
| new GraphBox{graph, {}}, | ||
| [](const GraphBox* b) { | ||
| GILReleaseGuard gil; | ||
| p_cuGraphDestroy(b->resource); | ||
| delete b; | ||
| } | ||
| ); | ||
| return GraphHandle(box, &box->resource); | ||
| } | ||
|
|
||
| GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) { | ||
| auto box = std::make_shared<const GraphBox>(GraphBox{graph, h_parent}); | ||
| return GraphHandle(box, &box->resource); | ||
| } | ||
|
|
||
| namespace { | ||
| struct GraphNodeBox { | ||
| CUgraphNode resource; | ||
| GraphHandle h_graph; | ||
| }; | ||
| } // namespace | ||
|
|
||
| static const GraphNodeBox* get_box(const GraphNodeHandle& h) { | ||
| const CUgraphNode* p = h.get(); | ||
| return reinterpret_cast<const GraphNodeBox*>( | ||
| reinterpret_cast<const char*>(p) - offsetof(GraphNodeBox, resource) | ||
| ); | ||
| } | ||
|
|
||
| GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) { | ||
| auto box = std::make_shared<const GraphNodeBox>(GraphNodeBox{node, h_graph}); | ||
| return GraphNodeHandle(box, &box->resource); | ||
| } | ||
|
|
||
| GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept { | ||
| return h ? get_box(h)->h_graph : GraphHandle{}; | ||
| } | ||
|
|
||
| // ============================================================================ | ||
| // Graphics Resource Handles | ||
| // ============================================================================ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can write `std::lock_guard lock(mutex_);` — CTAD was included in C++17. This removes a code ripple if you want to change the type of the mutex later.