-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathkernelguard.py
More file actions
5832 lines (5124 loc) · 234 KB
/
kernelguard.py
File metadata and controls
5832 lines (5124 loc) · 234 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
"""
KernelGuard — Rule-based GPU kernel hack detector.
Detects cheating/hacking patterns in GPU kernel competition submissions.
Usage:
kernelguard --jsonl /path/to/pairs.jsonl
kernelguard --parquet /path/to/submissions.parquet
kernelguard --audit-rules
"""
import argparse
import ast
import copy
from dataclasses import dataclass, field
import glob
import hashlib
import io
import json
import multiprocessing as mp
import os
import re
import sys
import tomllib
import tokenize
from collections import Counter, defaultdict
from datetime import datetime, timezone
from typing import Any, Optional
# NOTE(review): presumably the expected config-schema version and the profile
# name used when none is selected — their consumers are outside this section.
CONFIG_VERSION = 1
DEFAULT_PROFILE_NAME = "default"
# ---------------------------------------------------------------------------
# Compiled regex patterns (module-level for performance)
# ---------------------------------------------------------------------------
# Pattern 1: Timer monkey-patching
# NOTE(review): ``\s*=`` also accepts the first ``=`` of ``==``, so a comparison
# written as ``torch.cuda.Event == X`` matches as well.
RE_CUDA_EVENT_ASSIGN = re.compile(r'torch\.cuda\.Event\s*=\s*')
RE_CUDA_EVENT_SETATTR = re.compile(r'setattr\s*\(\s*torch\.cuda\s*,\s*[\'"]Event[\'"]\s*,')
RE_ELAPSED_TIME_METHOD_PATCH = re.compile(r'torch\.cuda\.Event\.elapsed_time\s*=')
RE_CUDA_SYNCHRONIZE_PATCH = re.compile(r'torch\.cuda\.synchronize\s*=')
# Class names such as _FakeEvent, DummyTimerEvent, MockEvent (case-insensitive).
RE_DUMMY_FAKE_EVENT = re.compile(r'class\s+_?(?:Dummy|Fake|Mock|Stub|Noop)\w*Event', re.IGNORECASE)
RE_ELAPSED_TIME_FAKE = re.compile(r'def\s+elapsed_time\s*\([^)]*\)\s*:')
# Pattern 2: Print/stdout injection
RE_EMIT_FAKE_BENCH = re.compile(r'def\s+_?emit_fake_benchmark')
RE_BUILTINS_PRINT_ASSIGN = re.compile(r'builtins\.print\s*=')
# Names like _patched_print / _orig_print, typical of print wrappers.
RE_PATCHED_PRINT = re.compile(r'_(?:patched|real|orig|original)_print')
RE_SYS_STDOUT_REPLACE = re.compile(r'sys\.stdout\s*=')
RE_SYS_STDERR_REPLACE = re.compile(r'sys\.stderr\s*=')
# print(...) / sys.stdout.write(...) whose argument text, up to the first ')',
# contains "Ranked Benchmark".
RE_RANKED_BENCHMARK_PRINT = re.compile(
r'(?:print|sys\.stdout\.write)\s*\([^)]*(?:Ranked\s+Benchmark)',
re.DOTALL,
)
# Any quote ... "Ranked Benchmark" ... quote span; DOTALL lets it cross lines.
RE_RANKED_BENCHMARK_STR = re.compile(r'["\'].*?Ranked\s+Benchmark.*?["\']', re.DOTALL)
# Pattern 3: Evaluator exploitation
RE_MAIN_TIME_KERNEL = re.compile(r'main\.time_kernel\s*=')
RE_MAIN_CLONE_DATA = re.compile(r'main\._clone_data\s*=')
RE_MAIN_CLEAR_L2 = re.compile(r'main\.clear_l2_cache\s*=')
# sys.modules['__main__'] or sys.modules.get('__main__') (either quote style).
RE_SYS_MODULES_MAIN = re.compile(
r'sys\.modules(?:\s*\[\s*[\'"]__main__[\'"]\s*\]|\.get\s*\(\s*[\'"]__main__[\'"]\s*\))',
)
# Group 1 is the local name bound to __main__.
# NOTE(review): requires ')' right after the key, so
# ``sys.modules.get('__main__', None)`` is not captured.
RE_MAIN_ALIAS_FROM_GET = re.compile(
r'(\w+)\s*=\s*sys\.modules\.get\s*\(\s*[\'"]__main__[\'"]\s*\)',
)
RE_MAIN_ALIAS_FROM_INDEX = re.compile(
r'(\w+)\s*=\s*sys\.modules\s*\[\s*[\'"]__main__[\'"]\s*\]',
)
# Group 1: object name; group 2: the harness attribute being overwritten.
RE_MAIN_ATTR_PATCH = re.compile(
r'(\w+)\.(time_kernel|_clone_data|clear_l2_cache)\s*=',
)
RE_MAIN_SETATTR_PATCH = re.compile(
r'setattr\s*\(\s*(\w+)\s*,\s*[\'"](time_kernel|_clone_data|clear_l2_cache)[\'"]\s*,',
)
# Pattern 4: torch._scaled_mm reference
RE_SCALED_MM_PYTHON = re.compile(r'torch\._scaled_mm\s*\(')
RE_SCALED_MM_ALIAS = re.compile(r'=\s*torch\._scaled_mm\s*(?!\()$', re.MULTILINE) # _mm = torch._scaled_mm (not a call)
# Whole-line alias assignment; group 1 is the alias name.
RE_SCALED_MM_ALIAS_ASSIGN = re.compile(r'^\s*(\w+)\s*=\s*torch\._scaled_mm\s*$', re.MULTILINE)
# C++ side: the ATen op name and a dispatcher schema lookup for it.
RE_SCALED_MM_CPP = re.compile(r'aten::_scaled_mm')
RE_SCALED_MM_SCHEMA = re.compile(r'findSchemaOrThrow\s*\(\s*["\']aten::_scaled_mm')
RE_SCALED_MM_ANY = re.compile(r'_scaled_mm')
# Pattern 5: torch.mm/bmm as main compute (Python-only check)
RE_TORCH_MM = re.compile(r'torch\.mm\s*\(')
RE_TORCH_BMM = re.compile(r'torch\.bmm\s*\(')
RE_TORCH_MATMUL = re.compile(r'torch\.matmul\s*\(')
# Pattern 6: Result caching by pointer
RE_WEAK_VALUE_DICT = re.compile(r'WeakValueDictionary')
RE_DECODED_CACHE = re.compile(r'_decoded_cache')
RE_PREPROCESS_CACHE = re.compile(r'_PREPROCESS_CACHE')
RE_RESULT_REUSE = re.compile(r'_result_reuse')
RE_SCALE_CACHE = re.compile(r'_scale_cache')
RE_OUTPUT_CACHE = re.compile(r'_OUTPUT_CACHE')
RE_RESULT_CACHE_GENERAL = re.compile(r'_(?:RESULT|GROUPED_RESULT|COMPUTE|GEMM)_CACHE')
RE_ID_DATA_CACHE = re.compile(r'id\s*\(\s*data\s*\)')
# Either ``name[int(x.data_ptr())]`` or a data_ptr() call followed later on the
# same line by cache/key/dict/hash.
RE_DATA_PTR_CACHE_KEY = re.compile(
r'(?:\w+\s*\[\s*int\s*\(\s*\w+\.data_ptr\s*\(\)\s*\)\s*\])|' # cache[int(A.data_ptr())]
r'(?:\.data_ptr\s*\(\)\s*[,\)].*?(?:cache|key|dict|hash))',
re.IGNORECASE,
)
# Access to a tensor's ``._version`` attribute.
RE_VERSION_CACHE = re.compile(r'\._version\b')
# ``x = <something>cache<something>.get(``; group 1 target, group 2 cache name.
RE_CACHE_GET_ASSIGN = re.compile(
r'(\w+)\s*=\s*(\w*(?:cache|reuse)\w*)\.get\s*\(',
re.IGNORECASE,
)
# "return cache[...]" but exclude compiled-kernel / module caches (legitimate)
RE_RETURN_CACHE_INDEX = re.compile(
r'return\s+(?!_?(?:compiled|kernel|module|func|op)_?\w*cache)'
r'\w*(?:cache|reuse)\w*\s*\[',
re.IGNORECASE,
)
# ``cache[key] = result|output|out|c_ref`` — storing a computed output.
RE_CACHE_STORE_OUTPUT = re.compile(
r'\w*(?:cache|reuse)\w*\s*\[[^\]]+\]\s*=\s*(?:result|output|out|c_ref)\b',
re.IGNORECASE,
)
# C++: ``if (... lastA/prevB/cache.last ...)`` immediately returning, or with a
# ``return`` inside a brace-free body of at most 200 chars.
RE_CPP_CACHE_FAST_RETURN = re.compile(
r'if\s*\([^)]*(?:cache\.last|cache\.prev|lastA|lastB|prevA|prevB)[^)]*\)\s*(?:\{[^{}]{0,200}?\breturn\b|return\b)',
re.DOTALL,
)
# Pattern 7: CUDA graph replay
RE_CTYPES_LIBCUDART = re.compile(r'ctypes\.CDLL\s*\(\s*["\']libcudart')
RE_CUDA_GRAPH_CREATE = re.compile(r'cudaGraphCreate')
RE_CUDA_GRAPH_LAUNCH = re.compile(r'cudaGraphLaunch')
RE_CUDA_GRAPH_INSTANTIATE = re.compile(r'cudaGraphInstantiate')
RE_CUDA_STREAM_CAPTURE = re.compile(r'cudaStreamBeginCapture')
# Pattern 8: Silent fallback — an ``except`` header followed by
# torch.matmul/mm/bmm within 500 characters. ``[^\n]`` cannot cross a newline
# (re.DOTALL has no effect on it), so this only matches a fallback call on the
# same line as the ``except ...:`` header, e.g.
# ``except Exception: return torch.mm(a, b)``.
# NOTE(review): if multi-line fallbacks should be caught, ``[^\n]`` would need
# to become ``.`` — confirm the intended scope.
RE_TORCH_MATMUL_FALLBACK = re.compile(
r'except\s+(?:Exception|RuntimeError|BaseException|TypeError)[^:]*:'
r'[^\n]{0,500}?'
r'(?:torch\.matmul|torch\.mm\s*\(|torch\.bmm\s*\()',
re.DOTALL,
)
# Pattern 9: Trivial/probe
RE_RETURN_EMPTY = re.compile(r'return\s+\[\s*\]')
RE_RETURN_NONE = re.compile(r'return\s+None\b')
# Pattern 11: torch.compile
RE_TORCH_COMPILE = re.compile(r'torch\.compile\s*\(')
# Pattern 12: Python CUDAGraph caching (not ctypes — uses torch.cuda.CUDAGraph API)
RE_CUDA_GRAPH_PYTHON = re.compile(r'torch\.cuda\.CUDAGraph\s*\(')
RE_CUDA_GRAPH_REPLAY = re.compile(r'\.replay\s*\(')
# NOTE(review): with re.DOTALL the two halves may be arbitrarily far apart in
# the searched text, so this is effectively a co-occurrence check of
# data_ptr() and graph/cache/dict.
RE_CUDA_GRAPH_CACHE_KEY = re.compile(
r'data_ptr\s*\(\).*?(?:graph|cache|dict)|'
r'(?:graph|cache|dict).*?data_ptr\s*\(\)',
re.IGNORECASE | re.DOTALL,
)
RE_GRAPH_CACHE_NAME = re.compile(r'\b(?:_graph_cache|graph_cache)\b', re.IGNORECASE)
# Pattern 13: Hardcoded benchmark shapes (KERNEL_CONFIGS gated by exact dimensions)
RE_KERNEL_CONFIGS_DICT = re.compile(r'KERNEL_CONFIGS\s*=\s*\{')
# Dict keys such as ``(1024, 512):`` or ``(1024, 512, 7):``.
RE_SHAPE_TUPLE_KEY = re.compile(r'\(\s*\d{2,5}\s*,\s*\d{2,5}\s*(?:,\s*\d{1,5}\s*)?\)\s*:')
# Same-line ``if ... == (1024, 2048`` / ``if ... in [(1024, 2048`` gates.
RE_SHAPE_IF_GATE = re.compile(
r'if\s+.*?(?:==|in)\s*[\[(]?\s*\(?\s*\d{3,5}\s*,\s*\d{3,5}',
)
# Pattern 14: Unsynchronized multi-stream dispatch
RE_GET_STREAM_FROM_POOL = re.compile(r'getStreamFromPool|get_stream_from_pool|torch\.cuda\.Stream\s*\(')
# Despite the name, this matches explicit stream synchronization such as
# ``stream.synchronize()`` or ``s1.synchronize()``.
RE_NO_SYNC_STREAM = re.compile(r'(?:stream|s)\d*\.synchronize\s*\(\)')
RE_STREAM_WAIT_EVENT = re.compile(r'\.wait_event\s*\(')
RE_STREAM_WAIT_STREAM = re.compile(r'\.wait_stream\s*\(')
RE_TORCH_CUDA_SYNCHRONIZE = re.compile(r'torch\.cuda\.synchronize\s*\(')
RE_CPP_STREAM_SYNC = re.compile(
r'(?:cudaStreamSynchronize|cudaDeviceSynchronize|cudaEventSynchronize|cudaStreamWaitEvent)\s*\(',
)
RE_CPP_METHOD_SYNC = re.compile(r'\.(?:synchronize|wait_event|wait_stream)\s*\(')
# Pattern 15: cudaEventDisableTiming
RE_CUDA_EVENT_DISABLE_TIMING = re.compile(r'cudaEventDisableTiming|disable_timing\s*=\s*True')
# C++/CUDA block markers for stripping
RE_CPP_MARKERS = re.compile(r'#include|__global__|__device__|__host__|extern\s+"C"|asm\s+volatile')
RE_TRIPLE_QUOTED = re.compile(r'(?:r)?(?:"""|\'\'\')(.*?)(?:"""|\'\'\')', re.DOTALL)
# ---------------------------------------------------------------------------
# Utility: strip C++/CUDA inline source from Python code
# ---------------------------------------------------------------------------
_LARGE_LITERAL_THRESHOLD = 2000  # chars; blobs larger than this are stripped
# Triple-quoted literal whose closing delimiter matches its opening one
# (group 1 = delimiter, group 2 = content). RE_TRIPLE_QUOTED lets ''' close a
# """ string (and vice versa). A ''' inside a """ literal therefore ends the
# match early, and every later match is shifted: real Python code sitting
# between two literals is then treated as string content — and stripped when
# it contains a C++ marker, hiding it from the detectors.
_RE_TRIPLE_QUOTED_PAIRED = re.compile(r'(?:r)?("""|\'\'\')(.*?)\1', re.DOTALL)
def strip_cpp_cuda_blocks(code: str) -> str:
    """Remove C/CUDA source literals and large data blobs from Python code.
    Strips triple-quoted strings that either:
    - Contain C++/CUDA source markers (__global__, #include, etc.), OR
    - Are larger than _LARGE_LITERAL_THRESHOLD characters (e.g. base64 blobs,
      embedded source code, lookup tables). These blobs are never hacking
      indicators and their long all-word-char content causes catastrophic
      backtracking in patterns like (\\w+)\\s*=\\s*literal.
    Each stripped literal (including an optional ``r`` prefix) is replaced by
    an empty ``""""""``; all other text is returned unchanged. Opening and
    closing delimiters must match, so a ''' inside a \"\"\" literal (or the
    reverse) does not end it.
    """
    def replacer(match):
        content = match.group(2)
        if len(content) > _LARGE_LITERAL_THRESHOLD or RE_CPP_MARKERS.search(content):
            return '""""""'  # empty triple-quoted string
        return match.group(0)
    return _RE_TRIPLE_QUOTED_PAIRED.sub(replacer, code)
