Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b75b621
the initial design for unified hint framework
starrryz Jan 12, 2026
0078272
update the logic of how to call backend method in basehinthandler
starrryz Jan 13, 2026
bfe1966
update hintmanager, wrap additional code into hintmanager, back no-hi…
starrryz Jan 21, 2026
2c64367
remove redundant code
starrryz Jan 26, 2026
a430a9e
fix import and python bugs
starrryz Jan 26, 2026
06a032a
fix import and python bugs_2
starrryz Jan 26, 2026
854b504
apply code-format change
starrryz Jan 26, 2026
9e2ef64
apply code-format change_2
starrryz Jan 26, 2026
51756b2
fix bug : circular import
starrryz Jan 27, 2026
19f80c5
fix bug : hintmanager name into hint_manager
starrryz Jan 27, 2026
45c93b4
fix bug : massive useless print
starrryz Jan 27, 2026
869c357
update spec hint-related codegen && jit
starrryz Mar 10, 2026
cc97432
remove redundant code in python triton src
starrryz Mar 10, 2026
d484e12
update hintmanager, Align with triton_v3.5.x branch.
starrryz Mar 10, 2026
5f7a336
fix hint manager import error
starrryz Mar 11, 2026
018a6e0
Merge branch 'triton_v3.2.x' into triton_v3.2.x_hint_manager
sgjzfzzf Mar 19, 2026
10655ef
Merge branch 'triton_v3.2.x' into triton_v3.2.x_hint_manager
sgjzfzzf Mar 23, 2026
bf15c64
add hint test on IR phase
starrryz Mar 26, 2026
8c11f4b
fix hint test on ascend
starrryz Mar 26, 2026
e665202
fix hint test on ascend 2
starrryz Mar 26, 2026
cc7e247
fix hint test on ascend 3
starrryz Mar 26, 2026
f23042e
fix hint test on ascend 4
starrryz Mar 26, 2026
48700c0
fix hint test on ascend final
starrryz Mar 27, 2026
4924d99
fix hint test on ascend final 2
starrryz Mar 27, 2026
04d87b5
fix hint test on ascend final 3
starrryz Mar 27, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ascend-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ jobs:
python3 14-accuracy-comparison.py
#python3 15-embedding_gather_demo.py
popd
# hint tests
pushd third_party/ascend/tutorials/hint
python3 test_comment_hint.py
popd
# pytest_ut
pushd third_party/ascend/unittest/pytest_ut
python3 -m pytest . \
Expand Down
135 changes: 135 additions & 0 deletions python/triton/compiler/hint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import sys
import importlib


class BaseHintHandler:
    """Base class for backend-specific hint handlers.

    Subclasses define hook methods; callers reach them by name through
    :meth:`trigger`. This base class defines no hooks of its own, so
    triggering any hook on it returns ``None``.
    """

    def trigger(self, hook_name, *args, **kwargs):
        """Call the hook named ``hook_name`` if this handler provides one.

        Args:
            hook_name: Attribute name of the hook to look up on ``self``.
            *args: Positional arguments forwarded to the hook.
            **kwargs: Keyword arguments forwarded to the hook.

        Returns:
            The hook's return value, or ``None`` when the attribute is missing
            or is not callable.

        Raises:
            TypeError: Re-raised unchanged from the hook call. When the
                arguments do not fit the hook's signature, a mismatch report
                is written to stderr first.
        """
        method = getattr(self, hook_name, None)
        if not callable(method):
            return None

        try:
            return method(*args, **kwargs)
        except TypeError as e:
            import inspect

            try:
                sig = inspect.signature(method)
            except Exception:
                # Diagnostics must never mask the original TypeError.
                sig = None

            if sig is not None:
                try:
                    sig.bind(*args, **kwargs)
                except TypeError:
                    # The arguments really do not fit: report it below.
                    pass
                else:
                    # The call itself was well-formed, so the TypeError came
                    # from inside the hook; a "mismatch" report would mislead.
                    raise

            expected = str(sig) if sig is not None else "(unknown)"
            actual_args = f"{len(args)} positional"
            actual_kwargs = f"keys={list(kwargs.keys())}" if kwargs else "no keywords"

            # Diagnostics belong on stderr, like the other warnings in this module.
            print(f"\n[Hint Trigger Mismatch] {self.__class__.__name__}.{hook_name}", file=sys.stderr)
            print(f" > Expect : {expected}", file=sys.stderr)
            print(f" > Actual : {actual_args}, {actual_kwargs}", file=sys.stderr)
            print(f" > Reason : {e}\n", file=sys.stderr)

            raise


