[contrib] Add Baichuan2-7B-Base NeuronX port #81

Open
dhwanw wants to merge 5 commits into main from contrib/baichuan2-7b-base

Conversation


@dhwanw dhwanw commented Mar 17, 2026

Description

NeuronX Distributed Inference port of baichuan-inc/Baichuan2-7B-Base, a Llama-2 architecture variant with fused W_pack QKV weights and NormHead lm_head (L2-normalized). The port handles direct loading to bypass trust_remote_code, fused QKV decomposition, and pre-normalized lm_head weight conversion.
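The direct-loading step can be sketched as follows. This is an illustrative stand-in, not the port's actual code: it reads the checkpoint's `config.json` as plain JSON and patches in keys a Llama-style config expects, instead of executing the repository's custom modeling code via `trust_remote_code`. The key names and defaults shown here are assumptions for illustration.

```python
import json

def patch_baichuan2_config(raw: dict) -> dict:
    """Hypothetical helper: map a Baichuan2 config.json onto the Llama
    config fields NxD Inference expects, without running remote code."""
    cfg = dict(raw)
    # Baichuan2-7B uses plain multi-head attention, so KV heads equal
    # attention heads (no GQA).
    cfg.setdefault("num_key_value_heads", cfg["num_attention_heads"])
    cfg.setdefault("rope_theta", 10000.0)
    # Present the model as a Llama variant to the loading machinery.
    cfg["architectures"] = ["LlamaForCausalLM"]
    return cfg

# Minimal example with a few fields from a Baichuan2-style config.json.
raw = json.loads(
    '{"hidden_size": 4096, "num_attention_heads": 32, "num_hidden_layers": 32}'
)
cfg = patch_baichuan2_config(raw)
```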

Model Information

Model Name: Baichuan2-7B-Base
Model Architecture: Decoder-only transformer (Llama-2 variant) -- 32 layers, 32 MHA heads (head_dim=128), fused W_pack QKV, NormHead lm_head with L2 normalization
Purpose: Multilingual text generation (Chinese/English)

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)
    • Validates model generation and coherence
    • Performance benchmarks (TTFT, throughput)
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types
    • Example Checkpoints: Links to compatible model checkpoints
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns

Optional Components

  • Unit Tests (CPU or Neuron-based)

Folder Structure

/contrib/models/Baichuan2-7B-Base/
  README.md
  /src
    modeling_baichuan2.py
  /test
    /integration
      test_model.py

Testing

Model was compiled and tested with TP=2, batch_size=1, seq_len=256, bfloat16 on trn1.32xlarge.

Test Results:

| Test | Status | Result |
|---|---|---|
| Smoke Test | ✅ PASS | Model loads successfully |
| Greedy Token Matching | ✅ PASS | 54.84% average (4/10 prompts at 100%) |
| Teacher-Forced Match | ✅ PASS | 98.59% average |
| Throughput | ✅ PASS | 16.6 tok/s |

The high teacher-forced match rate indicates the model is functionally correct. The lower greedy match on some prompts stems from BF16 precision: a single early token divergence cascades into a different generation path for the rest of the sequence.
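The distinction between the two metrics can be illustrated with a toy sketch (not the test suite's actual code): greedy match compares free-running generation token-by-token against a reference, so one early mismatch shifts every subsequent token, while teacher-forced match feeds the reference prefix at each step, so a divergence cannot cascade.

```python
def token_match_rate(reference, candidate):
    """Fraction of aligned positions where the token IDs agree."""
    n = min(len(reference), len(candidate))
    if n == 0:
        return 0.0
    return sum(r == c for r, c in zip(reference[:n], candidate[:n])) / n

# Toy sequences: under teacher forcing, only position 1 disagrees (7/8 match);
# in free-running generation the same early divergence can change every token
# after it, driving the greedy rate much lower.
ref = [5, 9, 9, 9, 9, 9, 9, 9]
teacher_forced = [5, 2, 9, 9, 9, 9, 9, 9]
rate = token_match_rate(ref, teacher_forced)  # 7/8 = 0.875
```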

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.22
  • Instance Type(s): trn1.32xlarge
  • PyTorch Version: 2.9
  • Python Version: 3.10
  • Configuration: TP=2, batch_size=1, seq_len=256, bfloat16

Additional Information

  • W_pack (fused QKV): Stores Q/K/V as a single fused tensor W_pack.weight [3*H, H], split into separate projections during weight conversion.
  • NormHead lm_head: Applies L2 normalization to lm_head weights; pre-normalized during weight conversion.
  • Direct loading: Bypasses trust_remote_code by loading config.json and safetensors directly, adding missing Llama-required keys.
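The two weight conversions above can be sketched as follows, using NumPy as a stand-in for the torch-based conversion code (function names are illustrative, not the port's actual API): the fused `W_pack.weight` tensor of shape `[3*H, H]` splits row-wise into Q/K/V, and NormHead's per-row L2 normalization is baked into the lm_head weights once at conversion time so a plain Llama lm_head can be used at inference.

```python
import numpy as np

def split_w_pack(w_pack, hidden_size):
    """Split a fused [3*H, H] W_pack tensor into Q, K, V projections."""
    assert w_pack.shape == (3 * hidden_size, hidden_size)
    q, k, v = np.split(w_pack, 3, axis=0)  # each [H, H]
    return q, k, v

def prenormalize_lm_head(weight, eps=1e-12):
    """L2-normalize each output row of the lm_head weight (NormHead).

    Doing this once during checkpoint conversion replaces the per-step
    normalization Baichuan2 performs at inference time.
    """
    norms = np.linalg.norm(weight, axis=1, keepdims=True)
    return weight / np.maximum(norms, eps)

H = 4  # toy size; Baichuan2-7B uses H = 4096
w_pack = np.arange(3 * H * H, dtype=np.float64).reshape(3 * H, H)
q, k, v = split_w_pack(w_pack, H)
lm_head = prenormalize_lm_head(np.ones((6, H)))  # every row now has unit norm
```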

Related Issues

N/A

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

dhwanw and others added 5 commits March 5, 2026 23:07
Extends NeuronLlamaForCausalLM with custom weight conversion for Baichuan2's
W_pack fused QKV split and NormHead lm_head normalization. Bypasses
trust_remote_code by loading config.json and safetensors directly.

Validated: 54.84% greedy, 98.59% teacher-forced (TP=2, bs=1, seq=128, bf16).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

SLURM job 7871 confirms contrib model passes token matching:
- Greedy match: 54.84% (351/640 tokens, >= 50% threshold)
- Teacher-forced match: 98.59% (>= 95% threshold)
- 4/10 prompts at 100% greedy match

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

- Add Apache license header and expanded docstring to modeling file
- Add copyright header to __init__.py
- Add standard helper functions (load_neuron_config_from_compiled,
  create_model_for_inference) to test_model.py
- Add generation_config fixture and performance tests (TTFT, throughput)
- Use /home/ubuntu/ path convention for MODEL_PATH/COMPILED_MODEL_PATH
- Use standard __main__ block with create_model_for_inference
- Simplify README architecture section to match Llama-2 format
- Add manual run instructions to README Testing section
- Remove non-standard files: test_token_match.py, run_validation.sh,
  validation_7871.out

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Use consistent CE/TG column table format across all contrib models.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@dhwanw dhwanw marked this pull request as ready for review March 19, 2026 19:38