[contrib] Add DiffLlama-0.3B-handcut NeuronX port#75

Open
dhwanw wants to merge 3 commits into main from contrib/DiffLlama-0.3B-handcut

Conversation


@dhwanw dhwanw commented Mar 17, 2026

Description

NeuronX Distributed Inference port of kajuma/DiffLlama-0.3B-handcut, a 0.3B-parameter Differential Transformer. DiffLlama implements a novel attention mechanism where V is transformed before the attention matmul (heads chunked into 2 halves, concatenated along head_dim), and after attention the output is split and subtracted with learned lambda parameters. This requires overriding the full attention forward since NXDI's built-in attention kernels cannot handle the modified V shape.
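The V transform and lambda subtraction described above can be sketched in plain PyTorch. This is an illustrative sketch, not the PR's actual modeling code: the shapes, the scalar `lam` placeholder (the real model learns per-head lambda parameters), and all variable names are assumptions for demonstration.

```python
import torch
import torch.nn.functional as F

# Assumed toy shapes: batch B, heads H, seq S, head_dim D.
B, H, S, D = 1, 8, 16, 64
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Chunk V's heads into 2 halves and concatenate along head_dim,
# doubling it (e.g. 64 -> 128), then repeat to restore the head count.
v1, v2 = v.chunk(2, dim=1)                  # (B, H/2, S, D) each
v_cat = torch.cat([v1, v2], dim=-1)         # (B, H/2, S, 2*D)
v_cat = v_cat.repeat(1, 2, 1, 1)            # (B, H, S, 2*D)

# Manual matmul + softmax + matmul, since fused attention kernels
# cannot consume the widened V. Causal mask built with torch.triu.
scores = q @ k.transpose(-1, -2) / D ** 0.5  # (B, H, S, S)
mask = torch.triu(torch.full((S, S), float("-inf")), diagonal=1)
attn = F.softmax(scores + mask, dim=-1)
out = attn @ v_cat                           # (B, H, S, 2*D)

# Split the widened output and subtract with a lambda weight
# (a fixed placeholder here; learned in the actual model).
lam = torch.tensor(0.5)
o1, o2 = out.chunk(2, dim=-1)                # (B, H, S, D) each
diff_out = o1 - lam * o2                     # (B, H, S, D)
```

The key point is that the output of the second matmul is twice as wide as a standard attention output, which is why the built-in kernels cannot be reused as-is.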

Model Information

Model Name: DiffLlama-0.3B-handcut
Model Architecture: Decoder-only transformer with differential attention, GQA (32 Q heads / 8 KV heads), Llama3-style RoPE scaling, SwiGLU MLP, RMSNorm, tied embeddings
Purpose: Text generation with differential attention mechanism

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)
    • Validates model generation and coherence
    • Performance benchmarks (TTFT, throughput)
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types
    • Example Checkpoints: Links to compatible model checkpoints
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns

Optional Components

  • Unit Tests (CPU or Neuron-based)

Folder Structure

/contrib/models/DiffLlama-0.3B-handcut/
  README.md
  /src
    modeling_diffllama.py
  /test
    /integration
      test_model.py

Testing

Model was compiled and tested with TP=1, batch_size=1, seq_len=128, bfloat16 on trn1.32xlarge.

Test Results:

| Test | Status | Result |
|------|--------|--------|
| Smoke Test | ✅ PASS | Model loads successfully |
| Greedy Token Matching | ✅ PASS | 94.69% average (7/10 prompts at 100%) |
| Teacher-Forced Match | ✅ PASS | 99.38% average |
| Throughput | ✅ PASS | 56.6 tok/s |

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.22
  • Instance Type(s): trn1.32xlarge
  • PyTorch Version: 2.9
  • Python Version: 3.10
  • Configuration: TP=1, batch_size=1, seq_len=128, bfloat16

Additional Information

  • Differential attention: V is transformed before the attention matmul -- heads are chunked into 2 halves, concatenated along head_dim (doubling it to 128), then repeated. After attention, the output is split and subtracted with learned lambda parameters, followed by RMSNorm on 2*head_dim features.
  • Custom attention forward: NXDI's built-in attention kernels cannot handle the modified V shape, so standard_causal_attention_forward is fully overridden with manual matmul + softmax + matmul.
  • Causal mask: Generated internally via torch.triu to avoid XLA shape broadcasting issues with framework-provided masks.
  • Llama3 RoPE scaling: Custom Llama3RotaryEmbedding with frequency-dependent scaling (high-frequency unchanged, low-frequency scaled by factor, mid-frequency interpolated).
  • Development build of transformers required: the DiffLlama architecture is not yet in a mainline HuggingFace transformers release.
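The Llama3-style frequency-dependent RoPE scaling mentioned above can be sketched as follows, mirroring the common formulation used in HuggingFace transformers. The parameter defaults (`factor`, `low_freq_factor`, `high_freq_factor`, `old_context_len`) are placeholder values for illustration, not this model's actual configuration.

```python
import math
import torch

def llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                          high_freq_factor=4.0, old_context_len=8192):
    # Each inverse frequency corresponds to a rotation wavelength.
    wavelen = 2 * math.pi / inv_freq
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    # Low-frequency components (long wavelengths): scaled by the factor.
    # High-frequency components (short wavelengths): left unchanged.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Mid-frequency components: smooth interpolation between the regimes.
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor)
    interpolated = (1 - smooth) / factor * inv_freq + smooth * inv_freq
    is_mid = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_mid, interpolated, scaled)

# Standard RoPE inverse frequencies for an assumed head_dim of 64.
head_dim, base = 64, 10000.0
inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
scaled = llama3_scale_inv_freq(inv_freq)
```

The highest-frequency entries pass through unchanged while the lowest-frequency entries are divided by the scaling factor, with a smooth transition in between.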

Related Issues

N/A

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

dhwanw and others added 3 commits March 5, 2026 22:20
NeuronX port of DiffLlama-0.3B-handcut with differential attention mechanism.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Use consistent CE/TG column table format across all contrib models.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@dhwanw dhwanw marked this pull request as ready for review March 19, 2026 19:50