Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
1916626
Use sorted batch pointers for Triton masks
denix56 Jan 24, 2026
fe21563
Merge pull request #1 from denix56/codex/add-triton-module-with-tests
denix56 Jan 24, 2026
03a3d7a
Fix triton pairwise grid sizing
denix56 Jan 24, 2026
f709778
Merge pull request #2 from denix56/codex/fix-assertion-error-in-knn-test
denix56 Jan 24, 2026
7b1f2db
Add Triton support
denix56 Jan 26, 2026
02cc6ae
Replace jit.script with custom ops registration using TORCH_LIBRARY
denix56 Jan 28, 2026
3b9b3d8
Merge pull request #3 from denix56/triton
denix56 Jan 28, 2026
4bdd9b0
Compute split-N tile counts via heuristics
denix56 Jan 29, 2026
cd01935
Merge pull request #4 from denix56/codex/implement-fused-k-nn-in-triton
denix56 Jan 29, 2026
9ca282b
Add Triton support
denix56 Jan 29, 2026
c2927bd
Add Triton support
denix56 Jan 30, 2026
ee711ea
Add Triton support
denix56 Jan 30, 2026
b4a5698
Add Triton support
denix56 Jan 30, 2026
cd17eef
Add Triton support (#6)
denix56 Jan 30, 2026
74a2a58
Fix flake8 errors
denix56 Jan 30, 2026
fcd64f2
Add Triton support (#7)
denix56 Jan 30, 2026
fb40555
Remove meaningless comments
denix56 Jan 30, 2026
971ab9d
Fix triton pairwise grid sizing
denix56 Jan 24, 2026
d59acf7
Add Triton support
denix56 Jan 26, 2026
1c70637
Replace jit.script with custom ops registration using TORCH_LIBRARY
denix56 Jan 28, 2026
53d2e09
Compute split-N tile counts via heuristics
denix56 Jan 29, 2026
e903e66
Add Triton support
denix56 Jan 29, 2026
5af902d
Add Triton support
denix56 Jan 30, 2026
56c55ce
Add Triton support
denix56 Jan 30, 2026
29e0b41
Fix flake8 errors
denix56 Jan 30, 2026
102736e
Remove meaningless comments
denix56 Jan 30, 2026
163fd6a
Add Triton support (#8)
denix56 Jan 30, 2026
a266f8e
Merge pull request #9 from denix56/triton
denix56 Jan 30, 2026
b693b30
Merge branch 'master' into triton
denix56 Jan 30, 2026
56d33dd
Merge pull request #11 from denix56/triton
denix56 Jan 30, 2026
4106591
Fix flake8 errors
denix56 Jan 30, 2026
0ad7981
Merge pull request #12 from denix56/triton
denix56 Jan 30, 2026
1c501d6
Gate torch.compile in tests for windows
denix56 Jan 30, 2026
2be1748
Merge pull request #13 from denix56/triton
denix56 Jan 30, 2026
4b30dc5
Fix mismatches with CUDA, added radius search support
denix56 Jan 30, 2026
a5e7d09
Fix mismatches with CUDA, added radius search support
denix56 Jan 31, 2026
4878474
Improve operations performance
denix56 Jan 31, 2026
2a89407
Improve operations performance
denix56 Jan 31, 2026
31551f1
Merge pull request #15 from denix56/triton
denix56 Jan 31, 2026
7b26a00
Fix flake8 errors
denix56 Jan 31, 2026
f717cb7
Merge pull request #16 from denix56/triton
denix56 Jan 31, 2026
437500e
Make radius func consistent with knn
denix56 Jan 31, 2026
1713540
Fix flake8 errors
denix56 Jan 31, 2026
cec6c99
Fix compilation error
denix56 Jan 31, 2026
8a875e1
Fix numerical errors
denix56 Feb 1, 2026
6aab27f
Fix numerical errors
denix56 Feb 1, 2026
2d5f204
Fix typing
denix56 Feb 1, 2026
38acdf8
Minor fixes
denix56 Feb 2, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
495 changes: 495 additions & 0 deletions benchmarks/test_benchmark_knn.py

Large diffs are not rendered by default.

157 changes: 157 additions & 0 deletions benchmarks/test_benchmark_nearest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import importlib.util
from itertools import product
from typing import Optional

import pytest
import torch
import torch_cluster as tc

# Short alias so the benchmarks read `nearest(...)` instead of `tc.nearest(...)`.
nearest = tc.nearest

# Skip every test in this module unless torch_cluster was built with CUDA
# support AND the `triton` package is importable — these benchmarks compare
# the CUDA and Triton backends against each other.
pytestmark = pytest.mark.skipif(
    not (
        torch.ops.torch_cluster.cuda_version() != -1
        and importlib.util.find_spec('triton') is not None
    ),
    reason='CUDA and Triton are required for Triton benchmark tests.',
)

# (num_x, num_y) problem sizes: powers of two, odd/near-power sizes, and
# very asymmetric cases with tiny `y` sets.
NEAREST_SIZES = [
    (256, 128),
    (1024, 512),
    (4096, 2048),
    (8192, 4096),
    (8192, 8192),
    (8201, 4103),
    (32000, 32000),
    (255, 127),
    (256, 5),
    (1024, 5),
    (4096, 5),
    (255, 5),
]
# Number of batch groups the nodes are split into.
NEAREST_GROUPS = [1, 2, 4, 8, 16, 32]
# Feature dimensionality of the point clouds.
FEATURES = [8, 64, 200]


def _assert_nearest_within_cuda(
out_cuda: torch.Tensor,
out_triton: torch.Tensor,
x: torch.Tensor,
y: torch.Tensor,
tol: Optional[float] = None,
) -> None:
if tol is None:
tol = 5 * torch.finfo(x.dtype).eps
x_f = x.float()
y_f = y.float()
cuda_dist = ((x_f - y_f[out_cuda]) ** 2).sum(dim=1)
triton_dist = ((x_f - y_f[out_triton]) ** 2).sum(dim=1)
thresh = cuda_dist + tol
margin = (triton_dist - thresh).max().item()
max_diff = torch.abs(triton_dist - cuda_dist).max().item()
print(f"[nearest][match] max_margin={margin:.6e} tol={tol:.1e}")
print(f"[nearest][match] max_diff={max_diff:.6e} tol={tol:.1e}")
assert (triton_dist <= thresh).all()
assert max_diff <= tol


def _make_batch(
num_nodes: int,
num_groups: int,
device: torch.device,
) -> torch.Tensor:
groups = max(1, min(num_groups, num_nodes))
counts = torch.full(
(groups,),
num_nodes // groups,
device=device,
dtype=torch.long,
)
remainder = num_nodes % groups
if remainder:
counts[:remainder] += 1
return torch.repeat_interleave(
torch.arange(groups, device=device),
counts,
)


def _nearest_param_grid():
    """Yield ``(num_x, num_y, num_groups, num_features)`` parameter tuples.

    Combinations where the group count exceeds the smaller point-cloud side
    are skipped, since a group cannot be empty on either side.
    """
    for (num_x, num_y), groups, feats in product(
            NEAREST_SIZES, NEAREST_GROUPS, FEATURES):
        if groups <= min(num_x, num_y):
            yield (num_x, num_y, groups, feats)


@pytest.mark.parametrize(
    'num_x,num_y,num_groups,num_features',
    _nearest_param_grid(),
)
@pytest.mark.benchmark(group="nearest")
def test_triton_nearest_benchmark_cuda(
    benchmark,
    num_x,
    num_y,
    num_groups,
    num_features,
):
    """Benchmark the CUDA backend of ``nearest`` on a random batched input."""
    torch.manual_seed(123)
    x = torch.randn(num_x, num_features, device='cuda')
    y = torch.randn(num_y, num_features, device='cuda')
    groups = min(num_groups, x.size(0), y.size(0))
    batch_x = _make_batch(num_x, groups, x.device)
    batch_y = _make_batch(num_y, groups, y.device)

    def cuda_fn():
        # Return the result for parity with the Triton benchmark's closure.
        return nearest(x, y, batch_x, batch_y, use_triton=False)

    # Warm up to exclude one-time CUDA context/JIT costs, then drain the
    # stream so timing starts from an idle device.
    for _ in range(5):
        cuda_fn()
    torch.cuda.synchronize()

    # NOTE(review): CUDA launches are asynchronous; without a synchronize
    # inside the timed callable the benchmark may mostly measure launch
    # overhead — confirm the pytest-benchmark setup accounts for this.
    benchmark(cuda_fn)
    print(
        f"[nearest][cuda] num_x={num_x} num_y={num_y} groups={groups}"
    )


@pytest.mark.parametrize(
    'num_x,num_y,num_groups,num_features',
    _nearest_param_grid(),
)
@pytest.mark.benchmark(group="nearest")
def test_triton_nearest_benchmark_triton(
    benchmark,
    num_x,
    num_y,
    num_groups,
    num_features,
):
    """Benchmark the Triton backend of ``nearest``, checking it against CUDA."""
    torch.manual_seed(123)
    x = torch.randn(num_x, num_features, device='cuda')
    y = torch.randn(num_y, num_features, device='cuda')
    groups = min(num_groups, x.size(0), y.size(0))
    batch_x = _make_batch(num_x, groups, x.device)
    batch_y = _make_batch(num_y, groups, y.device)

    def run_cuda():
        return nearest(x, y, batch_x, batch_y, use_triton=False)

    def run_triton():
        return nearest(x, y, batch_x, batch_y, use_triton=True)

    # The first warm-up iteration doubles as a correctness check of the
    # Triton result against the CUDA reference.
    reference = run_cuda()
    candidate = run_triton()
    _assert_nearest_within_cuda(reference, candidate, x, y)
    for _ in range(4):
        run_triton()
    torch.cuda.synchronize()

    benchmark(run_triton)
    print(
        f"[nearest][triton] num_x={num_x} num_y={num_y} groups={groups}"
    )
Loading
Loading