class HintManager:
    """Holds the hint handler selected for one backend name.

    Attributes are ``backend_name`` (the name passed in) and ``handler`` (the
    handler instance hooks are dispatched to).
    """

    # (backend name, handler module, handler class, label used in the warning)
    _HANDLER_SPECS = (
        ('npu', "triton.backends.ascend.ascend_hint_handler", "AscendHintHandler", "Ascend"),
        ('aipu', "triton.backends.aipu.aipu_hint_handler", "AipuHintHandler", "aipu"),
        ('cuda', "triton.backends.nvidia.nvidia_hint_handler", "NvidiaHintHandler", "Nvidia"),
    )

    def __init__(self, backend_name):
        self.backend_name = backend_name
        # Resolve the handler once, from the backend name.
        self.handler = self._load_handler(backend_name)

    def _load_handler(self, backend):
        """Return the handler instance for ``backend``.

        Falls back to a plain ``BaseHintHandler`` when the backend has no
        registered handler, or when importing its handler module raises
        ``ImportError`` (a warning is written to stderr in that case).
        """
        for known_backend, module_path, class_name, label in self._HANDLER_SPECS:
            if backend == known_backend:
                try:
                    handler_module = importlib.import_module(module_path)
                    return getattr(handler_module, class_name)()
                except ImportError as e:
                    print(f"[FlagTree] Warning: Failed to load {label} Hint Handler: {e}", file=sys.stderr)
                    return BaseHintHandler()
        return BaseHintHandler()


# Canonical backend names accepted by hint_get_flagtree_backend().
SUPPORTED_BACKENDS = ["aipu", "npu", "cuda"]

# TODO: the "npu" name will conflict if more backends are involved.
# Alias -> canonical backend name.
BACKEND_ALIASES = {
    "ascend": "npu",
    "huawei": "npu",
    "nvidia": "cuda",
}


def normalize_backend_name(name: str) -> str:
    """Return the canonical, lower-cased form of a backend name.

    Known aliases are translated through ``BACKEND_ALIASES``; any other name is
    returned lower-cased. A falsy ``name`` (empty string or ``None``) yields ``""``.
    """
    if name:
        lowered = name.lower()
        return BACKEND_ALIASES.get(lowered, lowered)
    return ""


def hint_get_flagtree_backend() -> str:
    """Detect the active backend and return its canonical name.

    Detection order:
      1. The device reported by ``triton.runtime.driver.active``.
      2. The first of ``torch.aipu`` / ``torch.npu`` / ``torch.cuda`` whose
         ``is_available()`` returns a truthy value.

    Returns:
        A name from ``SUPPORTED_BACKENDS``, or ``""`` when no supported backend
        is detected.
    """
    # Detection is best-effort, so a missing torch must not abort the caller:
    # without it only a string device reported by the driver can be used.
    try:
        import torch
    except ImportError:
        torch = None

    detected_backend = ""

    # Priority 1: Triton Driver
    # NOTE(review): only ImportError is treated as "driver unavailable"; confirm
    # whether driver.active can raise other errors on hosts without a device.
    try:
        from triton.runtime import driver
        if hasattr(driver, 'active') and hasattr(driver.active, 'get_active_torch_device'):
            device = driver.active.get_active_torch_device()
            if torch is not None and isinstance(device, torch.device):
                detected_backend = device.type
            # unimplemented support
            elif isinstance(device, str):
                detected_backend = device
    except ImportError:
        pass

    # TODO: some backends may not support priority 1, so priority 2 must be kept
    # Priority 2: Torch Global State
    if not detected_backend and torch is not None:
        # Probed in this order; the first available backend wins.
        for candidate in ("aipu", "npu", "cuda"):
            module = getattr(torch, candidate, None)
            if module and hasattr(module, "is_available") and module.is_available():
                detected_backend = candidate
                break

    # Normalization and validation: anything outside the supported set,
    # including the empty string, maps to "".
    canonical_backend = normalize_backend_name(detected_backend)
    if canonical_backend in SUPPORTED_BACKENDS:
        return canonical_backend
    return ""


