
[contrib] Add Gemma 2 9B NeuronX port#97

Open
dhwanw wants to merge 3 commits into main from contrib/gemma-2-9b

Conversation


@dhwanw dhwanw commented Mar 20, 2026

Description

Adds NeuronX port of Gemma 2 9B (model_type=gemma2) to the contrib models collection.

Model Information

| Field | Value |
|---|---|
| Model | google/gemma-2-9b |
| Architecture | Gemma2ForCausalLM (decoder-only, GQA, 4 RMSNorm layers per block) |
| Parameters | 9B |
| TP Degree | 8 |
| Precision | BF16 |

Checklist

  • Model compiles successfully on Neuron
  • Token matching validated (84.69% greedy, 99.38% teacher-forced)
  • Performance profiled (47.1 tok/s)
  • README with architecture details, usage, validation results
  • Integration tests included

Folder Structure

contrib/models/gemma-2-9b/
├── README.md
├── src/
│   ├── __init__.py
│   └── modeling_gemma2.py
└── test/
    └── integration/
        └── test_model.py

Testing

  • Token Match (greedy): 84.69% (10 prompts, 32 tokens each)
  • Token Match (teacher-forced): 99.38%
  • Throughput: 47.1 tok/s (TP=8, BS=1, seq_len=128)
  • 8/10 prompts at 100% greedy match
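For context, a greedy token-match rate like the figures above compares the ported model's greedy decode against a reference run position by position. The helper below is a hypothetical illustration of the metric, not code from this PR's integration tests:

```python
def token_match_rate(ref_tokens, neuron_tokens):
    """Fraction of positions where two greedy decodes emit the same token.

    Illustrative sketch only; the PR's tests may compute this differently.
    """
    n = min(len(ref_tokens), len(neuron_tokens))
    if n == 0:
        return 0.0
    matches = sum(r == t for r, t in zip(ref_tokens[:n], neuron_tokens[:n]))
    return matches / n

# Example: 3 of 4 positions agree -> 0.75
print(token_match_rate([101, 2009, 318, 12], [101, 2009, 257, 12]))
```

Averaging this rate over the 10 prompts (32 tokens each) would yield an aggregate figure like the 84.69% reported above.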

Note: Uses custom Gemma2NeuronConfig with attn_cls=NeuronGemma2Attention. Attention logit softcapping disabled (NKI kernel limitation). Sliding window disabled (head_dim=256 exceeds NKI limit of 128). Final logit softcapping (30.0) omitted from forward pass since tanh is monotonic and doesn't affect greedy argmax.
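The claim that omitting final logit softcapping is safe for greedy decoding follows from tanh being strictly increasing: capping rescales logits but never reorders them. A minimal sketch (NumPy assumed; `softcap` is a standalone illustration, not the PR's code):

```python
import numpy as np

def softcap(logits, cap=30.0):
    # Gemma 2-style final logit softcapping: cap * tanh(logits / cap).
    # tanh is strictly increasing, so the per-row ordering of logits
    # (and hence the argmax) is unchanged.
    return cap * np.tanh(logits / cap)

rng = np.random.default_rng(0)
logits = rng.normal(scale=20.0, size=(1000, 256))

assert np.array_equal(logits.argmax(axis=-1),
                      softcap(logits).argmax(axis=-1))
print("argmax preserved under softcapping")
```

Note that this argument only covers greedy argmax; sampled decoding or logprob comparisons would see the capped values.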

Compatibility

  • Neuron SDK: 2.22+
  • Instance: trn1.32xlarge

🤖 Generated with Claude Code

dhwanw and others added 3 commits March 20, 2026 02:12
- Google Gemma 2 9B with GQA (16 heads, 8 KV heads), head_dim=256
- 42 decoder layers, RMSNorm, GELU tanh, tied embeddings
- Final logit softcapping (30.0)
- Custom Gemma2NeuronConfig for model-specific settings
- Validated: 99.38% teacher-forced / 84.69% greedy (TP=8, seq_len=128, bf16)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
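The GQA layout described in this commit (16 query heads sharing 8 KV heads, head_dim=256) can be sketched as follows; `Gemma2AttnShape` is an illustrative stand-in, not the PR's actual Gemma2NeuronConfig:

```python
from dataclasses import dataclass

@dataclass
class Gemma2AttnShape:
    # Shapes from the commit message above; illustration only.
    num_attention_heads: int = 16
    num_key_value_heads: int = 8
    head_dim: int = 256

    @property
    def gqa_group_size(self) -> int:
        # Number of query heads that share each KV head.
        return self.num_attention_heads // self.num_key_value_heads

shape = Gemma2AttnShape()
print(shape.gqa_group_size)  # 2 query heads per KV head
```

With TP=8, the 8 KV heads also divide evenly across ranks, one KV head per tensor-parallel rank.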
47.1 tok/s, CE MBU=1.3%/MFU=0.6%, TG MBU=2.3%/MFU=0.0%
Profiled on trn1.32xlarge, TP=8, seq_len=128, BF16.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Remove redundant config overrides (sliding_window, attn_logit_softcapping
already force-set in add_derived_config), simplify overly defensive
exception handling in test helper, extract duplicate compilation logic
into compile_if_needed(), and remove unused GenerationConfig import.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@dhwanw dhwanw marked this pull request as ready for review March 21, 2026 05:54