Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nemo_rl/utils/flops_formulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class FLOPSConfig:
mamba_head_dim: Optional[int] = None
mamba_num_groups: Optional[int] = None
mamba_num_heads: Optional[int] = None
gated_linear_unit: Optional[int] = None


def gpt3(config: FLOPSConfig):
Expand Down
19 changes: 18 additions & 1 deletion nemo_rl/utils/flops_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig

from nemo_rl.utils.flops_formulas import FLOPSConfig, deepseekv3, llama, qwen2, qwen3
from nemo_rl.utils.flops_formulas import FLOPSConfig, deepseekv3, llama, qwen2, qwen3, nemotronh


def get_default_hf_config(model_name: str) -> PretrainedConfig:
Expand Down Expand Up @@ -96,6 +96,23 @@ def convert_config_to_flops_config(
mtp_num_layers=0,
causal_self_attn=True,
), deepseekv3
elif config.__class__.model_type == "nemotron_h":
return FLOPSConfig(
gbs=1,
enc_seq_len=config.max_position_embeddings if hasattr(config, "max_position_embeddings") else 2048,
hs=config.hidden_size,
ffn_hs=config.intermediate_size,
gated_linear_unit=None,
attention_heads=config.num_attention_heads,
query_groups=config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
mamba_state_dim=config.ssm_state_size,
mamba_head_dim=config.mamba_head_dim,
mamba_num_heads=config.mamba_num_heads,
mamba_num_groups=config.n_groups,
is_hybrid_model=True,
hybrid_override_pattern=config.hybrid_override_pattern,
vocab_size=config.vocab_size
), nemotronh
else:
raise ValueError(f"Unsupported config type: {type(config)}")

Expand Down
Loading