
Commit 3f13934

py: (#86)
1.rmsnorm
1 parent d0b3047 commit 3f13934

4 files changed

Lines changed: 37 additions & 25 deletions


Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from .t5layernorm import *

__all__ = [
    "T5LayerNorm", "RMSNorm",
]
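The re-export above makes both names part of the package's public API; judging from the relative import of .t5layernorm and the import used in modeling_llama.py below, this is presumably the norm package's __init__.py, though the file name itself is not shown in this extract. A minimal sketch of what the alias implies, under that assumption:

from deepx.nn.modules.norm import T5LayerNorm, RMSNorm

# RMSNorm is defined as an alias of T5LayerNorm (see the new module below),
# so both names refer to the same class object.
assert RMSNorm is T5LayerNorm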

front/py/deepx/nn/modules/norm/normalization.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@

from deepx.nn.modules import Module
from deepx import Tensor, ones, rsqrt

# Paper: https://arxiv.org/abs/1910.07467
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
class T5LayerNorm(Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style: no bias and no subtraction of the mean.
        """
        super().__init__()
        self.weight = ones((hidden_size,))
        self.register_parameter("weight", self.weight)
        self.variance_epsilon = eps

    def forward(self, x: Tensor):
        xtype = x.dtype
        # layer norm should always be calculated in float32
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x = x * rsqrt(variance + self.variance_epsilon)
        return (self.weight * x).todtype(xtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


RMSNorm = T5LayerNorm
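As described in the RMSNorm paper referenced above, the module rescales each hidden vector by the reciprocal root mean square of its elements, 1 / sqrt(mean(x_i^2) + eps), and then applies a learned per-feature weight; there is no mean subtraction and no bias. A minimal usage sketch, assuming only the deepx pieces that appear in this commit (ones for tensor construction and the norm import path used in modeling_llama.py); hidden_size = 4 is chosen purely for illustration:

from deepx import ones
from deepx.nn.modules.norm import RMSNorm

hidden_size = 4                      # illustrative value, not taken from the commit
norm = RMSNorm(hidden_size)          # weight initialized to ones, eps defaults to 1e-6
x = ones((2, hidden_size))           # a batch of two all-ones hidden vectors
y = norm.forward(x)                  # rms(x) is 1 here, so y should equal the weight broadcast over the batch
print(tuple(y.shape))                # expected: (2, 4)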

front/py/deepx/nn/modules/transformer/llama/modeling_llama.py

Lines changed: 3 additions & 3 deletions
@@ -1,8 +1,8 @@
 from typing import Optional,Tuple,Union
 from deepx.nn.modules import Module,ModuleList,Linear,Embedding
 from deepx import Tensor,cat,arange
-from front.py.deepx.nn.modules.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from deepx.nn.modules.mlp import MLP
+from deepx.nn.modules.transformer.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from deepx.nn.modules.mlp import GatedMLP
 from deepx.nn.modules.norm import RMSNorm
 from deepx.nn.modules.transformer import LlamaRotaryEmbedding,apply_rotary_pos_emb,grouped_query_attention as GQA
 from deepx.utils.config import Config
@@ -75,7 +75,7 @@ def __init__(self, config:dict, layer_idx: int):

         self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)

-        self.mlp = MLP(config)
+        self.mlp = GatedMLP(config)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
