Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions docs/examples/qwen3-moe-eagle3-offline.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Eagle3 for Qwen3-MoE - Offline

## Introduction

This document provides a step-by-step guide on how to train the EAGLE3 draft model for the Qwen3-30B-A3B (MoE) target model in an offline manner. In offline training, we generate the hidden states required by the EAGLE3 draft model beforehand and store them on disk. During training, we load them back into GPU memory. As offline training requires a lot of disk space, we do not recommend running this on large datasets such as Perfect-Blend.

## Training on ShareGPT dataset

### **Step 1. Prepare ShareGPT dataset**

First of all, we should download the dataset.

```shell
python ./scripts/prepare_data.py --dataset sharegpt
```

### **Step 2. Prepare Hidden States**

We need to prepare the hidden states for the training.

```shell
torchrun --nproc_per_node=8 \
    scripts/prepare_hidden_states.py \
    --target-model-path /home/data/weights/Qwen3-30B-A3B/ \
    --enable-aux-hidden-states \
    --data-path ./cache/dataset/sharegpt_train.jsonl \
    --chat-template qwen \
    --max-length 2048 \
    --tp-size 8 \
    --batch-size 32 \
    --num-samples 20 \
    --output-path ./cache/hidden_states
```

The hidden states will be saved to the disk in the `output-path` directory.

### **Step 3. Start Training**

```shell
torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_eagle3.py \
--target-model-path /home/data/weights/Qwen3-30B-A3B/ \
--draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3_moe.json \
--train-data-path ./cache/dataset/sharegpt_train.jsonl \
--train-hidden-states-path ./cache/hidden_states \
--output-dir ./outputs/qwen3-moe-8b-eagle3-sharegpt-offline \
--num-epochs 4 \
--batch-size 1 \
--learning-rate 1e-4 \
--max-length 2048 \
--save-interval 1984 \
--chat-template qwen \
--cache-dir $ROOT_DIR/cache \
--embedding-key model.embed_tokens.weight \
--sp-ulysses-size 4 \
--attention-backend "usp" \
--target-model-backend sglang

```
4 changes: 2 additions & 2 deletions specforge/core/eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def _compute_target_p_padded(target, t2d, loss_mask, length):
return target_p_padded, position_mask


@torch.compile(dynamic=None)
# @torch.compile(dynamic=None)
def _compute_target_p(target, t2d, loss_mask):
target_head = target
target_max_token = target_head.argmax(-1)
Expand All @@ -559,7 +559,7 @@ def _compute_target_p(target, t2d, loss_mask):
return target_p, position_mask


@torch.compile(dynamic=None)
# @torch.compile(dynamic=None)
def _compute_metric_acc(logits, target_p, position_mask, loss_mask):
return (
(logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)
Expand Down
4 changes: 2 additions & 2 deletions specforge/core/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


# Reference implementation
@torch.compile(dynamic=None)
# @torch.compile(dynamic=None)
def _compute_loss(logits, target_p, position_mask):
logits = logits.float()
out_logp = nn.LogSoftmax(dim=2)(logits)
Expand All @@ -30,7 +30,7 @@ def _calculate_settings(n):
raise RuntimeError(
f"Cannot launch Triton kernel since n = {n} exceeds the recommended Triton blocksize = {MAX_FUSED_SIZE}."
)

BLOCK_SIZE = 2048
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32
Expand Down
3 changes: 3 additions & 0 deletions specforge/modeling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# from .auto import AutoDistributedTargetModel, AutoDraftModelConfig, AutoEagle3DraftModel
from .auto import AutoDraftModelConfig, AutoEagle3DraftModel
from .draft.llama3_eagle import LlamaForCausalLMEagle3
from .draft.qwen3_moe_eagle import Qwen3MoEForCausalLMEagle3
from .draft.router_moe import Qwen3MoERouterForCausalLMEagle3
from .target.eagle3_target_model import (
CustomEagle3TargetModel,
HFEagle3TargetModel,
Expand All @@ -10,6 +12,7 @@

__all__ = [
"LlamaForCausalLMEagle3",
"Qwen3MoERouterForCausalLMEagle3",
"SGLangEagle3TargetModel",
"HFEagle3TargetModel",
"CustomEagle3TargetModel",
Expand Down
4 changes: 4 additions & 0 deletions specforge/modeling/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
)

from .draft.llama3_eagle import LlamaForCausalLMEagle3
from .draft.qwen3_moe_eagle import Qwen3MoEForCausalLMEagle3
from .draft.router_moe import Qwen3MoERouterForCausalLMEagle3
from .target.custom_backend import (
GptOssForCausalLM,
Llama4ForCausalLM,
Expand All @@ -34,6 +36,7 @@ class AutoEagle3DraftModel(AutoModelForCausalLMBase):
# the model mapping is currently hardcoded, we should support lazy model mapping via registry
_model_mapping = {
LlamaConfig: LlamaForCausalLMEagle3,
Qwen3MoeConfig: Qwen3MoERouterForCausalLMEagle3,
}

@classmethod
Expand Down Expand Up @@ -133,6 +136,7 @@ class AutoDraftModelConfig:

_config_mapping = {
"LlamaForCausalLMEagle3": LlamaConfig,
"Qwen3MoERouterForCausalLMEagle3": Qwen3MoeConfig,
}

@classmethod
Expand Down
15 changes: 12 additions & 3 deletions specforge/modeling/draft/llama3_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


@torch.compile(dynamic=True)
# @torch.compile(dynamic=True)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
Expand Down Expand Up @@ -272,7 +272,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
"sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
)

@torch.compile(dynamic=True)
# @torch.compile(dynamic=True)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len and seq_len > self.max_seq_len_cached:
Expand Down Expand Up @@ -1347,7 +1347,7 @@ def __init__(self, hidden_size, eps=1e-6):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

@torch.compile(dynamic=True)
# @torch.compile(dynamic=True)
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
Expand Down Expand Up @@ -1474,6 +1474,7 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
d2t = torch.zeros(self.draft_vocab_size, dtype=torch.int64)
self.register_buffer("t2d", t2d)
self.register_buffer("d2t", d2t)
self.save_idx = 1

def forward(
self,
Expand Down Expand Up @@ -1528,6 +1529,12 @@ def forward(

# norm
hidden_states = self.norm(hidden_states)
self.save_idx += 1
save_path = f"./router/{self.save_idx}.pt"
        # NOTE(review): debug-only dump — saves the tensor (moved to CPU to avoid holding a GPU reference; detached to cut the autograd graph). Together with the save_idx counter and exit(0) below, this looks like leftover debugging; remove before merge.
torch.save(hidden_states.detach().cpu(), save_path)
if self.save_idx >= 16:
exit(0)

return hidden_states

Expand Down Expand Up @@ -1563,3 +1570,5 @@ def backbone(
output_attentions=False,
use_cache=False,
)

#
Loading