# Created lazily by the first hint_trigger() call.
_global_hint_manager = None


def hint_trigger(hook_name, *args, **kwargs):
    """Dispatch ``hook_name`` to the handler of the process-wide hint manager.

    The manager is built on first use from ``hint_get_flagtree_backend()`` and
    reused afterwards. Returns whatever the handler's ``trigger`` returns.
    """
    global _global_hint_manager

    manager = _global_hint_manager
    if manager is None:
        manager = HintManager(hint_get_flagtree_backend())
        _global_hint_manager = manager
    return manager.handler.trigger(hook_name, *args, **kwargs)
79 changes: 79 additions & 0 deletions third_party/ascend/backend/ascend_hint_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# This file should be stored at third_party/<backend>/backend/
from triton.compiler.hint_manager import BaseHintHandler
import triton.language as language
import ast
from triton.compiler.code_generator import _is_triton_value


class AscendHintHandler(BaseHintHandler):
    """Hint hooks for the Ascend backend, driven by ``#@hint:`` source comments.

    All hooks are static methods; they are looked up by name through the
    ``trigger`` method inherited from ``BaseHintHandler``.
    """

    @staticmethod
    def _hints_for_node(code_generator, node):
        """Return the hint text recorded for ``node``'s source line, or ``None``.

        ``None`` is also returned when ``node`` has no ``lineno`` or
        ``code_generator`` has no ``jit_fn``.
        """
        if not (hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn')):
            return None
        # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
        module_ast = code_generator.jit_fn.parse()
        line_flagtree_hints = getattr(module_ast.body[0], 'line_flagtree_hints', {})
        return line_flagtree_hints.get(node.lineno)

    @staticmethod
    def _is_tl_load_call(expr):
        """Return True when ``expr`` is syntactically a ``tl.load(...)`` call."""
        return (isinstance(expr, ast.Call) and isinstance(expr.func, ast.Attribute)
                and isinstance(expr.func.value, ast.Name) and expr.func.value.id == 'tl'
                and expr.func.attr == 'load')

    @staticmethod
    def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values):
        """Annotate tensors loaded by a ``tl.load`` line hinted with ``dot_pad_only_k``.

        Args:
            code_generator: The code generator visiting ``node``; its
                ``builder`` creates the annotation.
            node: The assignment AST node being visited.
            names: Assignment target names, paired with ``values``.
            values: Assigned values; only Triton values are annotated.
        """
        # flagtree: runs after normal processing of the assignment.
        # Check the statement shape first: it is cheap, whereas the hint lookup
        # re-parses the kernel source, so only tl.load assignments pay for it.
        if not AscendHintHandler._is_tl_load_call(getattr(node, 'value', None)):
            return

        flagtree_hints = AscendHintHandler._hints_for_node(code_generator, node)
        if not flagtree_hints or 'dot_pad_only_k' not in flagtree_hints:
            return

        # Add the hint annotation to the loaded tensor(s).
        for _, value in zip(names, values):
            if _is_triton_value(value):
                hint_val = code_generator.builder.get_unit_attr()
                code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)

    @staticmethod
    def check_override_bind_sub_block(code_generator, node, bind_sub_block):
        """Return ``True`` when ``node``'s line carries a ``bind_sub_block`` hint.

        Otherwise the incoming ``bind_sub_block`` value is returned unchanged.
        """
        # flagtree: runs after normal processing of the loop header.
        flagtree_hints = AscendHintHandler._hints_for_node(code_generator, node)
        if flagtree_hints and 'bind_sub_block' in flagtree_hints:
            return True
        return bind_sub_block

    @staticmethod
    def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block):
        """Set the ``bind_sub_block`` boolean attribute on ``for_op``."""
        for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block))

    @staticmethod
    def maps_line_numbers_to_comment_hints(jit_fn):
        """Map line numbers of ``jit_fn.src`` to the hint text found on them.

        A hint is a comment that starts with ``#@hint:`` once its spaces are
        removed; the text after that marker is stored, also without spaces.

        Returns:
            A dict from the comment's line number to its hint text.
        """
        import tokenize
        from io import StringIO

        marker = '#@hint:'
        line_flagtree_hints = {}
        for tok in tokenize.generate_tokens(StringIO(jit_fn.src).readline):
            if tok.type != tokenize.COMMENT:
                continue
            comment = tok.string.replace(" ", "").strip()
            if comment.startswith(marker):
                # Key by the line on which the comment starts.
                line_flagtree_hints[tok.start[0]] = comment[len(marker):].strip()

        return line_flagtree_hints

    @staticmethod
    def attach_line_number_to_comment_mapping(tree, line_flagtree_hints):
        """Store ``line_flagtree_hints`` on the first statement of ``tree``."""
        # The lookup helpers read it back from the function definition node.
        tree.body[0].line_flagtree_hints = line_flagtree_hints
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..runtime import JITFunction
from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct)
from types import ModuleType
from .hint_manager import hint_trigger
# Central registry for all 'with' statement handlers
WITH_DISPATCH = {}

Expand Down Expand Up @@ -548,6 +549,9 @@ def visit_Assign(self, node):
value = language.semantic.to_tensor(value, self.builder)
self.set_value(name, value)

# switch into hintmanager
hint_trigger("ext_CodeGenerator_visit_Assign_hint_anno", self, node, names, values)

def visit_AugAssign(self, node):
name = node.target.id
lhs = ast.Name(id=name, ctx=ast.Load())
Expand Down Expand Up @@ -997,6 +1001,11 @@ def visit_For(self, node):
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
else:
raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
# hint manager
new_bind_sub_block = hint_trigger("check_override_bind_sub_block", self, node, bind_sub_block)
if new_bind_sub_block is not None:
bind_sub_block = new_bind_sub_block

# handle negative constant step (not supported by scf.for in MLIR)
negative_step = False
if _is_constexpr(step) and step.value < 0:
Expand Down Expand Up @@ -1072,6 +1081,9 @@ def visit_For(self, node):
tle = importlib.import_module("triton.experimental.tle", package=__package__)
if (IteratorClass is extension.parallel or IteratorClass is tle.dsa.parallel):
for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr())
# hint manager
if bind_sub_block:
hint_trigger("forop_setattr_for_bind_sub_block", self, for_op, bind_sub_block)

self.scf_stack.append(node)
self.builder.set_insertion_point_to_start(for_op.get_body(0))
Expand Down
10 changes: 10 additions & 0 deletions third_party/ascend/backend/spec/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,20 @@ def preload(self, specialization_data):
# the user might want to monkey-patch self.src dynamically.
# Our unit tests do this, for example.
def parse(self):
    """Parse ``self.src`` and return the resulting ``ast.Module``.

    The module must contain exactly one statement, a function definition.
    Comment hints are collected and attached to the tree through the hint
    manager, which hides the backend-specific implementation.
    """
    # hint manager
    from ..compiler.hint_manager import hint_trigger

    # Collected before parsing, in the same order as the original flow.
    hints_by_line = hint_trigger("maps_line_numbers_to_comment_hints", self)

    module_ast = ast.parse(self.src)
    assert isinstance(module_ast, ast.Module)
    statements = module_ast.body
    assert len(statements) == 1
    assert isinstance(statements[0], ast.FunctionDef)

    # hint manager
    # Attach the line-number-to-comment mapping to the function definition node.
    hint_trigger('attach_line_number_to_comment_mapping', module_ast, hints_by_line)

    return module_ast

def __call__(self, *args, **kwargs):
Expand Down
Loading
Loading