diff --git a/auto_round/experimental/utils.py b/auto_round/experimental/utils.py
index e90f9c0d5..b38f33d16 100644
--- a/auto_round/experimental/utils.py
+++ b/auto_round/experimental/utils.py
@@ -35,7 +35,12 @@ def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name:
     if hasattr(module, name):
         param = getattr(module, name)
         if isinstance(param, torch.nn.Parameter):
-            param.data.copy_(new_val)
+            if param.shape == new_val.shape:
+                param.data.copy_(new_val)
+            else:
+                # Re-create the parameter when shapes differ (e.g. after offload
+                # cleared it to an empty tensor).
+                module.register_parameter(name, torch.nn.Parameter(new_val.clone(), requires_grad=param.requires_grad))
         else:
             module.register_parameter(name, torch.nn.Parameter(new_val))
     else:
diff --git a/test/test_cpu/schemes/test_auto_scheme.py b/test/test_cpu/schemes/test_auto_scheme.py
index 9bd362bf3..fe7b18511 100644
--- a/test/test_cpu/schemes/test_auto_scheme.py
+++ b/test/test_cpu/schemes/test_auto_scheme.py
@@ -54,3 +54,41 @@ def test_layer_config(self, tiny_opt_model_path):
         avg_bits, _ = compute_avg_bits_for_model(model)
         print(avg_bits)
         assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3
+
+    def test_autoscheme_mxfp_with_static_kv(self, tiny_opt_model_path):
+        """MXFP4+MXFP8 AutoScheme with static_kv_dtype='fp8' should yield
+        non-zero k_scale and v_scale on the first attention layer."""
+        scheme = AutoScheme(
+            avg_bits=5.0,
+            options=("MXFP4", "MXFP8"),
+            nsamples=2,
+            seqlen=8,
+            ignore_scale_zp_bits=True,
+        )
+        ar = AutoRound(
+            tiny_opt_model_path,
+            scheme=scheme,
+            static_kv_dtype="fp8",
+            iters=0,
+            nsamples=2,
+            seqlen=8,
+            disable_opt_rtn=True,
+        )
+        quantized_model, _ = ar.quantize_and_save(
+            format="fake",
+            output_dir=self.save_dir,
+        )
+
+        # After quantize_and_save, the model's attention modules should have
+        # k_scale and v_scale registered as parameters with non-zero values.
+        attn = quantized_model.model.decoder.layers[0].self_attn
+        assert hasattr(attn, "k_scale"), "missing k_scale after quantization"
+        assert hasattr(attn, "v_scale"), "missing v_scale after quantization"
+        k_val = attn.k_scale.item()
+        v_val = attn.v_scale.item()
+        assert k_val != 0.0, (
+            "k_scale is 0.0 — scale was not collected during calibration with AutoScheme + static_kv_dtype"
+        )
+        assert v_val != 0.0, (
+            "v_scale is 0.0 — scale was not collected during calibration with AutoScheme + static_kv_dtype"
+        )
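For reviewers, a minimal sketch of the shape-mismatch path added in the first hunk, assuming the patched helper is importable from auto_round.experimental.utils as shown above; the toy module, parameter name, and values are illustrative only, not taken from this patch:

import torch

from auto_round.experimental.utils import update_parameter_data

# Toy module whose "k_scale" parameter was cleared to an empty tensor,
# mimicking the state left behind by offloading (illustrative setup).
mod = torch.nn.Module()
mod.register_parameter("k_scale", torch.nn.Parameter(torch.empty(0), requires_grad=False))

# Shapes differ (torch.Size([0]) vs. torch.Size([])), so the patched helper
# re-registers the parameter instead of calling copy_(), which would raise
# a size-mismatch error on the old code path.
update_parameter_data(mod, torch.tensor(0.125), "k_scale")

assert mod.k_scale.item() == 0.125
assert mod.k_scale.requires_grad is False  # requires_grad is carried over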