
[MLIR] Optimized linear layer in PyTorch frontend#596

Merged
Moehre2 merged 8 commits into main from mlir-opt-linear
Mar 20, 2026

Conversation


@Moehre2 Moehre2 commented Mar 17, 2026

  • Optimized the linear layer in the PyTorch frontend by enabling the MatmulNode to expand to a GEMM node with transpose
  • Restructured the daisy workflows for the PyTorch frontend
  • Model build time in the PyTorch daisy workflows is no longer included in the measured runtime
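The first bullet can be sketched in plain Python. This is a hypothetical illustration, not the project's actual lowering code: `torch.nn.Linear` computes `y = x @ W.T + b` with the weight stored as `(out_features, in_features)`, so instead of materializing the transposed weight and emitting a plain matmul, the frontend can emit a single GEMM node carrying a transpose-B flag that reads `W` with swapped indices.

```python
# Hypothetical sketch of matmul-with-transpose as a single GEMM
# (names and structure are illustrative, not the project's API).

def gemm(A, B, bias, transpose_b=False):
    """C[i][j] = sum_k A[i][k] * op(B)[k][j] + bias[j],
    where op(B) = B.T when transpose_b is set."""
    m, k = len(A), len(A[0])
    n = len(B) if transpose_b else len(B[0])
    C = [[bias[j] for j in range(n)] for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                # Read B with swapped indices instead of materializing B.T.
                b = B[j][p] if transpose_b else B[p][j]
                acc += A[i][p] * b
            C[i][j] += acc
    return C

x = [[1.0, 2.0]]          # batch of 1, in_features = 2
W = [[3.0, 4.0],          # stored (out_features, in_features), like nn.Linear
     [5.0, 6.0]]
b = [0.5, -0.5]
y = gemm(x, W, b, transpose_b=True)   # y = x @ W.T + b -> [[11.5, 16.5]]
```

The transpose thus costs nothing beyond an index swap in the GEMM kernel, which is exactly why folding it into the GEMM node is preferable to a separate transpose operation.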

@Moehre2 Moehre2 self-assigned this Mar 17, 2026

daisytuner Bot commented Mar 17, 2026

Daisytuner Report - python_npbench (zinnia)

@@                                   Benchmarks                                   @@
=====================================================================================
  Benchmark              Time        ΔTime       Thr         Energy      ΔEnergy     
=====================================================================================
# adi_numpy              1.32 s      +0.02%      N/A         131.28 J    +0.12%      
# adi_omp                15.95 s     -0.10%      N/A         1505.45 J   -0.11%      
# adi_cuda               4.80 s      +1.09%      N/A         465.19 J    +1.08%      
# adi_seq_tuning         15.85 s     -0.42%      N/A         1495.47 J   -0.62%      
# atax_numpy             2.15 s      -0.50%      N/A         222.81 J    -0.43%      
# atax_omp               2.47 s      +0.06%      N/A         259.24 J    +0.09%      
# atax_cuda              4.11 s      +0.17%      N/A         422.88 J    +0.15%      
# atax_seq_tuning        3.71 s      +0.10%      N/A         375.30 J    +0.07%      
# gemm_numpy             1.23 s      -1.27%      N/A         198.41 J    -1.10%      
# gemm_omp               1.11 s      +0.28%      N/A         162.01 J    +0.28%      
# gemm_cuda              10.63 s     +0.46%      N/A         1010.63 J   +0.51%      
# gemm_seq_tuning        1.11 s      -0.01%      N/A         161.73 J    -0.11%      
# gesummv_numpy          1.75 s      -0.92%      N/A         249.97 J    -1.06%      
# gesummv_omp            5.33 s      +0.35%      N/A         692.01 J    +0.25%      
# gesummv_cuda           8.40 s      +0.82%      N/A         1004.61 J   +0.59%      
# gesummv_seq_tuning     6.55 s      -0.41%      N/A         805.20 J    -0.13%      
# gemver_numpy           1.08 s      -0.38%      N/A         166.97 J    -0.53%      
# gemver_omp             714.68 ms   +0.56%      N/A         81.59 J     +0.97%      
# gemver_cuda            3.86 s      +0.63%      N/A         386.49 J    +0.18%      
# gemver_seq_tuning      4.48 s      +0.84%      N/A         433.74 J    +0.91%      
# k2mm_numpy             1.19 s      -0.78%      N/A         197.27 J    -0.54%      
# k2mm_omp               3.62 s      +0.40%      N/A         469.79 J    +0.47%      
# k2mm_cuda              13.60 s     +0.16%      N/A         1287.13 J   +0.12%      
# k2mm_seq_tuning        3.62 s      -0.11%      N/A         468.00 J    +0.11%      
# k3mm_numpy             1.03 s      -0.24%      N/A         184.19 J    -0.14%      
# k3mm_omp               5.72 s      -0.37%      N/A         796.84 J    -0.13%      
# k3mm_cuda              19.85 s     +0.25%      N/A         1868.92 J   +0.19%      
# k3mm_seq_tuning        5.72 s      -0.16%      N/A         792.39 J    -0.15%      
# mvt_numpy              2.42 s      -0.56%      N/A         247.59 J    -0.71%      
# mvt_omp                2.74 s      -0.09%      N/A         284.61 J    -0.15%      
# mvt_cuda               3.36 s      +0.54%      N/A         342.67 J    +0.35%      
# mvt_seq_tuning         2.75 s      +0.14%      N/A         285.10 J    +0.18%      
# symm_numpy             781.78 ms   -0.59%      N/A         80.62 J     -0.19%      
# symm_omp               8.42 s      +0.41%      N/A         802.01 J    +0.31%      
# symm_seq_tuning        8.36 s      -0.12%      N/A         796.53 J    -0.09%      
# syr2k_numpy            882.10 ms   +0.23%      N/A         89.64 J     -0.23%      
# syr2k_omp              9.81 s      -0.03%      N/A         931.91 J    -0.19%      
# syr2k_cuda             1.64 s      +0.08%      N/A         170.18 J    +0.23%      
# syr2k_seq_tuning       9.80 s      +0.24%      N/A         931.30 J    +0.28%      
# syrk_numpy             775.56 ms   +0.24%      N/A         80.09 J     +0.48%      
# syrk_omp               5.95 s      +0.93%      N/A         572.76 J    +1.03%      
# syrk_cuda              1.52 s      +0.53%      N/A         159.28 J    +0.86%      
# syrk_seq_tuning        5.90 s      +0.33%      N/A         567.60 J    +0.42%      
# trmm_numpy             879.33 ms   +0.22%      N/A         89.52 J     +0.03%      
# trmm_omp               3.10 s      +0.05%      N/A         306.07 J    +0.19%      
# trmm_seq_tuning        3.36 s      +1.10%      N/A         321.51 J    +1.10%      


daisytuner Bot commented Mar 19, 2026

Daisytuner Report - mlir_torch_models (chamomile)

@@                                   Benchmarks                                   @@
=====================================================================================
  Benchmark              Time        ΔTime       Thr         Energy      ΔEnergy     
=====================================================================================
# bn_conv_bn_relu_maxpool_torch            18.60 s     N/A         N/A         3616.28 J   N/A         
# bn_conv_bn_relu_maxpool_run_none         4.55 s      N/A         N/A         828.56 J    N/A         
# bn_conv_bn_relu_maxpool_run_sequential   4.61 s      N/A         N/A         876.99 J    N/A         
# bn_conv_bn_relu_maxpool_run_openmp       4.55 s      N/A         N/A         811.26 J    N/A         
# bn_conv_bn_relu_maxpool_run_cuda         3.64 s      N/A         N/A         718.39 J    N/A         


daisytuner Bot commented Mar 19, 2026

Daisytuner Report - mlir_torch_layers (chamomile)

@@                                   Benchmarks                                   @@
=====================================================================================
  Benchmark              Time        ΔTime       Thr         Energy      ΔEnergy     
=====================================================================================
# batchnorm_torch        19.27 s     N/A         N/A         3614.31 J   N/A         
# batchnorm_run_none     7.67 s      N/A         N/A         1418.62 J   N/A         
# batchnorm_run_sequential 8.48 s      N/A         N/A         1565.67 J   N/A         
# batchnorm_run_openmp   8.51 s      N/A         N/A         1571.35 J   N/A         
# batchnorm_run_cuda     12.71 s     N/A         N/A         2348.82 J   N/A         
# linear_torch           6.13 s      N/A         N/A         1433.52 J   N/A         
# linear_run_none        9.99 s      N/A         N/A         2651.34 J   N/A         
# linear_run_sequential  8.51 s      N/A         N/A         2378.13 J   N/A         
# linear_run_openmp      8.53 s      N/A         N/A         2381.48 J   N/A         
# linear_run_cuda        7.38 s      N/A         N/A         1371.67 J   N/A         
# matmul_torch           6.13 s      N/A         N/A         1434.18 J   N/A         
# matmul_run_none        10.04 s     N/A         N/A         2669.27 J   N/A         
# matmul_run_sequential  8.44 s      N/A         N/A         2374.23 J   N/A         
# matmul_run_openmp      8.48 s      N/A         N/A         2378.82 J   N/A         
# matmul_run_cuda        7.28 s      N/A         N/A         1360.09 J   N/A         
# pooling_torch          25.50 s     N/A         N/A         4811.07 J   N/A         
# pooling_run_none       17.26 s     N/A         N/A         3158.54 J   N/A         
# pooling_run_sequential 17.34 s     N/A         N/A         3174.27 J   N/A         
# pooling_run_openmp     17.30 s     N/A         N/A         3167.63 J   N/A         
# pooling_run_cuda       24.05 s     N/A         N/A         4430.54 J   N/A         
# relu_torch             18.98 s     N/A         N/A         3557.53 J   N/A         
# relu_run_none          4.40 s      N/A         N/A         822.70 J    N/A         
# relu_run_sequential    4.41 s      N/A         N/A         825.30 J    N/A         
# relu_run_openmp        4.41 s      N/A         N/A         825.36 J    N/A         
# relu_run_cuda          6.09 s      N/A         N/A         1134.64 J   N/A         

@Moehre2 Moehre2 requested review from Atrisan and ramonwirsch March 20, 2026 08:32
symbolic::Expression offset_a_; ///< Offset into tensor A (in elements)
symbolic::Expression offset_b_; ///< Offset into tensor B (in elements)

static bool has_basic_strides(symbolic::MultiExpression shape, symbolic::MultiExpression strides);

"basic" is not the most expressive name. Those checks should have thorough documentation so that there are no misunderstandings for what guarantees they provide

symbolic::Expression offset_a_; ///< Offset into tensor A (in elements)
symbolic::Expression offset_b_; ///< Offset into tensor B (in elements)

static bool has_basic_strides(symbolic::MultiExpression shape, symbolic::MultiExpression strides);

As static methods with no dependency on MatMulNode, these would be more intuitive to find on Tensor or somewhere similar. They're applicable to any tensor, no?

@Moehre2 Moehre2 merged commit 3f0642c into main Mar 20, 2026
21 checks passed
@Moehre2 Moehre2 deleted the mlir-opt-linear branch March 20, 2026 10:51