Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions kernelguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,228 @@ def detect_trivial_probe(code: str | SubmissionFacts) -> list[dict]:
return matches


def _expr_is_none(expr: ast.AST | None) -> bool:
return isinstance(expr, ast.Constant) and expr.value is None


def _expr_is_class_attr(expr: ast.AST | None, class_name: str, attr_name: str) -> bool:
return (
isinstance(expr, ast.Attribute)
and expr.attr == attr_name
and isinstance(expr.value, ast.Name)
and expr.value.id == class_name
)


def _expr_calls_input(expr: ast.AST | None, params: set[str]) -> bool:
if expr is None:
return False
return any(
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id in params
for node in ast.walk(expr)
)


def detect_sentinel_output_replay(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: None sentinel stores an input-derived result and replays it."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)
class_none_attrs: dict[str, set[str]] = {}

for stmt in tree.body:
if not isinstance(stmt, ast.ClassDef):
continue
inherited_attrs = set()
for base in stmt.bases:
if isinstance(base, ast.Name):
inherited_attrs.update(class_none_attrs.get(base.id, set()))
if inherited_attrs:
class_none_attrs.setdefault(stmt.name, set()).update(inherited_attrs)
for child in stmt.body:
if not isinstance(child, ast.Assign) or not _expr_is_none(child.value):
continue
for target in child.targets:
if isinstance(target, ast.Name):
class_none_attrs.setdefault(stmt.name, set()).add(target.id)

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue
params = {arg.arg for arg in node.args.args}
if not params:
continue

for sentinel in facts._none_inited:
stores_input = False
guarded = False
returns_sentinel = False
for child in ast.walk(node):
if isinstance(child, ast.Assign) and _expr_calls_input(child.value, params):
stores_input = stores_input or any(
isinstance(target, ast.Name) and target.id == sentinel
for target in child.targets
)
if stores_input:
continue
if isinstance(child, ast.If):
guarded = guarded or any(
isinstance(compare, ast.Compare)
and isinstance(compare.left, ast.Name)
and compare.left.id == sentinel
and len(compare.comparators) == 1
and _expr_is_none(compare.comparators[0])
for compare in ast.walk(child.test)
)
if guarded:
continue
if isinstance(child, ast.Return):
returns_sentinel = returns_sentinel or (
isinstance(child.value, ast.Name)
and child.value.id == sentinel
)
if stores_input and guarded and returns_sentinel:
return [{
"pattern": "LAST_CALL_REPLAY",
"severity": "critical",
"evidence": f"{entrypoint_name} replays input-derived output from a global sentinel",
}]

for class_name, attrs in class_none_attrs.items():
for attr_name in attrs:
stores_input = False
guarded = False
returns_attr = False
for child in ast.walk(node):
if isinstance(child, ast.Assign) and _expr_calls_input(child.value, params):
stores_input = stores_input or any(
_expr_is_class_attr(target, class_name, attr_name)
for target in child.targets
)
if stores_input:
continue
if isinstance(child, ast.If):
guarded = guarded or any(
isinstance(compare, ast.Compare)
and _expr_is_class_attr(compare.left, class_name, attr_name)
and len(compare.comparators) == 1
and _expr_is_none(compare.comparators[0])
for compare in ast.walk(child.test)
)
if guarded:
continue
if isinstance(child, ast.Return):
returns_attr = returns_attr or _expr_is_class_attr(child.value, class_name, attr_name)
if stores_input and guarded and returns_attr:
return [{
"pattern": "LAST_CALL_REPLAY",
"severity": "critical",
"evidence": f"{entrypoint_name} replays input-derived output from a class sentinel",
}]

return []


def _expr_is_torch_attr(expr: ast.AST | None, torch_aliases: set[str], attr: str) -> bool:
return (
isinstance(expr, ast.Attribute)
and expr.attr == attr
and isinstance(expr.value, ast.Name)
and expr.value.id in torch_aliases
)


def detect_bmm_self_matmul_output(code: str | SubmissionFacts) -> list[dict]:
"""Pattern: aliased torch.bmm computes an input against its transpose."""
facts = ensure_submission_facts(code)
tree = facts.ast_tree
if tree is None:
return []
entrypoint_name = entrypoint_label(facts.entrypoint_name)
torch_aliases = {"torch"}
bmm_aliases: set[str] = set()

for node in facts._imports:
for alias in node.names:
if alias.name == "torch":
torch_aliases.add(alias.asname or alias.name)
for node in facts._import_froms:
if node.module != "torch":
continue
for alias in node.names:
if alias.name == "bmm":
bmm_aliases.add(alias.asname or alias.name)

for stmt in tree.body:
if not isinstance(stmt, ast.Assign):
continue
if not _expr_is_torch_attr(stmt.value, torch_aliases, "bmm"):
continue
for target in stmt.targets:
if isinstance(target, ast.Name):
bmm_aliases.add(target.id)

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
if not is_entrypoint_name(node.name):
continue
params = {arg.arg for arg in node.args.args}
expanded_inputs: set[str] = set()
for child in ast.walk(node):
if not isinstance(child, ast.Assign):
continue
if not (
isinstance(child.value, ast.Call)
and isinstance(child.value.func, ast.Attribute)
and child.value.func.attr == "unsqueeze"
and isinstance(child.value.func.value, ast.Name)
and child.value.func.value.id in params
):
continue
for target in child.targets:
if isinstance(target, ast.Name):
expanded_inputs.add(target.id)

for child in ast.walk(node):
if not isinstance(child, ast.Return):
continue
for call in ast.walk(child.value):
if not isinstance(call, ast.Call) or len(call.args) < 2:
continue
calls_bmm = (
isinstance(call.func, ast.Name)
and call.func.id in bmm_aliases
) or _expr_is_torch_attr(call.func, torch_aliases, "bmm")
if not calls_bmm:
continue
first, second = call.args[:2]
if not (isinstance(first, ast.Name) and first.id in expanded_inputs):
continue
if not (
isinstance(second, ast.Call)
and isinstance(second.func, ast.Attribute)
and second.func.attr in {"transpose", "transpose_"}
and isinstance(second.func.value, ast.Name)
and second.func.value.id == first.id
):
continue
return [{
"pattern": "SELF_MATMUL_OUTPUT",
"severity": "critical",
"evidence": f"{entrypoint_name} returns bmm(input, input.transpose(...))",
}]

return []


def detect_torch_compile_cache(code: str | SubmissionFacts) -> list[dict]:
"""Pattern 11: torch.compile for pipeline graph caching."""
facts = ensure_submission_facts(code)
Expand Down Expand Up @@ -3483,6 +3705,10 @@ class RulePolicy:
"SHAPE_OUTPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_SHAPE_REPLAY_FIXTURES, "keep",
),
"SELF_MATMUL_OUTPUT": RulePolicy(
"SELF_MATMUL_OUTPUT", "fake_output", "hard", AUTO_FILTER, (),
(), "keep",
),
"TIMED_INPUT_REPLAY": RulePolicy(
"TIMED_INPUT_REPLAY", "result_reuse", "hard", AUTO_FILTER, (),
AMD_PACKAGE_TIMED_INPUT_REPLAY_FIXTURES, "keep",
Expand Down Expand Up @@ -3763,6 +3989,8 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
detect_result_caching,
detect_last_call_replay,
detect_shape_output_replay,
detect_sentinel_output_replay,
detect_bmm_self_matmul_output,
detect_timed_input_replay,
detect_cuda_graph_replay,
detect_silent_fallback,
Expand Down Expand Up @@ -3801,6 +4029,8 @@ def support_only_patterns(matched_patterns: list[dict]) -> bool:
("result_caching", detect_result_caching),
("last_call_replay", detect_last_call_replay),
("shape_output_replay", detect_shape_output_replay),
("sentinel_output_replay", detect_sentinel_output_replay),
("bmm_self_matmul_output", detect_bmm_self_matmul_output),
("timed_input_replay", detect_timed_input_replay),
("cuda_graph_replay", detect_cuda_graph_replay),
("silent_fallback", detect_silent_fallback),
Expand Down