Commit 4bbb37b

Author: Gourav Kumar

Profile sparse MLP and improve kernel visibility

1 parent fccf888 · commit 4bbb37b

17 files changed
Lines changed: 1025 additions & 99 deletions

experiments/old_structure/python/sparseflow/nn.py

Lines changed: 2 additions & 2 deletions
@@ -57,7 +57,7 @@ def from_dense(
             diff_report: Optional accuracy report (if return_diff=True)
         """
         # Get weight
-        weight = dense_linear.weight.data
+        weight = dense_linear.weight.detach()

         # Prune to 2:4 pattern
         weight_sparse = sf.prune_2_4(weight, method=method)
@@ -74,7 +74,7 @@ def from_dense(
         sparse_linear = SparseLinear(
             weight_compressed,
             metadata=None,
-            bias=dense_linear.bias.data if dense_linear.bias is not None else None
+            bias=dense_linear.bias.detach() if dense_linear.bias is not None else None
         )

         # Measure accuracy impact if requested
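The switch from `.data` to `.detach()` is about autograd safety rather than copying: both give a tensor that shares storage with the parameter, but the detached view keeps version tracking intact. A standalone sketch (plain PyTorch, not the repo's `SparseLinear.from_dense`) of the difference:

import torch
import torch.nn as nn

# Both .data and .detach() share storage with the parameter (no copy either way),
# but .detach() keeps autograd's version counter intact, so later in-place edits
# are caught instead of silently corrupting gradients.
dense = nn.Linear(8, 4)

w_data = dense.weight.data          # old style: unsafe escape hatch
w_detached = dense.weight.detach()  # new style: same storage, tracked view

assert w_data.data_ptr() == w_detached.data_ptr()  # same underlying buffer
print(w_detached.requires_grad)  # False: safe to prune/compress outside the graph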

sparseflow/__init__.py

Lines changed: 26 additions & 16 deletions
@@ -1,20 +1,30 @@
-"""SparseFlow: Hardware-aware sparse inference for A100"""
+"""
+SparseFlow package.
 
-from sparseflow.nn.sparseflow_linear import SparseFlowLinear, make_sparseflow_linear, prune_24_dense_weight
-from sparseflow.nn.policy import SparseFlowPolicy
-from sparseflow.compiled_model import compile_sparseflow_model, CompiledSparseFlowModel
+IMPORTANT:
+- Keep this __init__ lightweight.
+- Do NOT hard-import submodules that may depend on CUDA builds / optional ops.
+- Re-export symbols *best-effort* so `import sparseflow` doesn't crash.
+"""
 
-__version__ = "2.2.0.post1"
+from importlib import import_module
 
-__all__ = [
-    'SparseFlowLinear',
-    'make_sparseflow_linear',
-    'SparseFlowPolicy',
-    'prune_24_dense_weight',
-    'compile_sparseflow_model',
-    'CompiledSparseFlowModel',
-]
+__all__ = []
 
-# SparseFlow MLP module
-from sparseflow.nn.sparseflow_mlp import SparseFlowMLP, make_sparseflow_mlp
-__all__.extend(['SparseFlowMLP', 'make_sparseflow_mlp'])
+def _safe_export(mod: str, names: list[str]) -> None:
+    try:
+        m = import_module(mod)
+        g = globals()
+        for n in names:
+            if hasattr(m, n):
+                g[n] = getattr(m, n)
+                __all__.append(n)
+    except Exception:
+        # swallow import errors so package import stays healthy
+        pass
+
+# Best-effort re-exports
+_safe_export("sparseflow.nn.policy", ["SparseFlowPolicy"])
+_safe_export("sparseflow.nn.sparseflow_linear", ["SparseFlowLinear", "make_sparseflow_linear", "prune_24_dense_weight"])
+_safe_export("sparseflow.nn.sparseflow_mlp", ["SparseFlowMLP", "make_sparseflow_mlp"])
+_safe_export("sparseflow.nn.surgery", ["replace_llama_mlp_module"])

sparseflow/kernels/__init__.py

Whitespace-only changes.

sparseflow/kernels/fused_silu_mul.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
import torch
import triton
import triton.language as tl


@triton.jit
def _fused_silu_mul_2d_strided(
    g_ptr, u_ptr, o_ptr,
    n_rows: tl.constexpr, n_cols: tl.constexpr,
    g_s0: tl.constexpr, g_s1: tl.constexpr,
    u_s0: tl.constexpr, u_s1: tl.constexpr,
    o_s0: tl.constexpr, o_s1: tl.constexpr,
    BLOCK: tl.constexpr,
):
    pid0 = tl.program_id(0)  # row
    pid1 = tl.program_id(1)  # col-block
    row = pid0
    col0 = pid1 * BLOCK
    cols = col0 + tl.arange(0, BLOCK)
    mask = cols < n_cols

    g_offs = row * g_s0 + cols * g_s1
    u_offs = row * u_s0 + cols * u_s1
    o_offs = row * o_s0 + cols * o_s1

    g = tl.load(g_ptr + g_offs, mask=mask, other=0.0).to(tl.float32)
    u = tl.load(u_ptr + u_offs, mask=mask, other=0.0).to(tl.float32)

    s = 1.0 / (1.0 + tl.exp(-g))  # sigmoid(g)
    y = (g * s) * u               # silu(g) * u

    tl.store(o_ptr + o_offs, y.to(tl.float16), mask=mask)


def fused_silu_mul(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """
    Fast path: expects 2D tensors (e.g., [I, BT]) but DOES NOT require contiguity.
    Uses explicit strides so we avoid .contiguous() and the clone/copy_ tax.
    """
    assert gate.is_cuda and up.is_cuda
    assert gate.dtype == up.dtype
    assert gate.ndim == 2 and up.ndim == 2, "fused_silu_mul expects 2D tensors"

    I, BT = gate.shape
    assert up.shape == (I, BT)

    out = torch.empty((I, BT), device=gate.device, dtype=gate.dtype)

    grid = (I, triton.cdiv(BT, 1024))
    _fused_silu_mul_2d_strided[grid](
        gate, up, out,
        n_rows=I, n_cols=BT,
        g_s0=gate.stride(0), g_s1=gate.stride(1),
        u_s0=up.stride(0), u_s1=up.stride(1),
        o_s0=out.stride(0), o_s1=out.stride(1),
        BLOCK=1024,
        num_warps=4,
    )
    return out
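A quick way to sanity-check the strided kernel is to compare it against the eager `F.silu(gate) * up` on a non-contiguous (transposed) input, which exercises exactly the no-`.contiguous()` path the docstring advertises. The sketch below assumes a CUDA GPU with Triton and that the module is importable as `sparseflow.kernels.fused_silu_mul` (the path used by `sparseflow_mlp.py`); shapes are arbitrary:

import torch
import torch.nn.functional as F
from sparseflow.kernels.fused_silu_mul import fused_silu_mul

I, BT = 4096, 1000  # deliberately not a multiple of the 1024 block
gate = torch.randn(BT, I, device="cuda", dtype=torch.float16).t()  # [I, BT], non-contiguous
up = torch.randn(I, BT, device="cuda", dtype=torch.float16)

out = fused_silu_mul(gate, up)           # strided fast path, no .contiguous()
ref = F.silu(gate.float()) * up.float()  # fp32 eager reference

# fp16 store in the kernel -> loose tolerances
torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1e-2)
print("fused_silu_mul matches eager silu*mul")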

sparseflow/nn/__init__.py

Lines changed: 26 additions & 3 deletions
@@ -1,4 +1,27 @@
-from .policy import SparseFlowPolicy
-from .sparseflow_linear import SparseFlowLinear, make_sparseflow_linear, prune_24_dense_weight
+"""
+sparseflow.nn
 
-__all__ = ["SparseFlowPolicy", "SparseFlowLinear", "make_sparseflow_linear", "prune_24_dense_weight"]
+Keep this module lightweight. Do NOT hard-import optional CUDA-dependent modules
+or large submodules on import, otherwise `from sparseflow.nn.policy import ...`
+will fail if any other file has issues.
+"""
+
+from importlib import import_module
+
+__all__ = []
+
+def _safe_export(mod: str, names: list[str]) -> None:
+    try:
+        m = import_module(mod)
+        g = globals()
+        for n in names:
+            if hasattr(m, n):
+                g[n] = getattr(m, n)
+                __all__.append(n)
+    except Exception:
+        pass
+
+_safe_export("sparseflow.nn.policy", ["SparseFlowPolicy"])
+_safe_export("sparseflow.nn.surgery", ["replace_llama_mlp_module"])
+_safe_export("sparseflow.nn.sparseflow_linear", ["SparseFlowLinear", "make_sparseflow_linear", "prune_24_dense_weight"])
+_safe_export("sparseflow.nn.sparseflow_mlp", ["SparseFlowMLP", "make_sparseflow_mlp"])

sparseflow/nn/llama_surgery_mlp.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
"""
Shim module so tooling can import:
    from sparseflow.nn.llama_surgery_mlp import replace_llama_mlp_module

Actual implementation lives in tools/llama_surgery_mlp.py
"""
import os, sys

# Ensure repo root is on path so `tools.*` is importable when sparseflow is imported as a package.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from tools.llama_surgery_mlp import replace_llama_mlp_module  # re-export

__all__ = ["replace_llama_mlp_module"]
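Usage of the shim would look roughly like the following; the actual signature of `replace_llama_mlp_module` lives in `tools/llama_surgery_mlp.py`, which this commit does not show, so the call and the model id below are placeholders, not the repo's documented API:

import torch
from transformers import AutoModelForCausalLM

from sparseflow.nn.llama_surgery_mlp import replace_llama_mlp_module

# Placeholder model id; any Llama-style checkpoint with gate/up/down MLPs.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,
    device_map="cuda",
)

# Assumed contract: walk the model and swap each LlamaMLP for a SparseFlowMLP in place.
replace_llama_mlp_module(model)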

sparseflow/nn/sparseflow_linear.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Ws = self.W_sparse
         x2d_T = x2d.transpose(0, 1)  # NO contiguous
         y2d_T = torch.ops.aten._sparse_semi_structured_mm(Ws.packed, Ws.meta, x2d_T)
-        y2d = y2d_T.transpose(0, 1)  # NO contiguous
+        y2d = y2d_T.transpose(0, 1).contiguous()  # Force contiguous!
         if self.bias is not None:
             y2d = y2d + self.bias
         else:
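The commit comment only says "Force contiguous!", so the sketch below just illustrates what the change does to the output layout: the transposed result of the sparse mm is a column-major strided view, and the added `.contiguous()` pays for one copy so that downstream ops (bias add, reshape back to the leading dims, the next dense projection) see a row-major tensor. Shapes here are made up:

import torch

y2d_T = torch.randn(4096, 128)                 # [out_features, tokens], as the sparse mm returns it
y_view = y2d_T.transpose(0, 1)                 # [tokens, out_features], still a strided view
y_contig = y2d_T.transpose(0, 1).contiguous()  # the committed fix: one explicit copy

print(y_view.is_contiguous(), y_view.stride())      # False (1, 128): column-major walk per row
print(y_contig.is_contiguous(), y_contig.stride())  # True  (4096, 1): row-major for downstream ops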

sparseflow/nn/sparseflow_mlp.py

Lines changed: 93 additions & 77 deletions
@@ -1,13 +1,20 @@
 """SparseFlowMLP: Optimized MLP replacement"""
+
+from __future__ import annotations
+
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from typing import Optional
+
 from sparseflow.nn.policy import SparseFlowPolicy
 from sparseflow.nn.sparseflow_linear import prune_24_dense_weight
+from sparseflow.kernels.fused_silu_mul import fused_silu_mul
+
 
 class SparseFlowMLP(nn.Module):
     """Drop-in replacement for LlamaMLP using sparse tensor cores."""
 
     def __init__(
         self,
         hidden_size: int,
@@ -23,97 +30,106 @@ def __init__(
         dtype: torch.dtype = torch.float16,
     ):
         super().__init__()
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
+
+        # Some HF/profiling codepaths run under torch.inference_mode().
+        # Registering buffers from inference tensors can later trip:
+        #   "Cannot set version_counter for inference tensor".
+        # Force normal tensors here.
+        with torch.inference_mode(False):
+            gate_weight = gate_weight.detach().clone()
+            up_weight = up_weight.detach().clone()
+            down_weight = down_weight.detach().clone()
+
+            if gate_bias is not None:
+                gate_bias = gate_bias.detach().clone()
+            if up_bias is not None:
+                up_bias = up_bias.detach().clone()
+            if down_bias is not None:
+                down_bias = down_bias.detach().clone()
+
+        # Keep original dense down_weight for the final matmul (for now)
+        self.register_buffer("down_weight", down_weight, persistent=False)
+
+        self.hidden_size = int(hidden_size)
+        self.intermediate_size = int(intermediate_size)
         self.policy = policy
-
-        # Convert to sparse - store as SparseSemiStructuredTensor
+
+        # Convert to sparse semi-structured tensors (2:4)
         gate_pruned = prune_24_dense_weight(gate_weight.contiguous())
-        self.register_buffer("gate_sparse",
-                             torch.sparse.to_sparse_semi_structured(gate_pruned),
-                             persistent=False)
-        self.register_buffer("gate_bias", gate_bias.contiguous() if gate_bias is not None else None)
-
-        up_pruned = prune_24_dense_weight(up_weight.contiguous())
-        self.register_buffer("up_sparse",
-                             torch.sparse.to_sparse_semi_structured(up_pruned),
-                             persistent=False)
-        self.register_buffer("up_bias", up_bias.contiguous() if up_bias is not None else None)
-
+        up_pruned = prune_24_dense_weight(up_weight.contiguous())
         down_pruned = prune_24_dense_weight(down_weight.contiguous())
-        self.register_buffer("down_sparse",
-                             torch.sparse.to_sparse_semi_structured(down_pruned),
-                             persistent=False)
-        self.register_buffer("down_bias", down_bias.contiguous() if down_bias is not None else None)
+
+        self.register_buffer(
+            "gate_sparse",
+            torch.sparse.to_sparse_semi_structured(gate_pruned),
+            persistent=False,
+        )
+        self.register_buffer(
+            "up_sparse",
+            torch.sparse.to_sparse_semi_structured(up_pruned),
+            persistent=False,
+        )
+        self.register_buffer(
+            "down_sparse",
+            torch.sparse.to_sparse_semi_structured(down_pruned),
+            persistent=False,
+        )
+
+        # Biases (may be None)
+        self.register_buffer("gate_bias", gate_bias.contiguous() if gate_bias is not None else None, persistent=False)
+        self.register_buffer("up_bias", up_bias.contiguous() if up_bias is not None else None, persistent=False)
+        self.register_buffer("down_bias", down_bias.contiguous() if down_bias is not None else None, persistent=False)
 
     def _spmm_pick(self, Ws, xT):
-        """Auto-pick packed/meta orientation + handle alignment (K must be multiple of 16)"""
-        # Pad xT's last dimension (N/tokens) to multiple of 16 for CUTLASS alignment
+        """
+        Auto-pick packed/meta orientation + handle alignment.
+        xT: [K, Ntokens] (transposed view)
+        """
         orig_n = xT.shape[1]
         pad_val = (16 - (orig_n % 16)) % 16
-
-        if pad_val > 0:
-            xT = torch.nn.functional.pad(xT, (0, pad_val))
-
-        # Pick orientation based on shape
+        if pad_val:
+            xT = F.pad(xT, (0, pad_val))
+
+        # Ws is SparseSemiStructuredTensor
         if Ws.data.shape[1] == xT.shape[0]:
             res = torch.ops.aten._sparse_semi_structured_mm(Ws.packed, Ws.meta, xT)
         elif Ws.data.shape[0] == xT.shape[0]:
            res = torch.ops.aten._sparse_semi_structured_mm(Ws.packed_t, Ws.meta_t, xT)
         else:
-            raise RuntimeError(
-                f"Shape mismatch: Ws.data={tuple(Ws.data.shape)} xT={tuple(xT.shape)}"
-            )
-
-        # Slice back to original size
-        if pad_val > 0:
+            raise RuntimeError(f"Shape mismatch: Ws.data={tuple(Ws.data.shape)} xT={tuple(xT.shape)}")
+
+        if pad_val:
             res = res[:, :orig_n]
-
         return res
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward using low-level sparse ops"""
-        # Reshape to 2D
-        leading = x.shape[:-1]
-        T = 1
-        for d in leading:
-            T *= int(d)
-
-        x_2d = x.reshape(T, self.hidden_size).contiguous()  # [T, H]
-
-        # Gate projection - use transpose approach that we know works
-        xT = x_2d.transpose(0, 1)  # [H, T].contiguous()
-        gateT = self._spmm_pick(self.gate_sparse, xT)  # [I, T]
-        gate = gateT.transpose(0, 1)  # [T, I]
+    def forward(self, x):
+        # x: [B, T, H]
+        B, T, H = x.shape
+        BT = B * T
+
+        x2d = x.view(BT, H)
+        xT = x2d.transpose(0, 1)  # [H, BT]
+
+        gateT = self._spmm_pick(self.gate_sparse, xT)  # [I, BT]
+        upT = self._spmm_pick(self.up_sparse, xT)      # [I, BT]
+
         if self.gate_bias is not None:
-            gate = gate + self.gate_bias
-
-        # Up projection
-        upT = self._spmm_pick(self.up_sparse, xT)  # [I, T]
-        up = upT.transpose(0, 1)  # [T, I]
+            gateT = gateT + self.gate_bias.view(-1, 1)
         if self.up_bias is not None:
-            up = up + self.up_bias
-
-        # Activation
-        hidden = torch.nn.functional.silu(gate) * up  # [T, I]
-
-        # Down projection
-        hT = hidden.transpose(0, 1)  # [I, T]
-        outT = self._spmm_pick(self.down_sparse, hT)  # [H, T]
-        out = outT.transpose(0, 1)  # [T, H]
+            upT = upT + self.up_bias.view(-1, 1)
+
+        hiddenT = fused_silu_mul(gateT, upT)  # [I, BT]
+
+        # Down projection (dense for now): [H, I] @ [I, BT] -> [H, BT]
+        outT = torch.matmul(self.down_weight, hiddenT)
+
        if self.down_bias is not None:
-            out = out + self.down_bias
-
-        return out.reshape(*leading, self.hidden_size)
-
-
-def make_sparseflow_mlp(mlp_module, policy: SparseFlowPolicy = SparseFlowPolicy()):
-    """Convert LlamaMLP to SparseFlowMLP"""
-    return SparseFlowMLP(
-        hidden_size=mlp_module.gate_proj.in_features,
-        intermediate_size=mlp_module.gate_proj.out_features,
-        gate_weight=mlp_module.gate_proj.weight.data,
-        up_weight=mlp_module.up_proj.weight.data,
-        down_weight=mlp_module.down_proj.weight.data,
-        policy=policy,
-    )
+            outT = outT + self.down_bias.view(-1, 1)
+
+        out2d = outT.transpose(0, 1)  # [BT, H]
+        return out2d.view(B, T, H)
+
+
+def make_sparseflow_mlp(*args, **kwargs):
+    """Backward-compat factory for older tooling."""
+    return SparseFlowMLP(*args, **kwargs)
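Since the new `forward` stays in the transposed `[I, BT]` layout between the two sparse matmuls (with biases reshaped to column vectors) instead of flipping back to `[BT, I]` per projection, it is worth checking that the algebra matches the old per-token layout. A CPU sketch with dense stand-ins for the sparse mms and the Triton kernel:

import torch
import torch.nn.functional as F

B, T, H, I = 2, 3, 8, 16
x = torch.randn(B, T, H)
Wg, Wu, Wd = torch.randn(I, H), torch.randn(I, H), torch.randn(H, I)
bg, bu = torch.randn(I), torch.randn(I)

# Old layout: work in [BT, I]; biases broadcast along rows
x2d = x.reshape(B * T, H)
old = (F.silu(x2d @ Wg.t() + bg) * (x2d @ Wu.t() + bu)) @ Wd.t()

# New layout: stay in [I, BT]; biases reshaped to [I, 1] broadcast along columns
xT = x2d.transpose(0, 1)           # [H, BT]
gateT = Wg @ xT + bg.view(-1, 1)   # [I, BT]
upT = Wu @ xT + bu.view(-1, 1)     # [I, BT]
outT = Wd @ (F.silu(gateT) * upT)  # [H, BT]
new = outT.transpose(0, 1).reshape(B, T, H)

torch.testing.assert_close(new, old.reshape(B, T, H), rtol=1e-4, atol=1e-4)
print("transposed data flow matches the per-token layout")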
