From 962be875b69ebf7858d8830641763cc1d9521558 Mon Sep 17 00:00:00 2001 From: gecheng Date: Sun, 19 Apr 2026 19:49:14 +0800 Subject: [PATCH] bugfix, the output keys are inconsistent for eager output and exported program --- tzrec/acc/aot_utils.py | 120 ++++++++++++++++++++++--- tzrec/acc/aot_utils_test.py | 172 ++++++++++++++++++++++++++++++++++++ 2 files changed, 280 insertions(+), 12 deletions(-) create mode 100644 tzrec/acc/aot_utils_test.py diff --git a/tzrec/acc/aot_utils.py b/tzrec/acc/aot_utils.py index 861ea5ba..d569ba20 100644 --- a/tzrec/acc/aot_utils.py +++ b/tzrec/acc/aot_utils.py @@ -12,7 +12,7 @@ import json import os -from typing import Any, Dict, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union import torch from torch import nn @@ -38,6 +38,84 @@ logger.debug("cutlass_hstu_attention not available; skipping op registration") +def _build_aoti_output_field_names( + exported_pg: "torch.export.ExportedProgram", + eager_output_keys: List[str], +) -> List[str]: + """Align user-facing output names with the AOTI output-handle layout. + + ``torch._inductor.aoti_compile_and_package`` compiles the graph inside + ``exported_pg``. The resulting AOTI wrapper emits one ``output_handles[i]`` + per leaf in ``exported_pg.graph_signature.output_specs``. That list can be + **longer** than the eager module's return dict because + ``torch.export.export`` also surfaces buffer-mutation results, token + outputs, gradient-to-parameter signals, etc. as extra outputs. + + The runtime (TorchRecProcessor) names the emitted tensors positionally + using ``output_field_names.json``, so the JSON **must** have one entry per + AOTI output handle. We therefore walk ``output_specs`` in order, assign + the eager dict keys only to ``USER_OUTPUT`` slots, and fill every other + slot with an ignorable ``_unused_*`` placeholder. + + Args: + exported_pg: Result of ``torch.export.export(...)``. + eager_output_keys: Keys of the eager forward's return dict, in + dict-iteration order (which is the order ``torch.export`` flattens + them in). + + Returns: + List of names, one per AOTI output handle, with USER_OUTPUT slots + carrying the eager keys in order and all other slots filled with + ``_unused__`` placeholders. + + Raises: + RuntimeError: If the number of USER_OUTPUT slots does not match + ``len(eager_output_keys)``; this indicates the eager dict and + the exported program disagree on the user-visible outputs and the + exported model cannot be reliably consumed downstream. + """ + output_specs = exported_pg.graph_signature.output_specs + + # Identify USER_OUTPUT slots robustly across torch versions. + def _is_user_output(spec: Any) -> bool: + kind = getattr(spec, "kind", None) + if kind is None: + return False + # Enum path (current torch): kind.name == "USER_OUTPUT". + name = getattr(kind, "name", None) + if isinstance(name, str): + return name == "USER_OUTPUT" + # String fallback. + return str(kind).endswith("USER_OUTPUT") + + user_output_positions = [ + i for i, spec in enumerate(output_specs) if _is_user_output(spec) + ] + + if len(user_output_positions) != len(eager_output_keys): + raise RuntimeError( + "AOTI output-name alignment failed: exported program has " + f"{len(user_output_positions)} USER_OUTPUT slots " + f"(out of {len(output_specs)} total) but eager forward returned " + f"{len(eager_output_keys)} keys ({eager_output_keys}). " + "The eager dict and the torch.export graph disagree on " + "user-visible outputs; refusing to emit a mislabeled " + "output_field_names.json." + ) + + names: List[str] = [] + user_iter = iter(eager_output_keys) + for i, spec in enumerate(output_specs): + if _is_user_output(spec): + names.append(next(user_iter)) + else: + kind = getattr(spec, "kind", None) + kind_name = getattr(kind, "name", None) or str(kind) or "other" + names.append(f"_unused_{i}_{kind_name.lower()}") + + return names + + def load_model_aot( model_path: str, device: torch.device ) -> Union[CombinedModelWrapper, UnifiedAOTIModelWrapper]: @@ -137,7 +215,7 @@ def export_model_aot( # active — kernels like CUTLASS HSTU attention reject fp32 inputs. with torch.no_grad(): _out = dense_to_export(sparse_output) - aoti_output_keys = list(_out.keys()) + eager_output_keys = list(_out.keys()) del _out # pre_hook requires running arbitrary code at runtime @@ -149,6 +227,16 @@ def export_model_aot( args=(sparse_output,), dynamic_shapes=(dynamic_shapes,), ) + + # Align names with AOTI output-handle layout: the exported program may + # emit extra outputs (buffer mutations, tokens, ...) in addition to the + # eager dict's user-visible outputs. The runtime indexes into + # output_field_names.json positionally, so it must have exactly one + # entry per AOTI output handle. + aoti_output_field_names = _build_aoti_output_field_names( + exported_pg, eager_output_keys + ) + # AsserScalar codegen is not correct. with torch._inductor.config.patch( { @@ -159,13 +247,15 @@ def export_model_aot( aoti_dir = os.path.join(save_dir, "aoti") os.makedirs(aoti_dir, exist_ok=True) - # Save original model output field names to aoti directory - if aoti_output_keys: + # Save output field names to aoti directory (one per AOTI output + # handle; non-USER_OUTPUT slots are filled with _unused_* placeholders). + if aoti_output_field_names: output_names_path = os.path.join(aoti_dir, "output_field_names.json") with open(output_names_path, "w") as f: - json.dump(aoti_output_keys, f, indent=4) + json.dump(aoti_output_field_names, f, indent=4) logger.info( - f"Saved output field names to {output_names_path}: {aoti_output_keys}" + f"Saved output field names to {output_names_path}: " + f"{aoti_output_field_names}" ) torch._inductor.aoti_compile_and_package( @@ -408,7 +498,7 @@ def export_unified_model_aot( f.write(full_gm.code) result = full_gm(data) - aoti_output_keys = list(result.keys()) + eager_output_keys = list(result.keys()) del result # Pad any 0-size non-sequence sparse .values tensors so torch.export @@ -435,6 +525,11 @@ def export_unified_model_aot( dynamic_shapes=(dynamic_shapes,), ) + # Align names with AOTI output-handle layout (see _build_aoti_output_field_names). + aoti_output_field_names = _build_aoti_output_field_names( + exported_pg, eager_output_keys + ) + # Compile with AOTI logger.info("compiling unified model with AOTI...") with torch._inductor.config.patch( @@ -446,14 +541,15 @@ def export_unified_model_aot( aoti_dir = os.path.join(save_dir, "aoti") os.makedirs(aoti_dir, exist_ok=True) - # Save original model output field names (matches legacy - # export_model_aot behavior). - if aoti_output_keys: + # Save output field names (one per AOTI output handle; non-USER_OUTPUT + # slots are filled with _unused_* placeholders). + if aoti_output_field_names: output_names_path = os.path.join(aoti_dir, "output_field_names.json") with open(output_names_path, "w") as f: - json.dump(aoti_output_keys, f, indent=4) + json.dump(aoti_output_field_names, f, indent=4) logger.info( - f"Saved output field names to {output_names_path}: {aoti_output_keys}" + f"Saved output field names to {output_names_path}: " + f"{aoti_output_field_names}" ) torch._inductor.aoti_compile_and_package( diff --git a/tzrec/acc/aot_utils_test.py b/tzrec/acc/aot_utils_test.py new file mode 100644 index 00000000..351d839c --- /dev/null +++ b/tzrec/acc/aot_utils_test.py @@ -0,0 +1,172 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Unit tests for ``tzrec.acc.aot_utils._build_aoti_output_field_names``. + +These tests verify that the output-name alignment helper correctly maps the +eager forward's return-dict keys onto the AOTI output-handle layout emitted +by ``torch.export.export``. The helper is pure Python and only reads +``exported_pg.graph_signature.output_specs[i].kind`` / ``.target``, so the +tests use lightweight duck-typed fakes instead of invoking the real +``torch.export`` machinery. That keeps the tests fast, deterministic, and +independent of the surrounding CUDA / Triton / TorchRec stack. + +Regression coverage: + +* ``test_user_output_only``: the original (pre-bugfix) happy path where + ``USER_OUTPUT`` is the only output kind. +* ``test_mixed_output_kinds``: the buggy scenario from the HSTU export that + motivated this helper. Non-``USER_OUTPUT`` slots must be filled with + placeholders so the JSON written to disk has one entry per AOTI output + handle (and downstream tensors are not renamed by position drift). +* ``test_user_output_count_mismatch_raises``: if the exported program does + not expose the same number of USER_OUTPUT slots as the eager dict returns, + the helper must refuse to emit a mislabeled mapping. +* ``test_string_kind_fallback``: older torch builds surface ``kind`` as a + bare string rather than an enum; the helper must still identify + ``USER_OUTPUT`` slots in that case. +""" + +import unittest +from typing import Any, List, Optional + +from tzrec.acc.aot_utils import _build_aoti_output_field_names + + +class _FakeKind: + """Duck-typed stand-in for ``torch.export.graph_signature.OutputKind``. + + The real enum exposes ``.name`` (e.g. ``"USER_OUTPUT"``, + ``"BUFFER_MUTATION"``). ``_build_aoti_output_field_names`` reads exactly + that attribute plus ``str(kind)``, nothing else. + """ + + def __init__(self, name: str) -> None: + self.name = name + + def __str__(self) -> str: # matches enum ``str(OutputKind.X)`` shape. + return f"OutputKind.{self.name}" + + +class _FakeSpec: + """Stand-in for one entry of ``graph_signature.output_specs``.""" + + def __init__(self, kind: Any, target: Optional[str] = None) -> None: + self.kind = kind + self.target = target + + +class _FakeSignature: + def __init__(self, output_specs: List[_FakeSpec]) -> None: + self.output_specs = output_specs + + +class _FakeExportedProgram: + """Minimal duck-typed ExportedProgram exposing only ``.graph_signature``.""" + + def __init__(self, output_specs: List[_FakeSpec]) -> None: + self.graph_signature = _FakeSignature(output_specs) + + +def _make_program(kinds: List[str]) -> _FakeExportedProgram: + """Build a fake ExportedProgram with the given sequence of kind names.""" + return _FakeExportedProgram([_FakeSpec(_FakeKind(name)) for name in kinds]) + + +class BuildAotiOutputFieldNamesTest(unittest.TestCase): + def test_user_output_only(self) -> None: + """All outputs are USER_OUTPUT: names pass through unchanged.""" + program = _make_program(["USER_OUTPUT", "USER_OUTPUT", "USER_OUTPUT"]) + names = _build_aoti_output_field_names(program, ["logits", "probs", "length"]) + self.assertEqual(names, ["logits", "probs", "length"]) + + def test_mixed_output_kinds(self) -> None: + """Regression: USER_OUTPUT slots may appear after extra output kinds. + + Mirrors the HSTU export that motivated this helper: 35 Inductor-emitted + slots (parameter-preprocessing / buffer-mutation / token outputs) + precede the 5 actual USER_OUTPUT slots. Eager keys must land on the + USER_OUTPUT positions only, with placeholders elsewhere. + """ + kinds = ( + ["BUFFER_MUTATION"] * 3 + + ["TOKEN"] + + ["USER_INPUT_MUTATION"] * 2 + + ["USER_OUTPUT"] * 3 + ) + program = _make_program(kinds) + eager_keys = ["logits_is_click", "probs_is_click", "length"] + + names = _build_aoti_output_field_names(program, eager_keys) + + # One name per AOTI output handle. + self.assertEqual(len(names), len(kinds)) + # USER_OUTPUT slots receive eager keys in order. + self.assertEqual(names[-3:], eager_keys) + # Non-USER_OUTPUT slots are all placeholders (never eager keys). + for i in range(6): + self.assertTrue( + names[i].startswith("_unused_"), + f"expected placeholder at position {i}, got {names[i]!r}", + ) + # Placeholder carries its index so collisions are impossible. + self.assertIn(f"_{i}_", names[i]) + # And no placeholder ever duplicates a real user-output name. + for placeholder in names[:6]: + self.assertNotIn(placeholder, eager_keys) + + def test_placeholders_are_unique(self) -> None: + """Placeholders embed their slot index so names are globally unique.""" + kinds = ["BUFFER_MUTATION"] * 4 + ["USER_OUTPUT"] + program = _make_program(kinds) + names = _build_aoti_output_field_names(program, ["out"]) + self.assertEqual(len(set(names)), len(names)) + + def test_user_output_count_mismatch_raises(self) -> None: + """Fail loudly rather than silently misalign names.""" + program = _make_program(["USER_OUTPUT", "USER_OUTPUT"]) + with self.assertRaises(RuntimeError) as cm: + _build_aoti_output_field_names(program, ["only_one_key"]) + msg = str(cm.exception) + self.assertIn("USER_OUTPUT", msg) + self.assertIn("only_one_key", msg) + + def test_string_kind_fallback(self) -> None: + """Older torch versions may surface ``kind`` as a plain string.""" + # Mix of bare strings and enum-like objects to exercise both paths. + specs = [ + _FakeSpec("OutputKind.BUFFER_MUTATION"), + _FakeSpec("OutputKind.USER_OUTPUT"), + _FakeSpec(_FakeKind("USER_OUTPUT")), + ] + program = _FakeExportedProgram(specs) + names = _build_aoti_output_field_names(program, ["a", "b"]) + self.assertEqual(len(names), 3) + self.assertTrue(names[0].startswith("_unused_0_")) + self.assertEqual(names[1], "a") + self.assertEqual(names[2], "b") + + def test_empty_outputs(self) -> None: + """Zero outputs on both sides is a valid (if unusual) configuration.""" + program = _make_program([]) + names = _build_aoti_output_field_names(program, []) + self.assertEqual(names, []) + + def test_all_non_user_output_with_nonempty_eager_raises(self) -> None: + """If the graph exposes no USER_OUTPUT, any eager key is a mismatch.""" + program = _make_program(["BUFFER_MUTATION", "TOKEN"]) + with self.assertRaises(RuntimeError): + _build_aoti_output_field_names(program, ["x"]) + + +if __name__ == "__main__": + unittest.main()