[bugfix] The output keys of the AOTI model are inconsistent between eager output and exported program #478
base: master
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |
|
|
||
| import json | ||
| import os | ||
| -from typing import Any, Dict, Optional, Set, Union | ||
| +from typing import Any, Dict, List, Optional, Set, Union | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
@@ -38,6 +38,84 @@ | |
| logger.debug("cutlass_hstu_attention not available; skipping op registration") | ||
|
|
||
|
|
||
| def _build_aoti_output_field_names( | ||
| exported_pg: "torch.export.ExportedProgram", | ||
| eager_output_keys: List[str], | ||
| ) -> List[str]: | ||
| """Align user-facing output names with the AOTI output-handle layout. | ||
|
|
||
| ``torch._inductor.aoti_compile_and_package`` compiles the graph inside | ||
| ``exported_pg``. The resulting AOTI wrapper emits one ``output_handles[i]`` | ||
| per leaf in ``exported_pg.graph_signature.output_specs``. That list can be | ||
| **longer** than the eager module's return dict because | ||
| ``torch.export.export`` also surfaces buffer-mutation results, token | ||
| outputs, gradient-to-parameter signals, etc. as extra outputs. | ||
|
|
||
| The runtime (TorchRecProcessor) names the emitted tensors positionally | ||
| using ``output_field_names.json``, so the JSON **must** have one entry per | ||
| AOTI output handle. We therefore walk ``output_specs`` in order, assign | ||
| the eager dict keys only to ``USER_OUTPUT`` slots, and fill every other | ||
| slot with an ignorable ``_unused_*`` placeholder. | ||
|
|
||
| Args: | ||
| exported_pg: Result of ``torch.export.export(...)``. | ||
| eager_output_keys: Keys of the eager forward's return dict, in | ||
| dict-iteration order (which is the order ``torch.export`` flattens | ||
| them in). | ||
|
|
||
| Returns: | ||
| List of names, one per AOTI output handle, with USER_OUTPUT slots | ||
| carrying the eager keys in order and all other slots filled with | ||
| ``_unused_<idx>_<kind>`` placeholders. | ||
|
|
||
| Raises: | ||
| RuntimeError: If the number of USER_OUTPUT slots does not match | ||
| ``len(eager_output_keys)``; this indicates the eager dict and | ||
| the exported program disagree on the user-visible outputs and the | ||
| exported model cannot be reliably consumed downstream. | ||
| """ | ||
| output_specs = exported_pg.graph_signature.output_specs | ||
|
|
||
| # Identify USER_OUTPUT slots robustly across torch versions. | ||
| def _is_user_output(spec: Any) -> bool: | ||
| kind = getattr(spec, "kind", None) | ||
| if kind is None: | ||
| return False | ||
| # Enum path (current torch): kind.name == "USER_OUTPUT". | ||
| name = getattr(kind, "name", None) | ||
| if isinstance(name, str): | ||
| return name == "USER_OUTPUT" | ||
| # String fallback. | ||
| return str(kind).endswith("USER_OUTPUT") | ||
|
|
||
| user_output_positions = [ | ||
| i for i, spec in enumerate(output_specs) if _is_user_output(spec) | ||
| ] | ||
|
|
||
| if len(user_output_positions) != len(eager_output_keys): | ||
| raise RuntimeError( | ||
| "AOTI output-name alignment failed: exported program has " | ||
| f"{len(user_output_positions)} USER_OUTPUT slots " | ||
| f"(out of {len(output_specs)} total) but eager forward returned " | ||
| f"{len(eager_output_keys)} keys ({eager_output_keys}). " | ||
| "The eager dict and the torch.export graph disagree on " | ||
| "user-visible outputs; refusing to emit a mislabeled " | ||
| "output_field_names.json." | ||
| ) | ||
|
|
||
| names: List[str] = [] | ||
| user_iter = iter(eager_output_keys) | ||
| for i, spec in enumerate(output_specs): | ||
| if _is_user_output(spec): | ||
| names.append(next(user_iter)) | ||
| else: | ||
| kind = getattr(spec, "kind", None) | ||
| kind_name = getattr(kind, "name", None) or str(kind) or "other" | ||
| names.append(f"_unused_{i}_{kind_name.lower()}") | ||
|
|
||
| return names | ||
|
|
||
|
|
||
| def load_model_aot( | ||
| model_path: str, device: torch.device | ||
| ) -> Union[CombinedModelWrapper, UnifiedAOTIModelWrapper]: | ||
|
|
@@ -137,7 +215,7 @@ def export_model_aot( | |
| # active — kernels like CUTLASS HSTU attention reject fp32 inputs. | ||
| with torch.no_grad(): | ||
| _out = dense_to_export(sparse_output) | ||
| -aoti_output_keys = list(_out.keys()) | ||
| +eager_output_keys = list(_out.keys()) | ||
| del _out | ||
|
|
||
| # pre_hook requires running arbitrary code at runtime | ||
|
|
@@ -149,6 +227,16 @@ def export_model_aot( | |
| args=(sparse_output,), | ||
| dynamic_shapes=(dynamic_shapes,), | ||
| ) | ||
|
|
||
| # Align names with AOTI output-handle layout: the exported program may | ||
| # emit extra outputs (buffer mutations, tokens, ...) in addition to the | ||
| # eager dict's user-visible outputs. The runtime indexes into | ||
| # output_field_names.json positionally, so it must have exactly one | ||
| # entry per AOTI output handle. | ||
| aoti_output_field_names = _build_aoti_output_field_names( | ||
| exported_pg, eager_output_keys | ||
| ) | ||
|
|
||
| # AssertScalar codegen is not correct. | ||
| with torch._inductor.config.patch( | ||
| { | ||
|
|
@@ -159,13 +247,15 @@ def export_model_aot( | |
| aoti_dir = os.path.join(save_dir, "aoti") | ||
| os.makedirs(aoti_dir, exist_ok=True) | ||
|
|
||
| -# Save original model output field names to aoti directory | ||
| -if aoti_output_keys: | ||
| +# Save output field names to aoti directory (one per AOTI output | ||
| +# handle; non-USER_OUTPUT slots are filled with _unused_* placeholders). | ||
| +if aoti_output_field_names: | ||
|
The surrounding compile/save block (lines ~247–263: 
||
| output_names_path = os.path.join(aoti_dir, "output_field_names.json") | ||
| with open(output_names_path, "w") as f: | ||
| -json.dump(aoti_output_keys, f, indent=4) | ||
| +json.dump(aoti_output_field_names, f, indent=4) | ||
| logger.info( | ||
| -f"Saved output field names to {output_names_path}: {aoti_output_keys}" | ||
| +f"Saved output field names to {output_names_path}: " | ||
| +f"{aoti_output_field_names}" | ||
| ) | ||
|
|
||
| torch._inductor.aoti_compile_and_package( | ||
|
|
@@ -408,7 +498,7 @@ def export_unified_model_aot( | |
| f.write(full_gm.code) | ||
|
|
||
| result = full_gm(data) | ||
| -aoti_output_keys = list(result.keys()) | ||
| +eager_output_keys = list(result.keys()) | ||
| del result | ||
|
|
||
| # Pad any 0-size non-sequence sparse .values tensors so torch.export | ||
|
|
@@ -435,6 +525,11 @@ def export_unified_model_aot( | |
| dynamic_shapes=(dynamic_shapes,), | ||
| ) | ||
|
|
||
| # Align names with AOTI output-handle layout (see _build_aoti_output_field_names). | ||
| aoti_output_field_names = _build_aoti_output_field_names( | ||
| exported_pg, eager_output_keys | ||
| ) | ||
|
|
||
| # Compile with AOTI | ||
| logger.info("compiling unified model with AOTI...") | ||
| with torch._inductor.config.patch( | ||
|
|
@@ -446,14 +541,15 @@ def export_unified_model_aot( | |
| aoti_dir = os.path.join(save_dir, "aoti") | ||
| os.makedirs(aoti_dir, exist_ok=True) | ||
|
|
||
| -# Save original model output field names (matches legacy | ||
| -# export_model_aot behavior). | ||
| -if aoti_output_keys: | ||
| +# Save output field names (one per AOTI output handle; non-USER_OUTPUT | ||
| +# slots are filled with _unused_* placeholders). | ||
| +if aoti_output_field_names: | ||
| output_names_path = os.path.join(aoti_dir, "output_field_names.json") | ||
| with open(output_names_path, "w") as f: | ||
| -json.dump(aoti_output_keys, f, indent=4) | ||
| +json.dump(aoti_output_field_names, f, indent=4) | ||
| logger.info( | ||
| -f"Saved output field names to {output_names_path}: {aoti_output_keys}" | ||
| +f"Saved output field names to {output_names_path}: " | ||
| +f"{aoti_output_field_names}" | ||
| ) | ||
|
|
||
| torch._inductor.aoti_compile_and_package( | ||
|
|
||
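The diff above says the runtime names emitted tensors positionally from `output_field_names.json`. The following is a minimal sketch of what such a consumer could look like, not the actual TorchRecProcessor code, which this PR does not show. The function name `name_aoti_outputs` and the choice to drop `_unused_*` entries are assumptions made for illustration.

```python
import json
from typing import Dict, Sequence


def name_aoti_outputs(
    field_names: Sequence[str], outputs: Sequence[object]
) -> Dict[str, object]:
    """Pair positional outputs with names and drop placeholder slots.

    Hypothetical consumer: assumes one name per output handle, as the
    helper in this PR guarantees, and that ``_unused_*`` slots are ignorable.
    """
    if len(field_names) != len(outputs):
        raise ValueError(
            f"expected {len(field_names)} outputs, got {len(outputs)}"
        )
    return {
        name: out
        for name, out in zip(field_names, outputs)
        if not name.startswith("_unused_")
    }


# Names as the helper would write them for a graph with one buffer-mutation
# slot ahead of two user outputs.
field_names = json.loads('["_unused_0_buffer_mutation", "logits", "probs"]')
# Plain strings stand in for the tensors an AOTI runner would return.
outputs = ["<mutated buffer>", "<logits tensor>", "<probs tensor>"]
named = name_aoti_outputs(field_names, outputs)
print(named)
```

With the pre-fix JSON (`["logits", "probs"]`), the same three outputs would either fail the length check or, in a consumer without one, label the mutated buffer as `logits`, which is the position drift this PR addresses.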
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,172 @@ | ||
| # Copyright (c) 2024, Alibaba Group; | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
|
|
||
| """Unit tests for ``tzrec.acc.aot_utils._build_aoti_output_field_names``. | ||
|
|
||
| These tests verify that the output-name alignment helper correctly maps the | ||
| eager forward's return-dict keys onto the AOTI output-handle layout emitted | ||
| by ``torch.export.export``. The helper is pure Python and only reads | ||
| ``exported_pg.graph_signature.output_specs[i].kind``, so the | ||
| tests use lightweight duck-typed fakes instead of invoking the real | ||
| ``torch.export`` machinery. That keeps the tests fast, deterministic, and | ||
| independent of the surrounding CUDA / Triton / TorchRec stack. | ||
|
|
||
| Regression coverage: | ||
|
|
||
| * ``test_user_output_only``: the original (pre-bugfix) happy path where | ||
| ``USER_OUTPUT`` is the only output kind. | ||
| * ``test_mixed_output_kinds``: the buggy scenario from the HSTU export that | ||
| motivated this helper. Non-``USER_OUTPUT`` slots must be filled with | ||
| placeholders so the JSON written to disk has one entry per AOTI output | ||
| handle (and downstream tensors are not renamed by position drift). | ||
| * ``test_user_output_count_mismatch_raises``: if the exported program does | ||
| not expose the same number of USER_OUTPUT slots as the eager dict returns, | ||
| the helper must refuse to emit a mislabeled mapping. | ||
| * ``test_string_kind_fallback``: older torch builds surface ``kind`` as a | ||
| bare string rather than an enum; the helper must still identify | ||
| ``USER_OUTPUT`` slots in that case. | ||
| """ | ||
|
|
||
| import unittest | ||
| from typing import Any, List, Optional | ||
|
|
||
| from tzrec.acc.aot_utils import _build_aoti_output_field_names | ||
|
|
||
|
|
||
| class _FakeKind: | ||
| """Duck-typed stand-in for ``torch.export.graph_signature.OutputKind``. | ||
|
|
||
| The real enum exposes ``.name`` (e.g. ``"USER_OUTPUT"``, | ||
| ``"BUFFER_MUTATION"``). ``_build_aoti_output_field_names`` reads exactly | ||
| that attribute plus ``str(kind)``, nothing else. | ||
| """ | ||
|
|
||
| def __init__(self, name: str) -> None: | ||
| self.name = name | ||
|
|
||
| def __str__(self) -> str: # matches enum ``str(OutputKind.X)`` shape. | ||
| return f"OutputKind.{self.name}" | ||
|
|
||
|
|
||
| class _FakeSpec: | ||
| """Stand-in for one entry of ``graph_signature.output_specs``.""" | ||
|
|
||
| def __init__(self, kind: Any, target: Optional[str] = None) -> None: | ||
| self.kind = kind | ||
| self.target = target | ||
|
|
||
|
|
||
| class _FakeSignature: | ||
| def __init__(self, output_specs: List[_FakeSpec]) -> None: | ||
| self.output_specs = output_specs | ||
|
|
||
|
|
||
| class _FakeExportedProgram: | ||
| """Minimal duck-typed ExportedProgram exposing only ``.graph_signature``.""" | ||
|
|
||
| def __init__(self, output_specs: List[_FakeSpec]) -> None: | ||
| self.graph_signature = _FakeSignature(output_specs) | ||
|
|
||
|
|
||
| def _make_program(kinds: List[str]) -> _FakeExportedProgram: | ||
| """Build a fake ExportedProgram with the given sequence of kind names.""" | ||
| return _FakeExportedProgram([_FakeSpec(_FakeKind(name)) for name in kinds]) | ||
|
|
||
|
|
||
| class BuildAotiOutputFieldNamesTest(unittest.TestCase): | ||
| def test_user_output_only(self) -> None: | ||
| """All outputs are USER_OUTPUT: names pass through unchanged.""" | ||
| program = _make_program(["USER_OUTPUT", "USER_OUTPUT", "USER_OUTPUT"]) | ||
| names = _build_aoti_output_field_names(program, ["logits", "probs", "length"]) | ||
| self.assertEqual(names, ["logits", "probs", "length"]) | ||
|
|
||
| def test_mixed_output_kinds(self) -> None: | ||
| """Regression: USER_OUTPUT slots may appear after extra output kinds. | ||
|
|
||
| Mirrors the HSTU export that motivated this helper: 35 Inductor-emitted | ||
| slots (parameter-preprocessing / buffer-mutation / token outputs) | ||
| precede the 5 actual USER_OUTPUT slots. Eager keys must land on the | ||
| USER_OUTPUT positions only, with placeholders elsewhere. | ||
| """ | ||
| kinds = ( | ||
| ["BUFFER_MUTATION"] * 3 | ||
| + ["TOKEN"] | ||
| + ["USER_INPUT_MUTATION"] * 2 | ||
| + ["USER_OUTPUT"] * 3 | ||
| ) | ||
|
Comment on lines +100 to +105
This only covers the "USER_OUTPUT slots at the tail" layout. The zip-by-iterator logic would also pass if the code mistakenly bucketed all non-user outputs to the start. Consider adding interleaved ( 
||
| program = _make_program(kinds) | ||
| eager_keys = ["logits_is_click", "probs_is_click", "length"] | ||
|
|
||
| names = _build_aoti_output_field_names(program, eager_keys) | ||
|
|
||
| # One name per AOTI output handle. | ||
| self.assertEqual(len(names), len(kinds)) | ||
| # USER_OUTPUT slots receive eager keys in order. | ||
| self.assertEqual(names[-3:], eager_keys) | ||
| # Non-USER_OUTPUT slots are all placeholders (never eager keys). | ||
| for i in range(6): | ||
| self.assertTrue( | ||
| names[i].startswith("_unused_"), | ||
| f"expected placeholder at position {i}, got {names[i]!r}", | ||
| ) | ||
| # Placeholder carries its index so collisions are impossible. | ||
| self.assertIn(f"_{i}_", names[i]) | ||
| # And no placeholder ever duplicates a real user-output name. | ||
| for placeholder in names[:6]: | ||
| self.assertNotIn(placeholder, eager_keys) | ||
|
|
||
| def test_placeholders_are_unique(self) -> None: | ||
| """Placeholders embed their slot index so names are globally unique.""" | ||
| kinds = ["BUFFER_MUTATION"] * 4 + ["USER_OUTPUT"] | ||
| program = _make_program(kinds) | ||
| names = _build_aoti_output_field_names(program, ["out"]) | ||
| self.assertEqual(len(set(names)), len(names)) | ||
|
|
||
| def test_user_output_count_mismatch_raises(self) -> None: | ||
| """Fail loudly rather than silently misalign names.""" | ||
| program = _make_program(["USER_OUTPUT", "USER_OUTPUT"]) | ||
| with self.assertRaises(RuntimeError) as cm: | ||
| _build_aoti_output_field_names(program, ["only_one_key"]) | ||
| msg = str(cm.exception) | ||
| self.assertIn("USER_OUTPUT", msg) | ||
| self.assertIn("only_one_key", msg) | ||
|
|
||
| def test_string_kind_fallback(self) -> None: | ||
| """Older torch versions may surface ``kind`` as a plain string.""" | ||
| # Mix of bare strings and enum-like objects to exercise both paths. | ||
| specs = [ | ||
| _FakeSpec("OutputKind.BUFFER_MUTATION"), | ||
| _FakeSpec("OutputKind.USER_OUTPUT"), | ||
| _FakeSpec(_FakeKind("USER_OUTPUT")), | ||
| ] | ||
| program = _FakeExportedProgram(specs) | ||
| names = _build_aoti_output_field_names(program, ["a", "b"]) | ||
| self.assertEqual(len(names), 3) | ||
| self.assertTrue(names[0].startswith("_unused_0_")) | ||
| self.assertEqual(names[1], "a") | ||
| self.assertEqual(names[2], "b") | ||
|
|
||
| def test_empty_outputs(self) -> None: | ||
| """Zero outputs on both sides is a valid (if unusual) configuration.""" | ||
| program = _make_program([]) | ||
| names = _build_aoti_output_field_names(program, []) | ||
| self.assertEqual(names, []) | ||
|
|
||
| def test_all_non_user_output_with_nonempty_eager_raises(self) -> None: | ||
| """If the graph exposes no USER_OUTPUT, any eager key is a mismatch.""" | ||
| program = _make_program(["BUFFER_MUTATION", "TOKEN"]) | ||
| with self.assertRaises(RuntimeError): | ||
| _build_aoti_output_field_names(program, ["x"]) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
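The review comment on `test_mixed_output_kinds` asks for layouts other than "USER_OUTPUT at the tail". That comment is cut off in this capture, so the two layouts below (interleaved, and user outputs first) are one possible reading of it. To keep the sketch runnable on its own, `build_names` is a condensed stand-in for `_build_aoti_output_field_names` that handles only enum-style kinds; `build_names`, `make_program`, and `OutputLayoutTest` are hypothetical names.

```python
import unittest
from types import SimpleNamespace
from typing import Any, List


def build_names(exported_pg: Any, eager_keys: List[str]) -> List[str]:
    """Condensed stand-in for the PR's helper (enum-style ``kind.name`` only)."""
    specs = exported_pg.graph_signature.output_specs
    n_user = sum(spec.kind.name == "USER_OUTPUT" for spec in specs)
    if n_user != len(eager_keys):
        raise RuntimeError("USER_OUTPUT slot count does not match eager keys")
    keys = iter(eager_keys)
    names: List[str] = []
    for i, spec in enumerate(specs):
        if spec.kind.name == "USER_OUTPUT":
            names.append(next(keys))
        else:
            names.append(f"_unused_{i}_{spec.kind.name.lower()}")
    return names


def make_program(kinds: List[str]) -> Any:
    """Fake ExportedProgram exposing only graph_signature.output_specs[i].kind."""
    specs = [SimpleNamespace(kind=SimpleNamespace(name=k)) for k in kinds]
    return SimpleNamespace(graph_signature=SimpleNamespace(output_specs=specs))


class OutputLayoutTest(unittest.TestCase):
    def test_interleaved_layout(self) -> None:
        # User outputs alternate with other kinds; exact positions are checked.
        kinds = ["USER_OUTPUT", "BUFFER_MUTATION", "USER_OUTPUT", "TOKEN", "USER_OUTPUT"]
        names = build_names(make_program(kinds), ["a", "b", "c"])
        self.assertEqual(
            names, ["a", "_unused_1_buffer_mutation", "b", "_unused_3_token", "c"]
        )

    def test_user_outputs_first(self) -> None:
        # Mirror image of the tail layout covered by the existing test.
        kinds = ["USER_OUTPUT", "USER_OUTPUT", "BUFFER_MUTATION"]
        names = build_names(make_program(kinds), ["a", "b"])
        self.assertEqual(names, ["a", "b", "_unused_2_buffer_mutation"])
```

Asserting the full list, rather than slices such as `names[-3:]`, means an implementation that moved all placeholders to one end would fail both cases.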
The string fallback `str(kind).endswith("USER_OUTPUT")` is looser than it needs to be. None of today's `OutputKind` values (`LOSS_OUTPUT`, `BUFFER_MUTATION`, `GRADIENT_TO_PARAMETER`, `GRADIENT_TO_USER_INPUT`, `USER_INPUT_MUTATION`, `TOKEN`) trigger a false positive, but a future enum value that happens to end in `USER_OUTPUT` (or any subclass/wrapper whose `__str__` concatenates extra text) would silently misclassify. Prefer an anchored check, e.g. split on `.` and compare the final segment to `"USER_OUTPUT"`, or match against `{"USER_OUTPUT", "OutputKind.USER_OUTPUT"}` exactly. The `test_string_kind_fallback` test only exercises well-formed strings and wouldn't catch this.
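One way to apply the anchored check suggested above is sketched here. It follows the "split on `.` and compare the final segment" option; the function name `is_user_output` and the `NOT_USER_OUTPUT` kind are hypothetical and exist only to show the difference from `endswith`.

```python
from types import SimpleNamespace
from typing import Any

_USER_OUTPUT = "USER_OUTPUT"


def is_user_output(spec: Any) -> bool:
    """Return True only when the spec's kind is exactly USER_OUTPUT."""
    kind = getattr(spec, "kind", None)
    if kind is None:
        return False
    # Enum path: compare the member name directly.
    name = getattr(kind, "name", None)
    if isinstance(name, str):
        return name == _USER_OUTPUT
    # String fallback: anchor on the final dot-separated segment instead of
    # using endswith, so "OutputKind.NOT_USER_OUTPUT" is rejected.
    return str(kind).rsplit(".", 1)[-1] == _USER_OUTPUT


# Well-formed strings still match.
print(is_user_output(SimpleNamespace(kind="OutputKind.USER_OUTPUT")))
print(is_user_output(SimpleNamespace(kind="USER_OUTPUT")))
# A made-up kind that merely ends in USER_OUTPUT no longer matches.
print(is_user_output(SimpleNamespace(kind="OutputKind.NOT_USER_OUTPUT")))
```

A regression test for this could feed the helper a malformed string such as the last one and assert that it lands in a placeholder slot.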