Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions scripts/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def parse_args():
"magicoder-evol-instruct",
"sciq",
"camel",
"nebius-llama31-8b-infinity-instruct",
],
help="The demo dataset to quickly run the training for speculative decoding",
)
Expand Down Expand Up @@ -234,6 +235,19 @@ def process_sharegpt4v_row(row, dataset_name: str = None) -> Dict:
return row, skipped_count


def process_nebius_infinity_instruct(
    row: Dict, dataset_name: str = None
) -> Tuple[Dict, int]:
    """Reformat a Nebius Infinity-Instruct row into the shared conversation schema.

    Keeps only the first turn of ``row["conversation"]`` as the user message and
    pairs it with the model reply in ``row["generated_message"]``.

    Args:
        row: Raw dataset row; must contain "id", "conversation" (non-empty list
            of dicts with a "content" key), and "generated_message" (dict with
            a "content" key).
        dataset_name: Unused; present to match the other ``process_*`` helpers.

    Returns:
        A tuple of (reformatted row with "id" and "conversations" keys,
        skipped-row count — always 0 for this dataset).
    """
    user_turn = row["conversation"][0]
    assistant_turn = row["generated_message"]
    reformatted = {
        "id": str(row["id"]),
        "conversations": [
            {"role": "user", "content": user_turn["content"]},
            {"role": "assistant", "content": assistant_turn["content"]},
        ],
    }
    return reformatted, 0


