-
Notifications
You must be signed in to change notification settings - Fork 28
[contrib] Add DiffLlama-0.3B-handcut NeuronX port #75
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
dhwanw
wants to merge
3
commits into
main
Choose a base branch
from
contrib/DiffLlama-0.3B-handcut
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| # Contrib Model: DiffLlama | ||
|
|
||
| NeuronX Distributed Inference implementation of DiffLlama (Differential Transformer). | ||
|
|
||
| ## Model Information | ||
|
|
||
| - **HuggingFace ID:** `kajuma/DiffLlama-0.3B-handcut` | ||
| - **Model Type:** Decoder-only transformer with differential attention | ||
| - **Parameters:** 0.3B | ||
| - **License:** Apache-2.0 | ||
|
|
||
| ## Architecture Details | ||
|
|
||
| | Property | Value | | ||
| |----------|-------| | ||
| | Hidden Size | 2048 | | ||
| | Num Attention Heads | 32 (GQA: 8 KV heads) | | ||
| | Head Dimension | 64 | | ||
| | Num Hidden Layers | 16 | | ||
| | Vocab Size | 128256 | | ||
| | Max Position Embeddings | 131072 | | ||
| | Intermediate Size | 8192 | | ||
| | Position Encoding | RoPE with llama3 scaling (factor=32, original_max=8192) | | ||
| | Residual Connection | Pre-norm (LN -> Attn -> Add -> LN -> MLP -> Add) | | ||
| | Normalization | RMSNorm (eps=1e-5) | | ||
| | Activation | SiLU (SwiGLU MLP) | | ||
| | LM Head | Tied with embed_tokens | | ||
|
|
||
| ### Key Implementation Notes | ||
|
|
||
| - **Differential attention:** Unlike standard attention, DiffLlama transforms V before the attention matmul — chunk heads into 2 halves, concatenate along head_dim (doubling it to 128), then repeat. After attention, the output is split into 2 head groups and subtracted with learned lambda parameters, followed by RMSNorm on 2*head_dim features. | ||
| - **standard_causal_attention_forward override:** The full attention forward is overridden because NXDI's built-in attention kernels cannot handle the modified V shape. Manual attention (matmul + softmax + matmul) compiles correctly to XLA/HLO. | ||
| - **Causal mask:** Generated internally via `torch.triu` rather than using the framework-provided mask, which avoids XLA shape broadcasting issues. | ||
| - **Llama3 RoPE scaling:** Custom `Llama3RotaryEmbedding` extends NXDI's `RotaryEmbedding` with frequency-dependent scaling (high-frequency components unchanged, low-frequency scaled by factor, mid-frequency interpolated). | ||
| - **KV cache:** Stores original K, V (before V transformation); the transformation is reapplied at each token generation step. | ||
| - **HF transformers:** Requires custom transformers with DiffLlama support (not yet in mainline HuggingFace). Path: `/shared/dhwanw/agent_friday_test/example/transformers/src`. | ||
|
|
||
| ## Validation Results | ||
|
|
||
| **Validated:** 2026-03-05 | ||
| **Configuration:** TP=1, batch_size=1, seq_len=128, bfloat16 | ||
|
|
||
| ### Test Results | ||
|
|
||
| | Test | Status | Result | | ||
| |------|--------|--------| | ||
| | Smoke Test | PASS | Model loads successfully | | ||
| | Greedy Token Matching | PASS | **94.69% average** (7/10 prompts at 100%) | | ||
| | Teacher-Forced Match | PASS | **99.38% average** | | ||
|
|
||
| ### Greedy Match Details | ||
|
|
||
| 7 of 10 prompts achieve 100% greedy match. The 3 prompts with slight divergence (98.4%) occur on high-entropy natural language where BF16 precision causes late-stage cascading differences. One prompt ("Amazon River") diverges early (51.6% greedy) but maintains 98.4% teacher-forced match, indicating correct model behavior with a single early-token divergence that cascades. | ||
|
|
||
| ## Usage | ||
|
|
||
| ```python | ||
| from transformers import AutoTokenizer | ||
| from neuronx_distributed_inference.models.config import NeuronConfig | ||
|
|
||
| from src.modeling_diffllama import NeuronDiffLlamaForCausalLM, DiffLlamaInferenceConfig | ||
|
|
||
| model_path = "/path/to/DiffLlama-0.3B-handcut/" | ||
| compiled_model_path = "/path/to/compiled/" | ||
|
|
||
| # Configure | ||
| neuron_config = NeuronConfig( | ||
| tp_degree=1, | ||
| batch_size=1, | ||
| seq_len=128, | ||
| torch_dtype=torch.bfloat16, | ||
| ) | ||
|
|
||
| config = DiffLlamaInferenceConfig.from_pretrained( | ||
| model_path, | ||
| neuron_config=neuron_config, | ||
| ) | ||
|
|
||
| # Compile and load | ||
| model = NeuronDiffLlamaForCausalLM(model_path, config) | ||
| model.compile(compiled_model_path) | ||
| model.load(compiled_model_path) | ||
|
|
||
| # Generate | ||
| tokenizer = AutoTokenizer.from_pretrained(model_path) | ||
| # ... (see integration test for full example) | ||
| ``` | ||
|
|
||
| ## Performance | ||
|
|
||
| Profiled on trn1.32xlarge (single NeuronCore utilization): | ||
|
|
||
| | Metric | Context Encoding | Token Generation | | ||
| |--------|-----------------|------------------| | ||
| | Throughput | - | 56.6 tok/s | | ||
| | MBU (Memory) | 13.2% | 19.3% | | ||
| | MFU (Compute) | 5.3% | 0.1% | | ||
|
|
||
| *Batch size 1, sequence length 128, BF16 precision, TP=1* | ||
| ## Compatibility Matrix | ||
|
|
||
| | Instance/Version | 2.20+ | 2.19 and earlier | | ||
| |------------------|-------|------------------| | ||
| | Trn1 | Working | Not tested | | ||
| | Inf2 | Not tested | Not tested | | ||
|
|
||
| ## Testing | ||
|
|
||
| Run integration tests: | ||
|
|
||
| ```bash | ||
| pytest contrib/models/DiffLlama-0.3B-handcut/test/integration/test_model.py --capture=tee-sys | ||
| ``` | ||
|
|
||
| ## Example Checkpoints | ||
|
|
||
| * kajuma/DiffLlama-0.3B-handcut | ||
|
|
||
| ## References | ||
|
|
||
| - [Differential Transformer Paper](https://arxiv.org/abs/2410.05258) | ||
|
|
||
| ## Maintainer | ||
|
|
||
| Neuroboros Team - Annapurna Labs | ||
|
|
||
| **Last Updated:** 2026-03-05 | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .modeling_diffllama import NeuronDiffLlamaForCausalLM, DiffLlamaInferenceConfig | ||
|
|
||
| __all__ = ["NeuronDiffLlamaForCausalLM", "DiffLlamaInferenceConfig"] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change to "Annapurna Labs"