diff --git a/nemo_rl/utils/flops_formulas.py b/nemo_rl/utils/flops_formulas.py
index fe55e3b4cb..6a73e203b3 100644
--- a/nemo_rl/utils/flops_formulas.py
+++ b/nemo_rl/utils/flops_formulas.py
@@ -57,6 +57,7 @@ class FLOPSConfig:
     mamba_head_dim: Optional[int] = None
     mamba_num_groups: Optional[int] = None
     mamba_num_heads: Optional[int] = None
+    gated_linear_unit: Optional[int] = None
 
 
 def gpt3(config: FLOPSConfig):
diff --git a/nemo_rl/utils/flops_tracker.py b/nemo_rl/utils/flops_tracker.py
index bf09a210db..d0940cc5c8 100644
--- a/nemo_rl/utils/flops_tracker.py
+++ b/nemo_rl/utils/flops_tracker.py
@@ -24,7 +24,7 @@
 from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
 from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
 
-from nemo_rl.utils.flops_formulas import FLOPSConfig, deepseekv3, llama, qwen2, qwen3
+from nemo_rl.utils.flops_formulas import FLOPSConfig, deepseekv3, llama, qwen2, qwen3, nemotronh
 
 
 def get_default_hf_config(model_name: str) -> PretrainedConfig:
@@ -96,6 +96,23 @@ def convert_config_to_flops_config(
             mtp_num_layers=0,
             causal_self_attn=True,
         ), deepseekv3
+    elif config.__class__.model_type == "nemotron_h":
+        return FLOPSConfig(
+            gbs=1,
+            enc_seq_len=config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048,
+            hs=config.hidden_size,
+            ffn_hs=config.intermediate_size,
+            gated_linear_unit=None,
+            attention_heads=config.num_attention_heads,
+            query_groups=config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
+            mamba_state_dim=config.ssm_state_size,
+            mamba_head_dim=config.mamba_head_dim,
+            mamba_num_heads=config.mamba_num_heads,
+            mamba_num_groups=config.n_groups,
+            is_hybrid_model=True,
+            hybrid_override_pattern=config.hybrid_override_pattern,
+            vocab_size=config.vocab_size
+        ), nemotronh
     else:
         raise ValueError(f"Unsupported config type: {type(config)}")
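
For reviewers, a minimal sketch of how the new nemotron_h branch could be exercised end to end. It assumes the converter takes the HF config alone and returns a (FLOPSConfig, formula) pair, mirroring the deepseekv3 branch above; the NemotronHConfig import path (patterned after the qwen3 imports in the diff) and the formula(flops_config) call signature (inferred from `def gpt3(config: FLOPSConfig):`) are assumptions, not verified against this repo.

# Sketch only: assumes transformers ships NemotronHConfig at this path,
# mirroring the qwen3 imports above; the path may differ by version.
from transformers.models.nemotron_h.configuration_nemotron_h import NemotronHConfig

from nemo_rl.utils.flops_tracker import convert_config_to_flops_config

# Library-default hyperparameters; real use would load a checkpoint's config,
# e.g. via AutoConfig.from_pretrained(...).
hf_config = NemotronHConfig()

# Per the diff, the nemotron_h branch returns the populated FLOPSConfig paired
# with the matching per-architecture formula (nemotronh).
flops_config, formula = convert_config_to_flops_config(hf_config)

# Assumption: each formula takes a FLOPSConfig and returns a FLOP count,
# inferred from the `def gpt3(config: FLOPSConfig):` signature above.
print(f"FLOPs at gbs=1, enc_seq_len={flops_config.enc_seq_len}: {formula(flops_config):.3e}")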