speculative_decode.py
import time
from typing import Optional

import torch

from greedy_decode import get_model, get_prompt


def get_speculative_candidate(input_ids: torch.Tensor, ngram_len: int = 2, draft_len: int = 3) -> Optional[torch.Tensor]:
    """
    Find draft candidates for speculative decoding by prompt lookup.

    This function takes the last N (`ngram_len`) tokens of the input and
    scans the preceding tokens for an earlier occurrence of that n-gram.
    The intuition is that if this context has appeared before, the model
    is likely to continue it the same way, so we copy the K (`draft_len`)
    tokens that followed the earlier occurrence as the speculative draft.
    """
    input_list = input_ids[0].tolist()
    if len(input_list) < ngram_len:
        return None
    # The pattern to match is the last N tokens.
    pattern = input_list[-ngram_len:]
    history = input_list[:-ngram_len]
    # Iterate backwards through the history to find the most recent occurrence.
    for i in range(len(history) - ngram_len, -1, -1):
        if history[i:i + ngram_len] == pattern:
            # Match found: grab the subsequent K tokens as the draft candidate.
            start_idx = i + ngram_len
            end_idx = min(start_idx + draft_len, len(history))
            candidate = history[start_idx:end_idx]
            if len(candidate) == 0:
                return None
            return torch.tensor([candidate]).to(input_ids.device)
    return None
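

# A quick sanity check of the lookup (a sketch; the token IDs below are made
# up purely for illustration). With the default ngram_len=2, the trailing
# bigram [7, 8] also occurs earlier in the sequence, so the three tokens that
# followed it there become the draft:
#
#   >>> ids = torch.tensor([[1, 7, 8, 4, 5, 6, 2, 7, 8]])
#   >>> get_speculative_candidate(ids, ngram_len=2, draft_len=3)
#   tensor([[4, 5, 6]])
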
def speculative_decode(model, input_ids, max_new_tokens=25):
    input_ids = input_ids.to(model.device)
    tokens_generated = 0
    with torch.no_grad():
        while tokens_generated < max_new_tokens:
            candidate = get_speculative_candidate(input_ids)
            if candidate is not None:
                draft_len = candidate.shape[1]
                model_input = torch.cat([input_ids, candidate], dim=-1)
            else:
                draft_len = 0
                model_input = input_ids
            # Single forward pass over the prompt plus the optional draft.
            logits = model(model_input).logits
            # Verify tokens starting from the current end of input_ids.
            # The logit that predicts input_ids[N] is located at logits[N-1].
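            # Worked example of the index math (hypothetical sizes, for
            # intuition only): if input_ids currently holds 10 tokens and the
            # draft has 3, model_input has 13 tokens. logits[:, 9, :] predicts
            # the token at position 10 (the first draft token),
            # logits[:, 10, :] predicts position 11, and so on.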
            current_seq_len = input_ids.shape[1]
            all_drafts_accepted = True
            for i in range(draft_len):
                logit_index = current_seq_len + i - 1
                pred_logits = logits[:, logit_index, :]
                pred_token = torch.argmax(pred_logits, dim=-1).unsqueeze(0)
                candidate_token = candidate[:, i].unsqueeze(0)
                if pred_token == candidate_token:
                    # Accept the draft token and continue verifying.
                    input_ids = torch.cat([input_ids, candidate_token], dim=-1)
                    tokens_generated += 1
                else:
                    # Mismatch: reject the draft token and take the model's
                    # own prediction instead. Stop verifying, since the rest
                    # of the draft is now invalid.
                    input_ids = torch.cat([input_ids, pred_token], dim=-1)
                    tokens_generated += 1
                    all_drafts_accepted = False
                    break
                if tokens_generated >= max_new_tokens:
                    # Don't let accepted draft tokens overshoot the budget.
                    break
            # If the whole draft was accepted (or there was no draft), the
            # final logit already gives us one more token for free.
            if all_drafts_accepted and tokens_generated < max_new_tokens:
                last_logits = logits[:, -1, :]
                next_token = torch.argmax(last_logits, dim=-1).unsqueeze(0)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                tokens_generated += 1
    return input_ids


if __name__ == "__main__":
    model, tokenizer = get_model()
    prompt = get_prompt(tokenizer)
    max_new_tokens = 25

    start = time.perf_counter()
    generation = speculative_decode(model, prompt, max_new_tokens=max_new_tokens)
    elapsed = time.perf_counter() - start

    print(tokenizer.decode(generation[0], skip_special_tokens=True))
    print(f"Speculative decode time: {elapsed:.2f}s")
    # Measure throughput against the number of tokens actually produced,
    # not the requested budget.
    num_generated = generation.shape[1] - prompt.shape[1]
    throughput = num_generated / elapsed
    print(f"Output throughput: {throughput:.2f} tokens/s")