7 changes: 6 additions & 1 deletion auto_round/experimental/utils.py
@@ -35,7 +35,12 @@ def update_parameter_data(module: torch.nn.Module, new_val: torch.Tensor, name:
     if hasattr(module, name):
         param = getattr(module, name)
         if isinstance(param, torch.nn.Parameter):
-            param.data.copy_(new_val)
+            if param.shape == new_val.shape:
+                param.data.copy_(new_val)
+            else:
+                # Re-create the parameter when shapes differ (e.g. after offload
+                # cleared it to an empty tensor).
+                module.register_parameter(name, torch.nn.Parameter(new_val.clone(), requires_grad=param.requires_grad))
         else:
             module.register_parameter(name, torch.nn.Parameter(new_val))
     else:
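For context, a minimal sketch of the failure mode the new branch guards against (assumptions: plain PyTorch, a Linear layer standing in for any offloaded module; illustration only, not code from this PR):

import torch

lin = torch.nn.Linear(4, 4, bias=False)
# Simulate offload clearing the parameter to an empty tensor.
lin.weight.data = torch.empty(0)

new_val = torch.randn(4, 4)
try:
    # The old code path: copy_ into a zero-sized parameter fails.
    lin.weight.data.copy_(new_val)
except RuntimeError as err:
    print(f"copy_ raises on shape mismatch: {err}")

# The patched path: re-register a fresh Parameter with the new shape.
lin.register_parameter(
    "weight", torch.nn.Parameter(new_val.clone(), requires_grad=lin.weight.requires_grad)
)
assert lin.weight.shape == new_val.shape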
38 changes: 38 additions & 0 deletions test/test_cpu/schemes/test_auto_scheme.py
@@ -54,3 +54,41 @@ def test_layer_config(self, tiny_opt_model_path):
         avg_bits, _ = compute_avg_bits_for_model(model)
         print(avg_bits)
         assert target_bits - 0.1 < avg_bits <= target_bits + 1e-3
+
+    def test_autoscheme_mxfp_with_static_kv(self, tiny_opt_model_path):
+        """MXFP4+MXFP8 AutoScheme with static_kv_dtype='fp8' should yield
+        non-zero k_scale and v_scale on the first attention layer."""
+        scheme = AutoScheme(
+            avg_bits=5.0,
+            options=("MXFP4", "MXFP8"),
+            nsamples=2,
+            seqlen=8,
+            ignore_scale_zp_bits=True,
+        )
+        ar = AutoRound(
+            tiny_opt_model_path,
+            scheme=scheme,
+            static_kv_dtype="fp8",
+            iters=0,
+            nsamples=2,
+            seqlen=8,
+            disable_opt_rtn=True,
+        )
+        quantized_model, _ = ar.quantize_and_save(
+            format="fake",
+            output_dir=self.save_dir,
+        )
+
+        # After quantize_and_save, the model's attention modules should have
+        # k_scale and v_scale registered as parameters with non-zero values.
+        attn = quantized_model.model.decoder.layers[0].self_attn
+        assert hasattr(attn, "k_scale"), "missing k_scale after quantization"
+        assert hasattr(attn, "v_scale"), "missing v_scale after quantization"
+        k_val = attn.k_scale.item()
+        v_val = attn.v_scale.item()
+        assert k_val != 0.0, (
+            "k_scale is 0.0; scale was not collected during calibration with AutoScheme + static_kv_dtype"
+        )
+        assert v_val != 0.0, (
+            "v_scale is 0.0; scale was not collected during calibration with AutoScheme + static_kv_dtype"
+        )
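For readers unfamiliar with static KV quantization, the usual recipe for a static fp8 (E4M3) scale is the calibration-time absolute maximum divided by the format's maximum representable value, so a non-zero scale implies calibration data actually flowed through the layer. A rough sketch of that recipe (illustration only, not auto_round's implementation):

import torch

FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def static_fp8_scale(calib_tensors):
    # Largest magnitude seen across all calibration batches.
    amax = max(t.abs().max().item() for t in calib_tensors)
    return amax / FP8_E4M3_MAX

keys = [torch.randn(2, 8, 16) for _ in range(4)]  # stand-in for calibrated K states
k_scale = static_fp8_scale(keys)
assert k_scale > 0.0  # the property the new test asserts on the real model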