diff --git a/src/lmflow/pipeline/utils/lisa_trainer.py b/src/lmflow/pipeline/utils/lisa_trainer.py
index a1e441b1d..fb2b851e5 100644
--- a/src/lmflow/pipeline/utils/lisa_trainer.py
+++ b/src/lmflow/pipeline/utils/lisa_trainer.py
@@ -1,9 +1,111 @@
+import logging
 from typing import Optional
 
 import numpy as np
+import torch.nn as nn
 from transformers import PreTrainedModel
 from transformers.trainer_callback import TrainerCallback
 
+logger = logging.getLogger(__name__)
+
+# Mapping from model class name to the dot-separated path of transformer layers.
+# Add new entries here as new model families are released.
+CLASS_TO_LAYERS_MAP = {
+    # LLaMA family (1, 2, 3, 3.1, 3.2, 3.3)
+    "LlamaForCausalLM": "model.model.layers",
+    # Qwen family
+    "Qwen2ForCausalLM": "model.model.layers",
+    "Qwen2MoeForCausalLM": "model.model.layers",
+    # Mistral / Mixtral
+    "MistralForCausalLM": "model.model.layers",
+    "MixtralForCausalLM": "model.model.layers",
+    # Gemma family
+    "GemmaForCausalLM": "model.model.layers",
+    "Gemma2ForCausalLM": "model.model.layers",
+    "Gemma3ForCausalLM": "model.model.layers",
+    # Phi family (Microsoft)
+    "Phi3ForCausalLM": "model.model.layers",
+    "PhiForCausalLM": "model.model.layers",
+    # DeepSeek
+    "DeepseekV2ForCausalLM": "model.model.layers",
+    "DeepseekV3ForCausalLM": "model.model.layers",
+    # Cohere (Command R)
+    "CohereForCausalLM": "model.model.layers",
+    # OLMo (Allen AI)
+    "OlmoForCausalLM": "model.model.layers",
+    "Olmo2ForCausalLM": "model.model.layers",
+    # Falcon
+    "FalconForCausalLM": "model.transformer.h",
+    # GPT-2
+    "GPT2LMHeadModel": "model.transformer.h",
+    # GPT-NeoX / Pythia
+    "GPTNeoXForCausalLM": "model.gpt_neox.layers",
+    # Hymba
+    "HymbaForCausalLM": "model.model.layers",
+}
+
+# Common layer paths tried in order during dynamic fallback.
+_FALLBACK_LAYER_PATHS = [
+    "model.model.layers",
+    "model.transformer.h",
+    "model.gpt_neox.layers",
+    "model.layers",
+]
+
+
+def _resolve_layers(model: PreTrainedModel, layers_attribute: str):
+    """Walk the dot-separated layers_attribute path on model and return the layer list."""
+    obj = model
+    for attr in layers_attribute.split(".")[1:]:  # skip leading "model"
+        obj = getattr(obj, attr)
+    return obj
+
+
+def _get_layers_attribute(model: PreTrainedModel, lisa_layers_attribute: Optional[str] = None) -> str:
+    """Resolve the dot-separated path to the model's transformer layers.
+
+    Resolution order:
+    1. User-supplied lisa_layers_attribute override (highest priority).
+    2. CLASS_TO_LAYERS_MAP lookup by model class name.
+    3. Dynamic introspection across known common paths.
+
+    Raises ValueError if no path can be found.
+    """
+    unwrapped = model.module if hasattr(model, "module") else model
+    model_class_name = type(unwrapped).__name__
+
+    # 1. User override takes highest priority
+    if lisa_layers_attribute is not None:
+        return lisa_layers_attribute
+
+    # 2. Known architecture map
+    if model_class_name in CLASS_TO_LAYERS_MAP:
+        return CLASS_TO_LAYERS_MAP[model_class_name]
+
+    # 3. Dynamic fallback — inspect the actual model object
+    for path in _FALLBACK_LAYER_PATHS:
+        try:
+            # Resolve with _resolve_layers so the fallback honours the same convention
+            # as __init__: the leading "model" token refers to the callback's self.model.
+            obj = _resolve_layers(unwrapped, path)
+            if isinstance(obj, (list, nn.ModuleList)):
+                logger.warning(
+                    "Model class '%s' not in CLASS_TO_LAYERS_MAP. "
+                    "Dynamically detected layers at '%s'. "
+                    "Consider adding '%s' to CLASS_TO_LAYERS_MAP in lisa_trainer.py.",
+                    model_class_name, path, model_class_name,
+                )
+                return path
+        except AttributeError:
+            continue
+
+    raise ValueError(
+        f"Cannot locate transformer layers for model class '{model_class_name}'. "
+        f"Set lisa_layers_attribute in FinetunerArguments to the dot-separated "
+        f"path (e.g. 'model.model.layers'), or add '{model_class_name}' to "
+        f"CLASS_TO_LAYERS_MAP in src/lmflow/pipeline/utils/lisa_trainer.py."
+    )
+
 
 class DynamicLayerActivationCallback(TrainerCallback):
     def __init__(
@@ -18,49 +120,28 @@ def __init__(
         self.interval_steps = interval_steps
         self.model = model
 
-        # Determine the way to access layers based on the model type
-        class_to_layers_map = {
-            "LlamaForCausalLM": "model.model.layers",
-            "Qwen2ForCausalLM": "model.model.layers",
-            "MistralForCausalLM": "model.model.layers",
-            "MixtralForCausalLM": "model.model.layers",
-            "GemmaForCausalLM": "model.model.layers",
-            "GPT2LMHeadModel": "model.transformer.h",
-            "HymbaForCausalLM": "model.model.layers",
-        }
-        model_class_name = self.model.__class__.__name__
-        if model_class_name in class_to_layers_map:
-            self.layers_attribute = class_to_layers_map[model_class_name]
-        else:
-            assert lisa_layers_attribute is not None, "Please provide the attribute to access the layers of the model."
-            self.layers_attribute = lisa_layers_attribute
-        self.total_layers = len(
-            eval("self." + self.layers_attribute)
-        )  # Dynamically execute to get the number of layers
+        self.layers_attribute = _get_layers_attribute(model, lisa_layers_attribute)
+        self.total_layers = len(_resolve_layers(self.model, self.layers_attribute))
 
         self.active_layers_indices = []
 
     def freeze_all_layers(self):
-        layers = eval("self." + self.layers_attribute)  # Dynamically execute to get layers
+        layers = _resolve_layers(self.model, self.layers_attribute)
         for layer in layers:
             for param in layer.parameters():
                 param.requires_grad = False
 
     def on_step_begin(self, args, state, control, **kwargs):
-        # Check if it's time to switch active layers, including at step 0
         if state.global_step % self.interval_steps == 0:
             self.switch_active_layers()
 
     def switch_active_layers(self):
-        # First, disable gradients for all layers
         self.freeze_all_layers()
 
-        # Randomly select n_layers to activate
-        layers = eval("self." + self.layers_attribute)  # Re-fetch layer references
+        layers = _resolve_layers(self.model, self.layers_attribute)
         self.active_layers_indices = np.random.choice(range(self.total_layers), self.n_layers, replace=False)
-        print(f"Activating layers at indices: {self.active_layers_indices} for the next steps.", flush=True)
+        logger.info("Activating layers at indices: %s for the next steps.", self.active_layers_indices)
 
-        # Enable gradients only for the selected layers
         for idx in self.active_layers_indices:
             for param in layers[idx].parameters():
                 param.requires_grad = True
diff --git a/tests/pipeline/test_lisa_trainer.py b/tests/pipeline/test_lisa_trainer.py
new file mode 100644
index 000000000..fab5ab128
--- /dev/null
+++ b/tests/pipeline/test_lisa_trainer.py
@@ -0,0 +1,89 @@
+import types
+import pytest
+import torch.nn as nn
+
+from lmflow.pipeline.utils.lisa_trainer import (
+    CLASS_TO_LAYERS_MAP,
+    _get_layers_attribute,
+)
+
+
+def make_mock_model(class_name: str, layers_path: str = "model.model.layers", num_layers: int = 4):
+    """Build a minimal mock model with layers at the given dot-separated path.
+
+    Uses SimpleNamespace for nested attributes so we avoid nn.Module.__init__
+    complexity. The top-level object is given the requested class name so that
+    type(model).__name__ returns it correctly.
+
+    For example, layers_path="model.model.layers" creates mock.model.layers,
+    mirroring how the leading "model" token refers to the callback's self.model.
+    """
+    layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(num_layers)])
+
+    current = layers
+    for part in reversed(layers_path.split(".")[1:]):  # drop the leading "model" token
+        parent = types.SimpleNamespace()
+        setattr(parent, part, current)
+        current = parent
+
+    MockClass = type(class_name, (object,), {})
+    instance = object.__new__(MockClass)
+    instance.__dict__.update(vars(current))
+    return instance
+
+
+class TestGetLayersAttribute:
+
+    def test_known_architecture_uses_map(self):
+        """LLaMA is in CLASS_TO_LAYERS_MAP — should return the mapped path directly."""
+        model = make_mock_model("LlamaForCausalLM", "model.model.layers")
+        result = _get_layers_attribute(model)
+        assert result == "model.model.layers"
+
+    def test_newly_added_architecture_gemma2(self):
+        """Gemma2 was added to the expanded map — should resolve without fallback."""
+        model = make_mock_model("Gemma2ForCausalLM", "model.model.layers")
+        result = _get_layers_attribute(model)
+        assert result == "model.model.layers"
+        assert "Gemma2ForCausalLM" in CLASS_TO_LAYERS_MAP
+
+    def test_falcon_maps_to_transformer_h(self):
+        """FalconForCausalLM maps to model.transformer.h — verifies non-default path entries."""
+        model = make_mock_model("FalconForCausalLM", "model.transformer.h")
+        result = _get_layers_attribute(model)
+        assert result == "model.transformer.h"
+
+    def test_user_override_takes_precedence_over_map(self):
+        """User-supplied lisa_layers_attribute must win even for known architectures.
+
+        Uses a custom path that differs from both the map entry and all fallback
+        paths, so the only way the test passes is if the override is truly used.
+        """
+        model = make_mock_model("LlamaForCausalLM", "model.model.layers")
+        result = _get_layers_attribute(model, lisa_layers_attribute="model.custom.blocks")
+        assert result == "model.custom.blocks"
+
+    def test_dynamic_fallback_finds_transformer_h(self):
+        """Unknown model with layers at model.transformer.h — fallback iterates past first entry."""
+        model = make_mock_model("BrandNewGPTModel", "model.transformer.h")
+        result = _get_layers_attribute(model)
+        assert result == "model.transformer.h"
+
+    def test_completely_unknown_model_raises_valueerror(self):
+        """Unknown model with no recognizable layer path should raise a clear ValueError."""
+        model = make_mock_model("WeirdModelWithNoLayers", "model.model.layers")
+        model.__dict__.clear()
+        with pytest.raises(ValueError, match="Cannot locate transformer layers"):
+            _get_layers_attribute(model)
+
+    def test_dataparallel_wrapped_model_unwrapped(self):
+        """Model wrapped in DataParallel (.module) should be unwrapped before class lookup."""
+        inner = make_mock_model("LlamaForCausalLM", "model.model.layers")
+
+        # Simulate DataParallel wrapping: outer object has a .module attribute
+        WrapperClass = type("DataParallel", (object,), {})
+        wrapper = object.__new__(WrapperClass)
+        wrapper.module = inner
+
+        result = _get_layers_attribute(wrapper)
+        assert result == "model.model.layers"
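
For an architecture that neither appears in CLASS_TO_LAYERS_MAP nor matches any fallback path, the escape hatch is the lisa_layers_attribute override. Below is a minimal sketch of how that wiring is expected to look; the keyword names follow the attributes set in __init__, trainer.add_callback is the standard Hugging Face Trainer API, and "model.backbone.blocks" is a purely hypothetical path used for illustration.

    from lmflow.pipeline.utils.lisa_trainer import DynamicLayerActivationCallback

    # Unfreeze 2 randomly chosen transformer layers and reshuffle them every 20 steps.
    lisa_callback = DynamicLayerActivationCallback(
        n_layers=2,
        interval_steps=20,
        model=model,  # the underlying Hugging Face PreTrainedModel
        lisa_layers_attribute="model.backbone.blocks",  # hypothetical path, relative to the callback's self
    )
    trainer.add_callback(lisa_callback)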