def strip_python_comments(code: str) -> str:
    """Return ``code`` with ``#`` comments dropped and string literals untouched.
    Comments are removed at the token level, so a ``#`` inside a string is
    preserved. Source that cannot be tokenized is returned as-is.
    """
    read_line = io.StringIO(code).readline
    try:
        kept_tokens = list(
            filter(lambda tok: tok.type != tokenize.COMMENT,
                   tokenize.generate_tokens(read_line))
        )
    except (tokenize.TokenError, IndentationError, SyntaxError):
        return code
    return tokenize.untokenize(kept_tokens)
def extract_function_block(code: str, func_name: str) -> str:
    """Best-effort extraction of a Python function block from source text.
    Finds the first line that defines ``func_name`` (at any indentation) and
    collects the following lines until the next ``def``, ``async def``,
    ``class`` or decorator line at the same or a shallower indentation. Other
    dedented lines (comments, continuation lines, multi-line string content,
    module-level statements) do not end the block.
    Args:
        code: Python source text.
        func_name: Exact name of the function to extract.
    Returns:
        The block's lines joined with ``"\\n"``, or ``""`` if no ``def`` line
        for ``func_name`` exists.
    """
    lines = code.splitlines()
    func_re = re.compile(rf'^\s*def\s+{re.escape(func_name)}\s*\(')
    # A new definition at the same or shallower indent ends the block.
    # Decorators count too, so the next function's ``@decorator`` lines are
    # not attributed to this one.
    next_def_re = re.compile(r'^\s*(?:(?:async\s+)?def\s+\w+|class\s+\w+|@[A-Za-z_])')
    for i, line in enumerate(lines):
        if not func_re.match(line):
            continue
        base_indent = len(line) - len(line.lstrip())
        block = [line]
        for nxt in lines[i + 1:]:
            cur_indent = len(nxt) - len(nxt.lstrip())
            if nxt.strip() and cur_indent <= base_indent and next_def_re.match(nxt):
                break
            block.append(nxt)
        return "\n".join(block)
    return ""
