
[CPU/CUDA EP] Add DeformConv op support#27393

Open
ShirasawaSama wants to merge 58 commits into microsoft:main from ShirasawaSama:feature/add-deform-conv-2d-support

Conversation


@ShirasawaSama ShirasawaSama commented Feb 19, 2026

Description

This change adds support for the Deformable Convolution 2D operator (DeformConv2D) to ONNX Runtime. The branch implements the operator schema and registration, provides kernel implementations (CPU and GPU/CUDA where available), implements shape inference, and adds unit and integration tests to validate correctness and numerical parity with reference implementations. The changes include performance-oriented optimizations and necessary changes to build/test scripts.

Motivation and Context

Deformable convolutions are widely used in vision models that require spatial sampling flexibility (e.g., Deformable ConvNets, some detection/segmentation models). Native support in ONNX Runtime enables these models to run efficiently without custom operators or external runtimes, broadening the set of compatible models and improving performance and portability.

See also

@ShirasawaSama
Contributor Author

@microsoft-github-policy-service agree

@ShirasawaSama ShirasawaSama changed the title Feature/add deform conv 2d support Add deform conv 2d support Feb 19, 2026
@ShirasawaSama
Contributor Author

@fs-eire @tianleiwu Hello, could you please help me trigger a GitHub Copilot code review and unit test/build test pipeline?

Thank you very much.

Contributor

Copilot AI left a comment


Pull request overview

This pull request adds comprehensive support for the DeformConv2D (Deformable Convolution 2D) operator to ONNX Runtime for opsets 19-22. The implementation includes CPU and CUDA kernel implementations with extensive test coverage.

Changes:

  • Implements DeformConv operator for CPU (float) and CUDA (float, double, MLFloat16, BFloat16)
  • Adds deformable im2col kernels with bilinear interpolation for spatially-adaptive sampling
  • Includes comprehensive unit tests covering edge cases, data types, and parameter combinations
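
The bilinear sampling at fractional, offset-shifted positions is the core operation of the deformable im2col kernels described above. Below is a minimal NumPy sketch of the idea; the helper name and the zero-padding-outside-bounds convention are assumptions for illustration, not the PR's actual code.

```python
import numpy as np

def bilinear_interpolate(im, h, w):
    """Sample a 2-D array at fractional position (h, w), zero outside bounds."""
    H, W = im.shape
    if h <= -1 or h >= H or w <= -1 or w >= W:  # fully out of bounds
        return 0.0
    h0, w0 = int(np.floor(h)), int(np.floor(w))
    h1, w1 = h0 + 1, w0 + 1
    lh, lw = h - h0, w - w0  # fractional parts

    # Fetch the 4 neighbors, treating out-of-range taps as zero padding.
    def at(y, x):
        return im[y, x] if 0 <= y < H and 0 <= x < W else 0.0

    return ((1 - lh) * (1 - lw) * at(h0, w0) + (1 - lh) * lw * at(h0, w1)
            + lh * (1 - lw) * at(h1, w0) + lh * lw * at(h1, w1))

im = np.arange(9, dtype=float).reshape(3, 3)
center = bilinear_interpolate(im, 1.0, 1.0)  # integer position -> exact pixel 4.0
mid = bilinear_interpolate(im, 0.5, 0.5)     # average of pixels 0, 1, 3, 4 -> 2.0
```

Sampling at an integer position degenerates to a plain load, which is why a fast path for fully in-bounds neighbors (as mentioned in later comments) pays off.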

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc Comprehensive test suite with 24 test cases covering various configurations, edge cases, and data types
onnxruntime/test/providers/cpu/nn/deform_conv_expected_gen.py Python script to generate expected outputs using PyTorch's torchvision.ops.deform_conv2d as reference
onnxruntime/core/providers/cuda/nn/deform_conv_impl.h CUDA kernel interface declarations for im2col, bias addition, and GEMM output copy
onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu CUDA kernel implementations with optimizations for FP16/BF16 and memory access patterns
onnxruntime/core/providers/cuda/nn/deform_conv.h CUDA operator class declaration and attributes structure
onnxruntime/core/providers/cuda/nn/deform_conv.cc CUDA operator compute logic with batched processing and memory management
onnxruntime/core/providers/cuda/cuda_execution_provider.cc CUDA kernel registrations for opsets 19-22 across all supported types
onnxruntime/core/providers/cpu/nn/deform_conv.h CPU operator class declaration
onnxruntime/core/providers/cpu/nn/deform_conv.cc CPU operator implementation with float support only
onnxruntime/core/providers/cpu/cpu_execution_provider.cc CPU kernel registrations for opsets 19-22


@ShirasawaSama ShirasawaSama changed the title Add deform conv 2d support [CPU/CUDA EP] Add DeformConv op support Feb 23, 2026
@ShirasawaSama
Copy link
Contributor Author

Hello @hariharans29,

Sorry to ping here — I just wanted to kindly ask whether the current PR structure looks reasonable.

I realize this PR is relatively large (~2000 lines including tests) since it includes both CPU and CUDA implementations of DeformConv.

Would it make the review process easier if I split it into two separate PRs (e.g., CPU first, then CUDA)? I’m happy to restructure it if that aligns better with ORT’s contribution workflow.

I’m also planning to add a DirectML implementation in a follow-up PR. Please let me know if there are any specific design considerations I should keep in mind to ensure consistency across execution providers.

Additionally, would you recommend including performance benchmarks or accuracy comparisons against PyTorch’s native implementation? I can provide benchmark results and numerical validation reports if that would be helpful for review.

As I’m still relatively new to contributing here, I may not be reaching out to the most appropriate person. If this isn’t within your area or availability, I completely understand — I would really appreciate any suggestions on the appropriate next steps.

I really appreciate your time and guidance. Thank you!

@ShirasawaSama ShirasawaSama force-pushed the feature/add-deform-conv-2d-support branch from 2d85c09 to bb17da5 Compare February 26, 2026 16:03
Contributor

@tianleiwu tianleiwu left a comment


1. CPU kernel registration uses InputMemoryType(OrtMemTypeCPUInput, ...) incorrectly

File: onnxruntime/core/providers/cpu/nn/deform_conv.cc, lines ~370–384

.InputMemoryType(OrtMemTypeCPUInput, 2)   /* offset */
.InputMemoryType(OrtMemTypeCPUInput, 4),  /* optional mask */

InputMemoryType annotations are meaningful for non-CPU execution providers (e.g., CUDA EP) to specify that certain inputs should stay on CPU. For a CPU kernel, all inputs are already on CPU — these annotations are unnecessary and potentially confusing. They should be removed from the CPU registration macros.

4. GemmEx<double> addition in math_cpu.cc may affect other operators

File: onnxruntime/core/util/math_cpu.cc, +65 lines

A new GemmEx<double, ThreadPool> specialization is added with Eigen fallback when MLAS_SUPPORTS_GEMM_DOUBLE is not defined. This is a shared utility — if this template was previously uninstantiated/missing, it could break or affect other operators that might attempt to use it. This change should be:

  • Called out explicitly in the PR description
  • Tested independently or confirmed that no existing code path is affected
  • Considered for a separate PR since it's infrastructure-level

5. CUDA: No cudaGetLastError() / error checking after kernel launches

File: onnxruntime/core/providers/cuda/nn/deform_conv_impl.cu

After DeformableIm2ColKernel, DeformConvAddBiasKernel, and CopyGemmOutputRowMajorToNCHWKernel launches, there is no error checking (e.g., CUDA_RETURN_IF_ERROR(cudaGetLastError())). ORT's convention for CUDA kernels typically includes post-launch error checking to catch launch configuration errors and async kernel failures. Add error checking after each kernel launch.


Moderate Issues

6. CUDA: Potential integer overflow in DeformableIm2ColKernel address computation

Inside the kernel, address calculations like:

out_b * (channels * height * width) + in_c * (height * width)

use int64_t arithmetic, but out_b is derived from IndexT (which may be int32_t). When IndexT = int32_t, the intermediate products could overflow if channels * height * width is large. The address calculations should be explicitly cast to int64_t.
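
The wraparound hazard can be demonstrated with a small sketch that emulates int32 intermediates via `ctypes`; the tensor sizes here are illustrative, not taken from the kernel.

```python
import ctypes

def i32(x):
    """Truncate to a signed 32-bit value, emulating an int32 intermediate."""
    return ctypes.c_int32(x).value

channels, height, width = 512, 2048, 2048  # channels * height * width == 2**31
out_b = 3

# int32 arithmetic: the intermediate product wraps to a negative value
wrapped = i32(out_b * i32(channels * height * width))

# widening to 64-bit before multiplying (Python ints are unbounded) is safe
correct = out_b * (channels * height * width)
```

Casting `out_b` and friends to `int64_t` before the first multiplication, as suggested, avoids the truncated intermediate entirely.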

7. CPU: DeformableIm2col iterates over output spatially then channels — poor cache locality

File: onnxruntime/core/providers/cpu/nn/deform_conv.cc

The loop order is:

for c_col in [0, out_h*out_w):    // spatial
  for c_im in [0, C):             // channel
    for i, j in kernel:           // kernel

This means for each spatial position, all channels are processed. A better loop order for CPU cache efficiency would be to iterate channels in the outer loop (stride-1 access in data_im). This would significantly improve performance especially for large input tensors. Consider:

for c_im:
  for c_col:
    for i, j:

8. CPU: No multi-threading in DeformableIm2col

The CPU im2col implementation is entirely single-threaded. The thread_pool obtained from context is only passed to GemmEx. For large inputs, the im2col phase could be a bottleneck. Consider parallelizing the outer spatial or channel loop using concurrency::ThreadPool::TryParallelFor.
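
The suggestion maps naturally onto a parallel-for over the channel (outer) loop, since each channel writes a disjoint slice of `data_col`. A runnable Python sketch of the idea, with `ThreadPoolExecutor` standing in for ORT's `concurrency::ThreadPool::TryParallelFor` and a hypothetical `fill_channel` as the per-channel work:

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

C, spatial = 8, 16
data_col = np.zeros((C, spatial))

def fill_channel(c):
    # Stand-in for per-channel im2col work; each channel owns its own rows,
    # so no locking is needed.
    data_col[c, :] = c

# Serial reference for comparison.
serial = np.stack([np.full(spatial, c, dtype=float) for c in range(C)])

# Parallelize the outer (channel) loop across workers.
with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(fill_channel, range(C)))
```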

9. CUDA: CopyGemmOutputRowMajorToNCHW — unnecessary extra copy kernel

The CopyGemmOutputRowMajorToNCHW kernel copies GEMM results to NCHW layout after each group's GEMM. This is an extra memory copy pass. Could the GEMM output be written directly to the correct NCHW offsets by adjusting the cuBLAS call's output pointer and leading dimension? This would eliminate one kernel launch per group per batch block.

Alternatively, if the current layout mismatch is unavoidable due to the batched im2col approach, this should be documented clearly.

10. CUDA: Memory heuristic kMaxTempMemSize = 256MB is arbitrary

File: onnxruntime/core/providers/cuda/nn/deform_conv.cc, line ~210

constexpr size_t kMaxTempMemSize = 256 * 1024 * 1024;

This hardcoded limit doesn't account for actual GPU memory availability. On smaller GPUs (e.g., 4GB), 256MB is a significant fraction. Consider querying available memory or making this configurable, or at least documenting the rationale.

11. Test: Verify that the asymmetric padding computation's use of pad[1] as the left pad matches the ONNX spec

File: onnxruntime/test/providers/cpu/nn/deform_conv_op_test.cc

In RunDeformConvTest, the output size is computed as:

const int64_t out_h = (params.in_h + params.pad[0] + params.pad[2] - ...) / params.stride[0] + 1;
const int64_t out_w = (params.in_w + params.pad[1] + params.pad[3] - ...) / params.stride[1] + 1;

The ONNX spec defines pads as [x1_begin, x2_begin, x1_end, x2_end] (for 2D: [pad_top, pad_left, pad_bottom, pad_right]). The computation here is consistent but should be verified against the kernel code which uses pad_h = pads[0], pad_w = pads[1] (begin pads only) and pad_h_end = pads[2], pad_w_end = pads[3].
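
The output-size computation quoted above follows the standard ONNX convolution formula. A runnable sketch for one spatial axis (the helper name is hypothetical):

```python
def conv_out_dim(in_dim, pad_begin, pad_end, kernel, dilation, stride):
    """Standard ONNX convolution output-size formula for one spatial axis."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_dim + pad_begin + pad_end - effective_kernel) // stride + 1

# pads = [pad_top, pad_left, pad_bottom, pad_right] for a 2-D NCHW input
pads = [1, 2, 1, 2]
out_h = conv_out_dim(64, pads[0], pads[2], kernel=3, dilation=1, stride=1)
out_w = conv_out_dim(64, pads[1], pads[3], kernel=3, dilation=1, stride=1)
```

Note that H uses pads[0]/pads[2] and W uses pads[1]/pads[3], matching the test's `out_h`/`out_w` expressions.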


Minor Issues / Suggestions

12. Duplicated DeformConvAttributes struct between CPU and CUDA

The DeformConvAttributes struct is defined identically in both cpu/nn/deform_conv.h and cuda/nn/deform_conv.h. Consider extracting it to a shared header (e.g., core/providers/common/deform_conv_attributes.h) to avoid duplication and keep them in sync.

13. Duplicated validation logic between CPU and CUDA Compute functions

Both DeformConv<T>::Compute (CPU) and DeformConv<T>::ComputeInternal (CUDA) contain ~30 lines of identical validation code. This could be extracted to a shared validation helper.

14. CUDA kernel uses #pragma unroll on dynamic loop bounds

#pragma unroll
for (int64_t i = 0; i < weight_h; ++i) {

weight_h and weight_w are runtime values (kernel parameters), so #pragma unroll has no effect here with most compilers. It's harmless but misleading — either remove it or add a comment noting it's a hint for common small kernel sizes.

15. Test MinimalBilinear: offset semantics need clearer documentation

The test comment says:

// offset (1, 2, 2, 2): ch0=offset_h, ch1=offset_w per output position. (0,0):(0.5,0)->2.5, (0,1):(0.5,-1)->1

But the offset tensor has 8 values with shape [1, 2, 2, 2]. The mapping from offset values to output positions is not immediately clear from the comment. Consider adding a more explicit breakdown.

16. No ONNX model test nodes (.onnx test data)

The test suite only uses OpTester C++ tests. Consider also adding standard ONNX model test data files that can be used in model zoo validation or cross-platform testing.

18. GetGreatestDivisorBelowBound could be slow for large N with small bound

For N values that are prime (or have no small factors), this linearly scans from bound to 1. For bound = 32 this is negligible, but worth noting.

19. Test Python script (deform_conv_expected_gen.py) hard-codes PyTorch padding convention

The script uses padding=(pad_h, pad_w) (symmetric) for PyTorch but generates ONNX pads = [pad_h, pad_w, pad_h, pad_w]. This works for symmetric padding but wouldn't generate correct asymmetric test cases. This is fine for the current tests but limits the script's utility.

20. Consider adding a DeformConv shape inference function

The PR adds compute kernels but doesn't seem to add shape inference support. If ONNX Runtime's graph optimizer needs shape inference for DeformConv nodes (e.g., for memory planning), this may need to be added.

@ShirasawaSama ShirasawaSama force-pushed the feature/add-deform-conv-2d-support branch from bb17da5 to 7d2b779 Compare March 1, 2026 12:27
@ShirasawaSama ShirasawaSama marked this pull request as draft March 1, 2026 14:57
@ShirasawaSama ShirasawaSama force-pushed the feature/add-deform-conv-2d-support branch from 1e5baba to 1222ad4 Compare March 1, 2026 18:21
@ShirasawaSama
Contributor Author

@tianleiwu Thanks for the code review. Here is the updated CR response:


CR Response

1. CPU InputMemoryType for offset/mask
Fixed. Removed InputMemoryType(OrtMemTypeCPUInput, ...) from CPU kernel registration; CPU kernels no longer use these annotations.

4. GemmEx<double> in math_cpu.cc
Fixed. DeformConv now uses math::Gemm<T>, which for double uses the existing Gemm<double, ThreadPool> (Eigen-based) implementation in math_cpu.cc. No new GemmEx<double> specialization was added, and no existing path was changed by this PR.

5. CUDA: No cudaGetLastError() after kernel launches
Fixed. Added CUDA_CALL(cudaGetLastError()) after DeformableIm2ColKernel, DeformConvAddBiasKernel, and CopyGemmOutputRowMajorToNCHWKernel launches.

6. CUDA: Potential integer overflow with IndexT
Fixed. Address computations use explicit static_cast<int64_t> for out_b, in_c, and related IndexT values before multiplication. Large inputs use the use_64bit path with int64_t IndexT.

7. CPU: DeformableIm2col loop order for cache locality
Fixed. Loop order is "channels outer, spatial inner" for better cache locality and sequential data_col access.

8. CPU: No multi-threading in DeformableIm2col
Fixed. Im2col is parallelized over channels via TryParallelFor; bias add is parallelized over batch × output channels (N * M).

9. CUDA: CopyGemmOutputRowMajorToNCHW extra copy
Fixed. When cur_parallel == 1, GEMM output layout (pos, channel) matches NCHW Y[0, channel, pos], so results are written directly into the output with ldc = output_image_size and the copy kernel is skipped. The copy kernel runs only when cur_parallel > 1.

10. CUDA: kMaxTempMemSize = 256MB arbitrary
Fixed. Uses cudaMemGetInfo to get free memory, keeps 90% as a fragmentation buffer, and applies tiered caps (e.g., 16GB+ → 2GB, 8–16GB → 1GB, 4–8GB → 512MB, 2–4GB → 256MB, <2GB → 128MB). Result is max(32MB, min(tier_cap, free_mem)). 256MB is used only when the query fails.
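
The tiering described above can be sketched as follows. This is illustrative Python only: the 0.9 usable fraction and the exact tier boundaries follow the description, and the actual C++ code may differ in naming and detail.

```python
GiB = 1024 ** 3
MiB = 1024 ** 2

def effective_max_temp_bytes(free_bytes, total_bytes):
    """Tiered temp-memory cap: 90% of free memory, clamped by total-VRAM tiers."""
    usable = int(free_bytes * 0.9)  # reserve 10% against fragmentation
    if total_bytes >= 16 * GiB:
        cap = 2 * GiB
    elif total_bytes >= 8 * GiB:
        cap = 1 * GiB
    elif total_bytes >= 4 * GiB:
        cap = 512 * MiB
    elif total_bytes >= 2 * GiB:
        cap = 256 * MiB
    else:
        cap = 128 * MiB
    # Never below a 32 MB floor, never above the tier cap or usable free memory.
    return max(32 * MiB, min(cap, usable))
```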

11. Test: Asymmetric padding pad[1] semantics
Verified. pad[0..3] is used as [top, left, bottom, right]; out_h uses pad[0]/pad[2], out_w uses pad[1]/pad[3], matching the ONNX spec. Added a comment in RunDeformConvTest.

12. Duplicated DeformConvAttributes
Fixed. DeformConvAttributes and DeformConvParams live in shared core/providers/cpu/nn/deform_conv_attributes.h; both CPU and CUDA include it.

13. Duplicated validation logic
Fixed. Both CPU and CUDA call shared DeformConvValidateAndParse in deform_conv_attributes.h.

14. #pragma unroll on dynamic loop bounds
Fixed. Split into two paths: if constexpr (is_fixed) for compile-time fixed kernel sizes (3×3, 5×5) with #pragma unroll, and an else branch for dynamic kernel sizes without pragma.

15. MinimalBilinear test offset documentation
Fixed. Added comments describing the offset layout and the mapping from offset values to output positions.

16. No ONNX model test nodes
Fixed. Added deform_conv_test.onnx via deform_conv_test_gen.py; OnnxModelTest uses it with AddReferenceOutputs.

18. GetGreatestDivisorBelowBound performance
Fixed. Optimized with a fast path for multiples of the bound, and a branch that chooses between a linear scan from the bound down (when n >= bound²) and divisor enumeration from 1 to √n (when n < bound²) based on integer comparisons (no sqrt).
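
The strategy described can be sketched in Python as follows (illustrative only; the actual C++ helper may differ in naming and details):

```python
def greatest_divisor_below_bound(n, bound):
    """Largest divisor of n that is <= bound (n >= 1, bound >= 1)."""
    if n % bound == 0:  # fast path: bound itself divides n
        return bound
    if n >= bound * bound:
        # Few candidates below bound: scan down; terminates since 1 divides n.
        for d in range(bound, 0, -1):
            if n % d == 0:
                return d
    # n < bound^2: enumerate divisor pairs up to sqrt(n), keep the best <= bound.
    best = 1
    d = 1
    while d * d <= n:
        if n % d == 0:
            if d <= bound and d > best:
                best = d
            q = n // d
            if q <= bound and q > best:
                best = q
        d += 1
    return best
```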

19. deform_conv_expected_gen.py padding convention
Addressed. Added a docstring noting that the script uses symmetric padding only; asymmetric pads would require PyTorch API support and are not generated.

20. DeformConv shape inference
Not addressed. ONNX already provides shape inference for DeformConv (convPoolShapeInference), and ORT uses it for graph resolve and memory planning. No ORT-specific shape inference was added in this PR.

@ShirasawaSama ShirasawaSama marked this pull request as ready for review March 3, 2026 06:47
@ShirasawaSama ShirasawaSama requested a review from tianleiwu March 3, 2026 06:47

@tianleiwu
Contributor

@ShirasawaSama, Here are some issues found by AI, please take a look. They might not be correct.

1. Unused batch_idx in DeformConvAddBiasKernel — compiler warning

In deform_conv_impl.cu:301, batch_idx is computed via channel_div.divmod() but never referenced. This will produce compiler warnings on MSVC (C4189) and GCC/Clang (-Wunused-variable), which may fail ORT's -Werror CI builds.

// deform_conv_impl.cu:296-301
int64_t batch_idx, channel_idx;
channel_div.divmod(batch_channel_idx, batch_idx, channel_idx);
//                                    ^^^^^^^^^  never used

Fix: Replace with an unnamed variable or (void)batch_idx;:

int64_t batch_idx, channel_idx;
channel_div.divmod(batch_channel_idx, batch_idx, channel_idx);
(void)batch_idx;  // Only channel_idx is needed

2. CPU DeformableIm2col only passes begin-padding to im2col

In deform_conv.cc:214, only pad_h and pad_w (the begin-side pads) are passed to DeformableIm2col, while pad_h_end/pad_w_end are parsed in DeformConvParams but never forwarded.

This is technically correct for im2col — the end pads only affect output dimensions (which are correctly computed using all 4 pads). However, the fact that pad_h_end/pad_w_end are parsed and stored but never consumed is misleading. Consider:

  • Adding a comment in DeformConv::Compute() explaining why only the begin pads are needed
  • Or removing the unused pad_h_end/pad_w_end from the local variable destructuring

3. [CORRECTNESS] Missing validation for offset batch dimension

In deform_conv_attributes.h, the validation checks offset_shape[1], [2], [3] but does not verify that offset_shape[0] == N. Same issue for mask_shape[0]. A mismatch would cause silent out-of-bounds reads.

// Should add:
ORT_RETURN_IF_NOT(offset_shape[0] == params.N, "Offset batch size must match input batch size.");
if (params.use_mask) {
  ORT_RETURN_IF_NOT((*mask_shape)[0] == params.N, "Mask batch size must match input batch size.");
}

4. [PERF] CPU bias add could use SIMD/vectorized fill

The CPU bias addition (lines 265–280) uses a scalar loop Y_ptr[i] += bias_val. For large spatial sizes, this could benefit from vectorized operations. Not blocking, but worth noting for future optimization.

5. [STYLE] Template instantiation formatting in .cu file

Lines 440–443 have inconsistent indentation for template instantiations:

INST_DeformConvIm2ColImpl(float)
    INST_DeformConvIm2ColImpl(double)
        INST_DeformConvIm2ColImpl(half)
            INST_DeformConvIm2ColImpl(BFloat16)

These should be left-aligned. This is a minor style issue.

6. [ROBUSTNESS] cudaMemGetInfo called on every inference

GetDeformConvEffectiveMaxTempBytes() calls cudaMemGetInfo() on every ComputeInternal invocation. This is a synchronization point that can stall the GPU pipeline. Consider caching the result, or computing it once during kernel construction.

7. [DOCUMENTATION] CUDA GEMM layout trick deserves a clearer top-level comment

The cuBLAS transpose trick in deform_conv.cc:186–208 is well commented inline, but a one-line summary before the block (e.g., "// Compute Y = W * Col without transpose by exploiting cuBLAS column-major/row-major equivalence.") would help future maintainers skim the code.

@ShirasawaSama
Contributor Author

@tianleiwu Thank you for your reply! I will carefully review the code again.

Additionally, would you recommend adding benchmarks to this PR or to subsequent ones?

Personally, I'd prefer to include them in the next PR. In my local tests comparing different convolution kernels, I observed performance gains ranging from 50% to 140% compared to torchvision's CUDA implementation (depending on kernel size and input dimensions). I can share my benchmark Jupyter notebook in my personal repository if needed.

@ShirasawaSama ShirasawaSama marked this pull request as draft March 5, 2026 20:22
@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 6, 2026

All code review suggestions mentioned above have been addressed.

perf(DeformConv CPU): improve im2col, bilinear interpolation, and bias handling

  • DeformableIm2col: UseMask as template parameter; parallelize over channels*kernel_size
  • BilinearInterpolate: early OOB check, fast path when all 4 neighbors are in bounds
  • Reduce redundant work in hot loop: ptr_offset_h/w, ptr_mask, base_h/base_w
  • Add detailed comments on design choices and tensor layouts

@ShirasawaSama
Contributor Author

For cuda:

  • Refactored CUDA DeformConv execution flow to reduce lock contention: the mutex now only guards UpdateState, while im2col/GEMM/bias run outside the lock.
  • Added a direct-GEMM fast path using cublasGemmStridedBatchedHelper when gemm_writes_directly (cur_parallel == 1), replacing per-group GEMM launches in that case.
  • Fixed a variable shadowing issue in the batched GEMM path by renaming stride_w to stride_weight.
  • Restored and improved comments around the GEMM layout trick for better readability/maintainability.
  • Updated CUDA im2col mask handling to match CPU behavior: removed the mask == 0 early-exit branch and always compute interpolation followed by mask multiplication.
  • Added a 1x1 kernel-size specialization in launch dispatch (kH == 1 && kW == 1) alongside existing 3x3/5x5 specializations.

@ShirasawaSama ShirasawaSama marked this pull request as ready for review March 6, 2026 23:19
@ShirasawaSama ShirasawaSama force-pushed the feature/add-deform-conv-2d-support branch from 9d7d29d to cbe1eca Compare March 7, 2026 07:41
@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 7, 2026


benchmark_deform_conv.py

1. Absolute time (ms)

Config GPU TV (ms) GPU ORT (ms) CPU TV (ms) CPU ORT (ms)
B1 3x3 64x64 0.129 0.073 25.8 2.8
B1 5x5 64x64 0.311 0.138 62.8 7.0
B1 7x7 128x128 1.332 0.915 484.3 56.0
B2 3x3 64x64 0.200 0.155 50.2 5.3
B4 3x3 64x64 0.334 0.251 106.4 10.3
B2 5x5 128x128 1.535 1.016 509.0 57.7
B1 3x3 64x48 0.119 0.075 18.4 2.1
B1 5x5 100x80 0.375 0.250 95.8 14.5
B1 3x3 128x96 0.262 0.157 78.3 8.2
B1 3x3 63x63 0.130 0.077 17.8 2.6
B1 5x5 65x65 0.322 0.144 47.1 7.6
B1 7x7 127x127 1.350 0.952 335.1 54.0

2. Relative time (ORT / TV)

Ratio < 1 means ORT is faster.

Config GPU (ORT/TV) CPU (ORT/TV)
B1 3x3 64x64 0.57 0.11
B1 5x5 64x64 0.44 0.11
B1 7x7 128x128 0.69 0.12
B2 3x3 64x64 0.77 0.11
B4 3x3 64x64 0.75 0.10
B2 5x5 128x128 0.66 0.11
B1 3x3 64x48 0.63 0.12
B1 5x5 100x80 0.67 0.15
B1 3x3 128x96 0.60 0.10
B1 3x3 63x63 0.59 0.15
B1 5x5 65x65 0.45 0.16
B1 7x7 127x127 0.71 0.16

3. Numerical accuracy (TV vs ORT)

We report max absolute difference between TV and ORT outputs over all elements:

max_abs_diff = max |output_TV − output_ORT|

Config max_abs_diff
B1 3x3 64x64 0.000391
B1 5x5 64x64 0.000570
B1 7x7 128x128 0.000846
B2 3x3 64x64 0.000371
B4 3x3 64x64 0.000402
B2 5x5 128x128 0.000630
B1 3x3 64x48 0.000352
B1 5x5 100x80 0.000609
B1 3x3 128x96 0.000369
B1 3x3 63x63 0.000341
B1 5x5 65x65 0.000574
B1 7x7 127x127 0.000848

Relative error (e.g., |diff| / |TV|) is not used as the primary metric because the DeformConv output tensor contains many values close to zero. Dividing by very small reference values can artificially inflate the relative error without indicating a meaningful numerical discrepancy.

Instead, we report the maximum absolute difference. For float32 pipelines involving GEMM-style accumulation and bilinear interpolation, differences on the order of 1e-4–1e-3 are typically expected due to factors such as floating-point reduction order, fused multiply-add behavior, and potential TF32 usage on Tensor Core hardware.

All observed max_abs_diff values are below 1e-2, which is within a reasonable range for independent implementations of the same operator.
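
A toy example of why relative error is a poor primary metric here (illustrative numbers, not taken from the benchmark):

```python
import numpy as np

ref = np.array([1.0, 1e-7, -0.5])  # reference output with a near-zero entry
out = ref + 1e-4                   # uniform small absolute perturbation

abs_diff = np.abs(out - ref)
max_abs_diff = float(np.max(abs_diff))               # ~1e-4 for every element
max_rel_err = float(np.max(abs_diff / np.abs(ref)))  # dominated by the 1e-7 entry
```

The same ~1e-4 absolute difference yields a relative error of roughly 1000x at the near-zero element, even though nothing is numerically wrong.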

@ShirasawaSama
Copy link
Contributor Author

ShirasawaSama commented Mar 7, 2026

Random offset and masks:


1. Absolute time (ms)

Config GPU TV GPU ORT CPU TV CPU ORT
B1 3x3 64x64 0.127 0.088 24.9 2.6
B1 5x5 64x64 0.374 0.177 64.2 7.6
B1 7x7 128x128 1.333 0.926 483.5 57.9
B2 3x3 64x64 0.186 0.145 52.8 5.8
B4 3x3 64x64 0.351 0.259 109.5 12.0
B2 5x5 128x128 1.591 1.047 564.1 77.4
B1 3x3 64x48 0.121 0.073 21.5 2.6
B1 5x5 100x80 0.376 0.258 119.1 18.8
B1 3x3 128x96 0.255 0.156 86.6 10.6
B1 3x3 63x63 0.158 0.086 22.7 3.4
B1 5x5 65x65 0.502 0.151 62.7 10.3
B1 7x7 127x127 1.340 0.966 393.4 73.7

2. Relative time (ORT / TV), ratio < 1 ⇒ ORT faster

Config GPU CPU
B1 3x3 64x64 0.69 0.10
B1 5x5 64x64 0.47 0.12
B1 7x7 128x128 0.69 0.12
B2 3x3 64x64 0.78 0.11
B4 3x3 64x64 0.74 0.11
B2 5x5 128x128 0.66 0.14
B1 3x3 64x48 0.61 0.12
B1 5x5 100x80 0.69 0.16
B1 3x3 128x96 0.61 0.12
B1 3x3 63x63 0.55 0.15
B1 5x5 65x65 0.30 0.16
B1 7x7 127x127 0.72 0.19

3. Numerical accuracy (same inputs)

max_abs_diff = max |output_TV − output_ORT|. Relative error is not used as primary metric because outputs contain many near-zero values.

Config max_abs_diff
B1 3x3 64x64 0.000175
B1 5x5 64x64 0.000282
B1 7x7 128x128 0.000432
B2 3x3 64x64 0.000206
B4 3x3 64x64 0.000201
B2 5x5 128x128 0.000311
B1 3x3 64x48 0.000184
B1 5x5 100x80 0.000302
B1 3x3 128x96 0.000202
B1 3x3 63x63 0.000192
B1 5x5 65x65 0.000284
B1 7x7 127x127 0.000405

@ShirasawaSama ShirasawaSama force-pushed the feature/add-deform-conv-2d-support branch from 7adfb05 to c5d86e7 Compare March 14, 2026 11:48
@ShirasawaSama
Contributor Author

Rebased to latest main branch

@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 15, 2026

I’ve reviewed my code again and feel that the following two areas could be optimized, though they don’t affect the performance comparison report above:

  1. The logic for determining the maximum number of parallel images in CUDA currently uses the GCD, but a greedy approach would likely be more reasonable. Since common batch sizes are powers of two (e.g., 4, 8, 16, 32, 64), this is generally not a problem in practice; this is also noted in the comments.
  2. The value of kMaxTempMemSize, which caps the temporary memory that can be allocated on CUDA, is a bit too high. I'm currently computing it as 10% of total GPU VRAM (still capped at 2 GB). A ratio of 0.4 might be more reasonable, but I'm not entirely sure; offering different ratios for small and large GPU memory sizes could also be a good solution.

If you have any suggestions, I’m happy to make changes at any time.

Of course, I could also address these two details in a follow-up pull request, given how long the pipeline takes to run.

@tianleiwu tianleiwu requested a review from Copilot March 16, 2026 17:07
@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 13 out of 15 changed files in this pull request and generated 3 comments.




@ShirasawaSama
Contributor Author

Change list:

  • kernel_shape: If present, require exactly 2 elements and that they match W_shape[2] and W_shape[3]. Reject mismatches to avoid wrong GEMM K and out-of-bounds reads from the weight buffer.
  • pads: If present, require exactly 4 elements [pad_h_begin, pad_w_begin, pad_h_end, pad_w_end] and that all values are non-negative (per ONNX spec).
  • strides / dilations: If present, require exactly 2 elements each; otherwise reject to avoid silently using defaults for malformed models.
  • Scalar / shape checks: Require N > 0, C > 0, M > 0, and W_shape[1] > 0.
  • Offset / mask channel checks: Validate offset and mask channel counts using division (offset_shape[1] / (2*kH*kW) == offset_group and mask_shape[1] / (kH*kW) == offset_group) instead of multiplying offset_group * 2 * kH * kW (and similarly for mask) to avoid int64 overflow on large kernels.
  • Scope: Validation remains 2D-only (4D tensors, 2-D attribute lengths); no 3D support added.
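
The division-based channel check can be sketched as follows. Python integers don't overflow, so this only mirrors the shape logic; in C++ the division form avoids ever computing offset_group * 2 * kH * kW, which could exceed int64 for adversarial attribute values.

```python
def check_offset_channels(offset_channels, kH, kW, offset_group):
    """Validate offset_channels == offset_group * 2 * kH * kW via division."""
    per_group = 2 * kH * kW
    return (per_group > 0
            and offset_channels % per_group == 0
            and offset_channels // per_group == offset_group)
```

The modulo test is what makes the division form equivalent to the multiplication form: without it, truncating division could accept malformed channel counts.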

tianleiwu
tianleiwu previously approved these changes Mar 16, 2026
Contributor

@tianleiwu tianleiwu left a comment


LGTM

@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 16, 2026

--use_cuda --use_tensorrt

@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 16, 2026

After careful consideration and reviewing the implementation of conv op, I believe that allowing batch == 0 would be a more reasonable approach.

The last commit was made because the most recent pipeline failed; the new assert statement added last time (b96de81) caused an error when batch==0.

I apologize for running the pipeline again, but this should be the last time.

@ShirasawaSama
Contributor Author

ShirasawaSama commented Mar 17, 2026

I think the code in the current commit is good enough.

However, there is still room for improvement. In the next PR, I will remove DeformConvCopyGemmOutputRowMajorToNCHW to reduce the memory copy time required for multi-batch image processing in CUDA.

Of course, I don’t expect the improvement to be significant.

The current implementation already outperforms TorchVision and MMCV (I’ll add the performance comparison results for MMCV when I have time).

@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

tianleiwu
tianleiwu previously approved these changes Mar 17, 2026
@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu tianleiwu enabled auto-merge (squash) March 17, 2026 23:18
@tianleiwu
Contributor

tianleiwu commented Mar 18, 2026

@ShirasawaSama, please update two files under docs directory with artifacts here: https://aiinfra.visualstudio.com/PublicPackages/_build/results?buildId=1126907&view=artifacts&pathAsName=false&type=publishedArtifacts

It is required to pass "Windows GPU Doc Gen CI Pipeline"

@ShirasawaSama
Contributor Author

ok

auto-merge was automatically disabled March 18, 2026 04:35

Head branch was pushed to by a user without write access

@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).
