# Model / export configuration.
hidden_size = 8
eps = 1e-6
# NOTE(review): `dir` shadows the builtin dir(); the name is kept because
# downstream modules import it from this file by this exact name.
dir = '/home/lipeng/model/deepxmodel/llama/'
model_path = "/home/lipeng/model/deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
# (removed a stray debug `print()` that only emitted a blank line)
6-
7- from transformers import AutoTokenizer ,AutoConfig
def init_tokenizer(model_path):
    """Load the tokenizer for *model_path*, using EOS as the padding token.

    Many causal-LM tokenizers ship without a pad token; aliasing it to EOS
    makes padded batches work without changing the vocabulary.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    tok.pad_token = tok.eos_token
    return tok
12-
# Instantiate the tokenizer and model configuration once at import time;
# both are reused by the script body below (and imported by the DeepX side).
tokenizer = init_tokenizer (model_path )
config = AutoConfig .from_pretrained (model_path )
def tokenize_text(text, tokenizer):
    """Tokenize *text* and remap out-of-vocabulary ids to the UNK token.

    Args:
        text: input string to tokenize.
        tokenizer: a HuggingFace-style tokenizer; called with
            ``return_tensors="pt"`` and expected to expose ``vocab_size``
            (and optionally ``unk_token_id``).

    Returns:
        A LongTensor of shape (1, seq_len) with every id < ``vocab_size``.
    """
    import torch

    tokens = tokenizer(text, return_tensors="pt").input_ids
    # Ids >= vocab_size (e.g. added special tokens) would index past the
    # embedding table; map them to UNK, or to 0 when no UNK is defined.
    if torch.any(tokens >= tokenizer.vocab_size):
        unk_token_id = getattr(tokenizer, "unk_token_id", None)
        if unk_token_id is None:
            unk_token_id = 0
        tokens = torch.where(
            tokens < tokenizer.vocab_size,
            tokens,
            torch.tensor(unk_token_id, device=tokens.device),
        )
    return tokens
25-
############-------PyTorch-------################
import torch

# Build the sample input: tokenize a demo sentence (Chinese: "This is a test
# text used to demonstrate the embedding layer.") into a (1, seq_len) id tensor.
text = "这是一个测试文本,用于演示嵌入层的使用。"
torch_input = tokenize_text (text , tokenizer )
from deepxutil .torch import save_torch
# Persist the token ids so the DeepX side can load the identical input.
save_torch (torch_input ,dir + 'input' )
34-
35- # 创建网络
36-
class NetTorch(torch.nn.Module):
    """Reference module: token embedding + Llama rotary position embedding.

    Wraps an embedding table and HuggingFace's LlamaRotaryEmbedding so the
    rotary cos/sin tables produced by PyTorch can be compared against the
    DeepX port of the same computation.
    """

    from transformers.models.llama.modeling_llama import LlamaConfig

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.config = config
        self.embed_tokens = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

    def forward(self, x):
        """Embed token ids *x* and return the rotary tables for their positions."""
        embedded = self.embed_tokens(x)
        # One position id per sequence slot, shared across the batch /
        # decoder layers.
        positions = torch.arange(embedded.shape[1], device=embedded.device).unsqueeze(0)
        return self.rotary_emb(embedded, positions)
53-
54-
torch_net = NetTorch(config)
# Persist the embedding weights so the DeepX run uses identical parameters.
save_torch(torch_net.embed_tokens.weight, dir + 'weight')

# Forward pass. LlamaRotaryEmbedding.forward returns (cos, sin) in that
# order, so unpack cos first — the original code unpacked them swapped,
# which mislabelled the printed tensors.
torch_output = torch_net(torch_input)
torch_cos, torch_sin = torch_output

print("sin shape:", torch_sin.shape)
print("sin:", torch_sin)

print("cos shape:", torch_cos.shape)
print("cos:", torch_cos)
66-
1+ from .llama_rope_torch import dir ,config
672
683############-------DEEPX-------################
694from deepx .nn .modules import Embedding ,Module
@@ -86,10 +21,10 @@ def forward(self,x):
8621 position_ids = arange (start = 0 ,end = hidden_states .shape [1 ]).unsqueeze (0 )
8722 return self .rotary_emb (hidden_states , position_ids )
8823
89- net = NetDeepx ( configdict = config . to_dict ())
90- out = net . forward ( input )
91- out [ 0 ]. print ( )
92- out [1 ].print ()
93-
24+ if __name__ == "__main__" :
25+ net = NetDeepx ( configdict = config . to_dict () )
26+ out = net . forward ( input )
27+ out [0 ].print ()
28+ out [ 1 ]. print ()
9429
9530
0 commit comments