[bugfix] The output keys of aoti model are inconsistent between eager output and exported program #478
if isinstance(name, str):
    return name == "USER_OUTPUT"
# String fallback.
return str(kind).endswith("USER_OUTPUT")
The string fallback str(kind).endswith("USER_OUTPUT") is looser than it needs to be. None of today's OutputKind values (LOSS_OUTPUT, BUFFER_MUTATION, GRADIENT_TO_PARAMETER, GRADIENT_TO_USER_INPUT, USER_INPUT_MUTATION, TOKEN) trigger a false positive, but a future enum value that happens to end in USER_OUTPUT (or any subclass/wrapper whose __str__ concatenates extra text) would silently mis-classify. Prefer an anchored check, e.g. split on . and compare the final segment to "USER_OUTPUT", or match against {"USER_OUTPUT", "OutputKind.USER_OUTPUT"} exactly. The test_string_kind_fallback test only exercises well-formed strings and wouldn't catch this.
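A minimal sketch of the anchored check suggested above. The `OutputKind` class here is a stand-in enum so the example runs without torch, `NOT_A_USER_OUTPUT` is a hypothetical future value, and reading the name via `getattr(kind, "name", None)` is an assumption about how the helper obtains `name`:

```python
from enum import Enum


class OutputKind(Enum):
    # Stand-in for the real enum (assumption), so the sketch runs without torch.
    USER_OUTPUT = 1
    BUFFER_MUTATION = 2
    NOT_A_USER_OUTPUT = 3  # hypothetical future value that ends in USER_OUTPUT


def is_user_output(kind) -> bool:
    """Return True only when `kind` names exactly USER_OUTPUT."""
    # Prefer the enum member's own name when one is available.
    name = getattr(kind, "name", None)
    if isinstance(name, str):
        return name == "USER_OUTPUT"
    # Anchored string fallback: compare only the final dot-separated segment,
    # so "OutputKind.USER_OUTPUT" matches but "NOT_A_USER_OUTPUT" does not.
    return str(kind).rsplit(".", 1)[-1] == "USER_OUTPUT"
```

With this form, a string such as `"NOT_A_USER_OUTPUT"` is rejected, whereas `endswith("USER_OUTPUT")` would accept it.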
kinds = (
    ["BUFFER_MUTATION"] * 3
    + ["TOKEN"]
    + ["USER_INPUT_MUTATION"] * 2
    + ["USER_OUTPUT"] * 3
)
This only covers the "USER_OUTPUT slots at the tail" layout. The zip-by-iterator logic would also pass if the code mistakenly bucketed all non-user outputs to the start. Consider adding interleaved ([USER_OUTPUT, BUFFER_MUTATION, USER_OUTPUT, TOKEN, USER_OUTPUT]) and prefix ([USER_OUTPUT]*3 + [BUFFER_MUTATION]*3) patterns — those are the layouts most likely to regress an ordering bug.
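The extra layouts could be covered roughly as follows. `build_field_names` is a hypothetical stand-in for the helper under test, and the `_unused_<index>` placeholder naming is an assumption (the PR only says `_unused_*`). The point of the sketch is the exact, position-by-position expected lists, which fail if non-user outputs are bucketed to the wrong end:

```python
def build_field_names(kinds, user_keys):
    """Hypothetical stand-in: one field name per output slot, in slot order."""
    keys = iter(user_keys)
    names = []
    for index, kind in enumerate(kinds):
        if kind == "USER_OUTPUT":
            # USER_OUTPUT slots consume the eager output keys in order.
            names.append(next(keys))
        else:
            # Every other slot gets a placeholder (naming is an assumption).
            names.append(f"_unused_{index}")
    return names


USER_KEYS = ["a", "b", "c"]

# Each case pairs a slot layout with the exact expected name per position.
CASES = {
    "tail": (
        ["BUFFER_MUTATION"] * 3
        + ["TOKEN"]
        + ["USER_INPUT_MUTATION"] * 2
        + ["USER_OUTPUT"] * 3,
        ["_unused_0", "_unused_1", "_unused_2", "_unused_3",
         "_unused_4", "_unused_5", "a", "b", "c"],
    ),
    "interleaved": (
        ["USER_OUTPUT", "BUFFER_MUTATION", "USER_OUTPUT", "TOKEN", "USER_OUTPUT"],
        ["a", "_unused_1", "b", "_unused_3", "c"],
    ),
    "prefix": (
        ["USER_OUTPUT"] * 3 + ["BUFFER_MUTATION"] * 3,
        ["a", "b", "c", "_unused_3", "_unused_4", "_unused_5"],
    ),
}

for label, (kinds, expected) in CASES.items():
    actual = build_field_names(kinds, USER_KEYS)
    assert actual == expected, f"{label}: {actual!r} != {expected!r}"
```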
Review summary
The bugfix is well-motivated and the helper's contract is clearly documented. Main feedback in inline comments:
No security concerns. No performance concerns (export-time, <100 specs). Docstrings are internally consistent with the implementation (minor: the "extra output kinds" list in the docstring isn't exhaustive vs. the real |
if aoti_output_keys:
# Save output field names to aoti directory (one per AOTI output
# handle; non-USER_OUTPUT slots are filled with _unused_* placeholders).
if aoti_output_field_names:
The surrounding compile/save block (lines ~247–263: makedirs(aoti_dir) → optional JSON dump → aoti_compile_and_package, inside _inductor.config.patch) is copy-pasted to export_unified_model_aot at lines ~541–557. This PR is itself a fix that had to land in both places because they drifted, and the prior aoti_output_keys handling also lived in both. Consider extracting a small helper like _compile_and_save_aoti(exported_pg, aoti_output_field_names, save_dir) so the next fix only lands once.
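One possible shape for such a helper, as a sketch only. The file names, the keyword-only `compile_fn` and `config_ctx` parameters, and treating `save_dir` as the AOTI directory itself are all assumptions made so the example runs without torch; in the real code `compile_fn` would wrap `aoti_compile_and_package` and `config_ctx` would be the `_inductor.config.patch(...)` context:

```python
import json
import os
import tempfile
from contextlib import nullcontext

FIELD_NAMES_FILE = "output_field_names.json"  # hypothetical file name
PACKAGE_FILE = "model.pt2"  # hypothetical package name


def _compile_and_save_aoti(exported_pg, aoti_output_field_names, save_dir,
                           *, compile_fn, config_ctx=None):
    """Create the directory, optionally dump field names, then compile."""
    ctx = config_ctx if config_ctx is not None else nullcontext()
    with ctx:
        os.makedirs(save_dir, exist_ok=True)
        # Optional JSON dump: skipped when there are no field names to save.
        if aoti_output_field_names:
            with open(os.path.join(save_dir, FIELD_NAMES_FILE), "w") as f:
                json.dump(list(aoti_output_field_names), f)
        package_path = os.path.join(save_dir, PACKAGE_FILE)
        # compile_fn is injected so both export paths share this one helper.
        compile_fn(exported_pg, package_path)
    return package_path


# Usage with a stub compiler, so nothing here depends on torch.
calls = []


def fake_compile(exported_pg, package_path):
    calls.append((exported_pg, package_path))


with tempfile.TemporaryDirectory() as tmp:
    aoti_dir = os.path.join(tmp, "aoti")
    pkg = _compile_and_save_aoti(
        "exported-program-stub", ["_unused_0", "logits"], aoti_dir,
        compile_fn=fake_compile,
    )
    with open(os.path.join(aoti_dir, FIELD_NAMES_FILE)) as f:
        saved_names = json.load(f)
```

Both export entry points would then call the helper with their own exported program and directory, so a future fix to the save logic lands in one place.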