RE_IMPORT_MAIN_AS = re.compile(r'import\s+__main__\s+as\s+(\w+)')
RE_FROM_IMPORT_MAIN = re.compile(r'from\s+__main__\s+import\s+')
# ``X = sys.modules['__main__']`` or ``X = sys.modules.get('__main__'[, default])``;
# group 1 is the bound local name.
_RE_MAIN_ALIAS_ASSIGN = re.compile(
    r'(\w+)\s*=\s*sys\.modules(?:'
    r'\.get\s*\(\s*[\'"]__main__[\'"]\s*[,)]'
    r'|\s*\[\s*[\'"]__main__[\'"]\s*\])',
)
# A bare ``import __main__`` (possibly among other modules on the same line)
# binds the name ``__main__`` itself; the negative lookahead skips
# ``import __main__ as X``, which RE_IMPORT_MAIN_AS handles.
_RE_IMPORT_MAIN_PLAIN = re.compile(
    r'^\s*import\s+[^\n#]*\b__main__\b(?!\s+as\b)',
    re.MULTILINE,
)
def find_main_aliases(code: str) -> set[str]:
    """Collect names that may refer to the ``__main__`` (harness) module.
    Always includes ``"main"``. Also adds:
    - local names assigned from ``sys.modules['__main__']`` or
      ``sys.modules.get('__main__')`` (with or without a default argument);
    - names bound by ``import __main__ as X``;
    - ``"__main__"`` itself when the module is imported without an alias,
      so ``__main__.time_kernel = ...`` is checked like any other alias.
    """
    aliases = {"main"}
    aliases.update(m.group(1) for m in _RE_MAIN_ALIAS_ASSIGN.finditer(code))
    aliases.update(m.group(1) for m in RE_IMPORT_MAIN_AS.finditer(code))
    if _RE_IMPORT_MAIN_PLAIN.search(code):
        aliases.add("__main__")
    return aliases
