From 43034dd0aacebf70086897b649f5ae98ecc45e34 Mon Sep 17 00:00:00 2001 From: Prasanna Date: Thu, 7 May 2026 21:08:27 +0530 Subject: [PATCH 1/2] fix(detector): detect sentinel replay and bmm self output --- kernelguard.py | 230 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 230 insertions(+) diff --git a/kernelguard.py b/kernelguard.py index f086847..6d828e5 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -1163,6 +1163,228 @@ def detect_trivial_probe(code: str | SubmissionFacts) -> list[dict]: return matches +def _expr_is_none(expr: ast.AST | None) -> bool: + return isinstance(expr, ast.Constant) and expr.value is None + + +def _expr_is_class_attr(expr: ast.AST | None, class_name: str, attr_name: str) -> bool: + return ( + isinstance(expr, ast.Attribute) + and expr.attr == attr_name + and isinstance(expr.value, ast.Name) + and expr.value.id == class_name + ) + + +def _expr_calls_input(expr: ast.AST | None, params: set[str]) -> bool: + if expr is None: + return False + return any( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id in params + for node in ast.walk(expr) + ) + + +def detect_sentinel_output_replay(code: str | SubmissionFacts) -> list[dict]: + """Pattern: None sentinel stores an input-derived result and replays it.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + class_none_attrs: dict[str, set[str]] = {} + + for stmt in tree.body: + if not isinstance(stmt, ast.ClassDef): + continue + inherited_attrs = set() + for base in stmt.bases: + if isinstance(base, ast.Name): + inherited_attrs.update(class_none_attrs.get(base.id, set())) + if inherited_attrs: + class_none_attrs.setdefault(stmt.name, set()).update(inherited_attrs) + for child in stmt.body: + if not isinstance(child, ast.Assign) or not _expr_is_none(child.value): + continue + for target in child.targets: + if isinstance(target, ast.Name): + class_none_attrs.setdefault(stmt.name, set()).add(target.id) + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + params = {arg.arg for arg in node.args.args} + if not params: + continue + + for sentinel in facts._none_inited: + stores_input = False + guarded = False + returns_sentinel = False + for child in ast.walk(node): + if isinstance(child, ast.Assign) and _expr_calls_input(child.value, params): + stores_input = stores_input or any( + isinstance(target, ast.Name) and target.id == sentinel + for target in child.targets + ) + if stores_input: + continue + if isinstance(child, ast.If): + guarded = guarded or any( + isinstance(compare, ast.Compare) + and isinstance(compare.left, ast.Name) + and compare.left.id == sentinel + and len(compare.comparators) == 1 + and _expr_is_none(compare.comparators[0]) + for compare in ast.walk(child.test) + ) + if guarded: + continue + if isinstance(child, ast.Return): + returns_sentinel = returns_sentinel or ( + isinstance(child.value, ast.Name) + and child.value.id == sentinel + ) + if stores_input and guarded and returns_sentinel: + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} replays input-derived output from a global sentinel", + }] + + for class_name, attrs in class_none_attrs.items(): + for attr_name in attrs: + stores_input = False + guarded = False + returns_attr = False + for child in ast.walk(node): + if isinstance(child, ast.Assign) and _expr_calls_input(child.value, params): + stores_input = stores_input or any( + _expr_is_class_attr(target, class_name, attr_name) + for target in child.targets + ) + if stores_input: + continue + if isinstance(child, ast.If): + guarded = guarded or any( + isinstance(compare, ast.Compare) + and _expr_is_class_attr(compare.left, class_name, attr_name) + and len(compare.comparators) == 1 + and _expr_is_none(compare.comparators[0]) + for compare in ast.walk(child.test) + ) + if guarded: + continue + if isinstance(child, ast.Return): + returns_attr = returns_attr or _expr_is_class_attr(child.value, class_name, attr_name) + if stores_input and guarded and returns_attr: + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": f"{entrypoint_name} replays input-derived output from a class sentinel", + }] + + return [] + + +def _expr_is_torch_attr(expr: ast.AST | None, torch_aliases: set[str], attr: str) -> bool: + return ( + isinstance(expr, ast.Attribute) + and expr.attr == attr + and isinstance(expr.value, ast.Name) + and expr.value.id in torch_aliases + ) + + +def detect_bmm_self_matmul_output(code: str | SubmissionFacts) -> list[dict]: + """Pattern: aliased torch.bmm computes an input against its transpose.""" + facts = ensure_submission_facts(code) + tree = facts.ast_tree + if tree is None: + return [] + entrypoint_name = entrypoint_label(facts.entrypoint_name) + torch_aliases = {"torch"} + bmm_aliases: set[str] = set() + + for node in facts._imports: + for alias in node.names: + if alias.name == "torch": + torch_aliases.add(alias.asname or alias.name) + for node in facts._import_froms: + if node.module != "torch": + continue + for alias in node.names: + if alias.name == "bmm": + bmm_aliases.add(alias.asname or alias.name) + + for stmt in tree.body: + if not isinstance(stmt, ast.Assign): + continue + if not _expr_is_torch_attr(stmt.value, torch_aliases, "bmm"): + continue + for target in stmt.targets: + if isinstance(target, ast.Name): + bmm_aliases.add(target.id) + + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if not is_entrypoint_name(node.name): + continue + params = {arg.arg for arg in node.args.args} + expanded_inputs: set[str] = set() + for child in ast.walk(node): + if not isinstance(child, ast.Assign): + continue + if not ( + isinstance(child.value, ast.Call) + and isinstance(child.value.func, ast.Attribute) + and child.value.func.attr == "unsqueeze" + and isinstance(child.value.func.value, ast.Name) + and child.value.func.value.id in params + ): + continue + for target in child.targets: + if isinstance(target, ast.Name): + expanded_inputs.add(target.id) + + for child in ast.walk(node): + if not isinstance(child, ast.Return): + continue + for call in ast.walk(child.value): + if not isinstance(call, ast.Call) or len(call.args) < 2: + continue + calls_bmm = ( + isinstance(call.func, ast.Name) + and call.func.id in bmm_aliases + ) or _expr_is_torch_attr(call.func, torch_aliases, "bmm") + if not calls_bmm: + continue + first, second = call.args[:2] + if not (isinstance(first, ast.Name) and first.id in expanded_inputs): + continue + if not ( + isinstance(second, ast.Call) + and isinstance(second.func, ast.Attribute) + and second.func.attr in {"transpose", "transpose_"} + and isinstance(second.func.value, ast.Name) + and second.func.value.id == first.id + ): + continue + return [{ + "pattern": "SELF_MATMUL_OUTPUT", + "severity": "critical", + "evidence": f"{entrypoint_name} returns bmm(input, input.transpose(...))", + }] + + return [] + + def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]: """Pattern 11: torch.compile for pipeline graph caching.""" facts = ensure_submission_facts(code) @@ -3483,6 +3705,10 @@ class RulePolicy: "SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (), AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep", ), + "SELF_MATMUL_OUTPUT": RulePolicy( + "SELF_MATMUL_OUTPUT", "fake_output", "hard", AUTO_FILTER, (), + (), "keep", + ), "TIMED_INPUT_REPLAY": RulePolicy( "TIMED_INPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (), AMD_PACKAGE_TIMED_INPUT_REPLAY_FIXTURES, "keep", @@ -3763,6 +3989,8 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: detect_result_caching, detect_last_call_replay, detect_shape_output_replay, + detect_sentinel_output_replay, + detect_bmm_self_matmul_output, detect_timed_input_replay, detect_cuda_graph_replay, detect_silent_fallback, @@ -3801,6 +4029,8 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool: ("result_caching", detect_result_caching), ("last_call_replay", detect_last_call_replay), ("shape_output_replay", detect_shape_output_replay), + ("sentinel_output_replay", detect_sentinel_output_replay), + ("bmm_self_matmul_output", detect_bmm_self_matmul_output), ("timed_input_replay", detect_timed_input_replay), ("cuda_graph_replay", detect_cuda_graph_replay), ("silent_fallback", detect_silent_fallback), From 6b89247a408eed36e2949942aab1e4e469a0603b Mon Sep 17 00:00:00 2001 From: Prasanna Date: Thu, 7 May 2026 21:25:30 +0530 Subject: [PATCH 2/2] chore(eval): fresh target red 795