def load_dataset_from_path(data_path: Path):
suffix = data_path.suffix.split(".")[1]
ds = load_dataset(suffix, data_files=str(data_path), split="train")
Expand Down Expand Up @@ -580,6 +594,12 @@ def main():
raise Exception("Not supported sharegpt4v now")
download_vlm_dataset(args.dataset)
proc_fn = process_sharegpt4v_row
elif args.dataset == "nebius-llama31-8b-infinity-instruct":
ds = load_dataset(
"nebius/Llama-3.1-8B-Instruct-Infinity-Instruct-0625", split="train"
)
ds = ds.map(add_index, with_indices=True)
proc_fn = process_nebius_infinity_instruct
elif args.dataset == "allava4v":
ds = load_dataset("FreedomIntelligence/ALLaVA-4V", name="allava_laion")[
"instruct"
Expand Down
72 changes: 66 additions & 6 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,28 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
training_group.add_argument("--seed", type=int, default=0)
training_group.add_argument("--draft-accumulation-steps", type=int, default=1)

# LK / acceptance-rate loss arguments
lk_group = parser.add_argument_group("lk loss")
lk_group.add_argument(
"--lk-loss-type",
type=str,
default=None,
choices=["lambda", "alpha"],
help="Enable LK loss objective. Choices: lambda (hybrid KL+LK), alpha (pure acceptance-rate likelihood).",
)
lk_group.add_argument(
"--kl-scale",
type=float,
default=1.0,
help="Scale for adaptive KL weight: kl_weight = kl_scale * exp(-kl_decay * acc). Used when --lk-loss-type=lambda.",
)
lk_group.add_argument(
"--kl-decay",
type=float,
default=3.0,
help="Decay for adaptive KL weight. Used when --lk-loss-type=lambda.",
)

# data processing type
optimization_group = parser.add_argument_group("optimization")
optimization_group.add_argument(
Expand Down Expand Up @@ -339,6 +361,10 @@ def sanity_check(args: Namespace) -> None:
"""
args.dp_size = dist.get_world_size() // args.tp_size
args.target_batch_size = args.tp_size * args.batch_size
if args.kl_scale < 0:
raise ValueError(f"--kl-scale must be non-negative, got {args.kl_scale}")
if args.kl_decay < 0:
raise ValueError(f"--kl-decay must be non-negative, got {args.kl_decay}")
if args.attention_backend == "usp":
sp_sanity_check(args)

Expand Down Expand Up @@ -597,9 +623,9 @@ def run_forward(
data: dict,
target_model: Optional[Eagle3TargetModel] = None,
is_online: bool = True,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
if args.is_vlm and args.target_model_backend == "custom":
plosses, _, acces = eagle3_model(
plosses, acceptance_rates, acces = eagle3_model(
input_ids=data["input_ids"].cuda(),
attention_mask=data["attention_mask"].cuda(),
loss_mask=data["loss_mask"].cuda(),
Expand Down Expand Up @@ -651,7 +677,7 @@ def run_forward(
target.cuda()
) # The `data['target']` value occupies a large amount of GPU memory, with a shape of [seqlen, vocab_size]. It needs to be processed before being loaded into the GPU.
loss_mask = loss_mask.cuda()
plosses, _, acces = eagle3_model(
plosses, acceptance_rates, acces = eagle3_model(
input_ids=input_ids,
attention_mask=attention_mask,
loss_mask=loss_mask,
Expand All @@ -663,7 +689,7 @@ def run_forward(
image_grid_thw=image_grid_thw,
is_vlm=args.is_vlm,
)
return plosses, acces
return plosses, acces, acceptance_rates


def run_backward_and_update(
Expand All @@ -683,6 +709,7 @@ def run_backward_and_update(
def record_metrcs(
args: Namespace,
accuracies: List[torch.Tensor],
acceptance_rates: List[torch.Tensor],
plosses: List[torch.Tensor],
global_step: int,
tracker: Tracker,
Expand All @@ -706,6 +733,16 @@ def record_metrcs(
f"Eval - Step {global_step} [{global_step + 1}/{args.num_epochs}], position {i}, Acc: {accuracies[i]:.2f}"
)

acceptance_rates = torch.stack(acceptance_rates)
assert acceptance_rates.shape[0] == args.ttt_length
dist.all_reduce(acceptance_rates, op=dist.ReduceOp.AVG)
acceptance_rates = acceptance_rates.cpu().tolist()
for i in range(len(acceptance_rates)):
logdict[f"{mode}/acceptance_rate_{i}"] = acceptance_rates[i]
print_on_rank0(
f"Eval - Step {global_step} [{global_step + 1}/{args.num_epochs}], position {i}, Acceptance Rate: {acceptance_rates[i]:.4f}"
)

dist.all_reduce(plosses, op=dist.ReduceOp.AVG)
plosses = plosses.cpu().tolist()
for i in range(len(plosses)):
Expand Down Expand Up @@ -789,6 +826,9 @@ def main():
processor=processor,
length=args.ttt_length,
attention_backend=args.attention_backend,
lk_loss_type=args.lk_loss_type,
kl_scale=args.kl_scale,
kl_decay=args.kl_decay,
)
else:
if is_online:
Expand All @@ -797,13 +837,19 @@ def main():
draft_model=draft_model,
length=args.ttt_length,
attention_backend=args.attention_backend,
lk_loss_type=args.lk_loss_type,
kl_scale=args.kl_scale,
kl_decay=args.kl_decay,
)
else:
# offline: the target_model is TargetHead not a model
eagle3_model = OnlineEagle3Model(
draft_model=draft_model,
length=args.ttt_length,
attention_backend=args.attention_backend,
lk_loss_type=args.lk_loss_type,
kl_scale=args.kl_scale,
kl_decay=args.kl_decay,
)
eagle3_model = FSDP(
eagle3_model,
Expand Down Expand Up @@ -909,7 +955,7 @@ def main():
# ================================================
# 7.1 Training Step
# ================================================
plosses, acces = run_forward(
plosses, acces, acceptance_rates = run_forward(
args,
eagle3_model,
data,
Expand All @@ -923,6 +969,7 @@ def main():
record_metrcs(
args,
acces,
acceptance_rates,
plosses,
global_step // args.draft_accumulation_steps,
tracker,
Expand All @@ -935,10 +982,14 @@ def main():
last_time = time.time()
avg_loss = sum(pl for pl in plosses) / len(plosses)
avg_acc = sum(acces) / len(acces)
avg_acceptance_rate = sum(ar for ar in acceptance_rates) / len(
acceptance_rates
)
progress_bar.set_postfix(
{
"loss": f"{avg_loss:.2f}",
"acc": f"{avg_acc:.2f}",
"acceptance_rate": f"{avg_acceptance_rate:.2f}",
"time": f"{time_per_step:.2f}s",
}
)
Expand All @@ -958,27 +1009,36 @@ def main():
# Run evaluation
draft_model.eval()
eval_acces = [[] for _ in range(eagle3_model.length)]
eval_acceptance_rates = [[] for _ in range(eagle3_model.length)]
eval_plosses = [[] for _ in range(eagle3_model.length)]

for data in tqdm(eval_dataloader, desc=f"Evaluating Epoch {epoch}"):
with torch.no_grad():
plosses, acces = run_forward(
plosses, acces, acceptance_rates = run_forward(
args, eagle3_model, data, target_model, is_online
)
eval_acces = [
eval_acces[i] + [acces[i]] for i in range(len(acces))
]
eval_acceptance_rates = [
eval_acceptance_rates[i] + [acceptance_rates[i]]
for i in range(len(acceptance_rates))
]
eval_plosses = [
eval_plosses[i] + [plosses[i]] for i in range(len(plosses))
]

# compute average over all minibatches
eval_acces = [torch.stack(acc).mean() for acc in eval_acces]
eval_acceptance_rates = [
torch.stack(ar).mean() for ar in eval_acceptance_rates
]
eval_plosses = [torch.stack(pl).mean() for pl in eval_plosses]

record_metrcs(
args,
eval_acces,
eval_acceptance_rates,
eval_plosses,
global_step // args.draft_accumulation_steps,
tracker,
Expand Down
Loading
Loading