def find_scaled_mm_aliases(code: str) -> set[str]:
    """Collect local aliases bound directly to ``torch._scaled_mm``.
    Matches whole-line assignments such as ``mm = torch._scaled_mm``. A
    trailing ``# comment`` is tolerated, so annotating the alias line does
    not hide it. Calls (``x = torch._scaled_mm(...)``) are not aliases and
    are ignored.
    """
    return {
        m.group(1)
        for m in re.finditer(
            r'^\s*(\w+)\s*=\s*torch\._scaled_mm\s*(?:#[^\n]*)?$',
            code,
            re.MULTILINE,
        )
    }
def function_uses_scaled_mm(func_body: str, alias_names: set[str]) -> bool:
    """Return True when a function body calls torch._scaled_mm or one of its aliases.
    Args:
        func_body: Source text of the function to inspect.
        alias_names: Local names bound to ``torch._scaled_mm``.
    An alias counts only as a bare-name call (``mm(a, b)``). Attribute calls
    such as ``torch.mm(...)`` or ``self._mm(...)`` reach a different object
    even when the attribute shares the alias's name, so they are not treated
    as alias calls.
    """
    if RE_SCALED_MM_PYTHON.search(func_body):
        return True
    return any(
        re.search(rf'(?<![\w.]){re.escape(alias)}\s*\(', func_body)
        for alias in alias_names
    )
BASE_ENTRYPOINT_NAMES = ("custom_kernel",)
ENTRYPOINT_NAMES = BASE_ENTRYPOINT_NAMES
def _entrypoint_candidates() -> tuple[str, ...]:
return ENTRYPOINT_NAMES
def is_entrypoint_name(name: str) -> bool:
return name in ENTRYPOINT_NAMES
def entrypoint_label(name: Optional[str] = None) -> str:
return name or (ENTRYPOINT_NAMES[0] if ENTRYPOINT_NAMES else "entrypoint")
@dataclass
class SubmissionFacts:
    """Normalized source views and AST-derived summaries for one submission.
    Built once by ``build_submission_facts`` and accepted by every detector
    (via ``ensure_submission_facts``), so parsing and stripping are shared.
    """
    # Source views: untouched text; C++/CUDA and large literals stripped; and
    # additionally with Python comments stripped.
    raw_code: str
    python_only: str
    python_active: str
    ast_tree: Optional[ast.AST]
    main_aliases: set[str]
    scaled_mm_aliases: set[str]
    trusted_aliases: dict[str, str]
    # Located entrypoint and the source slices around it.
    entrypoint_name: Optional[str]
    custom_kernel_pos: Optional[int]
    code_before_custom_kernel: str
    code_from_custom_kernel: str
    custom_kernel_code: str
    custom_kernel_active: str
    # Memoized function sources, keyed by function name (raw / comment-free).
    _function_blocks: dict[str, str] = field(default_factory=dict)
    _active_function_blocks: dict[str, str] = field(default_factory=dict)
    # --- AST indices filled by _build_ast_index (left empty without an AST) ---
    # id() of each `.data_ptr(...)` call node / each `._version` attribute node.
    _nodes_with_data_ptr: set[int] = field(default_factory=set)
    _nodes_with_version: set[int] = field(default_factory=set)
    # Names of non-entrypoint functions containing those nodes.
    _data_ptr_helpers: set[str] = field(default_factory=set)
    _version_helpers: set[str] = field(default_factory=set)
    # Module-level names assigned the literal None.
    _none_inited: set[str] = field(default_factory=set)
    # Intended shape {target_name: [value_node, ...]}.
    # NOTE(review): _build_ast_index does not populate this field.
    _assignments_by_target: dict[str, list] = field(default_factory=dict)
    # Import / ImportFrom / ClassDef nodes found in the tree.
    _imports: list = field(default_factory=list)
    _import_froms: list = field(default_factory=list)
    _class_defs: list = field(default_factory=list)
    def get_function_block(self, func_name: str) -> str:
        """Return the raw source of ``func_name`` (memoized; "" if not found)."""
        if (cached := self._function_blocks.get(func_name)) is not None:
            return cached
        extracted = extract_function_block(self.raw_code, func_name)
        self._function_blocks[func_name] = extracted
        return extracted
    def get_active_function_block(self, func_name: str) -> str:
        """Return ``func_name``'s source with Python comments stripped (memoized)."""
        if (cached := self._active_function_blocks.get(func_name)) is not None:
            return cached
        stripped = strip_python_comments(self.get_function_block(func_name))
        self._active_function_blocks[func_name] = stripped
        return stripped
def build_submission_facts(code: str) -> SubmissionFacts:
    """Normalize and parse one submission so all detectors can share the work.
    Produces the stripped source views, the AST (may be None — see
    ``_safe_ast_parse``), alias sets, the first configured entrypoint found and
    the source slices before/from it, then fills the AST indices.
    When no entrypoint is found, both slices are the whole source.
    """
    python_only = strip_cpp_cuda_blocks(code)
    python_active = strip_python_comments(python_only)
    tree = _safe_ast_parse(code)
    # First configured entrypoint name with a matching `def` line wins.
    entrypoint_name: Optional[str] = None
    entry_match = None
    for candidate in _entrypoint_candidates():
        found = re.search(rf'^\s*def\s+{re.escape(candidate)}\s*\(', code, re.MULTILINE)
        if found:
            entrypoint_name, entry_match = candidate, found
            break
    if entry_match is None:
        custom_kernel_pos = None
        code_before_custom_kernel = code
        code_from_custom_kernel = code
    else:
        custom_kernel_pos = entry_match.start()
        code_before_custom_kernel = code[:custom_kernel_pos]
        code_from_custom_kernel = code[custom_kernel_pos:]
    custom_kernel_code = extract_function_block(code, entrypoint_name or entrypoint_label())
    custom_kernel_active = strip_python_comments(custom_kernel_code)
    trusted_aliases = {} if tree is None else _collect_trusted_aliases(tree)
    facts = SubmissionFacts(
        raw_code=code,
        python_only=python_only,
        python_active=python_active,
        ast_tree=tree,
        main_aliases=find_main_aliases(python_only),
        scaled_mm_aliases=find_scaled_mm_aliases(code_before_custom_kernel),
        trusted_aliases=trusted_aliases,
        entrypoint_name=entrypoint_name,
        custom_kernel_pos=custom_kernel_pos,
        code_before_custom_kernel=code_before_custom_kernel,
        code_from_custom_kernel=code_from_custom_kernel,
        custom_kernel_code=custom_kernel_code,
        custom_kernel_active=custom_kernel_active,
    )
    # Pre-seed the block caches with the entrypoint under its real name and,
    # always, under "custom_kernel" as well.
    seeded_names = [entrypoint_name, "custom_kernel"] if entrypoint_name else ["custom_kernel"]
    for seeded in seeded_names:
        facts._function_blocks[seeded] = custom_kernel_code
        facts._active_function_blocks[seeded] = custom_kernel_active
    _build_ast_index(facts)
    return facts
def _build_ast_index(facts: SubmissionFacts) -> None:
    """Fill the pre-computed AST index fields of ``facts`` in place.
    No-op when ``facts.ast_tree`` is None. Otherwise records:
      * ``_nodes_with_data_ptr`` / ``_nodes_with_version``: ``id()`` of every
        ``.data_ptr(...)`` call node and every ``._version`` attribute node;
      * ``_imports`` / ``_import_froms`` / ``_class_defs``: matching nodes in
        ``ast.walk`` order;
      * ``_none_inited``: root names of module-level ``X = None`` targets;
      * ``_data_ptr_helpers`` / ``_version_helpers``: names of non-entrypoint
        functions whose subtree (nested definitions included) contains one
        of the indexed nodes.
    NOTE(review): ``_assignments_by_target`` is not filled here and keeps its
    empty default — confirm whether any detector relies on it.
    """
    tree = facts.ast_tree
    if tree is None:
        return
    ptr_ids: set[int] = set()
    ver_ids: set[int] = set()
    import_nodes: list = []
    import_from_nodes: list = []
    class_nodes: list = []
    for node in ast.walk(tree):
        is_ptr_call = (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "data_ptr"
        )
        if is_ptr_call:
            ptr_ids.add(id(node))
        if isinstance(node, ast.Attribute) and node.attr == "_version":
            ver_ids.add(id(node))
        if isinstance(node, ast.Import):
            import_nodes.append(node)
        elif isinstance(node, ast.ImportFrom):
            import_from_nodes.append(node)
        elif isinstance(node, ast.ClassDef):
            class_nodes.append(node)
    # Only direct statements of the module body count, not nested scopes.
    none_roots = {
        root
        for stmt in tree.body
        if isinstance(stmt, ast.Assign)
        and isinstance(stmt.value, ast.Constant)
        and stmt.value.value is None
        for root in map(_ast_root_name, stmt.targets)
        if root
    }
    ptr_helpers: set[str] = set()
    ver_helpers: set[str] = set()
    for func in ast.walk(tree):
        if not isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if is_entrypoint_name(func.name):
            continue
        subtree_ids = {id(child) for child in ast.walk(func)}
        if not subtree_ids.isdisjoint(ptr_ids):
            ptr_helpers.add(func.name)
        if not subtree_ids.isdisjoint(ver_ids):
            ver_helpers.add(func.name)
    facts._nodes_with_data_ptr = ptr_ids
    facts._nodes_with_version = ver_ids
    facts._data_ptr_helpers = ptr_helpers
    facts._version_helpers = ver_helpers
    facts._none_inited = none_roots
    facts._imports = import_nodes
    facts._import_froms = import_from_nodes
    facts._class_defs = class_nodes
def _expr_has_data_ptr_fast(expr: ast.AST | None, index: set[int]) -> bool:
"""O(subtree) check using pre-computed index — avoids full ast.walk per call."""
if expr is None:
return False
for node in ast.walk(expr):
if id(node) in index:
return True
return False
def _expr_has_version_fast(expr: ast.AST | None, index: set[int]) -> bool:
if expr is None:
return False
for node in ast.walk(expr):
if id(node) in index:
return True
return False
def ensure_submission_facts(code_or_facts: str | SubmissionFacts) -> SubmissionFacts:
    """Normalize detector input: build facts from raw source, pass existing facts through."""
    if not isinstance(code_or_facts, SubmissionFacts):
        return build_submission_facts(code_or_facts)
    return code_or_facts
def _ast_root_name(expr: ast.AST | None) -> Optional[str]:
"""Return the left-most name that owns an expression, when present."""
cur = expr
while cur is not None:
if isinstance(cur, ast.Name):
return cur.id
if isinstance(cur, ast.Attribute):
cur = cur.value
continue
if isinstance(cur, ast.Subscript):
cur = cur.value
continue
break
return None
def _expr_names(expr: ast.AST | None) -> set[str]:
if expr is None:
return set()
return {
node.id
for node in ast.walk(expr)
if isinstance(node, ast.Name)
}
def _expr_has_data_ptr(expr: ast.AST | None) -> bool:
if expr is None:
return False
return any(
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and node.func.attr == "data_ptr"
for node in ast.walk(expr)
)
def _expr_has_tensor_version(expr: ast.AST | None) -> bool:
if expr is None:
return False
return any(
isinstance(node, ast.Attribute) and node.attr == "_version"
for node in ast.walk(expr)
)
_TRIVIAL_GPU_OPS = frozenset({
"fill_", "zero_", "copy_", "fill", "zero", "record",
})
def _body_has_calls(body: list[ast.stmt]) -> bool:
"""Return True if the body contains non-trivial function calls.
Tiny GPU ops like ``_tiny.fill_(0)`` or ``_anchor.copy_(_anchor)`` are
common dummy work used to keep CUDA timers non-zero; they don't count
as real compute and should not prevent replay detection.
"""
for stmt in body:
for nested in ast.walk(stmt):
if not isinstance(nested, ast.Call):
continue
# Allow trivial method calls: obj.fill_(0), obj.copy_(obj), etc.
if (isinstance(nested.func, ast.Attribute)
and nested.func.attr in _TRIVIAL_GPU_OPS):
continue
return True
return False
def _looks_stateful_name(name: str) -> bool:
lowered = name.lower()
return any(token in lowered for token in ("last", "prev", "cache", "saved", "memo"))
# ---------------------------------------------------------------------------
# Detectors
# ---------------------------------------------------------------------------
def detect_timer_monkeypatch(code: str | SubmissionFacts) -> list[dict]:
    """Pattern 1: torch.cuda.Event (or its timing helpers) replaced with fakes.
    Scans the raw submission text, with comments and string literals
    included. Emits one critical TIMER_MONKEYPATCH match per check that fires,
    in the fixed order of the table below.
    """
    source = ensure_submission_facts(code).raw_code
    checks = (
        (RE_CUDA_EVENT_ASSIGN, "torch.cuda.Event reassigned"),
        (RE_CUDA_EVENT_SETATTR, "setattr(torch.cuda, 'Event', ...) patching"),
        (RE_ELAPSED_TIME_METHOD_PATCH, "torch.cuda.Event.elapsed_time method patched"),
        (RE_CUDA_SYNCHRONIZE_PATCH, "torch.cuda.synchronize patched"),
        (RE_DUMMY_FAKE_EVENT, "Dummy/Fake Event class defined"),
    )
    return [
        {
            "pattern": "TIMER_MONKEYPATCH",
            "severity": "critical",
            "evidence": evidence,
        }
        for regex, evidence in checks
        if regex.search(source)
    ]
def detect_print_injection(code: str | SubmissionFacts) -> list[dict]:
    """Pattern 2: separate real benchmark spoofing from benign stdio handling.
    Scans the comment-stripped Python view ``facts.python_active``. Emits:
      * FAKE_BENCHMARK_EMIT (critical) for any of the following:
        - an ``_emit_fake_benchmark`` definition;
        - ``builtins.print`` reassignment;
        - a print/write whose arguments contain "Ranked Benchmark";
        - failing that, a ``_patched/_real/_orig/_original_print`` name
          together with a quoted "Ranked Benchmark" string.
      * STDIO_REDIRECT for ``sys.stdout`` / ``sys.stderr`` reassignment: low
        when an ``if sys.stdXXX is None:`` guard appears anywhere in the
        source, medium otherwise.
    Each (pattern, evidence) pair is reported at most once.
    """
    facts = ensure_submission_facts(code)
    python_active = facts.python_active
    matches: list[dict] = []
    seen: set[tuple[str, str]] = set()
    def add(pattern: str, severity: str, evidence: str) -> None:
        key = (pattern, evidence)
        if key in seen:
            return
        seen.add(key)
        matches.append({
            "pattern": pattern,
            "severity": severity,
            "evidence": evidence,
        })
    if RE_EMIT_FAKE_BENCH.search(python_active):
        add("FAKE_BENCHMARK_EMIT", "critical", "_emit_fake_benchmark function defined")
    if RE_BUILTINS_PRINT_ASSIGN.search(python_active):
        add("FAKE_BENCHMARK_EMIT", "critical", "builtins.print reassigned")
    if RE_RANKED_BENCHMARK_PRINT.search(python_active):
        add("FAKE_BENCHMARK_EMIT", "critical", "Ranked Benchmark string printed directly")
    elif RE_PATCHED_PRINT.search(python_active) and RE_RANKED_BENCHMARK_STR.search(python_active):
        add("FAKE_BENCHMARK_EMIT", "critical", "print patching plus Ranked Benchmark payload")
    # Stdio redirection is common as a compatibility shim around compilation and
    # should not be treated as benchmark spoofing on its own.
    if RE_SYS_STDOUT_REPLACE.search(python_active):
        if re.search(r'if\s+sys\.stdout\s+is\s+None\s*:', python_active):
            add("STDIO_REDIRECT", "low", "guarded sys.stdout compatibility redirect")
        else:
            add("STDIO_REDIRECT", "medium", "sys.stdout reassigned")
    if RE_SYS_STDERR_REPLACE.search(python_active):
        if re.search(r'if\s+sys\.stderr\s+is\s+None\s*:', python_active):
            add("STDIO_REDIRECT", "low", "guarded sys.stderr compatibility redirect")
        else:
            add("STDIO_REDIRECT", "medium", "sys.stderr reassigned")
    # Do not keep the old broad PRINT_INJECTION marker; the split rules carry
    # the action semantics now.
    return matches
def detect_evaluator_exploit(code: str | SubmissionFacts) -> list[dict]:
    """Pattern 3: patching evaluator internals (main.time_kernel, etc.).
    For each name in ``sorted(TRUSTED_FUNCTIONS)``, reports a critical
    EVALUATOR_EXPLOIT match for any of the following:
      * ``main.<name> = ...``;
      * ``<alias>.<name> = ...`` for every other ``__main__`` alias in
        ``facts.main_aliases``;
      * ``setattr(<alias>, '<name>', ...)`` for every alias, ``main`` included.
    A ``from __main__ import ...`` is reported as well. Identical evidence
    strings are reported once, in discovery order.
    """
    facts = ensure_submission_facts(code)
    # Scan the view with large data blobs (base64, lookup tables) removed: such
    # blobs cause O(n^2) backtracking in patterns like (\w+)\s*=\s*literal when
    # they form very long single-line word runs (e.g. 129KB base64 strings).
    source = facts.python_only
    targets = sorted(TRUSTED_FUNCTIONS)
    aliases = sorted(facts.main_aliases)
    findings: list[dict] = []
    reported: set[str] = set()
    def report(evidence: str) -> None:
        if evidence in reported:
            return
        reported.add(evidence)
        findings.append({
            "pattern": "EVALUATOR_EXPLOIT",
            "severity": "critical",
            "evidence": evidence,
        })
    # Direct assignment on the conventional `main` name.
    for attr in targets:
        if re.search(rf'\bmain\.{re.escape(attr)}\s*=', source):
            report(f"main.{attr} patched")
    # The same assignment through any other name bound to __main__.
    for obj in (name for name in aliases if name != "main"):
        for attr in targets:
            if re.search(rf'\b{re.escape(obj)}\.{re.escape(attr)}\s*=', source):
                report(f"{obj}.{attr} patched via __main__ alias")
    # setattr(...) on any alias, `main` included.
    for obj in aliases:
        for attr in targets:
            setattr_re = rf'setattr\s*\(\s*{re.escape(obj)}\s*,\s*[\'"]{re.escape(attr)}[\'"]\s*,'
            if re.search(setattr_re, source):
                report(f"setattr({obj}, '{attr}', ...) on __main__ alias")
    # Importing straight from __main__ is unusual in a kernel submission.
    if RE_FROM_IMPORT_MAIN.search(source):
        report("from __main__ import ... (direct harness access)")
    return findings
