|
1 | | -from typing import Tuple |
| 1 | +from typing import Tuple,Optional |
2 | 2 | import math |
| 3 | +from deepx.utils import Config |
3 | 4 | from deepx import arange,Tensor,where |
4 | 5 |
|
def _compute_default_rope_parameters(config: Config = None, seq_len: Optional[int] = None, **rope_kwargs) -> Tuple[Tensor, float]:
    """Compute the inverse frequencies for the default (vanilla) RoPE.

    Args:
        config: model configuration providing ``rope_theta`` and either
            ``head_dim`` or ``hidden_size`` / ``num_attention_heads``
            (optionally ``partial_rotary_factor``). Used when no
            ``rope_kwargs`` are given.
        seq_len: unused here; kept so every entry in ``ROPE_INIT_FUNCTIONS``
            shares the same signature.
        **rope_kwargs: explicit ``base`` and ``dim`` values; when present they
            take priority over ``config``.

    Returns:
        Tuple ``(inv_freq, attention_factor)``. ``inv_freq`` is a tensor with
        one entry per pair of rotary dimensions (``dim // 2`` values);
        ``attention_factor`` is always 1.0 for this RoPE type.

    Raises:
        ValueError: if neither ``config`` nor ``base``/``dim`` kwargs are given.
    """
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        dim = int(head_dim * partial_rotary_factor)
    else:
        # Previously this path fell through and crashed later with an opaque
        # NameError on `base`/`dim`; fail fast with a clear message instead.
        raise ValueError(
            "_compute_default_rope_parameters needs either a `config` or explicit `base`/`dim` rope kwargs"
        )

    attention_factor = 1.0  # Unused in this type of RoPE

    # Compute the inverse frequencies: 1 / base^(2i/dim) for i in [0, dim/2).
    # NOTE(review): assumes the deepx tensor returned by `arange` exposes a
    # `.float()` cast, mirroring the torch API — confirm against deepx.
    inv_freq = 1.0 / (base ** (arange(0, dim, 2, dtype="int64").float() / dim))
    return inv_freq, attention_factor
16 | 21 |
|
def _compute_llama3_parameters(config: Config, seq_len: Optional[int] = None, **rope_kwargs) -> Tuple[Tensor, float]:
    """Compute the inverse frequencies for Llama-3.1-style scaled RoPE.

    Starts from the default RoPE frequencies, then rescales them according to
    the ``rope_scaling`` section of ``config``: long-wavelength (low-frequency)
    components are divided by ``factor``, short-wavelength ones are kept as-is,
    and the band in between is linearly interpolated between the two.

    Args:
        config: model configuration; must carry a ``rope_scaling`` mapping with
            ``factor``, ``low_freq_factor``, ``high_freq_factor`` and
            ``original_max_position_embeddings``.
        seq_len: unused; present for signature uniformity.
        **rope_kwargs: forwarded to the default RoPE computation.

    Returns:
        Tuple ``(inv_freq_llama, attention_factor)``.
    """
    # Start from the unscaled (default) RoPE parameters.
    base_inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len, **rope_kwargs)

    scaling = config.rope_scaling
    factor = scaling["factor"]  # `8` in the original implementation
    low_freq_factor = scaling["low_freq_factor"]  # `1` in the original implementation
    high_freq_factor = scaling["high_freq_factor"]  # `4` in the original implementation
    old_context_len = scaling["original_max_position_embeddings"]  # `8192` in the original implementation

    low_cutoff = old_context_len / low_freq_factor
    high_cutoff = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / base_inv_freq
    # wavelen < high_cutoff: keep the frequency unchanged.
    # wavelen > low_cutoff: divide the frequency by `factor`.
    inv_freq_llama = where(wavelen > low_cutoff, base_inv_freq / factor, base_inv_freq)
    # Between the two cutoffs: blend the scaled and unscaled frequencies with a
    # smooth factor that ramps from 0 (at low_cutoff) to 1 (at high_cutoff).
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    blended = (1 - smooth) * inv_freq_llama / factor + smooth * inv_freq_llama
    in_band = ~(wavelen < high_cutoff) * ~(wavelen > low_cutoff)
    inv_freq_llama = where(in_band, blended, inv_freq_llama)

    return inv_freq_llama, attention_factor
53 | 45 |
|
54 | 46 | ROPE_INIT_FUNCTIONS = { |
|
0 commit comments