-import math
 from typing import Optional
 import torch
 from torch import nn
-from torch.nn import functional as F
-import numpy as np
+from .mha_kernels import MHAKernel
 
 
-def set_random_seed(
-    seed: int, rank: int = 0, force_deterministic: bool = False
-) -> None:
-    """
-    Set the random seed for numpy and torch.
-    """
-    np.random.seed(seed + rank)
-    torch.manual_seed(seed + rank)
-    if force_deterministic:
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
-
-class MultiHeadSelfAttentionKernel(nn.Module):
-    def __init__(self, hidden_dim: int, num_heads: int):
-        super().__init__()
-
-        self.hidden_dim: int = hidden_dim
-        self.num_heads: int = num_heads
-        self.head_size: int = hidden_dim // num_heads
-
-    def forward(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ):
37- """
38- Calculates softmax(Q @ KT / sqrt(dk)) @ V .
39-
40- Parameters
41- ----------
42- q : torch.Tensor; Shape: (q_len, hidden_dim)
43-
44- k : torch.Tensor; Shape: (kv_len, hidden_dim)
45-
46- v : torch.Tensor; Shape: (kv_len, hidden_dim)
47-
48- mask: torch.Tensor; Shape: (q_len, kv_len), optional
49-
50- Note
51- ----
52- When prefilling, q_len equals to seq_len (number of tokens in the input
53- seq);
54- When decoding, q_len equals to 1, refering to the newly generated
55- token. (Based on different sampling strategies, q_len could be larger
56- than 1.)
57- """
-
-        q_len, kv_len = q.size(0), k.size(0)
-        # q -> (num_heads, q_len, head_size)
-        q = q.view(q_len, self.num_heads, self.head_size).transpose(0, 1)
-        # k -> (num_heads, kv_len, head_size)
-        k = k.view(kv_len, self.num_heads, self.head_size).transpose(0, 1)
-        # v -> (num_heads, kv_len, head_size)
-        v = v.view(kv_len, self.num_heads, self.head_size).transpose(0, 1)
-        # scores -> (num_heads, q_len, kv_len)
-        scores = torch.matmul(q, k.transpose(-1, -2)) / (self.head_size**0.5)
-        scores = (
-            scores.masked_fill(mask == 0, float("-inf"))
-            if mask is not None
-            else scores
-        )
-        # attn_probs -> (num_heads, q_len, kv_len)
-        attn_probs = F.softmax(scores, dim=-1)
-        # out -> (num_heads, q_len, head_size)
-        out = torch.matmul(attn_probs, v)
-        # out -> (q_len, num_heads, head_size) -> (q_len, hidden_dim)
-        out = out.transpose(0, 1).reshape(q_len, self.hidden_dim)
-
-        return out
-
-
-class MultiHeadSelfAttention(nn.Module):
+class MHA(nn.Module):
     def __init__(
         self,
         embed_dim: int,
@@ -97,7 +21,7 @@ def __init__(
         self.Wv = nn.Linear(embed_dim, hidden_dim)
         self.Wo = nn.Linear(hidden_dim, embed_dim)
 
-        self.attn_kernel = MultiHeadSelfAttentionKernel(hidden_dim, num_heads)
+        self.attn_kernel = MHAKernel(hidden_dim, num_heads)
 
     def forward(
         self,
@@ -135,7 +59,7 @@ def forward(
         v = self.Wv(seq)
 
         # k -> (kv_len + seq_len, hidden_dim)
-        k = k if k_cache is None else torch.cat([k_cache, k.detach()], dim=0)
+        k = k if k_cache is None else torch.cat([k_cache, k.detach()], dim=0)
         # v -> (kv_len + seq_len, hidden_dim)
         v = v if v_cache is None else torch.cat([v_cache, v.detach()], dim=0)
 
@@ -148,9 +72,7 @@ def forward(
 class TransformerBlock(nn.Module):
     def __init__(self, embed_dim, num_heads, hidden_dim, mlp_dim, dropout=0.1):
         super().__init__()
-        self.attention = MultiHeadSelfAttention(
-            embed_dim, num_heads, hidden_dim
-        )
+        self.attention = MHA(embed_dim, num_heads, hidden_dim)
         self.norm1 = nn.RMSNorm(embed_dim)
         self.norm2 = nn.RMSNorm(embed_dim)
         self.mlp = nn.Sequential(
@@ -293,8 +215,6 @@ def forward(
 
 
 if __name__ == "__main__":
-    set_random_seed(114514)
-
     seq_len = 4
     vocab_size = 1024
     embed_dim = 128
@@ -347,5 +267,5 @@ def forward(
     for i in range(1, n_generate):
         probs = lm(token, is_prefilling=False)
         token = torch.argmax(probs[-1, :], dim=-1, keepdim=True)
-        print(f"The {i + 1}th predicted token: {token}")
+        print(f"The {i + 1}th predicted token: {token}")
         print(f"|- Token Shape: {token.shape}")
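
A quick sanity check for the relocated kernel. This is a minimal sketch, assuming MHAKernel in mha_kernels.py keeps the interface of the removed MultiHeadSelfAttentionKernel shown above; the import path and behavior are assumptions based on this diff, not confirmed by it:

import torch
from mha_kernels import MHAKernel  # assumed module; the diff imports it as .mha_kernels

hidden_dim, num_heads = 128, 8
kernel = MHAKernel(hidden_dim, num_heads)

# Prefill: q_len == kv_len == seq_len, with a causal mask so each token
# attends only to itself and earlier positions.
seq_len = 4
q = torch.randn(seq_len, hidden_dim)
k = torch.randn(seq_len, hidden_dim)
v = torch.randn(seq_len, hidden_dim)
mask = torch.tril(torch.ones(seq_len, seq_len))
print(kernel(q, k, v, mask=mask).shape)  # torch.Size([4, 128])

# Decode: a single new query against all cached keys/values; no mask is
# needed since the lone query may attend to every cached position.
q_new = torch.randn(1, hidden_dim)
print(kernel(q_new, k, v).shape)  # torch.Size([1, 128])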