def detect_scaled_mm_ref(code: str | SubmissionFacts) -> list[dict]:
"""Pattern 4: Using torch._scaled_mm as primary compute.
Scope-aware: if the file has a configured entrypoint function and
`_scaled_mm` only appears BEFORE that function, it's likely a
reference implementation (not the submission's compute path) and
should not be flagged.
"""
facts = ensure_submission_facts(code)
matches = []
code = facts.raw_code
entrypoint_name = entrypoint_label(facts.entrypoint_name)
custom_kernel_pos = facts.custom_kernel_pos or 0
# For scope-aware check: code from the configured entrypoint onward.
# If no entrypoint is found, check the entire file (conservative).
if facts.custom_kernel_pos is not None:
# Check if _scaled_mm is used at or after the configured entrypoint,
# or if _scaled_mm is aliased to a variable that the entrypoint could
# call indirectly.
code_from_ck = facts.code_from_custom_kernel
code_before_ck = facts.code_before_custom_kernel
has_python_after = bool(RE_SCALED_MM_PYTHON.search(code_from_ck))
has_alias_after = bool(RE_SCALED_MM_ALIAS.search(code_from_ck))
alias_names_before = facts.scaled_mm_aliases
has_alias_before = bool(alias_names_before)
has_cpp = bool(RE_SCALED_MM_CPP.search(code))
has_schema = bool(RE_SCALED_MM_SCHEMA.search(code))
# _scaled_mm only before the configured entrypoint — check if it actually
# reaches that code path through a helper or a direct alias call.
has_python_before = bool(RE_SCALED_MM_PYTHON.search(code_before_ck))
if ((has_python_before or has_alias_before) and not has_python_after and
not has_alias_after and
not has_cpp and not has_schema):
# Find function names defined before the entrypoint that use _scaled_mm
helper_funcs_with_mm = set()
for func_m in re.finditer(r'^def\s+(\w+)\s*\(', code_before_ck, re.MULTILINE):
func_name = func_m.group(1)
func_body = facts.get_function_block(func_name)
if function_uses_scaled_mm(func_body, alias_names_before):
helper_funcs_with_mm.add(func_name)
# Strip Python comments from the entrypoint body before call-site analysis
# so that commented-out calls (e.g. # result = ref_kernel(data)) are ignored.
ck_active = facts.custom_kernel_active
# Check if the entrypoint output comes from a helper that uses _scaled_mm.
# Validation-only calls (result not returned) are NOT the submission's compute.
compute_helpers = set()
for fn in helper_funcs_with_mm:
# Direct return: return helper(...)
if re.search(rf'\breturn\s+{re.escape(fn)}\s*\(', ck_active):
compute_helpers.add(fn)
continue
# Assigned return: var = helper(...) … return var
assign_m = re.search(rf'\b(\w+)\s*=\s*{re.escape(fn)}\s*\(', ck_active)
if assign_m:
var = assign_m.group(1)
after = ck_active[assign_m.end():]
if re.search(rf'\breturn\s+{re.escape(var)}\b', after):
compute_helpers.add(fn)
for alias in alias_names_before:
if re.search(rf'\breturn\s+{re.escape(alias)}\s*\(', ck_active):
compute_helpers.add(alias)
continue
assign_m = re.search(rf'\b(\w+)\s*=\s*{re.escape(alias)}\s*\(', ck_active)
if assign_m:
var = assign_m.group(1)
after = ck_active[assign_m.end():]
if re.search(rf'\breturn\s+{re.escape(var)}\b', after):
compute_helpers.add(alias)
if compute_helpers:
# Distinguish pure delegation from benchmark-shape if/else fallback.
# Two conditions must BOTH hold to downgrade to medium:
# 1. There is a primary non-ref compute branch (more return paths than
# just the ref-helper returns) → ref call is a shape-gated fallback.
# 2. All compute helpers have "ref" in their name — only the evaluator's
# own ref_kernel (or similar) counts as a true correctness fallback.
# A custom-named helper (e.g. cublas_kernel, fast_path) that wraps
# _scaled_mm is still a hack regardless of branching.
total_returns = len(re.findall(r'\breturn\b', ck_active))
# Count return paths that go through a ref helper — both direct
# ("return ref_kernel(data)") and assign-then-return
# ("results = ref_kernel(data); return results").
ref_path_returns = 0
for fn in compute_helpers:
ref_path_returns += len(
re.findall(rf'\breturn\s+{re.escape(fn)}\s*\(', ck_active)
)
for am in re.finditer(rf'\b(\w+)\s*=\s*{re.escape(fn)}\s*\(', ck_active):
var = am.group(1)
if re.search(rf'\breturn\s+{re.escape(var)}\b', ck_active[am.end():]):
ref_path_returns += 1
has_primary_custom_path = total_returns > ref_path_returns
all_helpers_are_ref_named = all(
'ref' in fn.lower() for fn in compute_helpers
)
if has_primary_custom_path and all_helpers_are_ref_named:
matches.append({
"pattern": "SCALED_MM_REF",
"severity": "medium",
"evidence": (
f"_scaled_mm helper(s) used as shape-gated fallback in custom_kernel: "
f"{', '.join(sorted(compute_helpers))} "
f"(primary path is custom, ref only in conditional branch)"
).replace("custom_kernel", entrypoint_name),
})
else:
matches.append({
"pattern": "SCALED_MM_REF",
"severity": "high",
"evidence": (
f"{entrypoint_name} returns output of _scaled_mm helper(s): "
f"{', '.join(sorted(compute_helpers))} "
),
})
return matches
# Helper called for validation only (result not returned).
# If called as a bare statement (no assignment), the helper writes
# into an in-place output argument — treat as low severity.
# Otherwise (assigned but not returned) it is a medium signal.
validation_calls = {fn for fn in helper_funcs_with_mm
if re.search(rf'\b{re.escape(fn)}\s*\(', ck_active)}
validation_calls.update(
alias for alias in alias_names_before
if re.search(rf'\b{re.escape(alias)}\s*\(', ck_active)
)
if validation_calls:
inplace = {fn for fn in validation_calls
if re.search(rf'^\s*{re.escape(fn)}\s*\(', ck_active, re.MULTILINE)}
non_inplace = validation_calls - inplace
if non_inplace:
matches.append({
"pattern": "SCALED_MM_REF",
"severity": "medium",
"evidence": (
f"_scaled_mm helper(s) called in {entrypoint_name} but result not returned: "
f"{', '.join(sorted(non_inplace))}"
),
})
if inplace:
matches.append({
"pattern": "SCALED_MM_REF",
"severity": "low",
"evidence": f"_scaled_mm helper(s) called as in-place mutation (result discarded): {', '.join(sorted(inplace))}",
})
return matches
# Helper defined but not called from the entrypoint (e.g. commented-out) → medium
if helper_funcs_with_mm:
matches.append({
"pattern": "SCALED_MM_REF",
"severity": "medium",
"evidence": (
f"_scaled_mm in helper(s) before {entrypoint_name}: "
f"{', '.join(sorted(helper_funcs_with_mm))} (not called in active code)"
),
})
return matches
# _scaled_mm or an alias only appears before the entrypoint and is never
# reached from it → likely reference/dead code, not the submission path.
return []
has_python = has_python_after
has_alias = has_alias_after
else:
has_python = bool(RE_SCALED_MM_PYTHON.search(code))
has_alias = bool(RE_SCALED_MM_ALIAS.search(code))
has_cpp = bool(RE_SCALED_MM_CPP.search(code))
has_schema = bool(RE_SCALED_MM_SCHEMA.search(code))
if has_python or has_alias or has_cpp or has_schema:
evidence_parts = []
if has_python:
evidence_parts.append("torch._scaled_mm() called in Python")
if has_alias:
evidence_parts.append("torch._scaled_mm aliased to variable")
if has_cpp:
evidence_parts.append("aten::_scaled_mm in C++ code")
if has_schema:
evidence_parts.append("findSchemaOrThrow for _scaled_mm")
matches.append({
"pattern": "SCALED_MM_REF",
"severity": "high",
"evidence": "; ".join(evidence_parts),
})
return matches
def detect_decode_mm_ref(code: str | SubmissionFacts) -> list[dict]:
    """Pattern 5: FP4 decode + torch.mm/bmm/matmul as the main compute path.

    Only torch.mm / torch.bmm / torch.matmul calls inside the submission
    entrypoint body (``facts.custom_kernel_code``) are counted.  A single such
    call is ignored unless it appears to feed the output (assigned to a
    result-like name, or on a ``return`` line), so small pre/post-processing
    steps are not flagged; two or more calls are always flagged.

    Severity is ``"high"`` when ``facts.python_only`` (a wider scope than the
    entrypoint body) mentions FP4-decode vocabulary (fp4, decode, sfa, sfb,
    scale_a, scale_b, to_blocked), otherwise ``"medium"``.

    Args:
        code: Raw submission source, or already-computed ``SubmissionFacts``.

    Returns:
        An empty list, or a list with a single match dict carrying
        ``pattern``, ``severity`` and ``evidence`` keys.
    """
    facts = ensure_submission_facts(code)
    python_only = facts.python_only
    custom_kernel_code = facts.custom_kernel_code
    entrypoint_name = entrypoint_label(facts.entrypoint_name)
    matches = []
    # Tighten scope: only count mm/bmm/matmul used in the submission entrypoint.
    if not custom_kernel_code:
        return matches

    # One scan per op; dict insertion order fixes the order the ops are
    # listed in the evidence string.
    op_counts = {
        "torch.mm": len(RE_TORCH_MM.findall(custom_kernel_code)),
        "torch.bmm": len(RE_TORCH_BMM.findall(custom_kernel_code)),
        "torch.matmul": len(RE_TORCH_MATMUL.findall(custom_kernel_code)),
    }
    ops = [op for op, count in op_counts.items() if count]
    if not ops:
        return matches
    mm_call_count = sum(op_counts.values())

    # Does an mm/bmm/matmul result flow straight to the output?  The target
    # must be a whole result-like name (``out``, ``c``, ``result``, ...),
    # optionally with a snake_case prefix (``c_out``, ``final_result``).
    # The leading ``\b`` stops names that merely *end* in a token (``src``,
    # ``acc`` and ``vec`` all end in "c") from counting as output.
    feeds_output = bool(re.search(
        r'\b(?:\w*_)?(?:result|output|out|c_ref|c|ret)\s*=\s*torch\.(?:mm|bmm|matmul)\s*\('
        r'|\breturn\s+.*?torch\.(?:mm|bmm|matmul)\s*\(',
        custom_kernel_code,
    ))
    # If only 1 call and it doesn't feed the output, likely a small helper step.
    if mm_call_count == 1 and not feeds_output:
        return matches

    # Decode vocabulary delimited like ``\b`` except that ``_`` also acts as
    # a separator (``[^\W_]`` = letters/digits only).  Plain ``\b`` treats
    # ``_`` as a word character, so snake_case names such as ``decode_fp4``
    # or ``sfa_perm`` would never register as decode context.
    has_decode_context = bool(re.search(
        r'(?<![^\W_])(?:fp4|decode|sfa|sfb|scale_a|scale_b|to_blocked)(?![^\W_])',
        python_only,
        re.IGNORECASE,
    ))
    severity = "high" if has_decode_context else "medium"
    evidence = f"{', '.join(ops)} used inside {entrypoint_name}"
    if not has_decode_context:
        evidence += " (no explicit FP4 decode context)"
    if feeds_output:
        evidence += " [feeds output path]"
    matches.append({
        "pattern": "DECODE_MM_REF",
        "severity": severity,
        "evidence": evidence,
    })
    return matches
def detect_result_caching(code: str | SubmissionFacts) -> list[dict]:
"""Pattern 6: split output replay from benign workspace/preprocess caches."""
facts = ensure_submission_facts(code)
code = facts.raw_code
python_only = facts.python_only
python_active = facts.python_active
matches = []
output_replay_signals = []
preprocess_signals = []
workspace_signals = []
runner_plan_signals = []
if RE_WEAK_VALUE_DICT.search(python_only):
workspace_signals.append("WeakValueDictionary")
if RE_DECODED_CACHE.search(python_only):
preprocess_signals.append("_decoded_cache")
if RE_PREPROCESS_CACHE.search(python_only):
preprocess_signals.append("_PREPROCESS_CACHE")
if RE_SCALE_CACHE.search(python_only):
preprocess_signals.append("_scale_cache")
if RE_RESULT_CACHE_GENERAL.search(python_only):
workspace_signals.append("_RESULT/_GROUPED_RESULT/_COMPUTE/_GEMM_CACHE")
if RE_VERSION_CACHE.search(python_only):
workspace_signals.append("tensor._version cache check")
# Scope all strong-signal checks to the configured entrypoint body.
# Helper functions that cache compilation artifacts (TensorMap, compiled kernels,
# plan descriptors, etc.) return those objects from caches, but that is legitimate —
# only a cache inside the entrypoint itself indicates result caching.
cache_scope = facts.custom_kernel_active if facts.custom_kernel_active else python_active
stores_output = bool(RE_CACHE_STORE_OUTPUT.search(cache_scope))
if RE_RESULT_REUSE.search(cache_scope):
output_replay_signals.append("_result_reuse")
if RE_OUTPUT_CACHE.search(cache_scope):
workspace_signals.append("_OUTPUT_CACHE")
if RE_ID_DATA_CACHE.search(cache_scope):
workspace_signals.append("id(data) cache key")
if RE_DATA_PTR_CACHE_KEY.search(cache_scope):
workspace_signals.append("data_ptr() cache key")
if RE_RETURN_CACHE_INDEX.search(cache_scope):
output_replay_signals.append("direct return from cache[...]")
if stores_output and output_replay_signals:
output_replay_signals.append("cache[...] stores output/result tensor")
elif stores_output:
workspace_signals.append("cache[...] stores reusable output/result tensor")
for var, cache_name in RE_CACHE_GET_ASSIGN.findall(cache_scope):
cache_lower = cache_name.lower()
if any(token in cache_lower for token in ("plan", "dispatch", "runner", "config")):
runner_plan_signals.append(f"{cache_name}.get(...) runner/plan cache")
elif any(token in cache_lower for token in ("decoded", "preprocess", "scale", "sort", "view", "shape", "quant", "meta", "pad")):
preprocess_signals.append(f"{cache_name}.get(...) preprocess cache")