Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 108 additions & 12 deletions tzrec/acc/aot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import json
import os
from typing import Any, Dict, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union

import torch
from torch import nn
Expand All @@ -38,6 +38,84 @@
logger.debug("cutlass_hstu_attention not available; skipping op registration")


def _build_aoti_output_field_names(
    exported_pg: "torch.export.ExportedProgram",
    eager_output_keys: List[str],
) -> List[str]:
    """Align user-facing output names with the AOTI output-handle layout.

    ``torch._inductor.aoti_compile_and_package`` compiles the graph inside
    ``exported_pg``. The resulting AOTI wrapper emits one ``output_handles[i]``
    per leaf in ``exported_pg.graph_signature.output_specs``. That list can be
    **longer** than the eager module's return dict because
    ``torch.export.export`` also surfaces buffer-mutation results, token
    outputs, gradient-to-parameter signals, etc. as extra outputs.

    The runtime (TorchRecProcessor) names the emitted tensors positionally
    using ``output_field_names.json``, so the JSON **must** have one entry per
    AOTI output handle. We therefore walk ``output_specs`` in order, assign
    the eager dict keys only to ``USER_OUTPUT`` slots, and fill every other
    slot with an ignorable ``_unused_*`` placeholder.

    Args:
        exported_pg: Result of ``torch.export.export(...)``. Only
            ``graph_signature.output_specs[i].kind`` is read.
        eager_output_keys: Keys of the eager forward's return dict, in
            dict-iteration order (which is the order ``torch.export`` flattens
            them in).

    Returns:
        List of names, one per AOTI output handle, with USER_OUTPUT slots
        carrying the eager keys in order and all other slots filled with
        ``_unused_<idx>_<kind>`` placeholders.

    Raises:
        RuntimeError: If the number of USER_OUTPUT slots does not match
            ``len(eager_output_keys)``; this indicates the eager dict and
            the exported program disagree on the user-visible outputs and the
            exported model cannot be reliably consumed downstream.
    """
    output_specs = exported_pg.graph_signature.output_specs

    # Identify USER_OUTPUT slots robustly across torch versions.
    def _is_user_output(spec: Any) -> bool:
        kind = getattr(spec, "kind", None)
        if kind is None:
            return False
        # Enum path (current torch): kind.name == "USER_OUTPUT".
        name = getattr(kind, "name", None)
        if isinstance(name, str):
            return name == "USER_OUTPUT"
        # String fallback. Compare the final dotted segment exactly so that
        # both "USER_OUTPUT" and "OutputKind.USER_OUTPUT" match, while a kind
        # that merely *ends with* the text (e.g. "NOT_USER_OUTPUT") does not.
        return str(kind).rpartition(".")[2] == "USER_OUTPUT"

    # Classify every slot exactly once so the count check and the naming
    # pass below can never disagree about which slots are user outputs.
    user_flags = [_is_user_output(spec) for spec in output_specs]
    num_user_outputs = sum(user_flags)

    if num_user_outputs != len(eager_output_keys):
        raise RuntimeError(
            "AOTI output-name alignment failed: exported program has "
            f"{num_user_outputs} USER_OUTPUT slots "
            f"(out of {len(output_specs)} total) but eager forward returned "
            f"{len(eager_output_keys)} keys ({eager_output_keys}). "
            "The eager dict and the torch.export graph disagree on "
            "user-visible outputs; refusing to emit a mislabeled "
            "output_field_names.json."
        )

    names: List[str] = []
    # The count check above guarantees this iterator is consumed exactly.
    user_iter = iter(eager_output_keys)
    for i, (spec, is_user) in enumerate(zip(output_specs, user_flags)):
        if is_user:
            names.append(next(user_iter))
            continue
        kind = getattr(spec, "kind", None)
        kind_name = getattr(kind, "name", None) or str(kind) or "other"
        # The slot index keeps placeholders unique even for repeated kinds.
        names.append(f"_unused_{i}_{kind_name.lower()}")

    return names


