Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions mllm-kernel/cmake/CPM.cmake
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
# SPDX-License-Identifier: MIT
# Download CPM.cmake on-the-fly
# This is a lightweight bootstrap that downloads the actual CPM.cmake
# Prefer the vendored CPM.cmake from the parent mllm repo. This avoids relying
# on network access for editable builds while keeping standalone fallback logic.

set(CPM_VERSION 0.42.0)
set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_VERSION}.cmake")
set(PARENT_CPM "${CMAKE_CURRENT_LIST_DIR}/../../cmake/CPM.cmake")

if(NOT EXISTS ${CPM_DOWNLOAD_LOCATION})
message(STATUS "Downloading CPM.cmake v${CPM_VERSION}...")
file(DOWNLOAD
https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_VERSION}/CPM.cmake
${CPM_DOWNLOAD_LOCATION}
STATUS download_status
)
list(GET download_status 0 download_status_code)
if(NOT download_status_code EQUAL 0)
# Fallback: copy from parent mllm project if available
set(PARENT_CPM "${CMAKE_CURRENT_SOURCE_DIR}/../cmake/CPM.cmake")
if(EXISTS ${PARENT_CPM})
message(STATUS "Using CPM.cmake from parent project")
file(COPY ${PARENT_CPM} DESTINATION "${CMAKE_BINARY_DIR}/cmake/")
file(RENAME "${CMAKE_BINARY_DIR}/cmake/CPM.cmake" ${CPM_DOWNLOAD_LOCATION})
else()
if(EXISTS "${PARENT_CPM}")
include("${PARENT_CPM}")
else()
if(NOT EXISTS "${CPM_DOWNLOAD_LOCATION}")
message(STATUS "Downloading CPM.cmake v${CPM_VERSION}...")
file(DOWNLOAD
https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_VERSION}/CPM.cmake
"${CPM_DOWNLOAD_LOCATION}"
STATUS download_status
)
list(GET download_status 0 download_status_code)
if(NOT download_status_code EQUAL 0)
message(FATAL_ERROR "Failed to download CPM.cmake")
endif()
endif()

include("${CPM_DOWNLOAD_LOCATION}")
endif()

include(${CPM_DOWNLOAD_LOCATION})
if(NOT COMMAND CPMAddPackage)
message(FATAL_ERROR "CPM.cmake loaded, but CPMAddPackage is not available")
endif()
6 changes: 4 additions & 2 deletions mllm-kernel/include/mllm_kernel/scalar_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <variant>
#endif

namespace host {
namespace mllm_kernel::host {

//
// ScalarType can represent a wide range of floating point and integer types,
Expand Down Expand Up @@ -257,4 +257,6 @@ static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;

static inline constexpr auto kFloat16Id = kFloat16.id();
} // namespace host
} // namespace mllm_kernel::host

namespace host = ::mllm_kernel::host;
5 changes: 1 addition & 4 deletions mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
#pragma once

#include <mllm_kernel/utils.cuh>
#include <mllm_kernel/scalar_type.hpp>

#include <iostream>

// Bridge the mllm_kernel::host namespace to the `host` namespace expected by
// Marlin code (originally from sglang).
namespace host = ::mllm_kernel::host;

namespace device::marlin {
// Marlin params

Expand Down
2 changes: 2 additions & 0 deletions mllm-kernel/mllm_kernel/cuda/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from .awq_marlin_repack import awq_marlin_repack
from .gdn_decode import gdn_decode
from .gptq_marlin import gptq_marlin_gemm
from .gptq_marlin_repack import gptq_marlin_repack
from .store_cache import can_use_store_cache, store_cache

__all__ = [
"add_constant",
"awq_marlin_repack",
"gptq_marlin_repack",
"can_use_store_cache",
"gdn_decode",
"gptq_marlin_gemm",
Expand Down
7 changes: 4 additions & 3 deletions mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@
@cache_once
def _make_gptq_marlin_gemm_kernel(dtype: torch.dtype):
"""JIT-compile the GPTQ Marlin GEMM kernel for a specific dtype."""
args = make_cpp_args(dtype)
cpp_args = make_cpp_args(dtype)

@jit(
args=args,
args=[dtype],
device="cuda",
cuda_files=["gemm/marlin/gptq_marlin.cuh"],
cuda_wrappers=[("gptq_marlin_gemm", f"gptq_marlin_gemm<{args}>")],
cpp_wrappers=[],
cuda_wrappers=[("gptq_marlin_gemm", f"gptq_marlin_gemm<{cpp_args}>")],
func_name="gptq_marlin_gemm",
)
def _kernel(
Expand Down
75 changes: 75 additions & 0 deletions mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin_repack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""GPTQ/Compressed-Tensors Marlin repack CUDA JIT kernel."""

from __future__ import annotations

from typing import Optional

import torch

from mllm_kernel.jit_utils import cache_once, jit


def _normalize_perm(
perm: Optional[torch.Tensor], size_k: int, device: torch.device
) -> torch.Tensor:
if perm is None or perm.numel() == 0:
return torch.empty(0, dtype=torch.int32, device=device)
if perm.device != device:
raise ValueError("perm must live on the same device as b_q_weight")
if perm.dtype != torch.int32:
raise ValueError("perm must be int32")
if perm.numel() != size_k:
raise ValueError("perm length must equal size_k")
if torch.any(perm < 0) or torch.any(perm >= size_k):
raise ValueError("perm values must be in [0, size_k)")
return perm.contiguous()


@cache_once
def _make_gptq_marlin_repack_kernel():
    """JIT-compile the GPTQ repack kernel.

    Returns a callable ``_kernel(compiled_module, b_q_weight, perm, out,
    size_k, size_n, num_bits)`` that launches the compiled CUDA repack
    routine. ``out`` is written in place; nothing is returned.
    """

    # args=[]: unlike the GEMM kernel, the wrapper name carries no template
    # arguments, so a single compiled module serves every call site
    # (memoized by @cache_once).
    @jit(
        args=[],
        device="cuda",
        cuda_files=["gemm/marlin/gptq_marlin_repack.cuh"],
        cpp_wrappers=[],
        cuda_wrappers=[("gptq_marlin_repack", "gptq_marlin_repack")],
        func_name="gptq_marlin_repack",
    )
    def _kernel(
        compiled_module,  # presumably injected by @jit (callers omit it) — confirm against jit_utils
        b_q_weight: torch.Tensor,
        perm: torch.Tensor,
        out: torch.Tensor,
        size_k: int,
        size_n: int,
        num_bits: int,
    ) -> None:
        # Thin dispatch into the compiled extension; `out` is mutated in place.
        compiled_module.gptq_marlin_repack(
            b_q_weight, perm, out, size_k, size_n, num_bits
        )

    return _kernel


def gptq_marlin_repack(
    b_q_weight: torch.Tensor,
    perm: Optional[torch.Tensor],
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack GPTQ/Compressed-Tensors weights into Marlin layout.

    Args:
        b_q_weight: Packed quantized weight of shape
            ``(size_k // (32 // num_bits), size_n)`` with dtype ``torch.int32``.
        perm: Optional int32 activation-reordering permutation of length
            ``size_k`` (``None`` or empty means no reordering).
        size_k: Input-feature dimension of the unpacked weight.
        size_n: Output-feature dimension of the unpacked weight.
        num_bits: Bits per quantized value; must be a positive divisor of 32.

    Returns:
        The repacked int32 weight in Marlin tile layout, shape
        ``(size_k // 16, size_n * 16 // (32 // num_bits))``.

    Raises:
        ValueError: If ``num_bits``, the alignment of ``size_k``/``size_n``,
            or ``b_q_weight``'s shape/dtype is unsupported.
    """
    # Validate before allocating `out`: the floor divisions below would
    # otherwise silently truncate unsupported sizes (dropping a size_k tail,
    # or leaving part of `out` uninitialized for a misaligned size_n).
    if num_bits <= 0 or 32 % num_bits != 0:
        raise ValueError("num_bits must be a positive divisor of 32")
    pack_factor = 32 // num_bits
    tile_size = 16
    if size_k % tile_size != 0:
        raise ValueError(f"size_k must be divisible by {tile_size}")
    if size_k % pack_factor != 0:
        raise ValueError(f"size_k must be divisible by the pack factor ({pack_factor})")
    if (size_n * tile_size) % pack_factor != 0:
        raise ValueError(
            f"size_n * {tile_size} must be divisible by the pack factor ({pack_factor})"
        )
    if b_q_weight.dtype != torch.int32:
        raise ValueError("b_q_weight must be int32")
    expected_shape = (size_k // pack_factor, size_n)
    if tuple(b_q_weight.shape) != expected_shape:
        raise ValueError(f"b_q_weight must have shape {expected_shape}")

    out = torch.empty(
        (size_k // tile_size, size_n * tile_size // pack_factor),
        dtype=b_q_weight.dtype,
        device=b_q_weight.device,
    )
    kernel = _make_gptq_marlin_repack_kernel()
    perm_t = _normalize_perm(perm, size_k, b_q_weight.device)
    kernel(b_q_weight, perm_t, out, size_k, size_n, num_bits)
    return out
151 changes: 151 additions & 0 deletions mllm-kernel/tests/test_gptq_marlin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import pytest
import torch
import torch.nn.functional as F

from mllm_kernel.cuda.jit import gptq_marlin_gemm, gptq_marlin_repack


CUDA_ONLY = pytest.mark.skipif(
not torch.cuda.is_available(), reason="requires CUDA"
)


def _compute_scalar_type_id(
exponent: int,
mantissa: int,
signed: bool,
bias: int,
finite_values_only: bool = False,
nan_repr: int = 1,
) -> int:
bit_offset = 0
result = 0
for value, width in [
(exponent, 8),
(mantissa, 8),
(signed, 1),
(bias, 32),
(finite_values_only, 1),
(nan_repr, 8),
]:
result |= (int(value) & ((1 << width) - 1)) << bit_offset
bit_offset += width
return result


SCALAR_TYPE_UINT4B8_ID = _compute_scalar_type_id(0, 4, False, 8)


def _pack_checkpoint_weight(q_weight: torch.Tensor, num_bits: int) -> torch.Tensor:
pack_factor = 32 // num_bits
size_n, size_k = q_weight.shape
packed = torch.zeros(
(size_n, size_k // pack_factor),
dtype=torch.int32,
device=q_weight.device,
)
for i in range(pack_factor):
packed.bitwise_or_(q_weight[:, i::pack_factor].int() << (num_bits * i))
return packed


def _get_scale_perms() -> tuple[list[int], list[int]]:
scale_perm: list[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: list[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
)
return scale_perm, scale_perm_single


def _marlin_permute_scales(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
    """Permute quantization scales into the column order Marlin expects.

    Grouped quantization (``group_size`` in ``[1, size_k)``) uses the grouped
    permutation; channelwise (``group_size == -1`` or a full-``size_k`` group)
    uses the single-group permutation.
    """
    scale_perm, scale_perm_single = _get_scale_perms()
    grouped = group_size != -1 and group_size < size_k
    perm = scale_perm if grouped else scale_perm_single
    permuted = s.reshape((-1, len(perm)))[:, perm]
    return permuted.reshape((-1, size_n)).contiguous()


def _marlin_make_workspace(device: torch.device) -> torch.Tensor:
    """Allocate a zeroed int32 workspace with one slot per SM on ``device``."""
    num_sms = torch.cuda.get_device_properties(device).multi_processor_count
    return torch.zeros(
        num_sms, dtype=torch.int, device=device, requires_grad=False
    )


@CUDA_ONLY
def test_gptq_marlin_gemm_matches_reference_for_uint4b8() -> None:
    """End-to-end check that the Marlin GPTQ GEMM matches an fp16 reference.

    Builds a random 4-bit checkpoint-format weight (unsigned codes with
    bias 8, i.e. uint4b8 — see SCALAR_TYPE_UINT4B8_ID), packs and repacks it
    into Marlin layout, runs gptq_marlin_gemm, and compares against F.linear
    on the explicitly dequantized weight within a mean relative error budget.
    """
    # Fixed seed keeps the tolerance-based assertion deterministic.
    torch.manual_seed(2026)
    device = torch.device("cuda")
    size_m = 13
    size_n = 64
    size_k = 128
    group_size = 32
    num_bits = 4

    # Raw 4-bit codes in [0, 16), one per int32 element before packing.
    q_weight = torch.randint(
        0,
        1 << num_bits,
        (size_n, size_k),
        dtype=torch.int32,
        device=device,
    )
    # Scales in [0.5, 1.5): strictly positive, so the relative-error
    # denominator below cannot collapse toward zero.
    scales = (
        torch.rand(
            (size_n, size_k // group_size),
            dtype=torch.float16,
            device=device,
        )
        + 0.5
    )
    packed = _pack_checkpoint_weight(q_weight, num_bits=num_bits)
    # Shared empty tensor: no activation reordering, zero points, or g_idx.
    empty = torch.empty(0, dtype=torch.int32, device=device)
    marlin_q = gptq_marlin_repack(
        packed.t().contiguous(),  # (size_k // pack_factor, size_n) after transpose
        perm=empty,
        size_k=size_k,
        size_n=size_n,
        num_bits=num_bits,
    )
    marlin_s = _marlin_permute_scales(
        scales.t().contiguous(),
        size_k=size_k,
        size_n=size_n,
        group_size=group_size,
    )
    x = torch.randn((size_m, size_k), dtype=torch.float16, device=device)
    workspace = _marlin_make_workspace(device)

    out = gptq_marlin_gemm(
        a=x,
        c=None,
        b_q_weight=marlin_q,
        b_scales=marlin_s,
        global_scale=None,
        b_zeros=empty,
        g_idx=empty,
        perm=empty,
        workspace=workspace,
        b_q_type_id=SCALAR_TYPE_UINT4B8_ID,
        size_m=size_m,
        size_n=size_n,
        size_k=size_k,
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=False,
        is_zp_float=False,
    )

    # Reference: dequantize as (code - 8) * scale, broadcasting each scale
    # over its group of `group_size` columns, then a plain fp16 linear.
    ref_weight = (q_weight.to(torch.float16) - 8) * scales.repeat_interleave(
        group_size, dim=1
    )
    ref_out = F.linear(x, ref_weight)
    # Mean relative error over the whole output; the 4% budget absorbs fp16
    # accumulation-order differences between kernel and reference.
    rel_mean_err = torch.mean(torch.abs(out - ref_out)) / torch.mean(
        torch.abs(ref_out)
    )

    assert rel_mean_err < 0.04
Loading