
Add blenderbot-3B encoder-decoder NeuronX port#98

Open
dhwanw wants to merge 3 commits into main from contrib/blenderbot-3B

Conversation


@dhwanw dhwanw commented Mar 20, 2026

Summary

  • Adds NeuronX port of facebook/blenderbot-3B encoder-decoder model following the Whisper pattern
  • Separate encoder and decoder NeuronApplicationBase subclasses with prefill/decode dispatch
  • Key fixes: encoder attention mask, cross-attention mask, and prevention of dead-code elimination of the encoder output in cross-attention
  • Validated at 87.9% token match across 5 prompts (20 tokens each)
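The prefill/decode dispatch mentioned above can be sketched as follows. This is an illustrative outline only; the class and method names (`DecoderDispatch`, `prefill_model`, `decode_model`) are hypothetical and not the actual contrib code.

```python
class DecoderDispatch:
    """Route the first decoder step to the prefill graph and all later
    steps to the single-token decode graph, as described in the summary."""

    def __init__(self, prefill_model, decode_model):
        self.prefill_model = prefill_model
        self.decode_model = decode_model

    def __call__(self, input_ids, encoder_output, position):
        # Position 0: no KV cache exists yet, so run the prefill graph
        # over the full decoder prompt.
        if position == 0:
            return self.prefill_model(input_ids, encoder_output)
        # Later positions: the KV cache is populated, so run the
        # one-token decode graph on only the newest token.
        return self.decode_model(input_ids[-1:], encoder_output)
```

In the Whisper-style pattern this PR follows, prefill and decode are traced as separate Neuron graphs with different input shapes, so a Python-level dispatcher like this picks the right compiled wrapper per step.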

Test plan

  • Run integration tests: pytest contrib/models/blenderbot-3B/test/integration/test_blenderbot_inference.py
  • Verify smoke test (model loads, correct config dimensions)
  • Verify per-prompt token match >= 70%
  • Verify aggregate accuracy >= 75%
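For reference, a token-match metric of the kind these thresholds describe can be computed as below. This is a minimal sketch, assuming the metric is positional agreement between generated and reference token IDs; the function name is illustrative.

```python
def token_match(generated, reference):
    """Fraction of positions where the generated token ID equals the
    reference token ID (compared over the shorter of the two sequences)."""
    n = min(len(generated), len(reference))
    if n == 0:
        return 0.0
    matches = sum(g == r for g, r in zip(generated, reference))
    return matches / n
```

A per-prompt score of 0.70 or higher would then satisfy the first threshold, and the mean across prompts would be checked against the 0.75 aggregate bar.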

🤖 Generated with Claude Code

dhwanw and others added 3 commits March 20, 2026 02:12
- Google Gemma 2 9B with GQA (16 heads, 8 KV heads), head_dim=256
- 42 decoder layers, RMSNorm, GELU tanh, tied embeddings
- Final logit softcapping (30.0)
- Custom Gemma2NeuronConfig for model-specific settings
- Validated: 99.38% teacher-forced / 84.69% greedy (TP=8, seq_len=128, bf16)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
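The final-logit softcapping mentioned in this commit (cap = 30.0) is the standard Gemma 2 formulation: logits are squashed through tanh so their magnitude never exceeds the cap. A minimal pure-Python sketch, applied elementwise in the real model:

```python
import math

def softcap(logit, cap=30.0):
    """Gemma 2 style logit softcapping: bounds the logit to (-cap, cap)
    while staying approximately linear for small values."""
    return cap * math.tanh(logit / cap)
```

This keeps extreme logits from dominating the softmax while leaving typical values nearly unchanged.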
47.1 tok/s; context encoding (CE) MBU=1.3%, MFU=0.6%; token generation (TG) MBU=2.3%, MFU=0.0%.
Profiled on trn1.32xlarge, TP=8, seq_len=128, BF16.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
First encoder-decoder (seq2seq) model in contrib, following the Whisper
pattern with separate encoder/decoder NeuronApplicationBase subclasses.

Key fixes applied to the original handoff port:
- Added encoder attention mask to prevent attention to padding tokens
- Added cross-attention mask to prevent decoder attending to encoder padding
- Fixed cross-attention to always compute K/V from encoder output (prevents
  Neuron tracer dead code elimination of encoder output during decode)
- Fixed decoder dispatch between prefill and decode model wrappers

Validated: 87.9% token match across 5 prompts (TP=8, float32).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
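The cross-attention masking fix in this commit can be illustrated with a small numpy sketch: encoder positions that were padding receive a large negative bias before the softmax, so the decoder cannot attend to them. Shapes and names here are hypothetical, not the contrib code; the separate dead-code-elimination fix (always recomputing K/V from the encoder output inside the traced decode graph) is not shown.

```python
import numpy as np

def cross_attention(q, k, v, encoder_pad_mask):
    """q: (tgt, d); k, v: (src, d); encoder_pad_mask: (src,) with 1 = real
    token, 0 = padding. Padded source positions are masked out with a
    large negative score so they get ~zero attention weight."""
    scores = q @ k.T / np.sqrt(q.shape[-1])                    # (tgt, src)
    scores = np.where(encoder_pad_mask[None, :] == 1, scores, -1e9)
    # Numerically stable softmax over the source dimension.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

Without the mask, decoder queries would mix in values computed from padding tokens, which is one plausible source of the token-match degradation the fixes address.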
