Skip to content

Commit 3984a99

Browse files
committed
rope, embedding 验证
1 parent c3f8fa9 commit 3984a99

5 files changed

Lines changed: 62 additions & 92 deletions

File tree

.github/ISSUE_TEMPLATE/operator.md

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,29 @@
1+
---
2+
name: 算子新增
3+
about: 用于提交新的算子实现请求
4+
title: '[算子] '
5+
labels: enhancement, operator
6+
assignees: ''
7+
---
8+
9+
## 算子新增
10+
该算子的数学表达为:
11+
12+
## 影响组件
13+
14+
### front
15+
1.
16+
2.
17+
18+
### 引擎
19+
1.
20+
2.
21+
22+
## 其他说明
23+
24+
<!-- 请在此处添加其他相关信息,如:
25+
- 参考实现(如PyTorch中的实现)
26+
- 性能要求
27+
- 测试用例
28+
- 其他注意事项
29+
-->

front/py/deepx/tensor/tensor.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,8 @@ def __radd__(self, other:Union[Number,'Tensor']):
119119
def __sub__(self, other:Union[Number,'Tensor']):
120120
return self.sub(other)
121121
def __rsub__(self, other:Union[Number,'Tensor']):
122-
return self.sub(other)
122+
x=self.mul(-1)
123+
return x.add(other)
123124
def __mul__(self, other:Union[Number,'Tensor']):
124125
return self.mul(other)
125126
def __rmul__(self, other:Union[Number,'Tensor']):
@@ -156,7 +157,7 @@ def __matmul__(self, other:'Tensor'):
156157
return self.matmul(other)
157158
def __rmatmul__(self, other:'Tensor'):
158159
return other.matmul(self)
159-
#gather
160+
160161
def __getitem__(self, index:'Tensor'):
161162
return self.indexselect(index)
162163

