11"""SparseFlowMLP: Optimized MLP replacement"""
2+
3+ from __future__ import annotations
4+
25import torch
36import torch .nn as nn
7+ import torch .nn .functional as F
48from typing import Optional
9+
510from sparseflow .nn .policy import SparseFlowPolicy
611from sparseflow .nn .sparseflow_linear import prune_24_dense_weight
12+ from sparseflow .kernels .fused_silu_mul import fused_silu_mul
13+
714
815class SparseFlowMLP (nn .Module ):
916 """Drop-in replacement for LlamaMLP using sparse tensor cores."""
10-
17+
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        gate_weight: torch.Tensor,
        up_weight: torch.Tensor,
        down_weight: torch.Tensor,
        gate_bias: Optional[torch.Tensor] = None,
        up_bias: Optional[torch.Tensor] = None,
        down_bias: Optional[torch.Tensor] = None,
        policy: SparseFlowPolicy = SparseFlowPolicy(),
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()

        # Some HF/profiling codepaths run under torch.inference_mode().
        # Registering buffers from inference tensors can later trip:
        #   "Cannot set version_counter for inference tensor".
        # Force normal tensors here.
        with torch.inference_mode(False):
            gate_weight = gate_weight.detach().clone()
            up_weight = up_weight.detach().clone()
            down_weight = down_weight.detach().clone()

            if gate_bias is not None:
                gate_bias = gate_bias.detach().clone()
            if up_bias is not None:
                up_bias = up_bias.detach().clone()
            if down_bias is not None:
                down_bias = down_bias.detach().clone()

        # Keep the original dense down_weight for the final matmul (for now)
        self.register_buffer("down_weight", down_weight, persistent=False)

        self.hidden_size = int(hidden_size)
        self.intermediate_size = int(intermediate_size)
        self.policy = policy

        # Convert to sparse semi-structured tensors (2:4)
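        # (2:4 sparsity keeps at most two nonzero values in every contiguous
        # group of four weights, the pattern that sparse tensor cores accelerate.)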
        gate_pruned = prune_24_dense_weight(gate_weight.contiguous())
        up_pruned = prune_24_dense_weight(up_weight.contiguous())
        down_pruned = prune_24_dense_weight(down_weight.contiguous())

        self.register_buffer(
            "gate_sparse",
            torch.sparse.to_sparse_semi_structured(gate_pruned),
            persistent=False,
        )
        self.register_buffer(
            "up_sparse",
            torch.sparse.to_sparse_semi_structured(up_pruned),
            persistent=False,
        )
        self.register_buffer(
            "down_sparse",
            torch.sparse.to_sparse_semi_structured(down_pruned),
            persistent=False,
        )

        # Biases (may be None)
        self.register_buffer("gate_bias", gate_bias.contiguous() if gate_bias is not None else None, persistent=False)
        self.register_buffer("up_bias", up_bias.contiguous() if up_bias is not None else None, persistent=False)
        self.register_buffer("down_bias", down_bias.contiguous() if down_bias is not None else None, persistent=False)

    def _spmm_pick(self, Ws, xT):
        """
        Auto-pick the packed/meta orientation and handle alignment.
        xT: [K, Ntokens] (transposed activation view). The token dimension is
        padded to a multiple of 16 to satisfy CUTLASS alignment requirements.
        """
        orig_n = xT.shape[1]
        pad_val = (16 - (orig_n % 16)) % 16
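        # e.g. 50 tokens -> pad_val = (16 - 50 % 16) % 16 = 14, so the kernel sees
        # 64 columns; the padding is sliced off again before returning.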
        if pad_val:
            xT = F.pad(xT, (0, pad_val))

        # Ws is a SparseSemiStructuredTensor
        if Ws.data.shape[1] == xT.shape[0]:
            res = torch.ops.aten._sparse_semi_structured_mm(Ws.packed, Ws.meta, xT)
        elif Ws.data.shape[0] == xT.shape[0]:
            res = torch.ops.aten._sparse_semi_structured_mm(Ws.packed_t, Ws.meta_t, xT)
        else:
            raise RuntimeError(f"Shape mismatch: Ws.data={tuple(Ws.data.shape)} xT={tuple(xT.shape)}")

        if pad_val:
            res = res[:, :orig_n]
        return res

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, H]
        B, T, H = x.shape
        BT = B * T

        x2d = x.view(BT, H)
        xT = x2d.transpose(0, 1)  # [H, BT]

        gateT = self._spmm_pick(self.gate_sparse, xT)  # [I, BT]
        upT = self._spmm_pick(self.up_sparse, xT)  # [I, BT]

        if self.gate_bias is not None:
            gateT = gateT + self.gate_bias.view(-1, 1)
        if self.up_bias is not None:
            upT = upT + self.up_bias.view(-1, 1)

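        # fused_silu_mul is assumed to compute F.silu(gateT) * upT in a single
        # kernel (the dense reference path computed silu(gate) * up).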
        hiddenT = fused_silu_mul(gateT, upT)  # [I, BT]

        # Down projection (dense for now): [H, I] @ [I, BT] -> [H, BT]
        outT = torch.matmul(self.down_weight, hiddenT)

        if self.down_bias is not None:
            outT = outT + self.down_bias.view(-1, 1)

        out2d = outT.transpose(0, 1)  # [BT, H]
        return out2d.view(B, T, H)


def make_sparseflow_mlp(*args, **kwargs):
    """Backward-compat factory for older tooling."""
    return SparseFlowMLP(*args, **kwargs)
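

# Usage sketch (assumption: a Hugging Face Llama-style model whose MLP modules
# expose gate_proj / up_proj / down_proj Linear layers; `model` and that layout
# are not defined in this file):
#
#     mlp = model.model.layers[0].mlp
#     model.model.layers[0].mlp = make_sparseflow_mlp(
#         hidden_size=mlp.gate_proj.in_features,
#         intermediate_size=mlp.gate_proj.out_features,
#         gate_weight=mlp.gate_proj.weight.data,
#         up_weight=mlp.up_proj.weight.data,
#         down_weight=mlp.down_proj.weight.data,
#     )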