[contrib] Add DiffLlama-0.3B-handcut NeuronX port #75
Commits:
- NeuronX port of DiffLlama-0.3B-handcut with differential attention mechanism (Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>)
- Use consistent CE/TG column table format across all contrib models (Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>)
Description
NeuronX Distributed Inference port of kajuma/DiffLlama-0.3B-handcut, a 0.3B-parameter Differential Transformer. DiffLlama uses a differential attention mechanism in which V is reshaped before the attention matmul (its heads are chunked into two halves and concatenated along head_dim), and after attention the output is split back into two branches that are subtracted using learned lambda parameters. This requires overriding the full attention forward, since NXDI's built-in attention kernels cannot handle the modified V shape.
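A minimal PyTorch sketch of this flow for illustration only; the shapes, lambda value, and mask handling below are assumptions, not the port's actual code:

```python
import math
import torch
import torch.nn.functional as F

# Schematic of the differential-attention flow described above.
# Sizes and the lambda value are placeholders, not the model's real config.
bsz, n_heads, seq, head_dim = 1, 32, 128, 64
lam = torch.tensor(0.5)  # learned lambda (placeholder value)

q = torch.randn(bsz, n_heads, seq, head_dim)
k = torch.randn(bsz, n_heads, seq, head_dim)  # after GQA expansion of the KV heads
v = torch.randn(bsz, n_heads, seq, head_dim)

# 1. Transform V before the attention matmul: chunk its heads into two halves,
#    concatenate along head_dim, then repeat so every Q head has a matching V.
v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1)  # (1, 16, 128, 128)
v = v.repeat(1, 2, 1, 1)                         # (1, 32, 128, 128)

# 2. Manual causal attention: matmul + softmax + matmul with a torch.triu mask.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
causal = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1)
attn = F.softmax(scores + causal, dim=-1)
out = attn @ v                                   # (1, 32, 128, 128)

# 3. Split the output into the two attention branches and subtract with lambda.
out1, out2 = torch.chunk(out, 2, dim=1)          # each (1, 16, 128, 128)
diff_out = out1 - lam * out2                     # reshaped back to hidden size downstream
```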
Model Information
Model Name: DiffLlama-0.3B-handcut
Model Architecture: Decoder-only transformer with differential attention, GQA (32 Q heads / 8 KV heads; expansion sketched after this list), Llama3-style RoPE scaling, SwiGLU MLP, RMSNorm, tied embeddings
Purpose: Text generation with differential attention mechanism
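With the GQA layout listed above, each of the 8 KV heads serves 32 / 8 = 4 query heads. A short sketch of that expansion; head_dim and sequence length here are hypothetical, not the model's real values:

```python
import torch

# GQA expansion: 8 KV heads are repeated so all 32 query heads have a partner.
bsz, seq, head_dim = 1, 128, 64          # placeholder sizes
num_q_heads, num_kv_heads = 32, 8
groups = num_q_heads // num_kv_heads     # 4 query heads per KV head

k = torch.randn(bsz, num_kv_heads, seq, head_dim)
v = torch.randn(bsz, num_kv_heads, seq, head_dim)

k = k.repeat_interleave(groups, dim=1)   # -> (1, 32, 128, 64)
v = v.repeat_interleave(groups, dim=1)   # -> (1, 32, 128, 64)
```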
Checklist
Required Components
- Integration test (test/integration/test_model.py)
- Model source (src/)

Optional Components
Folder Structure
Testing
The model was compiled and tested with TP=1, batch_size=1, seq_len=128, and bfloat16 on a trn1.32xlarge instance.
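For reference, the Neuron output can be compared against a CPU baseline run through Hugging Face transformers. A minimal sketch; the prompt and generation settings below are hypothetical:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "kajuma/DiffLlama-0.3B-handcut"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Greedy generation on CPU as a reference to compare against the Neuron port.
inputs = tokenizer("The differential transformer", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```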
Test Results:
Compatibility
Tested with:
Additional Information
- `standard_causal_attention_forward` is fully overridden with a manual matmul + softmax + matmul.
- The causal mask is constructed with `torch.triu` to avoid XLA shape broadcasting issues with framework-provided masks.
- Rotary embeddings use `Llama3RotaryEmbedding` with frequency-dependent scaling (high-frequency components unchanged, low-frequency components scaled by the factor, mid-frequency components interpolated).
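A sketch of that frequency-dependent scaling, following the common Llama3 RoPE scheme; the factor, wavelength thresholds, rope base, and head_dim below are placeholder values, not this model's config:

```python
import math
import torch

def llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                          high_freq_factor=4.0, original_max_pos=8192):
    """Scale RoPE inverse frequencies: high frequencies kept, low frequencies
    divided by `factor`, and the band in between linearly interpolated."""
    low_freq_wavelen = original_max_pos / low_freq_factor
    high_freq_wavelen = original_max_pos / high_freq_factor
    wavelen = 2 * math.pi / inv_freq

    # Interpolation coefficient for the mid-frequency band.
    smooth = (original_max_pos / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq

    # Low frequencies (long wavelengths) are scaled; high frequencies are unchanged.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    mid = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(mid, smoothed, scaled)

head_dim = 64  # placeholder
inv_freq = 1.0 / (500000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
scaled_inv_freq = llama3_scale_inv_freq(inv_freq)
```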
Related Issues
N/A
vLLM Integration
By submitting this PR, I confirm that: