Skip to content

Exception when trying to train on qwen3 8b model: AttributeError: 'ControlModule' object has no attribute 'attention_type' #64

@itszn

Description

@itszn
# https://huggingface.co/Qwen/Qwen3-8B/tree/main
model_name = "/Volumes/models/Qwen3-8B/"

hf_token = "..."
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
tokenizer.pad_token_id = 0

# device_map="auto" will distribute the model over multiple GPUs
# this notebook was run on a runpod 3xA100—the cuda:0 device will need to have enough spare memory
# to do inference on for this notebook to work
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="mps:0", token=hf_token
)

wrapped_model = model
model = ControlModel(wrapped_model, list(range(1, 35)))

def chat_template_unparse(messages: list[tuple[str, str]]) -> str:
    template = []
    for role, content in messages:
        template.append(
            f"<|im_start|>{role}\n\n{content}<|im_end|>"
        )
    if messages[-1][0] != "assistant":
        # prefill assistant prefix
        template.append("<|im_start|>assistant\n\n")
    return "".join(template)

with open("./data/all_truncated_outputs.json") as f:
    output_suffixes = json.load(f)
truncated_output_suffixes = [
    tokenizer.convert_tokens_to_string(tokens[:i])
    for tokens in (tokenizer.tokenize(s) for s in output_suffixes)
    for i in range(1, len(tokens))
]
truncated_output_suffixes_512 = [
    tokenizer.convert_tokens_to_string(tokens[:i])
    for tokens in (tokenizer.tokenize(s) for s in output_suffixes[:512])
    for i in range(1, len(tokens))
]

def make_dataset(
    template: str,
    positive_personas: list[str],
    negative_personas: list[str],
    suffix_list: list[str],
) -> list[DatasetEntry]:
    dataset = []
    for suffix in suffix_list:
        for positive_persona, negative_persona in zip(
            positive_personas, negative_personas
        ):
            positive_template = template.format(persona=positive_persona)
            negative_template = template.format(persona=negative_persona)
            dataset.append(
                DatasetEntry(
                    positive=f"{positive_template}{suffix}",
                    negative=f"{negative_template}{suffix}",
                )
            )
    return dataset

def make_vector(template, pos, neg):
    ds = make_dataset(
        template,
        pos, neg,
        truncated_output_suffixes,
    )
    model.reset()
    vec = ControlVector.train(
        model, tokenizer, ds, batch_size=32, method="pca_center"
    )
    return vec

vec = make_vector(
    chat_template_unparse([("user", "{persona}")]),
    ["Pretend you're a happy person making statements about the world"],
    ["Pretend you're a depressed person making statements about the world"]
)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[12], line 1
----> 1 vec = make_vector(
      2     chat_template_unparse([("user", "{persona}")]),
      3     ["Pretend you're a happy person making statements about the world"],
      4     ["Pretend you're a depressed person making statements about the world"]
      5 )

Cell In[10], line 8, in make_vector(template, pos, neg)
      2 ds = make_dataset(
      3     template,
      4     pos, neg,
      5     truncated_output_suffixes,
      6 )
      7 model.reset()
----> 8 vec = ControlVector.train(
      9     model, tokenizer, ds, batch_size=32, method="pca_center"
     10 )
     11 return vec

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/repeng/extract.py:53](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/repeng/extract.py#line=52), in ControlVector.train(cls, model, tokenizer, dataset, **kwargs)
     36 """
     37 Train a ControlVector for a given model and tokenizer using the provided dataset.
     38 
   (...)     50     ControlVector: The trained vector.
     51 """
     52 with torch.inference_mode():
---> 53     dirs = read_representations(
     54         model,
     55         tokenizer,
     56         dataset,
     57         **kwargs,
     58     )
     59 return cls(model_type=model.config.model_type, directions=dirs)

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/repeng/extract.py:267](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/repeng/extract.py#line=266), in read_representations(model, tokenizer, inputs, hidden_layers, batch_size, method, transform_hiddens)
    264 # the order is [positive, negative, positive, negative, ...]
    265 train_strs = [s for ex in inputs for s in (ex.positive, ex.negative)]
--> 267 layer_hiddens = batched_get_hiddens(
    268     model, tokenizer, train_strs, hidden_layers, batch_size
    269 )
    271 if transform_hiddens is not None:
    272     layer_hiddens = transform_hiddens(layer_hiddens)

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/repeng/extract.py:350](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/repeng/extract.py#line=349), in batched_get_hiddens(model, tokenizer, inputs, hidden_layers, batch_size)
    348 encoded_batch = tokenizer(batch, padding=True, return_tensors="pt")
    349 encoded_batch = encoded_batch.to(model.device)
--> 350 out = model(**encoded_batch, output_hidden_states=True)
    351 attention_mask = encoded_batch["attention_mask"]
    352 for i in range(len(batch)):

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/repeng/control.py:121](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/repeng/control.py#line=120), in ControlModel.__call__(self, *args, **kwargs)
    120 def __call__(self, *args, **kwargs):
--> 121     return self.model(*args, **kwargs)

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py:1751](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py#line=1750), in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py:1762](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py#line=1761), in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/transformers/utils/generic.py:943](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/transformers/utils/generic.py#line=942), in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    940     set_attribute_for_modules(self, "_is_top_level_module", False)
    942 try:
--> 943     output = func(self, *args, **kwargs)
    944     if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
    945         output = output.to_tuple()

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py:570](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py#line=569), in Qwen3ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, cache_position, logits_to_keep, **kwargs)
    565 output_hidden_states = (
    566     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    567 )
    569 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 570 outputs: BaseModelOutputWithPast = self.model(
    571     input_ids=input_ids,
    572     attention_mask=attention_mask,
    573     position_ids=position_ids,
    574     past_key_values=past_key_values,
    575     inputs_embeds=inputs_embeds,
    576     use_cache=use_cache,
    577     output_attentions=output_attentions,
    578     output_hidden_states=output_hidden_states,
    579     cache_position=cache_position,
    580     **kwargs,
    581 )
    583 hidden_states = outputs.last_hidden_state
    584 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py:1751](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py#line=1750), in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py:1762](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py#line=1761), in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/transformers/utils/generic.py:943](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/transformers/utils/generic.py#line=942), in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    940     set_attribute_for_modules(self, "_is_top_level_module", False)
    942 try:
--> 943     output = func(self, *args, **kwargs)
    944     if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module):
    945         output = output.to_tuple()

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py:460](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py#line=459), in Qwen3Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, cache_position, **flash_attn_kwargs)
    455 if output_hidden_states:
    456     all_hidden_states += (hidden_states,)
    458 layer_outputs = decoder_layer(
    459     hidden_states,
--> 460     attention_mask=causal_mask_mapping[decoder_layer.attention_type],
    461     position_ids=position_ids,
    462     past_key_value=past_key_values,
    463     output_attentions=output_attentions,
    464     use_cache=use_cache,
    465     cache_position=cache_position,
    466     position_embeddings=position_embeddings,
    467     **flash_attn_kwargs,
    468 )
    470 hidden_states = layer_outputs[0]
    472 if output_attentions:

File [~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py:1940](http://metta:8888/lab/tree/~/.cache/uv/archive-v0/V2To1YSJ6Tv55EOIzYrge/lib/python3.12/site-packages/torch/nn/modules/module.py#line=1939), in Module.__getattr__(self, name)
   1938     if name in modules:
   1939         return modules[name]
-> 1940 raise AttributeError(
   1941     f"'{type(self).__name__}' object has no attribute '{name}'"
   1942 )

AttributeError: 'ControlModule' object has no attribute 'attention_type'

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions