
support hadamard transform for mxfp4 with rtn or autoround method #1515

Merged
chensuyue merged 43 commits into intel:main from lkk12014402:support_hadamard_transform
Mar 20, 2026
Conversation

@lkk12014402
Contributor

Description

  1. Support Hadamard transform for MXFP4

original linear:

$$ y = Wx $$

Apply a transform matrix $$H$$ (a Hadamard matrix satisfies $$H^\top H = I$$, so $$H^{-1} = H^\top$$):

$$ y = W x = (W H^\top) (H x) $$

define:

  • $$W' = W H^\top$$ (rotated weight)
  • $$x' = H x$$ (rotated activation)

then:

$$ y = W' x' $$
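For concreteness, here is a minimal PyTorch sketch of this identity, using a normalized Sylvester-construction Hadamard matrix (dimensions are arbitrary; this is an illustration, not the PR's implementation):

```python
# Minimal sketch: rotating the weight and the activation with an orthogonal
# Hadamard matrix leaves the linear output unchanged.
import torch

def hadamard(n: int) -> torch.Tensor:
    """Sylvester construction; n must be a power of two."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1),
                       torch.cat([H, -H], dim=1)], dim=0)
    return H / H.shape[0] ** 0.5  # normalize so that H^T H = I

n = 64
H = hadamard(n)
W = torch.randn(128, n)
x = torch.randn(n)

y = W @ x                    # original: y = W x
y_rot = (W @ H.T) @ (H @ x)  # rotated:  y = W' x'
assert torch.allclose(y, y_rot, atol=1e-5)
```

The rotation helps because $$H x$$ spreads activation outliers across all channels, which makes the distribution friendlier to low-bit quantization.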

  2. Support evaluation with huggingface/transformers

with huggingface/transformers

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "./mxfp4_transformed_model"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
model.to("cuda")
print(model)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

with lm_eval (the command below uses the hf backend; lm_eval also supports a vllm backend)

lm_eval --model hf --model_args pretrained=./mxfp4_transformed_model --tasks gsm8k --batch_size 8

Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
@lkk12014402 lkk12014402 requested review from Copilot, wenhuach21 and yiliu30 and removed request for Copilot and yiliu30 March 9, 2026 03:03
@lkk12014402
Contributor Author

This PR is a refactored version of the original PR: #1349

@wenhuach21 wenhuach21 requested a review from n1ck-guo March 9, 2026 03:09
Contributor

Copilot AI left a comment


Pull request overview

Adds Hadamard-based rotation support for MXFP4/NVFP4 workflows by introducing an experimental transform pipeline (weight rotation + activation-side transform during inference) and a CUDA test exercising quantize→save→HF load→generate.

Changes:

  • Introduce experimental transform modules/config and a Triton MXFP4 Hadamard+QDQ kernel wrapper (a rough QDQ reference sketch follows this list).
  • Plumb transform_config through quantization scheme/config and apply activation transforms during HF model conversion.
  • Add a CUDA integration test that applies transform + quantizes/saves then runs HF generation.
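For intuition about the Hadamard+QDQ kernel above, here is a rough PyTorch reference of MXFP4 quantize-dequantize: round-to-nearest onto the E2M1 grid with a shared power-of-two scale per block. The block size and scale rule follow the OCP MX convention as I understand it; the actual Triton kernel may differ.

```python
import torch

# E2M1 (FP4) representable magnitudes
FP4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def mxfp4_qdq(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    """Fake-quantize x to MXFP4; assumes x.numel() is a multiple of block_size."""
    xb = x.reshape(-1, block_size)
    amax = xb.abs().amax(dim=1, keepdim=True).clamp_min(2.0 ** -126)
    # Shared power-of-two block scale, chosen so the block max lands
    # within the E2M1 range (max magnitude 6).
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 2)
    q = xb / scale
    # Round-to-nearest onto the FP4 grid, keeping the sign
    idx = (q.abs().unsqueeze(-1) - FP4_GRID).abs().argmin(dim=-1)
    return (FP4_GRID[idx] * torch.sign(q) * scale).reshape(x.shape)
```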

Reviewed changes

Copilot reviewed 9 out of 10 changed files in this pull request and generated 15 comments.

Summary per file:

| File | Description |
| --- | --- |
| auto_round/experimental/transform/apply.py | Applies transforms to model modules; registers activation pre-hooks or fuses weight transforms. |
| auto_round/experimental/transform/transforms.py | Defines Hadamard/identity transforms and the transform factory. |
| auto_round/experimental/transform/transform_config.py | Pydantic config object for transform serialization. |
| auto_round/experimental/transform/triton/mxfp4.py | Triton kernel for Hadamard + FP4 QDQ on activations. |
| auto_round/inference/convert_model.py | Threads transform_config into the layer config and registers the activation transform during conversion. |
| auto_round/schemes.py | Adds a transform_config field to QuantizationScheme. |
| auto_round/compressors/base.py | Adds transform_config to serialization keys and compressor init args. |
| auto_round/autoround.py | Adds transform_config arg pass-through on AutoRound construction. |
| test/test_cuda/transform/test_mxfp4_transform.py | New CUDA test for transform + MXFP4 quantize/save and HF inference. |
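As context for the two modes of apply.py noted above, a hypothetical sketch of both paths (the hook signature is standard PyTorch; the fusion and hook logic here are assumptions, not the module's actual code):

```python
import torch

def register_activation_transform(linear: torch.nn.Linear, H: torch.Tensor):
    """Online path: rotate the activation (x' = H x) in a forward pre-hook."""
    def pre_hook(module, args):
        (x,) = args
        return (x @ H.T,)  # for row-vector batches, x @ H^T realizes H x
    return linear.register_forward_pre_hook(pre_hook)

def fuse_weight_transform(linear: torch.nn.Linear, H: torch.Tensor) -> None:
    """Offline path: fold the rotation into the weight (W' = W H^T)."""
    with torch.no_grad():
        linear.weight.copy_(linear.weight @ H.T)
```

Applying both to the same nn.Linear reproduces the original output, since x H^T H W^T = x W^T; quantizing the rotated weight afterwards is what makes the rotation worthwhile.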


@wenhuach21
Contributor

I left several comments in previous review rounds, and several of them have not been addressed.

@wenhuach21 wenhuach21 requested a review from lvliang-intel March 9, 2026 03:27
@wenhuach21
Contributor

@n1ck-guo please review carefully; otherwise, it will be your job to refine the API later.

@chensuyue chensuyue added this to the 0.10.3 milestone Mar 9, 2026
lkk12014402 and others added 5 commits March 9, 2026 13:37
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
@wenhuach21
Contributor

Additionally, it would be better to run some accuracy tests for Qwen and LLaMA to verify correctness and document the results in the docs folder.

@wenhuach21
Contributor

Besides, the same Hadamard transformation for fused modules such as QKV and MoE is not handled yet. If it is not supported in this version, we should warn users.

@yiliu30
Contributor

yiliu30 commented Mar 16, 2026

Code review

No issues found. Checked for bugs and CLAUDE.md compliance.

🤖 Generated with Claude Code

- If this code review was useful, please react with 👍. Otherwise, react with 👎.

@lkk12014402
Contributor Author

> Code review
>
> No issues found. Checked for bugs and CLAUDE.md compliance.
>
> 🤖 Generated with Claude Code

Thanks~

@lkk12014402
Contributor Author

lkk12014402 commented Mar 17, 2026

> Besides, the same Hadamard transformation for fused modules such as QKV and MoE is not handled yet. If it is not supported in this version, we should warn users.

I see. As a next step, we plan to enable fused modules such as QKV, MLP, and MoE. In the current version, the Hadamard transform for these fused patterns is not handled yet, so we will add clear guidance in the documentation and warn users when it is not supported, to avoid confusion (a sketch of such a guard follows).
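A hypothetical sketch of such a user-facing warning (the module-name patterns are assumptions for illustration, not the actual implementation):

```python
import logging

logger = logging.getLogger(__name__)

# Assumed name fragments that hint at fused modules (QKV, MLP, MoE)
FUSED_HINTS = ("qkv", "gate_up", "experts")

def warn_on_unsupported_fused_modules(model) -> None:
    for name, _module in model.named_modules():
        if any(hint in name.lower() for hint in FUSED_HINTS):
            logger.warning(
                "Hadamard transform is not applied to fused module '%s' "
                "in this version; accuracy may be affected.", name
            )
```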

@lkk12014402
Contributor Author

@wenhuach21 @yiliu30 please help review

Signed-off-by: lkk12014402 <kaokao.lv@intel.com>
@lkk12014402
Contributor Author

lkk12014402 commented Mar 17, 2026

@wenhuach21 The test accuracy results are here, evaluated with lm_eval using --model hf.

| Llama-3.1-8B-Instruct | Gsm8k_llama (strict) | Mmlu_llama | Piqa | Hellaswag | Average | Ratio |
| --- | --- | --- | --- | --- | --- | --- |
| BF16 | 0.8264 | 0.6763 | 0.8014 | 0.5976 | 0.7254 | / |
| MXFP4 (RTN) | 0.5830 | 0.5625 | 0.7606 | 0.5568 | 0.615725 | 84.88% |
| mxfp4 (iters=200) | 0.6929 | 0.6204 | 0.7797 | 0.5568 | 0.6624 | 91.31% |
| deterministic_hadamard + mxfp4 (iters=200) | 0.7491 | 0.6398 | 0.7780 | 0.5685 | 0.6838 | 94.26% |
| random_hadamard (same matrix for each layer) + mxfp4 | 0.7316 | 0.6431 | 0.7742 | 0.5666 | 0.67887 | 93.58% |
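For reference, an evaluation command along these lines should reproduce the runs above (the exact task names are assumptions; only gsm8k appears verbatim earlier in this PR):

lm_eval --model hf --model_args pretrained=./mxfp4_transformed_model --tasks gsm8k_llama,mmlu_llama,piqa,hellaswag --batch_size 8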

@wenhuach21
Contributor

> @wenhuach21 The test accuracy results are here, evaluated with lm_eval using --model hf.
>
> [results table quoted above]

Better to document it, and also add Qwen3-8B.

@hshen14
Contributor

hshen14 commented Mar 18, 2026

Besides the Llama 3.1 8B, can we also post results for a 70B model? Please prepare the README to show the recipe and accuracy data in a separate PR.

Contributor

@yiliu30 yiliu30 left a comment


Others LGTM

@lkk12014402
Contributor Author

> Prepare the readme to show the recipe and acc data in a separate PR.

will do

@wenhuach21
Contributor

wenhuach21 commented Mar 20, 2026

please resolve the API issue and then merge

@chensuyue chensuyue merged commit a49550b into intel:main Mar 20, 2026
24 checks passed