4 changes: 4 additions & 0 deletions examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml
@@ -44,6 +44,10 @@ modules:
# Mamba-specific: must provide spec
spec: ['megatron.core.models.mamba.mamba_layer_specs', 'mamba_stack_spec']

# Tokenizer
tokenizer_type: HuggingFaceTokenizer
tokenizer_model: EleutherAI/gpt-neox-20b

# parallel
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
@@ -0,0 +1,57 @@
work_group: ${PRIMUS_TEAM:amd}
user_name: ${PRIMUS_USER:root}
exp_name: ${PRIMUS_EXP_NAME:mamba_370M_sft_posttrain}
workspace: ${PRIMUS_WORKSPACE:./output}

modules:
post_trainer:
framework: megatron_bridge
config: sft_trainer.yaml

# Model to run
model: mamba_370M.yaml

overrides:
stderr_sink_level: DEBUG

# Parallelism configuration
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
pipeline_dtype: null
virtual_pipeline_model_parallel_size: null
context_parallel_size: 1
sequence_parallel: false
use_megatron_fsdp: false

# Finetuning-specific params
#pretrained_checkpoint: null
peft: "none"
packed_sequence: false

# Training configuration
train_iters: 200
global_batch_size: 128
micro_batch_size: 4
seq_length: 2048
eval_interval: 30
save_interval: 50

# Optimizer configuration
finetune_lr: 5.0e-6
min_lr: 0.0
lr_warmup_iters: 50
lr_decay_iters: null

# W&B logging
wandb_project: null
wandb_entity: null
wandb_exp_name: null

# Precision
precision_config: bf16_mixed
comm_overlap_config: null

# Turbo - disabled for Mamba (not supported)
enable_primus_turbo: false
use_turbo_attention: false
use_turbo_grouped_mlp: false
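Worked batch-size check for this config (assuming a single 8-GPU MI300X node, so data-parallel size 8 with TP=1 and PP=1): global_batch_size 128 / (micro_batch_size 4 × 8 data-parallel ranks) = 4 gradient-accumulation steps per iteration.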
@@ -0,0 +1,57 @@
work_group: ${PRIMUS_TEAM:amd}
user_name: ${PRIMUS_USER:root}
exp_name: ${PRIMUS_EXP_NAME:zebra_llama_1B_sft_posttrain}
workspace: ${PRIMUS_WORKSPACE:./output}

modules:
post_trainer:
framework: megatron_bridge
config: sft_trainer.yaml

# Model to run
model: zebra_llama_1B.yaml

overrides:
stderr_sink_level: DEBUG

# Parallelism configuration
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
pipeline_dtype: null
virtual_pipeline_model_parallel_size: null
context_parallel_size: 1
sequence_parallel: false
use_megatron_fsdp: false

# Finetuning-specific params
#pretrained_checkpoint: null
peft: "none"
packed_sequence: false

# Training configuration
train_iters: 200
global_batch_size: 64
micro_batch_size: 8
seq_length: 8192
eval_interval: 30
save_interval: 50

# Optimizer configuration
finetune_lr: 5.0e-6
min_lr: 0.0
lr_warmup_iters: 50
lr_decay_iters: null

# W&B logging
wandb_project: null
wandb_entity: null
wandb_exp_name: null

# Precision
precision_config: bf16_mixed
comm_overlap_config: null

# Turbo - disabled for hybrid Mamba+MLA (not supported)
enable_primus_turbo: false
use_turbo_attention: false
use_turbo_grouped_mlp: false
@@ -0,0 +1,57 @@
work_group: ${PRIMUS_TEAM:amd}
user_name: ${PRIMUS_USER:root}
exp_name: ${PRIMUS_EXP_NAME:zebra_llama_3B_sft_posttrain}
workspace: ${PRIMUS_WORKSPACE:./output}

modules:
post_trainer:
framework: megatron_bridge
config: sft_trainer.yaml

# Model to run
model: zebra_llama_3B.yaml

overrides:
stderr_sink_level: DEBUG

# Parallelism configuration
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
pipeline_dtype: null
virtual_pipeline_model_parallel_size: null
context_parallel_size: 1
sequence_parallel: false
use_megatron_fsdp: false

# Finetuning-specific params
#pretrained_checkpoint: null
peft: "none"
packed_sequence: false

# Training configuration
train_iters: 200
global_batch_size: 32
micro_batch_size: 4
seq_length: 8192
eval_interval: 30
save_interval: 50

# Optimizer configuration
finetune_lr: 5.0e-6
min_lr: 0.0
lr_warmup_iters: 50
lr_decay_iters: null

# W&B logging
wandb_project: null
wandb_entity: null
wandb_exp_name: null

# Precision
precision_config: bf16_mixed
comm_overlap_config: null

# Turbo - disabled for hybrid Mamba+MLA (not supported)
enable_primus_turbo: false
use_turbo_attention: false
use_turbo_grouped_mlp: false
@@ -0,0 +1,57 @@
work_group: ${PRIMUS_TEAM:amd}
user_name: ${PRIMUS_USER:root}
exp_name: ${PRIMUS_EXP_NAME:zebra_llama_8B_sft_posttrain}
workspace: ${PRIMUS_WORKSPACE:./output}

modules:
post_trainer:
framework: megatron_bridge
config: sft_trainer.yaml

# Model to run
model: zebra_llama_8B.yaml

overrides:
stderr_sink_level: DEBUG

# Parallelism configuration
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
pipeline_dtype: null
virtual_pipeline_model_parallel_size: null
context_parallel_size: 1
sequence_parallel: false
use_megatron_fsdp: false

# Finetuning-specific params
#pretrained_checkpoint: null
peft: "none"
packed_sequence: false

# Training configuration
train_iters: 200
global_batch_size: 16
micro_batch_size: 2
seq_length: 8192
eval_interval: 30
save_interval: 50

# Optimizer configuration
finetune_lr: 5.0e-6
min_lr: 0.0
lr_warmup_iters: 50
lr_decay_iters: null

# W&B logging
wandb_project: null
wandb_entity: null
wandb_exp_name: null

# Precision
precision_config: bf16_mixed
comm_overlap_config: null

# Turbo - disabled for hybrid Mamba+MLA (not supported)
enable_primus_turbo: false
use_turbo_attention: false
use_turbo_grouped_mlp: false
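The four post-train configs above all rely on `${VAR:default}` placeholders such as `${PRIMUS_USER:root}`. A minimal sketch of that substitution convention, assuming a plain environment-variable lookup with a literal fallback (illustration only, not Primus's actual resolver):

```python
import os
import re

# Resolve "${VAR:default}" placeholders against the environment.
# Hypothetical helper for illustration; Primus's real resolver may differ.
_PLACEHOLDER = re.compile(r"\$\{([A-Z0-9_]+):([^}]*)\}")

def resolve_placeholders(text: str) -> str:
    # For each ${VAR:default}, use the environment value if set, else the default.
    return _PLACEHOLDER.sub(lambda m: os.environ.get(m.group(1), m.group(2)), text)

# Example: with PRIMUS_EXP_NAME unset, the default is kept.
print(resolve_placeholders("exp_name: ${PRIMUS_EXP_NAME:zebra_llama_8B_sft_posttrain}"))
# -> exp_name: zebra_llama_8B_sft_posttrain
```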
@@ -10,9 +10,6 @@
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec

# Import MambaStack from relative path
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.ssm.mlp_layer import MLPLayer
@@ -22,6 +19,12 @@
MLASelfAttentionSubmodules,
)

# Import HybridStack from the Primus hybrid backend
from primus.backends.megatron.core.models.hybrid.hybrid_block import (
HybridStack,
HybridStackSubmodules,
)

# Inference layers may not be available in older Megatron versions
# They're only used in hybrid_inference_stack_spec, not the training spec
try:
@@ -49,11 +52,12 @@
use_te=True,
num_experts=8, # Can be any positive integer (must not be None).
moe_grouped_gemm=True,
moe_use_legacy_grouped_gemm=False,
)

hybrid_stack_spec = ModuleSpec(
module=MambaStack,
submodules=MambaStackSubmodules(
module=HybridStack,
submodules=HybridStackSubmodules(
mamba_layer=ModuleSpec(
module=MambaLayer,
submodules=MambaLayerSubmodules(
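The hunk above (truncated here) wraps the inference-layer imports in a try/except so the training spec still imports on older Megatron versions. A minimal sketch of that optional-import pattern; the module path and class name below are placeholders, since the real imports are not visible in this diff:

```python
# Optional-import pattern: degrade gracefully when inference layers are missing.
# "SomeInferenceLayer" and its module path are placeholders for illustration.
try:
    from megatron.core.inference.some_inference_layer import SomeInferenceLayer
    HAVE_INFERENCE_LAYERS = True
except ImportError:
    SomeInferenceLayer = None
    HAVE_INFERENCE_LAYERS = False

# Only the inference spec depends on the optional import; the training spec
# (hybrid_stack_spec) never references it, so importing this module always works.
```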
8 changes: 3 additions & 5 deletions primus/backends/megatron/megatron_pretrain_trainer.py
@@ -42,11 +42,9 @@ def train(self):

log_rank_0("Using GPT model provider and training components")

# Upstream pretrain entrypoints set this in their __main__ blocks, but Primus imports the
# provider directly and calls pretrain() programmatically. Without restoring this flag,
# only TP rank 0 enters dataset construction while the core dataset builder still issues
# distributed barriers, which deadlocks for TP>1.
train_valid_test_datasets_provider.is_distributed = True
# Configure training components
if hasattr(train_valid_test_datasets_provider, "is_distributed"):
train_valid_test_datasets_provider.is_distributed = True

# Handle Megatron version differences (v0.12.0 vs newer with inprocess_restart)
wrapped_pretrain = pretrain
51 changes: 40 additions & 11 deletions primus/backends/megatron_bridge/config_utils.py
@@ -185,6 +185,43 @@ def _merge_dict_to_dataclass(target: Any, source_dict: dict, path: str = "") ->
)


def _resolve_recipe(recipe: str, flavor: str):
"""
Resolve a recipe module and function by searching multiple namespaces.

Search order:
1. primus.backends.megatron_bridge.recipes.{recipe} (Primus-side extensions)
2. megatron.bridge.recipes.{recipe} (upstream Megatron-Bridge)

Returns:
Tuple of (module, full_module_path) for the first namespace that
contains the requested *flavor* function.

Raises:
AssertionError if the recipe cannot be found in any namespace.
"""
search_prefixes = [
"primus.backends.megatron_bridge.recipes",
"megatron.bridge.recipes",
]

for prefix in search_prefixes:
full_module_path = f"{prefix}.{recipe}"
try:
module = importlib.import_module(full_module_path)
except ImportError:
continue
if hasattr(module, flavor):
return module, full_module_path

# Build a helpful error message listing all paths that were tried.
tried = [f"{p}.{recipe}" for p in search_prefixes]
assert False, (
f"Recipe loading failed: Function '{flavor}' not found. "
f"Searched modules: {tried}"
)


def load_recipe_config(backend_args: SimpleNamespace) -> Any:
recipe = backend_args.recipe
flavor = backend_args.flavor
@@ -193,21 +230,13 @@ def load_recipe_config(backend_args: SimpleNamespace) -> Any:
assert recipe, "Recipe must be specified for Megatron-Bridge backend"
assert flavor, "Flavor must be specified for Megatron-Bridge backend"

# Construct full module path and function name
full_module_path = f"megatron.bridge.recipes.{recipe}"
function_name = flavor

log_rank_0(f"Loading recipe: {full_module_path}.{function_name}()")
# Resolve recipe module from Primus-side extensions or upstream Megatron-Bridge
module, full_module_path = _resolve_recipe(recipe, function_name)

# Import module and get function
try:
module = importlib.import_module(full_module_path)
except ImportError as e:
assert False, f"Recipe loading failed: Cannot import '{full_module_path}': {e}"
log_rank_0(f"Loading recipe: {full_module_path}.{function_name}()")

assert hasattr(
module, function_name
), f"Recipe loading failed: Function '{function_name}' not found in '{full_module_path}'"
recipe_func = getattr(module, function_name)

# Convert backend_args to dict once (used for both recipe call and config override)
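For context, a usage sketch of the new recipe resolution path; the recipe and flavor names are hypothetical, and the sketch assumes `recipe`, `flavor`, and `model` are the only fields `load_recipe_config` needs from the backend args:

```python
from types import SimpleNamespace

from primus.backends.megatron_bridge.config_utils import load_recipe_config

# Hypothetical recipe/flavor values; real ones come from the module config.
backend_args = SimpleNamespace(recipe="mamba", flavor="mamba_370M", model="mamba_370M.yaml")

# _resolve_recipe() tries primus.backends.megatron_bridge.recipes.mamba first,
# then megatron.bridge.recipes.mamba, and asserts if neither module defines
# a `mamba_370M` function.
config = load_recipe_config(backend_args)
```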
7 changes: 4 additions & 3 deletions primus/backends/megatron_bridge/megatron_bridge_adapter.py
@@ -39,20 +39,21 @@ def __init__(self, framework: str = "megatron_bridge"):
super().__init__(framework)
self.third_party_dir_name = "Megatron-Bridge"

def load_trainer_class(self, stage: str = "pretrain"):
def load_trainer_class(self, stage: str = "sft"):
"""
Return the Megatron-Bridge Trainer class for the specified training stage.

Args:
stage: Training stage ("sft" for supervised fine-tuning)
stage: Training stage ("sft" for supervised fine-tuning,
"pretrain" also routes to SFT trainer)

Returns:
Trainer class for the specified stage

Raises:
ValueError: If stage is not supported
"""
if stage == "sft":
if stage in ("pretrain", "sft"):
from primus.backends.megatron_bridge.megatron_bridge_posttrain_trainer import (
MegatronBridgePosttrainTrainer,
)
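A short usage sketch of the widened stage routing, assuming the adapter class shown in this file is named `MegatronBridgeAdapter` (the class definition sits outside this hunk):

```python
from primus.backends.megatron_bridge.megatron_bridge_adapter import MegatronBridgeAdapter

adapter = MegatronBridgeAdapter()

# Both "pretrain" and "sft" now resolve to MegatronBridgePosttrainTrainer;
# any other stage is expected to raise ValueError per the docstring above.
trainer_cls = adapter.load_trainer_class("pretrain")
assert trainer_cls is adapter.load_trainer_class("sft")
```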