front/py/deepx/transformer/modeling_rope_utils.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -27,24 +27,27 @@ def _compute_llama3_parameters(config:dict={
2727
# Gets the default RoPE parameters
2828
inv_freq, attention_factor = _compute_default_rope_parameters(config)
2929

30+
factor = config["rope_scaling"]["factor"] # `8` in the original implementation
3031
low_freq_factor = config["rope_scaling"]["low_freq_factor"] # `1` in the original implementation
3132
high_freq_factor = config["rope_scaling"]["high_freq_factor"] # `4` in the original implementation
3233
old_context_len = config["rope_scaling"]["original_max_position_embeddings"] # `8192` in the original implementation
33-
low_freq_wavelen = old_context_len /low_freq_factor
34-
high_freq_wavelen = old_context_len/ high_freq_factor
34+
35+
low_freq_wavelen = old_context_len / low_freq_factor
36+
high_freq_wavelen = old_context_len / high_freq_factor
3537

3638
wavelen = 2 * math.pi / inv_freq
37-
factor=config["rope_scaling"]["factor"]
38-
cases=wavelen > low_freq_wavelen
39-
inv_freq_llama = where(cases, inv_freq /factor, inv_freq)
39+
40+
# wavelen < high_freq_wavelen: do nothing
41+
# wavelen > low_freq_wavelen: divide by factor
42+
inv_freq_llama = where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
4043
# otherwise: interpolate between the two, using a smooth factor
41-
smooth_factor = (old_context_len / wavelen -low_freq_factor) / ( high_freq_factor - low_freq_factor)
42-
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
44+
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
45+
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
4346
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
4447
inv_freq_llama = where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
4548

4649
return inv_freq_llama, attention_factor
47-
50+
4851
ROPE_INIT_FUNCTIONS = {
4952
"default": _compute_default_rope_parameters,
5053
# "linear": _compute_linear_scaling_rope_parameters,
Lines changed: 6 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -1,69 +1,4 @@
1-
hidden_size = 8
2-
eps = 1e-6
3-
dir='/home/lipeng/model/deepxmodel/llama/'
4-
model_path="/home/lipeng/model/deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
5-
print()
6-
7-
from transformers import AutoTokenizer,AutoConfig
8-
def init_tokenizer(model_path):
9-
tokenizer = AutoTokenizer.from_pretrained(model_path)
10-
tokenizer.pad_token = tokenizer.eos_token
11-
return tokenizer
12-
13-
tokenizer = init_tokenizer(model_path)
14-
config=AutoConfig.from_pretrained(model_path)
15-
def tokenize_text(text, tokenizer):
16-
tokens = tokenizer(text, return_tensors="pt").input_ids
17-
import torch
18-
# 处理超出词汇表范围的token
19-
if torch.any(tokens >= tokenizer.vocab_size):
20-
# 获取UNK token ID,如果没有则使用0
21-
unk_token_id = tokenizer.unk_token_id if hasattr(tokenizer, 'unk_token_id') and tokenizer.unk_token_id is not None else 0
22-
# 替换所有超出范围的token为UNK
23-
tokens = torch.where(tokens < tokenizer.vocab_size, tokens, torch.tensor(unk_token_id, device=tokens.device))
24-
return tokens
25-
26-
############-------PyTorch-------################
27-
import torch
28-
29-
# 创建输入
30-
text = "这是一个测试文本,用于演示嵌入层的使用。"
31-
torch_input = tokenize_text(text, tokenizer)
32-
from deepxutil.torch import save_torch
33-
save_torch(torch_input,dir+'input')
34-
35-
# 创建网络
36-
37-
class NetTorch(torch.nn.Module):
38-
from transformers.models.llama.modeling_llama import LlamaConfig
39-
def __init__(self,config:LlamaConfig):
40-
super().__init__()
41-
self.padding_idx = config.pad_token_id
42-
self.config = config
43-
self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
44-
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
45-
self.rotary_emb = LlamaRotaryEmbedding(config=config)
46-
47-
def forward(self,x):
48-
inputs_embeds = self.embed_tokens(x)
49-
hidden_states = inputs_embeds
50-
# create position embeddings to be shared across the decoder layers
51-
position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
52-
return self.rotary_emb(hidden_states, position_ids)
53-
54-
55-
torch_net = NetTorch(config)
56-
save_torch(torch_net.embed_tokens.weight,dir+'weight')
57-
# 前向传播
58-
torch_output = torch_net(torch_input)
59-
torch_sin, torch_cos = torch_output
60-
61-
print("sin shape:",torch_sin.shape)
62-
print("sin:", torch_sin)
63-
64-
print("cos shape:", torch_cos.shape)
65-
print("cos:", torch_cos)
66-
1+
from .llama_rope_torch import dir,config
672

683
############-------DEEPX-------################
694
from deepx.nn.modules import Embedding,Module
@@ -86,10 +21,10 @@ def forward(self,x):
8621
position_ids = arange(start=0,end=hidden_states.shape[1]).unsqueeze(0)
8722
return self.rotary_emb(hidden_states, position_ids)
8823

89-
net = NetDeepx(configdict=config.to_dict())
90-
out=net.forward(input)
91-
out[0].print()
92-
out[1].print()
93-
24+
if __name__ == "__main__":
25+
net = NetDeepx(configdict=config.to_dict())
26+
out=net.forward(input)
27+
out[0].print()
28+
out[1].print()
9429

9530

front/py/examples/4_transformer/llama/llama_rope_torch.py

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -52,23 +52,25 @@ def __init__(self, config: LlamaConfig):
5252
self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
5353
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
5454
self.rotary_emb = LlamaRotaryEmbedding(config=config)
55-
55+
print("rotary_emb.inv_freq")
56+
print(self.rotary_emb.inv_freq)
5657
def forward(self, x):
5758
inputs_embeds = self.embed_tokens(x)
59+
print(inputs_embeds)
5860
hidden_states = inputs_embeds
5961
# create position embeddings to be shared across the decoder layers
6062
position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device).unsqueeze(0)
6163
return self.rotary_emb(hidden_states, position_ids)
6264

65+
if __name__ == "__main__":
66+
torch_net = NetTorch(config)
67+
save_torch(torch_net.embed_tokens.weight, dir + 'weight')
68+
# 前向传播
69+
torch_output = torch_net(torch_input)
70+
torch_sin, torch_cos = torch_output
6371

64-
torch_net = NetTorch(config)
65-
save_torch(torch_net.embed_tokens.weight, dir + 'weight')
66-
# 前向传播
67-
torch_output = torch_net(torch_input)
68-
torch_sin, torch_cos = torch_output
69-
70-
print("sin shape:", torch_sin.shape)
71-
print("sin:", torch_sin)
72+
print("sin shape:", torch_sin.shape)
73+
print("sin:", torch_sin)
7274

73-
print("cos shape:", torch_cos.shape)
74-
print("cos:", torch_cos)
75+
print("cos shape:", torch_cos.shape)
76+
print("cos:", torch_cos)

0 commit comments

Comments (0)