26 changes: 26 additions & 0 deletions configs/glm4-flash-eagle3.json
@@ -0,0 +1,26 @@
{
  "architectures": [
    "LlamaForCausalLMEagle3"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151329,
  "eos_token_id": 151336,
  "head_dim": 102,
Contributor review comment (severity: medium):
The head_dim is set to 102. While this is a valid configuration, it's not a multiple of 8. This can lead to suboptimal performance on modern hardware accelerators like GPUs, which often have optimized kernels for dimensions that are multiples of 8 or 64. If this value is not a strict requirement of the model architecture, consider adjusting it to a nearby multiple of 8 (e.g., 104 or 96) to potentially improve training and inference speed.
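The arithmetic behind this comment can be checked in a few lines (the values are copied from the JSON above; nothing else is assumed):

```python
# Values copied from configs/glm4-flash-eagle3.json
hidden_size = 2048
num_attention_heads = 20
head_dim = 102

# hidden_size / num_heads is not an integer here, so an explicit head_dim is required
print(hidden_size / num_attention_heads)  # 102.4

# 102 is not a multiple of 8, which is what the review flags
print(head_dim % 8)  # 6

# The nearby multiples of 8 the review suggests as alternatives
print([d for d in (96, 104) if d % 8 == 0])  # [96, 104]
```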

  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "num_attention_heads": 20,
  "num_hidden_layers": 1,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "use_cache": true,
  "vocab_size": 154880,
  "draft_vocab_size": 32000
}
55 changes: 55 additions & 0 deletions examples/run_glm4_flash_eagle3_debug.sh
@@ -0,0 +1,55 @@
#!/bin/bash
# GLM-4.7-Flash EAGLE3 Debug Training Script
# Quick test run with verbose logging

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels

# Load wandb API key from persistent storage
if [ -f /gustavo/.wandb_key ]; then
    export WANDB_API_KEY=$(cat /gustavo/.wandb_key)
fi
Comment on lines +10 to +12
Contributor review comment (security-medium):

The script hardcodes a path to a secret file, /gustavo/.wandb_key, which leaks developer details and makes the script non-portable and potentially insecure. Prefer a generic path such as $HOME/.wandb_key, or an environment variable.

Suggested change:
if [ -f "$HOME/.wandb_key" ]; then
    export WANDB_API_KEY=$(cat "$HOME/.wandb_key")
fi

NUM_GPUS=1
TP_SIZE=1
BUILD_DATASET_NUM_PROC=64

echo "========================================"
echo "GLM-4.7-Flash EAGLE3 DEBUG Training"
echo "========================================"
echo "NUM_GPUS: $NUM_GPUS"
echo "TP_SIZE: $TP_SIZE"
echo "Testing with 1 epoch, verbose logging"
echo "Using single GPU for debugging"
echo "Loss debugging ENABLED"
echo "========================================"

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3.py \
    --target-model-path zai-org/GLM-4.7-Flash \
    --trust-remote-code \
    --draft-model-config $ROOT_DIR/configs/glm4-flash-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
    --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \
    --output-dir $ROOT_DIR/outputs/glm4-flash-eagle3-debug \
    --num-epochs 1 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 512 \
    --chat-template glm4 \
    --cache-dir $ROOT_DIR/cache \
    --embedding-key model.embed_tokens.weight \
    --tp-size $TP_SIZE \
    --target-model-backend sglang \
    --sglang-mem-fraction-static 0.75 \
    --log-interval 10 \
    --save-interval 500 \
    --eval-interval 500 \
    --verbose \
    --debug-loss \
    --report-to wandb \
    --wandb-project baby-shark-glm-eagle3 \
    --wandb-name glm4-flash-eagle3-debug-$(date +%Y%m%d-%H%M%S)
45 changes: 45 additions & 0 deletions examples/run_glm4_flash_eagle3_online.sh
@@ -0,0 +1,45 @@
#!/bin/bash
# GLM-4.7-Flash EAGLE3 Training Script
# Usage: ./examples/run_glm4_flash_eagle3_online.sh [NUM_GPUS] [TP_SIZE]

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels

# Load wandb API key from persistent storage
if [ -f /gustavo/.wandb_key ]; then
    export WANDB_API_KEY=$(cat /gustavo/.wandb_key)
fi
Comment on lines +10 to +12
Contributor review comment (security-medium):

The script hardcodes a path to a secret file, /gustavo/.wandb_key, which leaks developer details and makes the script non-portable and potentially insecure. Prefer a generic path such as $HOME/.wandb_key, or an environment variable.

Suggested change:
if [ -f "$HOME/.wandb_key" ]; then
    export WANDB_API_KEY=$(cat "$HOME/.wandb_key")
fi


NUM_GPUS=${1:-1}
TP_SIZE=${2:-1}
BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64}

echo "========================================"
echo "GLM-4.7-Flash EAGLE3 Training"
echo "========================================"
echo "NUM_GPUS: $NUM_GPUS"
echo "TP_SIZE: $TP_SIZE"
echo "BUILD_DATASET_NUM_PROC: $BUILD_DATASET_NUM_PROC"
echo "========================================"

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_eagle3.py \
    --target-model-path zai-org/GLM-4.7-Flash \
Comment on lines +28 to +30
Contributor review comment (security-medium):

The script expands the positional arguments $NUM_GPUS and $TP_SIZE unquoted inside the torchrun command. Unquoted expansions undergo word splitting and globbing, so a caller who controls the script's arguments can smuggle extra flags or arguments into the command line (e.g., by passing `1 --some-flag` as a single argument).

Recommendation: quote the variables (e.g., "$NUM_GPUS") and validate that they are integers before use.

    --trust-remote-code \
    --draft-model-config $ROOT_DIR/configs/glm4-flash-eagle3.json \
    --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \
    --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \
    --output-dir $ROOT_DIR/outputs/glm4-flash-eagle3-sharegpt \
    --num-epochs 10 \
    --batch-size 1 \
    --learning-rate 1e-4 \
    --max-length 4096 \
    --chat-template glm4 \
    --cache-dir $ROOT_DIR/cache \
    --embedding-key model.embed_tokens.weight \
    --tp-size $TP_SIZE \
    --target-model-backend sglang \
    --sglang-mem-fraction-static 0.4
77 changes: 77 additions & 0 deletions scripts/mix_datasets.py
@@ -0,0 +1,77 @@
#!/usr/bin/env python3
"""Mix multiple datasets according to specified ratios."""
import json
import random
from pathlib import Path
from typing import List, Tuple

def load_jsonl(path: str) -> List[dict]:
    """Load JSONL file."""
    data = []
    with open(path, 'r') as f:
        for line in f:
            data.append(json.loads(line))
    return data

def mix_datasets(
    datasets: List[Tuple[str, float]],  # [(path, ratio)]
    output_path: str,
    total_samples: int = None,
    seed: int = 42
):
    """Mix datasets according to ratios."""
    random.seed(seed)

    # Load all datasets
    all_data = []
    for path, ratio in datasets:
        data = load_jsonl(path)
        print(f"Loaded {len(data)} samples from {path}")

        if total_samples:
            # Sample according to ratio
            n_samples = int(total_samples * ratio)
            sampled = random.sample(data, min(n_samples, len(data)))
        else:
            # Use all data weighted by ratio
            sampled = random.sample(data, int(len(data) * ratio / sum(r for _, r in datasets)))
Contributor review comment (severity: medium):
The sum of ratios sum(r for _, r in datasets) is recalculated on every iteration of the loop. This is inefficient, especially if there are many datasets. This value should be calculated once before the loop begins and stored in a variable.
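The hoisted version the comment asks for might look like the sketch below. It operates on already-loaded lists rather than files so it stays self-contained; the function name is illustrative, not part of the script:

```python
import random

def mix_loaded(datasets, seed=42):
    """Mix already-loaded datasets by ratio, with the ratio sum hoisted out of the loop."""
    rng = random.Random(seed)
    total_ratio = sum(ratio for _, ratio in datasets)  # computed once, not per iteration
    mixed = []
    for data, ratio in datasets:
        n = int(len(data) * ratio / total_ratio)
        mixed.extend(rng.sample(data, n))
    rng.shuffle(mixed)
    return mixed

a = [{"id": i} for i in range(100)]
b = [{"id": i} for i in range(100, 200)]
mixed = mix_loaded([(a, 0.75), (b, 0.25)], seed=0)
print(len(mixed))  # int(100 * 0.75) + int(100 * 0.25) = 100
```

Using a local `random.Random(seed)` instead of the module-level `random.seed(seed)` also keeps the sampling reproducible without mutating global state.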


        all_data.extend(sampled)
        print(f"Added {len(sampled)} samples ({ratio*100:.1f}%)")

    # Shuffle combined dataset
    random.shuffle(all_data)

    # Normalize IDs to strings (fix type mismatch between datasets)
    for i, item in enumerate(all_data):
        item['id'] = str(item.get('id', i))

    # Save mixed dataset
    with open(output_path, 'w') as f:
        for item in all_data:
            f.write(json.dumps(item) + '\n')

    print(f"\nSaved {len(all_data)} mixed samples to {output_path}")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--total-samples", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Mix for Experiment J: 45% ShareGPT, 35% UltraChat, 20% PerfectBlend
    base_dir = Path(__file__).parent.parent / "cache" / "dataset"
    datasets = [
        (str(base_dir / "sharegpt_train.jsonl"), 0.45),
        (str(base_dir / "ultrachat_train.jsonl"), 0.35),
        (str(base_dir / "perfectblend_train.jsonl"), 0.20),
    ]

    mix_datasets(
        datasets=datasets,
        output_path=args.output,
        total_samples=args.total_samples,
        seed=args.seed
    )
6 changes: 6 additions & 0 deletions specforge/data/parse.py
@@ -121,6 +121,12 @@ def parse(
break
messages.append(sentence)

# Check if we have any actual content to train on (not just system prompt)
has_assistant = any(m["role"] == "assistant" for m in messages)
if not has_assistant:
# No assistant response to train on - skip this sample
return None
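A minimal illustration of the new guard; the message dicts mirror the role/content shape that parse() iterates over, and the helper name is illustrative:

```python
# Sketch of the guard added above: samples with no assistant turn carry no
# trainable tokens, so parse() now returns None for them.
def has_trainable_content(messages):
    return any(m["role"] == "assistant" for m in messages)

system_only = [{"role": "system", "content": "You are helpful."},
               {"role": "user", "content": "Hi"}]
full_turn = system_only + [{"role": "assistant", "content": "Hello!"}]

print(has_trainable_content(system_only))  # False
print(has_trainable_content(full_turn))    # True
```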

try:
conversation = self.apply_chat_template(messages, **kwargs)
except (ValueError, TypeError):
15 changes: 12 additions & 3 deletions specforge/data/preprocessing.py
@@ -84,9 +84,14 @@ def _apply_loss_mask_from_chat_template(
)

# Find spans of assistant responses using regex
# Match assistant header with OR without end_of_turn_token prefix
# (first response often lacks the prefix)
assistant_pattern = (
re.escape(assistant_message_separator)
+ r"(.*?)(?="
r"(?:"
+ re.escape(assistant_message_separator)
+ "|"
+ re.escape(chat_template.assistant_header)
+ r")(.*?)(?="
+ re.escape(user_message_separator)
+ "|$)"
)
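The effect of the widened pattern can be sketched with the glm4 tokens registered elsewhere in this PR. The separator strings below are assumptions about how SpecForge composes them (end-of-turn token + header); only the token literals come from the new template:

```python
import re

# Assumed separator composition, using the glm4 ChatTemplate tokens from template.py
assistant_header = "<|assistant|></think>"
end_of_turn = "<|endoftext|>"
assistant_sep = end_of_turn + assistant_header  # header WITH the end-of-turn prefix
user_sep = end_of_turn + "<|user|>"

# Match the assistant header with or without the end_of_turn prefix,
# since the first assistant response often lacks it
pattern = (
    r"(?:" + re.escape(assistant_sep) + "|" + re.escape(assistant_header)
    + r")(.*?)(?=" + re.escape(user_sep) + "|$)"
)

text = (
    "<|user|>What is 2+2?"
    "<|assistant|></think>4"                             # first turn: no prefix
    "<|endoftext|><|user|>Thanks"
    "<|endoftext|><|assistant|></think>You're welcome."  # later turn: prefixed
    "<|endoftext|><|user|>Bye"
)
spans = [m.group(1) for m in re.finditer(pattern, text, re.DOTALL)]
print(spans)  # ["4", "You're welcome."]
```

With the old pattern (prefix required), the first response "4" would receive no loss mask; the alternation recovers it.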
@@ -163,13 +168,17 @@ def preprocess_conversations(
if not source:
# if the source is None, skip it
continue
input_ids, loss_mask = parser.parse(
result = parser.parse(
source,
max_length,
preformatted=is_preformatted,
train_only_last_turn=train_only_last_turn,
**kwargs_item,
)
if result is None:
# Skip invalid conversations (e.g., no assistant response)
continue
input_ids, loss_mask = result
results["input_ids"].append(input_ids[None, :])
results["loss_mask"].append(loss_mask[None, :])
results["attention_mask"].append(torch.ones_like(loss_mask)[None, :])
10 changes: 10 additions & 0 deletions specforge/data/template.py
@@ -308,3 +308,13 @@ def get_all_template_names(self) -> List[str]:
end_of_turn_token="</longcat_s>",
),
)

TEMPLATE_REGISTRY.register(
    name="glm4",
    template=ChatTemplate(
        assistant_header="<|assistant|></think>",
        user_header="<|user|>",
        system_prompt="",  # GLM tokenizer handles system prompts natively - don't duplicate
        end_of_turn_token="<|endoftext|>",
    ),
)
14 changes: 14 additions & 0 deletions specforge/modeling/auto.py
@@ -18,8 +18,17 @@
modeling_utils,
)

try:
    from transformers import Glm4MoeLiteConfig
    _GLM4_CONFIG_AVAILABLE = True
except ImportError:
    Glm4MoeLiteConfig = None
    _GLM4_CONFIG_AVAILABLE = False

from .draft.llama3_eagle import LlamaForCausalLMEagle3
from .target.custom_backend import (
    _GLM4_AVAILABLE,
    Glm4MoeLiteForCausalLM,
    GptOssForCausalLM,
    Llama4ForCausalLM,
    LlamaForCausalLM,
@@ -86,6 +95,7 @@ def filtered_warning(msg):
class AutoDistributedTargetModel(AutoModelForCausalLMBase):
    # the model mapping is currently hardcoded, we should support lazy model mapping via registry
    _model_mapping = {
        **({Glm4MoeLiteConfig: [Glm4MoeLiteForCausalLM]} if _GLM4_CONFIG_AVAILABLE and _GLM4_AVAILABLE else {}),
        Llama4TextConfig: [Llama4ForCausalLM],
        Qwen3MoeConfig: [Qwen3MoeForCausalLM],
        Qwen2Config: [Qwen2ForCausalLM],
@@ -172,4 +182,8 @@ def from_file(cls, config_path: str):
        if "draft_vocab_size" not in config or config["draft_vocab_size"] is None:
            config["draft_vocab_size"] = config.get("vocab_size", None)

        # Ensure rope_scaling is None if not explicitly set, to avoid "default" type errors
        if "rope_scaling" not in config:
            config["rope_scaling"] = None

        return cls._config_mapping[architecture].from_dict(config)
9 changes: 8 additions & 1 deletion specforge/modeling/draft/llama3_eagle.py
@@ -556,7 +556,14 @@ def rope_get(key, default=None):
scaling_type = rope_get("rope_type", rope_get("type"))
scaling_factor = rope_get("factor")

if scaling_type == "linear":
if scaling_type == "default" or scaling_type is None:
# Handle "default" rope_type as standard RoPE (no scaling)
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=getattr(self.config, "rope_theta", 10000),
)
elif scaling_type == "linear":
if scaling_factor is None:
raise ValueError(
"Linear RoPE scaling requires 'factor' in rope_scaling config."
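The dispatch added in this hunk can be condensed into a standalone sketch. The function name and return strings are illustrative stand-ins for the real embedding objects; the branching logic mirrors the diff:

```python
# Illustrative sketch of the rope_scaling dispatch in llama3_eagle.py:
# a missing or "default" rope_type now falls back to standard (unscaled) RoPE.
def resolve_rope(rope_scaling):
    def rope_get(key, default=None):
        return default if rope_scaling is None else rope_scaling.get(key, default)

    scaling_type = rope_get("rope_type", rope_get("type"))  # new key first, then legacy key
    if scaling_type is None or scaling_type == "default":
        return "standard"  # plain LlamaRotaryEmbedding, no scaling
    if scaling_type == "linear":
        factor = rope_get("factor")
        if factor is None:
            raise ValueError("Linear RoPE scaling requires 'factor' in rope_scaling config.")
        return f"linear:{factor}"
    raise ValueError(f"Unsupported rope_type: {scaling_type}")

print(resolve_rope(None))                               # standard
print(resolve_rope({"rope_type": "default"}))           # standard
print(resolve_rope({"type": "linear", "factor": 4.0}))  # linear:4.0
```

This is why `from_file` setting `rope_scaling` to None when absent (see auto.py above) is safe: both None and "default" land on the unscaled branch.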
8 changes: 8 additions & 0 deletions specforge/modeling/target/custom_backend/__init__.py
@@ -1,3 +1,10 @@
try:
    from .glm4_moe_lite import Glm4MoeLiteForCausalLM
    _GLM4_AVAILABLE = True
except ImportError:
    Glm4MoeLiteForCausalLM = None
    _GLM4_AVAILABLE = False

from .gpt_oss import GptOssForCausalLM
from .llama import LlamaForCausalLM
from .llama4 import Llama4ForCausalLM
@@ -7,6 +14,7 @@
from .qwen3_moe import Qwen3MoeForCausalLM

__all__ = [
    "Glm4MoeLiteForCausalLM",
    "GptOssForCausalLM",
    "LlamaForCausalLM",
    "Llama4ForCausalLM",