def load_model_aot(
model_path: str, device: torch.device
) -> Union[CombinedModelWrapper, UnifiedAOTIModelWrapper]:
Expand Down Expand Up @@ -137,7 +215,7 @@ def export_model_aot(
# active — kernels like CUTLASS HSTU attention reject fp32 inputs.
with torch.no_grad():
_out = dense_to_export(sparse_output)
aoti_output_keys = list(_out.keys())
eager_output_keys = list(_out.keys())
del _out

# pre_hook requires running arbitrary code at runtime
Expand All @@ -149,6 +227,16 @@ def export_model_aot(
args=(sparse_output,),
dynamic_shapes=(dynamic_shapes,),
)

# Align names with AOTI output-handle layout: the exported program may
# emit extra outputs (buffer mutations, tokens, ...) in addition to the
# eager dict's user-visible outputs. The runtime indexes into
# output_field_names.json positionally, so it must have exactly one
# entry per AOTI output handle.
aoti_output_field_names = _build_aoti_output_field_names(
exported_pg, eager_output_keys
)

# AssertScalar codegen is not correct.
with torch._inductor.config.patch(
{
Expand All @@ -159,13 +247,15 @@ def export_model_aot(
aoti_dir = os.path.join(save_dir, "aoti")
os.makedirs(aoti_dir, exist_ok=True)

# Save original model output field names to aoti directory
if aoti_output_keys:
# Save output field names to aoti directory (one per AOTI output
# handle; non-USER_OUTPUT slots are filled with _unused_* placeholders).
if aoti_output_field_names:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The surrounding compile/save block (lines ~247–263: makedirs(aoti_dir) → optional JSON dump → aoti_compile_and_package, inside _inductor.config.patch) is copy-pasted to export_unified_model_aot at lines ~541–557. This PR is itself a fix that had to land in both places because they drifted, and the prior aoti_output_keys handling also lived in both. Consider extracting a small helper like _compile_and_save_aoti(exported_pg, aoti_output_field_names, save_dir) so the next fix only lands once.

output_names_path = os.path.join(aoti_dir, "output_field_names.json")
with open(output_names_path, "w") as f:
json.dump(aoti_output_keys, f, indent=4)
json.dump(aoti_output_field_names, f, indent=4)
logger.info(
f"Saved output field names to {output_names_path}: {aoti_output_keys}"
f"Saved output field names to {output_names_path}: "
f"{aoti_output_field_names}"
)

torch._inductor.aoti_compile_and_package(
Expand Down Expand Up @@ -408,7 +498,7 @@ def export_unified_model_aot(
f.write(full_gm.code)

result = full_gm(data)
aoti_output_keys = list(result.keys())
eager_output_keys = list(result.keys())
del result

# Pad any 0-size non-sequence sparse .values tensors so torch.export
Expand All @@ -435,6 +525,11 @@ def export_unified_model_aot(
dynamic_shapes=(dynamic_shapes,),
)

# Align names with AOTI output-handle layout (see _build_aoti_output_field_names).
aoti_output_field_names = _build_aoti_output_field_names(
exported_pg, eager_output_keys
)

# Compile with AOTI
logger.info("compiling unified model with AOTI...")
with torch._inductor.config.patch(
Expand All @@ -446,14 +541,15 @@ def export_unified_model_aot(
aoti_dir = os.path.join(save_dir, "aoti")
os.makedirs(aoti_dir, exist_ok=True)

# Save original model output field names (matches legacy
# export_model_aot behavior).
if aoti_output_keys:
# Save output field names (one per AOTI output handle; non-USER_OUTPUT
# slots are filled with _unused_* placeholders).
if aoti_output_field_names:
output_names_path = os.path.join(aoti_dir, "output_field_names.json")
with open(output_names_path, "w") as f:
json.dump(aoti_output_keys, f, indent=4)
json.dump(aoti_output_field_names, f, indent=4)
logger.info(
f"Saved output field names to {output_names_path}: {aoti_output_keys}"
f"Saved output field names to {output_names_path}: "
f"{aoti_output_field_names}"
)

torch._inductor.aoti_compile_and_package(
Expand Down
172 changes: 172 additions & 0 deletions tzrec/acc/aot_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Unit tests for ``tzrec.acc.aot_utils._build_aoti_output_field_names``.

These tests verify that the output-name alignment helper correctly maps the
eager forward's return-dict keys onto the AOTI output-handle layout emitted
by ``torch.export.export``. The helper is pure Python and only reads
``exported_pg.graph_signature.output_specs[i].kind`` / ``.target``, so the
tests use lightweight duck-typed fakes instead of invoking the real
``torch.export`` machinery. That keeps the tests fast, deterministic, and
independent of the surrounding CUDA / Triton / TorchRec stack.

Regression coverage:

* ``test_user_output_only``: the original (pre-bugfix) happy path where
``USER_OUTPUT`` is the only output kind.
* ``test_mixed_output_kinds``: the buggy scenario from the HSTU export that
motivated this helper. Non-``USER_OUTPUT`` slots must be filled with
placeholders so the JSON written to disk has one entry per AOTI output
handle (and downstream tensors are not renamed by position drift).
* ``test_user_output_count_mismatch_raises``: if the exported program does
not expose the same number of USER_OUTPUT slots as the eager dict returns,
the helper must refuse to emit a mislabeled mapping.
* ``test_string_kind_fallback``: older torch builds surface ``kind`` as a
bare string rather than an enum; the helper must still identify
``USER_OUTPUT`` slots in that case.
"""

import unittest
from typing import Any, List, Optional

from tzrec.acc.aot_utils import _build_aoti_output_field_names


class _FakeKind:
    """Lightweight substitute for ``torch.export.graph_signature.OutputKind``.

    Only the two things ``_build_aoti_output_field_names`` looks at are
    modelled: the ``.name`` attribute (e.g. ``"USER_OUTPUT"``,
    ``"BUFFER_MUTATION"``) and the ``str(kind)`` rendering.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        # Same ``ClassName.MEMBER`` shape as ``str(OutputKind.X)``.
        return "OutputKind." + f"{self.name}"


class _FakeSpec:
    """Fake for a single element of ``graph_signature.output_specs``.

    Stores the output ``kind`` and an optional ``target`` verbatim.
    """

    def __init__(self, kind: Any, target: Optional[str] = None) -> None:
        self.kind, self.target = kind, target


class _FakeSignature:
    """Fake graph signature that exposes nothing but ``output_specs``."""

    def __init__(self, output_specs: List[_FakeSpec]) -> None:
        # Kept by reference; the list is not copied.
        self.output_specs = output_specs


class _FakeExportedProgram:
    """Bare-bones ExportedProgram look-alike with only ``.graph_signature``."""

    def __init__(self, output_specs: List[_FakeSpec]) -> None:
        signature = _FakeSignature(output_specs)
        self.graph_signature = signature


def _make_program(kinds: List[str]) -> _FakeExportedProgram:
    """Create a fake ExportedProgram with one spec per kind name, in order."""
    specs = []
    for kind_name in kinds:
        specs.append(_FakeSpec(_FakeKind(kind_name)))
    return _FakeExportedProgram(specs)


class BuildAotiOutputFieldNamesTest(unittest.TestCase):
    """Tests for ``_build_aoti_output_field_names`` using duck-typed fakes.

    Layout coverage includes USER_OUTPUT slots at the tail, at the head, and
    interleaved with other output kinds, so an implementation that buckets
    placeholders to one end instead of preserving slot order is caught.
    """

    def test_user_output_only(self) -> None:
        """All outputs are USER_OUTPUT: names pass through unchanged."""
        program = _make_program(["USER_OUTPUT", "USER_OUTPUT", "USER_OUTPUT"])
        names = _build_aoti_output_field_names(program, ["logits", "probs", "length"])
        self.assertEqual(names, ["logits", "probs", "length"])

    def test_mixed_output_kinds(self) -> None:
        """Regression: USER_OUTPUT slots may appear after extra output kinds.

        Mirrors the HSTU export that motivated this helper: 35 Inductor-emitted
        slots (parameter-preprocessing / buffer-mutation / token outputs)
        precede the 5 actual USER_OUTPUT slots (scaled down here to 6 extra
        slots followed by 3 user outputs). Eager keys must land on the
        USER_OUTPUT positions only, with placeholders elsewhere.
        """
        kinds = (
            ["BUFFER_MUTATION"] * 3
            + ["TOKEN"]
            + ["USER_INPUT_MUTATION"] * 2
            + ["USER_OUTPUT"] * 3
        )
        program = _make_program(kinds)
        eager_keys = ["logits_is_click", "probs_is_click", "length"]

        names = _build_aoti_output_field_names(program, eager_keys)

        # One name per AOTI output handle.
        self.assertEqual(len(names), len(kinds))
        # USER_OUTPUT slots receive eager keys in order.
        self.assertEqual(names[-3:], eager_keys)
        # Non-USER_OUTPUT slots are all placeholders (never eager keys).
        for i in range(6):
            self.assertTrue(
                names[i].startswith("_unused_"),
                f"expected placeholder at position {i}, got {names[i]!r}",
            )
            # Placeholder carries its index so collisions are impossible.
            self.assertIn(f"_{i}_", names[i])
        # And no placeholder ever duplicates a real user-output name.
        for placeholder in names[:6]:
            self.assertNotIn(placeholder, eager_keys)

    def test_interleaved_output_kinds(self) -> None:
        """USER_OUTPUT slots interleaved with other kinds keep their positions."""
        kinds = ["USER_OUTPUT", "BUFFER_MUTATION", "USER_OUTPUT", "TOKEN", "USER_OUTPUT"]
        program = _make_program(kinds)
        eager_keys = ["logits", "probs", "length"]

        names = _build_aoti_output_field_names(program, eager_keys)

        self.assertEqual(len(names), len(kinds))
        # Eager keys land on slots 0, 2 and 4, in eager order.
        self.assertEqual([names[0], names[2], names[4]], eager_keys)
        # Placeholders stay at their own slot index.
        self.assertTrue(names[1].startswith("_unused_1_"))
        self.assertTrue(names[3].startswith("_unused_3_"))

    def test_user_outputs_before_extra_outputs(self) -> None:
        """USER_OUTPUT slots at the head are followed by placeholders."""
        kinds = ["USER_OUTPUT"] * 3 + ["BUFFER_MUTATION"] * 3
        program = _make_program(kinds)
        eager_keys = ["logits", "probs", "length"]

        names = _build_aoti_output_field_names(program, eager_keys)

        self.assertEqual(len(names), len(kinds))
        self.assertEqual(names[:3], eager_keys)
        for i in range(3, 6):
            self.assertTrue(
                names[i].startswith(f"_unused_{i}_"),
                f"expected placeholder at position {i}, got {names[i]!r}",
            )

    def test_placeholders_are_unique(self) -> None:
        """Placeholders embed their slot index so names are globally unique."""
        kinds = ["BUFFER_MUTATION"] * 4 + ["USER_OUTPUT"]
        program = _make_program(kinds)
        names = _build_aoti_output_field_names(program, ["out"])
        self.assertEqual(len(set(names)), len(names))

    def test_user_output_count_mismatch_raises(self) -> None:
        """Fail loudly rather than silently misalign names."""
        program = _make_program(["USER_OUTPUT", "USER_OUTPUT"])
        with self.assertRaises(RuntimeError) as cm:
            _build_aoti_output_field_names(program, ["only_one_key"])
        msg = str(cm.exception)
        self.assertIn("USER_OUTPUT", msg)
        self.assertIn("only_one_key", msg)

    def test_string_kind_fallback(self) -> None:
        """Older torch versions may surface ``kind`` as a plain string."""
        # Mix of bare strings and enum-like objects to exercise both paths.
        specs = [
            _FakeSpec("OutputKind.BUFFER_MUTATION"),
            _FakeSpec("OutputKind.USER_OUTPUT"),
            _FakeSpec(_FakeKind("USER_OUTPUT")),
        ]
        program = _FakeExportedProgram(specs)
        names = _build_aoti_output_field_names(program, ["a", "b"])
        self.assertEqual(len(names), 3)
        self.assertTrue(names[0].startswith("_unused_0_"))
        self.assertEqual(names[1], "a")
        self.assertEqual(names[2], "b")

    def test_empty_outputs(self) -> None:
        """Zero outputs on both sides is a valid (if unusual) configuration."""
        program = _make_program([])
        names = _build_aoti_output_field_names(program, [])
        self.assertEqual(names, [])

    def test_all_non_user_output_with_nonempty_eager_raises(self) -> None:
        """If the graph exposes no USER_OUTPUT, any eager key is a mismatch."""
        program = _make_program(["BUFFER_MUTATION", "TOKEN"])
        with self.assertRaises(RuntimeError):
            _build_aoti_output_field_names(program, ["x"])


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading