From af045e78c336cc5a816c591a1fe36dad738355a2 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Sat, 14 Mar 2026 06:02:59 -0400 Subject: [PATCH 01/28] fix: replace 'TBD merged' invariant description with real text The _try_merge_hypotheses() function in APIContainRelation was leaving text_description as the placeholder "TBD merged" on every hypothesis that went through the merge path. These descriptions surfaced verbatim in invariants.json and the HTML checker report, making merged invariants uninterpretable. Generate the description from the parent API's full name and the generalized child param, consistent with the non-merged path. Also assert the parent param type to satisfy the static checker. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/invariant/contain_relation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index 0c4d0e95..ac42ae65 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -276,14 +276,16 @@ def _merge_hypotheses(hypotheses: list[Hypothesis]) -> list[Hypothesis]: setattr(merged_child_param, field, generalized_value) # construct the merged hypotheses + parent_param = hypotheses[0].invariant.params[0] + assert isinstance(parent_param, APIParam) merged_hypothesis = Hypothesis( invariant=Invariant( relation=hypotheses[0].invariant.relation, params=[ - hypotheses[0].invariant.params[0], + parent_param, merged_child_param, ], - text_description="TBD merged", + text_description=f"{parent_param.api_full_name} contains {merged_child_param}", num_positive_examples=len(all_positive_examples), num_negative_examples=len(all_positive_examples), precondition=None, # to be inferred later From 148e75f30b2fa0f24a6752c7e2d831cc87da64b3 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Sat, 14 Mar 2026 06:03:19 -0400 Subject: [PATCH 02/28] docs: fix broken link and remove under-construction banner - usage-guide.md: fix broken link to 5-min-tutorial.md (was pointing to ./docs/5-min.md which does not exist) - technical-doc.md: remove the 'under construction' warning that was the first thing visitors saw on the technical docs page Co-Authored-By: Claude Sonnet 4.6 --- docs/technical-doc.md | 3 --- docs/usage-guide.md | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/technical-doc.md b/docs/technical-doc.md index c4923738..a49d6a4f 100644 --- a/docs/technical-doc.md +++ b/docs/technical-doc.md @@ -1,8 +1,5 @@ # TrainCheck Documentation -🚜 This documentation is under construction. We welcome any feedback or questions through GitHub Issues or [our Discord server](https://discord.gg/DPEd7Xeg). - - TrainCheck is a lightweight, invariant-based instrumentation and analysis tool for identifying silent correctness issues in PyTorch training pipelines. It infers behavioral invariants from correct reference runs (e.g., official examples or clean configurations), then checks other scripts for behavioral violations. TrainCheck is designed to be minimally intrusiveβ€”requiring no code modifications or rewrites of training logic. ## πŸ”§ System Overview diff --git a/docs/usage-guide.md b/docs/usage-guide.md index 06e2b935..39f12cbd 100644 --- a/docs/usage-guide.md +++ b/docs/usage-guide.md @@ -4,7 +4,7 @@ TrainCheck helps detect and diagnose silent errors in deep learning training run ## πŸš€ Quick Start -Check out the [5-minute guide](./docs/5-min.md) for a minimal working example. +Check out the [5-minute guide](5-min-tutorial.md) for a minimal working example. ## βœ… Common Use Cases From 4bbb451513d03634f7199c9014661018ea6b8858 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Sat, 14 Mar 2026 08:33:33 -0400 Subject: [PATCH 03/28] feat: add to_display_name() for human-readable invariant labels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each Relation subclass now implements to_display_name(params) returning a natural-language string like "Optimizer.zero_grad() changes Parameter.grad: non-zero β†’ None" instead of raw class names or opaque text_description fields. - base_cls.py: add _short_api_name() helper and default to_display_name() returning None on the Relation base class - All 8 relation types implemented: APIContainRelation, ConsistencyRelation, FunctionCoverRelation, FunctionLeadRelation, DistinctArgumentRelation, ConsistentOutputRelation, ConsistentInputOutputRelation, ThresholdRelation - checker_report._format_invariant_label() now calls to_display_name() first, falling back to text_description, then raw params Co-Authored-By: Claude Sonnet 4.6 --- .../invariant/DistinctArgumentRelation.py | 8 ++ traincheck/invariant/base_cls.py | 22 +++++ traincheck/invariant/consistency_relation.py | 11 +++ .../invariant/consistency_transient_vars.py | 90 +++++++++++++++++++ traincheck/invariant/contain_relation.py | 29 ++++++ traincheck/invariant/cover_relation.py | 12 +++ traincheck/invariant/lead_relation.py | 12 +++ traincheck/reporting/checker_report.py | 3 + 8 files changed, 187 insertions(+) diff --git a/traincheck/invariant/DistinctArgumentRelation.py b/traincheck/invariant/DistinctArgumentRelation.py index 025f7c69..23a63bc0 100644 --- a/traincheck/invariant/DistinctArgumentRelation.py +++ b/traincheck/invariant/DistinctArgumentRelation.py @@ -15,6 +15,7 @@ OnlineCheckerResult, Param, Relation, + _short_api_name, ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online @@ -465,6 +466,13 @@ def _get_variables_to_check(inv: Invariant): def _get_apis_to_check(inv: Invariant): return None + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if not params or not isinstance(params[0], APIParam): + return None + func_short = _short_api_name(params[0].api_full_name) + return f"{func_short}() receives distinct arguments at each step" + @staticmethod def _get_api_args_map_to_check(inv): return [inv.params[0].api_full_name] diff --git a/traincheck/invariant/base_cls.py b/traincheck/invariant/base_cls.py index d8649377..071ecaea 100644 --- a/traincheck/invariant/base_cls.py +++ b/traincheck/invariant/base_cls.py @@ -1921,6 +1921,28 @@ def online_check( """Check the invariant online, i.e. during the trace collection process.""" pass + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + """Return a short, human-readable label for this invariant given its params. + + Returns None to fall back to text_description / raw param rendering. + Subclasses should override this to provide meaningful natural-language names. + """ + return None + + +def _short_api_name(full_name: str) -> str: + """Shorten a fully-qualified API name for display. + + 'torch.optim.optimizer.Optimizer.zero_grad' β†’ 'Optimizer.zero_grad' + 'torch.nn.functional.dropout' β†’ 'functional.dropout' + """ + parts = full_name.split(".") + for i, part in enumerate(parts): + if part and part[0].isupper(): + return ".".join(parts[i:]) + return ".".join(parts[-2:]) if len(parts) >= 2 else full_name + def read_inv_file(file_path: str | list[str]) -> list[Invariant]: if isinstance(file_path, str): diff --git a/traincheck/invariant/consistency_relation.py b/traincheck/invariant/consistency_relation.py index 06b24714..74f35c06 100644 --- a/traincheck/invariant/consistency_relation.py +++ b/traincheck/invariant/consistency_relation.py @@ -111,6 +111,17 @@ class ConsistencyRelation(Relation): where the consistency relationships of the variables across different nodes is crucial to maintain. """ + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if not params: + return None + p = params[0] + if not isinstance(p, VarTypeParam): + return None + var_short = p.var_type.split(".")[-1] + attr = p.attr_name + return f"{var_short}.{attr} stays consistent across training steps" + @staticmethod def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: """Infer Invariants for the ConsistencyRelation.""" diff --git a/traincheck/invariant/consistency_transient_vars.py b/traincheck/invariant/consistency_transient_vars.py index 485800e9..98ce0721 100644 --- a/traincheck/invariant/consistency_transient_vars.py +++ b/traincheck/invariant/consistency_transient_vars.py @@ -7,6 +7,7 @@ from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( + _NOT_SET, APIParam, Arguments, CheckerResult, @@ -20,6 +21,7 @@ Param, Relation, VarTypeParam, + _short_api_name, make_hashable, ) from traincheck.invariant.precondition import find_precondition @@ -580,6 +582,20 @@ def static_check_all( ) # raise NotImplementedError + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 2: + return None + api, vt = params[0], params[1] + if not isinstance(api, APIParam) or not isinstance(vt, VarTypeParam): + return None + func_short = _short_api_name(api.api_full_name) + attr = vt.attr_name + const = vt.const_value + if const is not _NOT_SET: + return f"{func_short}() consistently returns tensors with {attr}={const}" + return f"{func_short}() consistently returns tensors with consistent {attr}" + @staticmethod def _get_identifying_params(inv: Invariant) -> list[Param]: return [inv.params[0]] @@ -965,6 +981,40 @@ def static_check_all(trace, inv, check_relation_first): triggered=triggered, ) + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 3: + return None + input_param, api_param, output_param = params[0], params[1], params[2] + if not isinstance(api_param, APIParam): + return None + if not isinstance(input_param, InputOutputParam) or not isinstance( + output_param, InputOutputParam + ): + return None + func_short = _short_api_name(api_param.api_full_name) + in_path = ( + ".".join(str(p) for p in input_param.additional_path) + if input_param.additional_path + else (input_param.name or "?") + ) + out_path = ( + ".".join(str(p) for p in output_param.additional_path) + if output_param.additional_path + else (output_param.name or "?") + ) + in_ref = ( + f"input[{input_param.index}].{in_path}" + if input_param.index is not None + else f"input.{in_path}" + ) + out_ref = ( + f"output[{output_param.index}].{out_path}" + if output_param.index is not None + else f"output.{out_path}" + ) + return f"{func_short}(): {in_ref} consistent with {out_ref}" + @staticmethod def _get_identifying_params(inv: Invariant) -> list[Param]: return [inv.params[1]] @@ -1453,6 +1503,46 @@ def _get_apis_to_check(inv: Invariant): assert isinstance(inv.params[1], APIParam) return [inv.params[1].api_full_name] + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 3: + return None + first, api_param, second = params[0], params[1], params[2] + if not isinstance(api_param, APIParam): + return None + if not isinstance(first, InputOutputParam) or not isinstance( + second, InputOutputParam + ): + return None + func_short = _short_api_name(api_param.api_full_name) + # min case: params=[output, api, input_threshold] β†’ output β‰₯ threshold + if not first.is_input and second.is_input: + out_path = ( + ".".join(str(p) for p in first.additional_path) + if first.additional_path + else "value" + ) + out_ref = ( + f"output[{first.index}].{out_path}" + if first.index is not None + else "output" + ) + return f"{func_short}(): {out_ref} β‰₯ {second.name}" + # max case: params=[input_threshold, api, output] β†’ output ≀ threshold + if first.is_input and not second.is_input: + out_path = ( + ".".join(str(p) for p in second.additional_path) + if second.additional_path + else "value" + ) + out_ref = ( + f"output[{second.index}].{out_path}" + if second.index is not None + else "output" + ) + return f"{func_short}(): {out_ref} ≀ {first.name}" + return None + @staticmethod def _get_api_args_map_to_check(inv): return None diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index ac42ae65..7196b73f 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -10,6 +10,7 @@ from traincheck.config.config import ANALYSIS_SKIP_FUNC_NAMES from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( + _NOT_SET, APIParam, Arguments, CheckerResult, @@ -23,6 +24,7 @@ Relation, VarNameParam, VarTypeParam, + _short_api_name, calc_likelihood, construct_api_param, construct_var_param_from_var_change, @@ -340,6 +342,33 @@ class APIContainRelation(Relation): - [ ] Make the Dynamic Analysis part less ad-hoc as of its current form in the code """ + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 2: + return None + parent = params[0] + child = params[1] + if not isinstance(parent, APIParam): + return None + parent_short = _short_api_name(parent.api_full_name) + if isinstance(child, APIParam): + child_short = _short_api_name(child.api_full_name) + return f"{parent_short}() always calls {child_short}()" + if isinstance(child, (VarTypeParam, VarNameParam)): + var_short = child.var_type.split(".")[-1] + attr = child.attr_name + pre = child.pre_value + post = child.post_value + const = child.const_value + if pre is not _NOT_SET and post is not _NOT_SET: + pre_str = "non-zero" if pre == "non_zero" else str(pre) + post_str = str(post) + return f"{parent_short}() changes {var_short}.{attr}: {pre_str} β†’ {post_str}" + if const is not _NOT_SET: + return f"{parent_short}() sees {var_short}.{attr} = {const}" + return f"{parent_short}() accesses {var_short}.{attr}" + return None + @staticmethod def generate_hypothesis(trace) -> list[Hypothesis]: # let's play it dumb here, diff --git a/traincheck/invariant/cover_relation.py b/traincheck/invariant/cover_relation.py index d4e7c2d9..5dd1dcb9 100644 --- a/traincheck/invariant/cover_relation.py +++ b/traincheck/invariant/cover_relation.py @@ -17,6 +17,7 @@ OnlineCheckerResult, Param, Relation, + _short_api_name, ) from traincheck.invariant.lead_relation import ( check_same_level, @@ -103,6 +104,17 @@ class FunctionCoverRelation(Relation): every time function B is called, a function A invocation exists before it. """ + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 2: + return None + a, b = params[0], params[1] + if not isinstance(a, APIParam) or not isinstance(b, APIParam): + return None + a_short = _short_api_name(a.api_full_name) + b_short = _short_api_name(b.api_full_name) + return f"{a_short}() always occurs when {b_short}() is called" + @staticmethod def generate_hypothesis(trace) -> list[Hypothesis]: """Generate hypothesis for the FunctionCoverRelation on trace.""" diff --git a/traincheck/invariant/lead_relation.py b/traincheck/invariant/lead_relation.py index 2db839a9..00c222ec 100644 --- a/traincheck/invariant/lead_relation.py +++ b/traincheck/invariant/lead_relation.py @@ -17,6 +17,7 @@ OnlineCheckerResult, Param, Relation, + _short_api_name, ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online @@ -1102,6 +1103,17 @@ def _get_apis_to_check(inv: Invariant): api_name_list.append(param.api_full_name) return api_name_list + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 2: + return None + a, b = params[0], params[-1] + if not isinstance(a, APIParam) or not isinstance(b, APIParam): + return None + a_short = _short_api_name(a.api_full_name) + b_short = _short_api_name(b.api_full_name) + return f"{a_short}() always precedes {b_short}()" + @staticmethod def _get_api_args_map_to_check(inv): return None diff --git a/traincheck/reporting/checker_report.py b/traincheck/reporting/checker_report.py index 92e36cee..10f6dbff 100644 --- a/traincheck/reporting/checker_report.py +++ b/traincheck/reporting/checker_report.py @@ -10,6 +10,9 @@ def _format_invariant_label(invariant: Invariant) -> str: + display = invariant.relation.to_display_name(invariant.params) + if display: + return display if invariant.text_description: return invariant.text_description params = ", ".join(str(param) for param in invariant.params) From 695fba02308eb50a702026fec2fc4f5b34af654f Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Sat, 14 Mar 2026 08:34:55 -0400 Subject: [PATCH 04/28] feat: add violation summary with step numbers and recurrence MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - checker_report: add _extract_violation_steps() and _build_violation_entry() helpers; _count_failed_invariants() now tracks first_step per invariant; HTML report shows "first seen at step N Β· M occurrences" under each item - reporting/__init__.py: export build_violations_summary() - checker.py: write violations_summary.json alongside failed.log for each trace, containing first_violation_step, distinct_invariants_violated, and per-violation entries with display_name, relation_type, first/last step, occurrences Co-Authored-By: Claude Sonnet 4.6 --- traincheck/checker.py | 15 ++++- traincheck/reporting/__init__.py | 1 + traincheck/reporting/checker_report.py | 78 ++++++++++++++++++++++++-- 3 files changed, 89 insertions(+), 5 deletions(-) diff --git a/traincheck/checker.py b/traincheck/checker.py index 45086ea8..a24904f7 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -7,7 +7,11 @@ from tqdm import tqdm from traincheck.invariant import CheckerResult, Invariant, read_inv_file -from traincheck.reporting import ReportEmitter, build_offline_report_data +from traincheck.reporting import ( + ReportEmitter, + build_offline_report_data, + build_violations_summary, +) from traincheck.trace import MDNONEJSONEncoder, Trace, select_trace_implementation from traincheck.utils import register_custom_excepthook @@ -292,6 +296,15 @@ def main(): json.dump(res.to_dict(), f, indent=4, cls=MDNONEJSONEncoder) f.write("\n") + violations_summary = build_violations_summary(results_per_trace) + with open( + os.path.join( + args.output_dir, trace_parent_folder, "violations_summary.json" + ), + "w", + ) as f: + json.dump(violations_summary, f, indent=2) + results_by_trace.append((trace_parent_folder, results_per_trace)) report_data = build_offline_report_data( diff --git a/traincheck/reporting/__init__.py b/traincheck/reporting/__init__.py index 75bd0c9e..f3c090ec 100644 --- a/traincheck/reporting/__init__.py +++ b/traincheck/reporting/__init__.py @@ -2,4 +2,5 @@ ReportEmitter, build_offline_report_data, build_online_report_data, + build_violations_summary, ) diff --git a/traincheck/reporting/checker_report.py b/traincheck/reporting/checker_report.py index 10f6dbff..c14be4b1 100644 --- a/traincheck/reporting/checker_report.py +++ b/traincheck/reporting/checker_report.py @@ -19,6 +19,41 @@ def _format_invariant_label(invariant: Invariant) -> str: return f"{invariant.relation.__name__}({params})" +def _extract_violation_steps(trace: list[dict] | None) -> list[int]: + """Extract training step numbers from a violation trace.""" + if not trace: + return [] + return [ + r["meta_vars.step"] + for r in trace + if isinstance(r, dict) and r.get("meta_vars.step") is not None + ] + + +def _build_violation_entry(result: CheckerResult) -> dict: + steps = _extract_violation_steps(result.trace) + return { + "display_name": _format_invariant_label(result.invariant), + "relation_type": result.invariant.relation.__name__, + "first_step": min(steps) if steps else None, + "last_step": max(steps) if steps else None, + "occurrences": len(result.trace) if result.trace else 1, + } + + +def build_violations_summary(results: list[CheckerResult]) -> dict: + """Build a pre-digested summary of all violations for machine and human consumption.""" + failed = [r for r in results if not r.check_passed] + all_steps = [] + for r in failed: + all_steps.extend(_extract_violation_steps(r.trace)) + return { + "first_violation_step": min(all_steps) if all_steps else None, + "distinct_invariants_violated": len(failed), + "violations": [_build_violation_entry(r) for r in failed], + } + + def _summarize_results(results: Iterable[CheckerResult]) -> dict[str, int]: failed = sum(1 for res in results if not res.check_passed) not_triggered = sum(1 for res in results if res.triggered is False) @@ -55,14 +90,29 @@ def _count_failed_invariants( results: Iterable[CheckerResult], ) -> list[dict[str, object]]: counter: Counter[tuple[str, str]] = Counter() + first_steps: dict[tuple[str, str], int | None] = {} for res in results: if not res.check_passed: label = _format_invariant_label(res.invariant) relation = res.invariant.relation.__name__ - counter[(label, relation)] += 1 + key = (label, relation) + counter[key] += 1 + steps = _extract_violation_steps(res.trace) + if steps: + existing = first_steps.get(key) + first_steps[key] = ( + min(steps) if existing is None else min(existing, min(steps)) + ) + elif key not in first_steps: + first_steps[key] = None top_pairs = counter.most_common(10) return [ - {"label": label, "relation": relation, "count": count} + { + "label": label, + "relation": relation, + "count": count, + "first_step": first_steps.get((label, relation)), + } for (label, relation), count in top_pairs ] @@ -206,9 +256,19 @@ def percent(part: int, total: int) -> float: top_items = [] for entry in top_violations: label = esc(str(entry.get("label", ""))) - detail = esc(str(entry.get("relation", ""))) + relation = esc(str(entry.get("relation", ""))) count = entry.get("count") + first_step = entry.get("first_step") count_html = f'{count}' if count else "" + if first_step is not None: + step_note = f"first seen at step {first_step}" + if count and count > 1: + step_note += f" Β· {count} occurrences" + detail = esc(f"{entry.get('relation', '')} β€” {step_note}") + else: + detail = relation + if count and count > 1: + detail = esc(f"{entry.get('relation', '')} β€” {count} occurrences") top_items.append( f'
  • {label}' f'{detail}{count_html}
  • ' @@ -231,9 +291,19 @@ def percent(part: int, total: int) -> float: failed_list_items = [] for failed_item in trace["failed_invariants"][:10]: label = esc(str(failed_item.get("label", ""))) - detail = esc(str(failed_item.get("relation", ""))) + relation = esc(str(failed_item.get("relation", ""))) count = failed_item.get("count") + first_step = failed_item.get("first_step") count_html = f'{count}' if count else "" + if first_step is not None: + step_note = f"first seen at step {first_step}" + if count and count > 1: + step_note += f" Β· {count} occurrences" + detail = esc(f"{relation} β€” {step_note}") + else: + detail = relation + if count and count > 1: + detail = esc(f"{relation} β€” {count} occurrences") failed_list_items.append( f'
  • {label}' f'{detail}{count_html}
  • ' From cfb035e5b9a74af9fe6a2d47d0272a9c20f0b1a0 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Sat, 14 Mar 2026 08:36:53 -0400 Subject: [PATCH 05/28] test: add semantic unit tests for display names and violation summary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_display_names.py: one test class per relation type (8 total, 24 tests) verifying that key semantic tokens appear in to_display_name() output given known param lists β€” independent of the inference algorithm - test_violation_summary.py: pure function tests for _extract_violation_steps(), _build_violation_entry(), and build_violations_summary() (14 tests) All 38 tests pass. Co-Authored-By: Claude Sonnet 4.6 --- tests/test_display_names.py | 249 ++++++++++++++++++++++++++++++++ tests/test_violation_summary.py | 167 +++++++++++++++++++++ 2 files changed, 416 insertions(+) create mode 100644 tests/test_display_names.py create mode 100644 tests/test_violation_summary.py diff --git a/tests/test_display_names.py b/tests/test_display_names.py new file mode 100644 index 00000000..3590c499 --- /dev/null +++ b/tests/test_display_names.py @@ -0,0 +1,249 @@ +"""Semantic unit tests for Relation.to_display_name(). + +These tests verify that key *meaning* tokens appear in the output for each +relation type given a known params list. They do NOT test inference logic β€” +the params are constructed directly, so the tests remain stable even if the +inference algorithm changes. +""" + +import pytest + +from traincheck.invariant.base_cls import ( + _NOT_SET, + APIParam, + InputOutputParam, + VarTypeParam, +) +from traincheck.invariant.consistency_relation import ConsistencyRelation +from traincheck.invariant.consistency_transient_vars import ( + ConsistentInputOutputRelation, + ConsistentOutputRelation, + ThresholdRelation, +) +from traincheck.invariant.contain_relation import APIContainRelation +from traincheck.invariant.cover_relation import FunctionCoverRelation +from traincheck.invariant.DistinctArgumentRelation import DistinctArgumentRelation +from traincheck.invariant.lead_relation import FunctionLeadRelation + + +class TestAPIContainRelationDisplayName: + def test_state_transition(self): + params = [ + APIParam("torch.optim.optimizer.Optimizer.zero_grad"), + VarTypeParam( + "torch.nn.Parameter", "grad", pre_value="non_zero", post_value=None + ), + ] + name = APIContainRelation.to_display_name(params) + assert name is not None + assert "zero_grad" in name + assert "grad" in name + assert "non" in name.lower() # "non-zero" + + def test_api_calls_api(self): + params = [ + APIParam("torch.optim.optimizer.Optimizer.step"), + APIParam("torch.optim.adadelta.adadelta"), + ] + name = APIContainRelation.to_display_name(params) + assert name is not None + assert "step" in name + assert "adadelta" in name + + def test_const_value(self): + params = [ + APIParam("torch.nn.modules.module.Module.forward"), + VarTypeParam("torch.nn.Parameter", "requires_grad", const_value=True), + ] + name = APIContainRelation.to_display_name(params) + assert name is not None + assert "forward" in name + assert "requires_grad" in name + + def test_returns_none_for_empty_params(self): + assert APIContainRelation.to_display_name([]) is None + + def test_returns_none_for_single_param(self): + assert APIContainRelation.to_display_name([APIParam("torch.foo")]) is None + + +class TestConsistencyRelationDisplayName: + def test_basic(self): + params = [VarTypeParam("torch.nn.Parameter", "grad")] + name = ConsistencyRelation.to_display_name(params) + assert name is not None + assert "Parameter" in name + assert "grad" in name + assert any(w in name.lower() for w in ("consistent", "stay", "step")) + + def test_returns_none_for_empty(self): + assert ConsistencyRelation.to_display_name([]) is None + + def test_returns_none_for_non_vartype(self): + assert ConsistencyRelation.to_display_name([APIParam("torch.foo.bar")]) is None + + +class TestFunctionCoverRelationDisplayName: + def test_cover_direction(self): + params = [ + APIParam("torch.distributed.is_initialized"), + APIParam("torch.nn.modules.module.Module.eval"), + ] + name = FunctionCoverRelation.to_display_name(params) + assert name is not None + assert "is_initialized" in name + assert "eval" in name + assert any(w in name.lower() for w in ("occurs", "cover", "when")) + + def test_returns_none_for_insufficient_params(self): + assert FunctionCoverRelation.to_display_name([APIParam("torch.foo")]) is None + + +class TestFunctionLeadRelationDisplayName: + def test_ordering(self): + params = [ + APIParam("torch.Tensor.backward"), + APIParam("torch.optim.optimizer.Optimizer.step"), + ] + name = FunctionLeadRelation.to_display_name(params) + assert name is not None + assert "backward" in name + assert "step" in name + assert any(w in name.lower() for w in ("precede", "before", "lead")) + + def test_merged_three_params(self): + """Merged lead invariants can have 3 APIParams; display uses first and last.""" + params = [ + APIParam("torch.Tensor.backward"), + APIParam("torch.optim.optimizer.Optimizer.zero_grad"), + APIParam("torch.optim.optimizer.Optimizer.step"), + ] + name = FunctionLeadRelation.to_display_name(params) + assert name is not None + assert "backward" in name + assert "step" in name + + def test_returns_none_for_single_param(self): + assert FunctionLeadRelation.to_display_name([APIParam("torch.foo")]) is None + + +class TestDistinctArgumentRelationDisplayName: + def test_basic(self): + params = [APIParam("torch.nn.init.normal_")] + name = DistinctArgumentRelation.to_display_name(params) + assert name is not None + assert "normal_" in name + assert any(w in name.lower() for w in ("distinct", "different", "argument")) + + def test_returns_none_for_empty(self): + assert DistinctArgumentRelation.to_display_name([]) is None + + def test_returns_none_for_non_api_param(self): + params = [VarTypeParam("torch.nn.Parameter", "grad")] + assert DistinctArgumentRelation.to_display_name(params) is None + + +class TestConsistentOutputRelationDisplayName: + def test_with_const_value(self): + params = [ + APIParam("torch.nn.functional.relu"), + VarTypeParam("torch.Tensor", "dtype", const_value="float32"), + ] + name = ConsistentOutputRelation.to_display_name(params) + assert name is not None + assert "relu" in name + assert "dtype" in name + assert "float32" in name + assert any(w in name.lower() for w in ("consistent", "return")) + + def test_without_const_value(self): + params = [ + APIParam("torch.nn.functional.relu"), + VarTypeParam("torch.Tensor", "ndim"), + ] + name = ConsistentOutputRelation.to_display_name(params) + assert name is not None + assert "relu" in name + assert "ndim" in name + + def test_returns_none_for_insufficient_params(self): + assert ConsistentOutputRelation.to_display_name([APIParam("torch.foo")]) is None + + +class TestConsistentInputOutputRelationDisplayName: + def test_basic(self): + in_p = InputOutputParam( + name="input", + index=0, + type="torch.Tensor", + additional_path=("itemsize",), + api_name="kaiming_uniform_", + is_input=True, + ) + out_p = InputOutputParam( + name="output", + index=0, + type="torch.Tensor", + additional_path=("ndim",), + api_name="kaiming_uniform_", + is_input=False, + ) + api_p = APIParam("torch.nn.init.kaiming_uniform_") + name = ConsistentInputOutputRelation.to_display_name([in_p, api_p, out_p]) + assert name is not None + assert "kaiming_uniform_" in name + assert "itemsize" in name + assert "ndim" in name + assert "input" in name.lower() + assert "output" in name.lower() + + def test_returns_none_for_insufficient_params(self): + api_p = APIParam("torch.foo") + assert ConsistentInputOutputRelation.to_display_name([api_p]) is None + + +class TestThresholdRelationDisplayName: + def _make_output_param(self, api_name: str) -> InputOutputParam: + return InputOutputParam( + name="output_tensors", + index=0, + type="torch.Tensor", + additional_path=("value",), + api_name=api_name, + is_input=False, + ) + + def _make_threshold_param(self, name: str, api_name: str) -> InputOutputParam: + return InputOutputParam( + name=name, + index=None, + type="float", + additional_path=None, + api_name=api_name, + is_input=True, + ) + + def test_min_threshold_gte(self): + """params=[output, api, threshold] β†’ output β‰₯ threshold.""" + api_p = APIParam("torch.optim.optimizer.Optimizer.step") + out_p = self._make_output_param("Optimizer.step") + thresh_p = self._make_threshold_param("lr", "Optimizer.step") + name = ThresholdRelation.to_display_name([out_p, api_p, thresh_p]) + assert name is not None + assert "Optimizer.step" in name + assert "lr" in name + assert "β‰₯" in name + + def test_max_threshold_lte(self): + """params=[threshold, api, output] β†’ output ≀ threshold.""" + api_p = APIParam("torch.optim.optimizer.Optimizer.step") + out_p = self._make_output_param("Optimizer.step") + thresh_p = self._make_threshold_param("lr", "Optimizer.step") + name = ThresholdRelation.to_display_name([thresh_p, api_p, out_p]) + assert name is not None + assert "Optimizer.step" in name + assert "lr" in name + assert "≀" in name + + def test_returns_none_for_insufficient_params(self): + assert ThresholdRelation.to_display_name([APIParam("torch.foo")]) is None diff --git a/tests/test_violation_summary.py b/tests/test_violation_summary.py new file mode 100644 index 00000000..db4962f8 --- /dev/null +++ b/tests/test_violation_summary.py @@ -0,0 +1,167 @@ +"""Semantic unit tests for violation summary helpers. + +Tests verify the extraction and aggregation logic β€” pure function tests that +do not depend on trace loading or the inference algorithm. +""" + +import pytest + +from traincheck.invariant.base_cls import APIParam, CheckerResult, Invariant +from traincheck.invariant.cover_relation import FunctionCoverRelation +from traincheck.reporting.checker_report import ( + _build_violation_entry, + _extract_violation_steps, + build_violations_summary, +) + +# --------------------------------------------------------------------------- +# _extract_violation_steps +# --------------------------------------------------------------------------- + + +def test_extract_steps_basic(): + trace = [{"meta_vars.step": 1}, {"meta_vars.step": 3}] + assert _extract_violation_steps(trace) == [1, 3] + + +def test_extract_steps_missing_key(): + trace = [{"function": "foo"}, {"meta_vars.step": 5}] + assert _extract_violation_steps(trace) == [5] + + +def test_extract_steps_none_trace(): + assert _extract_violation_steps(None) == [] + + +def test_extract_steps_empty_trace(): + assert _extract_violation_steps([]) == [] + + +def test_extract_steps_none_value_skipped(): + trace = [{"meta_vars.step": None}, {"meta_vars.step": 2}] + assert _extract_violation_steps(trace) == [2] + + +# --------------------------------------------------------------------------- +# Helper to build minimal CheckerResult fixtures +# --------------------------------------------------------------------------- + + +def _make_invariant() -> Invariant: + return Invariant( + relation=FunctionCoverRelation, + params=[ + APIParam("torch.distributed.is_initialized"), + APIParam("torch.nn.modules.module.Module.eval"), + ], + precondition=None, + text_description="test invariant", + ) + + +def _make_result(steps: list[int] | None, check_passed: bool = False) -> CheckerResult: + trace = [{"meta_vars.step": s} for s in steps] if steps is not None else None + return CheckerResult( + invariant=_make_invariant(), + trace=trace, + check_passed=check_passed, + triggered=True, + ) + + +# --------------------------------------------------------------------------- +# _build_violation_entry +# --------------------------------------------------------------------------- + + +def test_build_entry_fields_present(): + result = _make_result(steps=[1, 2, 3]) + entry = _build_violation_entry(result) + assert "display_name" in entry + assert "relation_type" in entry + assert "first_step" in entry + assert "last_step" in entry + assert "occurrences" in entry + + +def test_build_entry_step_values(): + result = _make_result(steps=[1, 5, 3]) + entry = _build_violation_entry(result) + assert entry["first_step"] == 1 + assert entry["last_step"] == 5 + assert entry["occurrences"] == 3 + + +def test_build_entry_no_steps(): + # trace records that have no meta_vars.step key + trace = [{"function": "foo"}] + result = CheckerResult( + invariant=_make_invariant(), + trace=trace, + check_passed=False, + triggered=True, + ) + entry = _build_violation_entry(result) + assert entry["first_step"] is None + assert entry["last_step"] is None + + +def test_build_entry_display_name_is_string(): + result = _make_result(steps=[1]) + entry = _build_violation_entry(result) + assert isinstance(entry["display_name"], str) + assert len(entry["display_name"]) > 0 + + +def test_build_entry_relation_type(): + result = _make_result(steps=[1]) + entry = _build_violation_entry(result) + assert entry["relation_type"] == "FunctionCoverRelation" + + +# --------------------------------------------------------------------------- +# build_violations_summary +# --------------------------------------------------------------------------- + + +def test_summary_no_failures(): + results = [_make_result(steps=[1], check_passed=True)] + summary = build_violations_summary(results) + assert summary["distinct_invariants_violated"] == 0 + assert summary["violations"] == [] + assert summary["first_violation_step"] is None + + +def test_summary_with_failures(): + results = [ + _make_result(steps=[2, 4], check_passed=False), + _make_result(steps=[1], check_passed=False), + _make_result(steps=[5], check_passed=True), + ] + summary = build_violations_summary(results) + assert summary["distinct_invariants_violated"] == 2 + assert summary["first_violation_step"] == 1 + assert len(summary["violations"]) == 2 + + +def test_summary_first_step_across_violations(): + results = [ + _make_result(steps=[10, 20], check_passed=False), + _make_result(steps=[3], check_passed=False), + ] + summary = build_violations_summary(results) + assert summary["first_violation_step"] == 3 + + +def test_summary_no_step_data(): + # violation with trace records that carry no step information + result = CheckerResult( + invariant=_make_invariant(), + trace=[{"function": "foo"}], + check_passed=False, + triggered=True, + ) + summary = build_violations_summary([result]) + assert summary["distinct_invariants_violated"] == 1 + assert summary["first_violation_step"] is None + assert summary["violations"][0]["first_step"] is None From 6d1b71c113388dbc7f0f949ae0acb8f6ce56706d Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 13:57:41 -0400 Subject: [PATCH 06/28] fix: clean up display names for _TRAINCHECK_ attrs and non_zero values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two label quality issues found during live workload testing: - APIContainRelation.to_display_name: return None for attrs starting with _TRAINCHECK_ (internal proxy bookkeeping IDs that are meaningless to users) - APIContainRelation.to_display_name: normalize non_zero β†’ non-zero in both pre and post values via shared _fmt_val() helper - _format_invariant_label: when to_display_name returns None and params include a _TRAINCHECK_ attr, produce "Func() [internal tracking]" instead of falling back to the raw text_description containing the ugly internal name - tests: add test_post_value_non_zero_normalized and test_traincheck_internal_attr_hidden to test_display_names.py Co-Authored-By: Claude Sonnet 4.6 --- tests/test_display_names.py | 29 ++++++++++++++++++++++++ traincheck/invariant/contain_relation.py | 13 +++++++---- traincheck/reporting/checker_report.py | 16 +++++++++++++ 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/tests/test_display_names.py b/tests/test_display_names.py index 3590c499..68ffd112 100644 --- a/tests/test_display_names.py +++ b/tests/test_display_names.py @@ -60,6 +60,35 @@ def test_const_value(self): assert "forward" in name assert "requires_grad" in name + def test_post_value_non_zero_normalized(self): + """non_zero post-value should render as 'non-zero', not 'non_zero'.""" + params = [ + APIParam("torch.optim.sgd.SGD.step"), + VarTypeParam( + "torch.nn.Parameter", + "data", + pre_value="non_zero", + post_value="non_zero", + ), + ] + name = APIContainRelation.to_display_name(params) + assert name is not None + assert "non_zero" not in name + assert "non-zero" in name + + def test_traincheck_internal_attr_hidden(self): + """Attributes starting with _TRAINCHECK_ are internal proxy IDs and should be filtered.""" + params = [ + APIParam("torch.optim.sgd.SGD.step"), + VarTypeParam( + "torch.nn.Parameter", + "_TRAINCHECK_grad_ID", + pre_value="above_zero", + post_value="above_zero", + ), + ] + assert APIContainRelation.to_display_name(params) is None + def test_returns_none_for_empty_params(self): assert APIContainRelation.to_display_name([]) is None diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index 7196b73f..05af1d4f 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -355,15 +355,20 @@ def to_display_name(params: list[Param]) -> str | None: child_short = _short_api_name(child.api_full_name) return f"{parent_short}() always calls {child_short}()" if isinstance(child, (VarTypeParam, VarNameParam)): - var_short = child.var_type.split(".")[-1] attr = child.attr_name + # Skip internal TrainCheck proxy bookkeeping attributes + if attr.startswith("_TRAINCHECK_"): + return None + var_short = child.var_type.split(".")[-1] pre = child.pre_value post = child.post_value const = child.const_value + + def _fmt_val(v: object) -> str: + return "non-zero" if v == "non_zero" else str(v) + if pre is not _NOT_SET and post is not _NOT_SET: - pre_str = "non-zero" if pre == "non_zero" else str(pre) - post_str = str(post) - return f"{parent_short}() changes {var_short}.{attr}: {pre_str} β†’ {post_str}" + return f"{parent_short}() changes {var_short}.{attr}: {_fmt_val(pre)} β†’ {_fmt_val(post)}" if const is not _NOT_SET: return f"{parent_short}() sees {var_short}.{attr} = {const}" return f"{parent_short}() accesses {var_short}.{attr}" diff --git a/traincheck/reporting/checker_report.py b/traincheck/reporting/checker_report.py index c14be4b1..c07eebb7 100644 --- a/traincheck/reporting/checker_report.py +++ b/traincheck/reporting/checker_report.py @@ -7,12 +7,28 @@ from typing import Iterable from traincheck.invariant import CheckerResult, Invariant +from traincheck.invariant.base_cls import APIParam, VarNameParam, VarTypeParam def _format_invariant_label(invariant: Invariant) -> str: display = invariant.relation.to_display_name(invariant.params) if display: return display + # When to_display_name returns None, fall back β€” but sanitize params that + # contain internal TrainCheck proxy bookkeeping names (_TRAINCHECK_*) so + # they never surface raw in the UI. + for p in invariant.params: + if isinstance(p, (VarTypeParam, VarNameParam)) and p.attr_name.startswith( + "_TRAINCHECK_" + ): + # Build a minimal label using only the API param, hiding internals + api_parts = [q for q in invariant.params if isinstance(q, APIParam)] + from traincheck.invariant.base_cls import _short_api_name + + if api_parts: + func = _short_api_name(api_parts[0].api_full_name) + return f"{func}() [internal tracking]" + return f"{invariant.relation.__name__} [internal tracking]" if invariant.text_description: return invariant.text_description params = ", ".join(str(param) for param in invariant.params) From 8ebe1a68bdc6a66427c0ab4b81cde1c5c013e3fb Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 14:15:43 -0400 Subject: [PATCH 07/28] fix: silence noisy stdout/stderr during traincheck-check - checker.py: stream handler now set to WARNING level so per-invariant INFO logs stay in the log file; summary lines printed via print() - cover_relation.py, lead_relation.py: convert remaining print() debug calls to logger.debug(); change all leave=True to leave=False on inner tqdm bars so they don't persist after completing - DistinctArgumentRelation.py, consistency_relation.py, consistency_transient_vars.py: same leave=False cleanup Co-Authored-By: Claude Sonnet 4.6 --- traincheck/checker.py | 16 ++++- .../invariant/DistinctArgumentRelation.py | 27 ++++---- traincheck/invariant/consistency_relation.py | 9 ++- .../invariant/consistency_transient_vars.py | 12 +++- traincheck/invariant/cover_relation.py | 68 ++++++++++--------- traincheck/invariant/lead_relation.py | 64 ++++++++--------- 6 files changed, 114 insertions(+), 82 deletions(-) diff --git a/traincheck/checker.py b/traincheck/checker.py index a24904f7..556d3774 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -236,7 +236,12 @@ def main(): logger.info("Reading traces from %s", "\n".join(trace_files)) traces.append(read_trace_file(trace_files)) - logger.addHandler(logging.StreamHandler()) + # Warnings and above go to stderr so the user sees them; detailed per-invariant + # INFO logs stay in the log file only. + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.WARNING) + logger.addHandler(stream_handler) + results_by_trace: list[tuple[str, list[CheckerResult]]] = [] for trace, trace_parent_folder in zip(traces, trace_parent_folders): @@ -248,6 +253,15 @@ def main(): res for res in results_per_trace if res.triggered is False ] + n_checked = len(results_per_trace) + n_failed = len(results_per_trace_failed) + n_passed = n_checked - n_failed + n_not_triggered = len(results_per_trace_not_triggered) + print(f"Checking finished. {n_checked} invariants checked") + print(f"Total failed invariants: {n_failed}/{n_checked}") + print(f"Total passed invariants: {n_passed}/{n_checked}") + print(f"Total invariants that are not triggered: {n_not_triggered}/{n_checked}") + logger.info("Checking finished. %d invariants checked", len(results_per_trace)) logger.info( "Total failed invariants: %d/%d", diff --git a/traincheck/invariant/DistinctArgumentRelation.py b/traincheck/invariant/DistinctArgumentRelation.py index 23a63bc0..b74a2564 100644 --- a/traincheck/invariant/DistinctArgumentRelation.py +++ b/traincheck/invariant/DistinctArgumentRelation.py @@ -1,3 +1,4 @@ +import logging from itertools import combinations from typing import Any, Dict, Iterable, List, Set, Tuple @@ -22,6 +23,8 @@ from traincheck.trace.trace import Trace from traincheck.utils import safe_isnan +logger = logging.getLogger(__name__) + EXP_GROUP_NAME = "distinct_arg" MAX_FUNC_NUM_CONSECUTIVE_CALL = 6 IOU_THRESHHOLD = 0.1 # pre-defined threshhold for IOU @@ -199,7 +202,7 @@ class DistinctArgumentRelation(Relation): def generate_hypothesis(trace) -> list[Hypothesis]: """Generate hypothesis for the DistinctArgumentRelation on trace.""" # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") listed_arguments: Dict[ str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]] ] = {} @@ -210,7 +213,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: function_pool, listed_arguments = get_event_data_per_function_per_step( trace, function_pool ) - print("End preprocessing") + logger.debug("End preprocessing") # If there is no filtered function, return [], [] if not function_pool: @@ -221,7 +224,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: # function_pool.add("torch.nn.init.normal_") # 2. Generating hypothesis - print("Start generating hypo...") + logger.debug("Start generating hypo...") hypothesis_with_examples = { func_name: Hypothesis( invariant=Invariant( @@ -235,10 +238,10 @@ def generate_hypothesis(trace) -> list[Hypothesis]: ) for func_name in function_pool } - print("End generating hypo") + logger.debug("End generating hypo") # 3. Add positive and negative examples - print("Start adding examples...") + logger.debug("Start adding examples...") for func_name in tqdm(function_pool): flag = False for step, records in listed_arguments[func_name].items(): @@ -280,7 +283,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: if not flag: hypothesis_with_examples.pop(func_name) - print("End adding examples") + logger.debug("End adding examples") return list(hypothesis_with_examples.values()) @@ -290,7 +293,7 @@ def collect_examples(trace, hypothesis): inv = hypothesis.invariant # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") listed_arguments: Dict[ str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]] ] = {} @@ -306,7 +309,7 @@ def collect_examples(trace, hypothesis): trace, function_pool ) - print("End preprocessing") + logger.debug("End preprocessing") if not function_pool: return @@ -344,7 +347,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: # DistinctArgumentRelation.collect_examples(trace, hypothesis) # 4. Precondition inference - print("Start precondition inference...") + logger.debug("Start precondition inference...") failed_hypothesis = [] for hypothesis in all_hypotheses.copy(): preconditions = find_precondition(hypothesis, [trace]) @@ -355,7 +358,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: FailedHypothesis(hypothesis, "Precondition not found") ) all_hypotheses.remove(hypothesis) - print("End precondition inference") + logger.debug("End precondition inference") return ( list([hypo.invariant for hypo in all_hypotheses]), @@ -390,7 +393,7 @@ def static_check_all( assert inv.precondition is not None, "Invariant should have a precondition." # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") listed_arguments: Dict[ str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]] ] = {} @@ -415,7 +418,7 @@ def static_check_all( ) events_list = get_event_list(trace, function_pool) - print("End preprocessing") + logger.debug("End preprocessing") if not inv.precondition.verify(events_list, EXP_GROUP_NAME, trace): return CheckerResult( diff --git a/traincheck/invariant/consistency_relation.py b/traincheck/invariant/consistency_relation.py index 74f35c06..78d7a732 100644 --- a/traincheck/invariant/consistency_relation.py +++ b/traincheck/invariant/consistency_relation.py @@ -191,6 +191,7 @@ def skip_attrs_with_different_dtypes(attr1, attr2): combinations(var_insts, 2), desc="Generating Hypothesis for Consistency Relation", total=len(var_insts) * (len(var_insts) - 1) // 2, + leave=False, ): for attr in var_insts[var_inst]: for other_attr in var_insts[other_var_inst]: @@ -432,7 +433,9 @@ def static_check_all( start_time_collecting_pairs = time.time() num_collected_pairs = 0 value_pairs_to_check: dict[float, list[tuple]] = {} - for i, var1_id in enumerate(tqdm(type1_attr1, desc="Collecting Value Pairs")): + for i, var1_id in enumerate( + tqdm(type1_attr1, desc="Collecting Value Pairs", leave=False) + ): for j, var2_id in enumerate(type2_attr2): if var_type1 == var_type2 and attr1 == attr2 and i >= j: continue @@ -473,7 +476,9 @@ def static_check_all( start_time_checking_pairs = time.time() num_checked_pairs = 0 value_pairs_to_check = dict(sorted(value_pairs_to_check.items())) - for time_pair in tqdm(value_pairs_to_check, desc="Checking Value Pairs"): + for time_pair in tqdm( + value_pairs_to_check, desc="Checking Value Pairs", leave=False + ): for attr1_val, attr2_val in value_pairs_to_check[time_pair]: traces = [attr1_val.traces[-1], attr2_val.traces[-1]] num_checked_pairs += 1 diff --git a/traincheck/invariant/consistency_transient_vars.py b/traincheck/invariant/consistency_transient_vars.py index 98ce0721..55bfea02 100644 --- a/traincheck/invariant/consistency_transient_vars.py +++ b/traincheck/invariant/consistency_transient_vars.py @@ -534,7 +534,9 @@ def static_check_all( triggered = False # for each function call, check if the property holds for func_call_id in tqdm( - func_call_ids, desc=f"Checking invariant {inv.text_description}" + func_call_ids, + desc=f"Checking invariant {inv.text_description}", + leave=False, ): func_call_event = trace.query_func_call_event(func_call_id) if isinstance( @@ -938,7 +940,9 @@ def static_check_all(trace, inv, check_relation_first): triggered = False for func_call_id in tqdm( - func_call_ids, desc=f"Checking invariant {inv.text_description}" + func_call_ids, + desc=f"Checking invariant {inv.text_description}", + leave=False, ): func_call_event = trace.query_func_call_event(func_call_id) if isinstance( @@ -1441,7 +1445,9 @@ def static_check_all(trace, inv, check_relation_first): triggered = False for func_call_id in tqdm( - func_call_ids, desc=f"Checking invariant {inv.text_description}" + func_call_ids, + desc=f"Checking invariant {inv.text_description}", + leave=False, ): func_call_event = trace.query_func_call_event(func_call_id) if isinstance( diff --git a/traincheck/invariant/cover_relation.py b/traincheck/invariant/cover_relation.py index 5dd1dcb9..a8e88723 100644 --- a/traincheck/invariant/cover_relation.py +++ b/traincheck/invariant/cover_relation.py @@ -31,6 +31,8 @@ from traincheck.trace.trace import Trace from traincheck.trace.trace_pandas import TracePandas +logger = logging.getLogger(__name__) + EXP_GROUP_NAME = "func_cover" @@ -122,7 +124,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -156,9 +158,9 @@ def generate_hypothesis(trace) -> list[Hypothesis]: trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -170,13 +172,13 @@ def generate_hypothesis(trace) -> list[Hypothesis]: valid_relations = trace.valid_relations_cover else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -194,10 +196,10 @@ def generate_hypothesis(trace) -> list[Hypothesis]: valid_relations[(func_A, func_B)] = True trace.same_level_func_cover = same_level_func trace.valid_relations_cover = valid_relations - print("End same level checking") + logger.debug("End same level checking") # 3. Generating hypothesis - print("Start generating hypo...") + logger.debug("Start generating hypo...") hypothesis_with_examples = { (func_A, func_B): Hypothesis( invariant=Invariant( @@ -214,12 +216,12 @@ def generate_hypothesis(trace) -> list[Hypothesis]: ) for (func_A, func_B), _ in valid_relations.items() } - print("End generating hypo") + logger.debug("End generating hypo") # 4. Add positive and negative examples - print("Start adding examples...") + logger.debug("Start adding examples...") for (process_id, thread_id), events_list in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Group" + listed_events.items(), ascii=True, leave=False, desc="Group" ): for (func_A, func_B), _ in tqdm( @@ -320,7 +322,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: (func_A, func_B) ].negative_examples.add_example(example) - print("End adding examples") + logger.debug("End adding examples") return list(hypothesis_with_examples.values()) @@ -331,7 +333,7 @@ def collect_examples(trace, hypothesis): logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -365,9 +367,9 @@ def collect_examples(trace, hypothesis): trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -379,13 +381,13 @@ def collect_examples(trace, hypothesis): valid_relations = trace.valid_relations_cover else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -403,7 +405,7 @@ def collect_examples(trace, hypothesis): valid_relations[(func_A, func_B)] = True trace.same_level_func_cover = same_level_func trace.valid_relations_cover = valid_relations - print("End same level checking") + logger.debug("End same level checking") inv = hypothesis.invariant @@ -421,12 +423,12 @@ def collect_examples(trace, hypothesis): function_pool = set(function_pool).intersection(function_pool_temp) if len(function_pool) == 0: - print( + logger.debug( "No relevant function calls found in the trace, skipping the collecting" ) return - print("Starting collecting iteration...") + logger.debug("Starting collecting iteration...") # for i in tqdm(range(invariant_length - 1)): for i in range(invariant_length - 1): param_A = inv.params[i] @@ -540,7 +542,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: if_merge = True - print("Start precondition inference...") + logger.debug("Start precondition inference...") failed_hypothesis = [] for hypothesis in all_hypotheses.copy(): preconditions = find_precondition(hypothesis, [trace]) @@ -551,17 +553,17 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: FailedHypothesis(hypothesis, "Precondition not found") ) all_hypotheses.remove(hypothesis) - print("End precondition inference") + logger.debug("End precondition inference") if not if_merge: return ( list([hypo.invariant for hypo in all_hypotheses]), failed_hypothesis, ) - print("End precondition inference") + logger.debug("End precondition inference") # 6. Merge invariants - print("Start merging invariants...") + logger.debug("Start merging invariants...") relation_pool: Dict[ GroupedPreconditions | None, List[Tuple[APIParam, APIParam]] ] = {} @@ -591,7 +593,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: text_description="Merged FunctionCoverRelation in Ordered List", ) merged_ininvariants.append(new_invariant) - print("End merging invariants") + logger.debug("End merging invariants") return merged_ininvariants, failed_hypothesis @@ -625,7 +627,7 @@ def static_check_all( logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -667,9 +669,9 @@ def static_check_all( trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -681,13 +683,13 @@ def static_check_all( valid_relations = trace.valid_relations_cover else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -705,7 +707,7 @@ def static_check_all( valid_relations[(func_A, func_B)] = True trace.same_level_func_cover = same_level_func trace.valid_relations_cover = valid_relations - print("End same level checking") + logger.debug("End same level checking") inv_triggered = False @@ -722,7 +724,7 @@ def static_check_all( function_pool = set(function_pool).intersection(set(function_pool_temp)) # type: ignore if len(function_pool) == 0: - print( + logger.debug( "No relevant function calls found in the trace, skipping the checking" ) return CheckerResult( @@ -732,8 +734,8 @@ def static_check_all( triggered=False, ) - print("Starting checking iteration...") - for i in tqdm(range(invariant_length - 1)): + logger.debug("Starting checking iteration...") + for i in tqdm(range(invariant_length - 1), leave=False): param_A = inv.params[i] param_B = inv.params[i + 1] diff --git a/traincheck/invariant/lead_relation.py b/traincheck/invariant/lead_relation.py index 00c222ec..54af49ac 100644 --- a/traincheck/invariant/lead_relation.py +++ b/traincheck/invariant/lead_relation.py @@ -24,6 +24,8 @@ from traincheck.trace.trace import Trace from traincheck.trace.trace_pandas import TracePandas +logger = logging.getLogger(__name__) + EXP_GROUP_NAME = "func_lead" MAX_FUNC_NUM_CONSECUTIVE_CALL = 4 # ideally this should be proportional to the number of training and testing iterations in the trace @@ -276,7 +278,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -312,9 +314,9 @@ def generate_hypothesis(trace) -> list[Hypothesis]: trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -326,13 +328,13 @@ def generate_hypothesis(trace) -> list[Hypothesis]: valid_relations = trace.valid_relations_lead else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -350,10 +352,10 @@ def generate_hypothesis(trace) -> list[Hypothesis]: valid_relations[(func_A, func_B)] = True trace.same_level_func_lead = same_level_func trace.valid_relations_lead = valid_relations - print("End same level checking") + logger.debug("End same level checking") # 3. Generating hypothesis - print("Start generating hypo...") + logger.debug("Start generating hypo...") hypothesis_with_examples = { (func_A, func_B): Hypothesis( invariant=Invariant( @@ -370,12 +372,12 @@ def generate_hypothesis(trace) -> list[Hypothesis]: ) for (func_A, func_B), _ in valid_relations.items() } - print("End generating hypo") + logger.debug("End generating hypo") # 4. Add positive and negative examples - print("Start adding examples...") + logger.debug("Start adding examples...") for (process_id, thread_id), events_list in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Group" + listed_events.items(), ascii=True, leave=False, desc="Group" ): for (func_A, func_B), _ in tqdm( @@ -512,7 +514,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: (func_A, func_B) ].negative_examples.add_example(example) - print("End adding examples") + logger.debug("End adding examples") return list(hypothesis_with_examples.values()) @@ -523,7 +525,7 @@ def collect_examples(trace, hypothesis): logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -559,9 +561,9 @@ def collect_examples(trace, hypothesis): trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -573,13 +575,13 @@ def collect_examples(trace, hypothesis): valid_relations = trace.valid_relations_lead else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -597,7 +599,7 @@ def collect_examples(trace, hypothesis): valid_relations[(func_A, func_B)] = True trace.same_level_func_lead = same_level_func trace.valid_relations_lead = valid_relations - print("End same level checking") + logger.debug("End same level checking") inv = hypothesis.invariant @@ -614,12 +616,12 @@ def collect_examples(trace, hypothesis): function_pool = set(function_pool).intersection(function_pool_temp) if len(function_pool) == 0: - print( + logger.debug( "No relevant function calls found in the trace, skipping the collecting" ) return - print("Starting collecting iteration...") + logger.debug("Starting collecting iteration...") for i in range(invariant_length - 1): param_A = inv.params[i] param_B = inv.params[i + 1] @@ -765,7 +767,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: # for hypothesis in all_hypotheses: # FunctionLeadRelation.collect_examples(trace, hypothesis) - print("Start precondition inference...") + logger.debug("Start precondition inference...") failed_hypothesis = [] for hypothesis in all_hypotheses.copy(): preconditions = find_precondition(hypothesis, [trace]) @@ -776,7 +778,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: FailedHypothesis(hypothesis, "Precondition not found") ) all_hypotheses.remove(hypothesis) - print("End precondition inference") + logger.debug("End precondition inference") if_merge = True @@ -787,7 +789,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: ) # 6. Merge invariants - print("Start merging invariants...") + logger.debug("Start merging invariants...") relation_pool: Dict[ GroupedPreconditions | None, List[Tuple[APIParam, APIParam]] ] = {} @@ -817,7 +819,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: text_description="Merged FunctionLeadRelation in Ordered List", ) merged_ininvariants.append(new_invariant) - print("End merging invariants") + logger.debug("End merging invariants") return merged_ininvariants, failed_hypothesis @@ -852,7 +854,7 @@ def static_check_all( logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -893,9 +895,9 @@ def static_check_all( trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -907,13 +909,13 @@ def static_check_all( valid_relations = trace.valid_relations_lead else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -931,7 +933,7 @@ def static_check_all( valid_relations[(func_A, func_B)] = True trace.same_level_func_lead = same_level_func trace.valid_relations_lead = valid_relations - print("End same level checking") + logger.debug("End same level checking") inv_triggered = False @@ -948,7 +950,7 @@ def static_check_all( function_pool = set(function_pool).intersection(set(function_pool_temp)) if len(function_pool) == 0: - print( + logger.debug( "No relevant function calls found in the trace, skipping the checking" ) return CheckerResult( @@ -958,7 +960,7 @@ def static_check_all( triggered=False, ) - print("Starting checking iteration...") + logger.debug("Starting checking iteration...") for i in range(invariant_length - 1): param_A = inv.params[i] param_B = inv.params[i + 1] From 4645aadab7f60a7db1e4404bca29e941c7f4adb7 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 14:20:28 -0400 Subject: [PATCH 08/28] feat: replace multi-bar clutter with single live-stats progress bar MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit check_engine() now shows one bar: "{N checked Β· M left Β· X violated} P%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| N/total [elapsed --- traincheck/checker.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/traincheck/checker.py b/traincheck/checker.py index 556d3774..078dc0c9 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -40,17 +40,31 @@ def check_engine( ) -> list[CheckerResult]: logger = logging.getLogger(__name__) results = [] - for inv in tqdm( - invariants, desc="Checking invariants", unit="invariant", leave=False - ): - assert ( - inv.precondition is not None - ), "Invariant precondition is None. It should at least be 'Unconditional' or an empty list. Please check the invariant file and the inference process." - logger.info("=====================================") - res = inv.check(trace, check_relation_first) - res.calc_and_set_time_precentage(trace.get_start_time(), trace.get_end_time()) - logger.info("Invariant %s on trace %s: %s", inv, trace, res) - results.append(res) + total = len(invariants) + n_violated = 0 + bar_fmt = ( + "{desc} {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]" + ) + with tqdm(total=total, bar_format=bar_fmt, unit="inv") as pbar: + pbar.set_description(f"0 checked Β· {total} left Β· 0 violated") + for i, inv in enumerate(invariants): + assert ( + inv.precondition is not None + ), "Invariant precondition is None. It should at least be 'Unconditional' or an empty list. Please check the invariant file and the inference process." + logger.info("=====================================") + res = inv.check(trace, check_relation_first) + res.calc_and_set_time_precentage( + trace.get_start_time(), trace.get_end_time() + ) + logger.info("Invariant %s on trace %s: %s", inv, trace, res) + results.append(res) + if not res.check_passed: + n_violated += 1 + done = i + 1 + pbar.set_description( + f"{done} checked Β· {total - done} left Β· {n_violated} violated" + ) + pbar.update(1) return results From 1c3f8adc30338d4281f7ce63cfb6dcf3dcb779b4 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 14:29:51 -0400 Subject: [PATCH 09/28] fix: suppress all inner progress bars during checking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add traincheck/progress.py β€” a thin tqdm wrapper that checks utils._suppress_inner_progress before creating each bar. check_engine() sets the flag after opening the single outer checking bar, so only "N checked Β· M left Β· X violated" is visible during a check run. All relation code and trace-layer code now imports from traincheck.progress instead of tqdm directly (one-line change per file, no logic changes). Co-Authored-By: Claude Sonnet 4.6 --- traincheck/checker.py | 49 +++++++++++-------- .../invariant/DistinctArgumentRelation.py | 3 +- traincheck/invariant/consistency_relation.py | 3 +- .../invariant/consistency_transient_vars.py | 2 +- traincheck/invariant/contain_relation.py | 2 +- traincheck/invariant/cover_relation.py | 3 +- traincheck/invariant/lead_relation.py | 3 +- traincheck/invariant/precondition.py | 3 +- traincheck/progress.py | 23 +++++++++ traincheck/trace/trace_dict.py | 3 +- traincheck/trace/trace_pandas.py | 3 +- traincheck/trace/trace_polars.py | 2 +- traincheck/utils.py | 4 ++ 13 files changed, 67 insertions(+), 36 deletions(-) create mode 100644 traincheck/progress.py diff --git a/traincheck/checker.py b/traincheck/checker.py index 078dc0c9..9542743f 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -6,6 +6,7 @@ from tqdm import tqdm +import traincheck.utils as _tc_utils from traincheck.invariant import CheckerResult, Invariant, read_inv_file from traincheck.reporting import ( ReportEmitter, @@ -45,26 +46,34 @@ def check_engine( bar_fmt = ( "{desc} {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]" ) - with tqdm(total=total, bar_format=bar_fmt, unit="inv") as pbar: - pbar.set_description(f"0 checked Β· {total} left Β· 0 violated") - for i, inv in enumerate(invariants): - assert ( - inv.precondition is not None - ), "Invariant precondition is None. It should at least be 'Unconditional' or an empty list. Please check the invariant file and the inference process." - logger.info("=====================================") - res = inv.check(trace, check_relation_first) - res.calc_and_set_time_precentage( - trace.get_start_time(), trace.get_end_time() - ) - logger.info("Invariant %s on trace %s: %s", inv, trace, res) - results.append(res) - if not res.check_passed: - n_violated += 1 - done = i + 1 - pbar.set_description( - f"{done} checked Β· {total - done} left Β· {n_violated} violated" - ) - pbar.update(1) + with tqdm( + total=total, + bar_format=bar_fmt, + unit="inv", + desc=f"0 checked Β· {total} left Β· 0 violated", + ) as pbar: + _tc_utils._suppress_inner_progress = True + try: + for i, inv in enumerate(invariants): + assert ( + inv.precondition is not None + ), "Invariant precondition is None. It should at least be 'Unconditional' or an empty list. Please check the invariant file and the inference process." + logger.info("=====================================") + res = inv.check(trace, check_relation_first) + res.calc_and_set_time_precentage( + trace.get_start_time(), trace.get_end_time() + ) + logger.info("Invariant %s on trace %s: %s", inv, trace, res) + results.append(res) + if not res.check_passed: + n_violated += 1 + done = i + 1 + pbar.set_description( + f"{done} checked Β· {total - done} left Β· {n_violated} violated" + ) + pbar.update(1) + finally: + _tc_utils._suppress_inner_progress = False return results diff --git a/traincheck/invariant/DistinctArgumentRelation.py b/traincheck/invariant/DistinctArgumentRelation.py index b74a2564..5a1d8fd1 100644 --- a/traincheck/invariant/DistinctArgumentRelation.py +++ b/traincheck/invariant/DistinctArgumentRelation.py @@ -2,8 +2,6 @@ from itertools import combinations from typing import Any, Dict, Iterable, List, Set, Tuple -from tqdm import tqdm - from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( # GroupedPreconditions, APIParam, @@ -20,6 +18,7 @@ ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.utils import safe_isnan diff --git a/traincheck/invariant/consistency_relation.py b/traincheck/invariant/consistency_relation.py index 78d7a732..96a215e1 100644 --- a/traincheck/invariant/consistency_relation.py +++ b/traincheck/invariant/consistency_relation.py @@ -2,8 +2,6 @@ import time from itertools import combinations -from tqdm import tqdm - from traincheck.config import config from traincheck.invariant.base_cls import ( CheckerResult, @@ -19,6 +17,7 @@ ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import Liveness, VarInstId diff --git a/traincheck/invariant/consistency_transient_vars.py b/traincheck/invariant/consistency_transient_vars.py index 55bfea02..e3e9eec9 100644 --- a/traincheck/invariant/consistency_transient_vars.py +++ b/traincheck/invariant/consistency_transient_vars.py @@ -3,7 +3,6 @@ from typing import Hashable import pandas as pd -from tqdm import tqdm from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( @@ -26,6 +25,7 @@ ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( FuncCallEvent, diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index 05af1d4f..f599f51b 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -5,7 +5,6 @@ from typing import Type import numpy as np -from tqdm import tqdm from traincheck.config.config import ANALYSIS_SKIP_FUNC_NAMES from traincheck.instrumentor.tracer import TraceLineType @@ -38,6 +37,7 @@ get_var_raw_event_before_time, set_meta_vars_online, ) +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( ALL_EVENT_TYPES, diff --git a/traincheck/invariant/cover_relation.py b/traincheck/invariant/cover_relation.py index a8e88723..704a74be 100644 --- a/traincheck/invariant/cover_relation.py +++ b/traincheck/invariant/cover_relation.py @@ -2,8 +2,6 @@ from itertools import permutations from typing import Any, Dict, List, Set, Tuple -from tqdm import tqdm - from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( APIParam, @@ -28,6 +26,7 @@ ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.trace_pandas import TracePandas diff --git a/traincheck/invariant/lead_relation.py b/traincheck/invariant/lead_relation.py index 54af49ac..27dcc2ff 100644 --- a/traincheck/invariant/lead_relation.py +++ b/traincheck/invariant/lead_relation.py @@ -2,8 +2,6 @@ from itertools import permutations from typing import Any, Dict, Iterable, List, Set, Tuple -from tqdm import tqdm - from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( APIParam, @@ -21,6 +19,7 @@ ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.trace_pandas import TracePandas diff --git a/traincheck/invariant/precondition.py b/traincheck/invariant/precondition.py index b2040d83..7056c01c 100644 --- a/traincheck/invariant/precondition.py +++ b/traincheck/invariant/precondition.py @@ -2,8 +2,6 @@ from itertools import combinations from typing import Hashable -from tqdm import tqdm - import traincheck.config.config as config from traincheck.invariant.base_cls import ( PT, @@ -14,6 +12,7 @@ Preconditions, UnconditionalPrecondition, ) +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import MD_NONE from traincheck.utils import safe_isnan diff --git a/traincheck/progress.py b/traincheck/progress.py new file mode 100644 index 00000000..7947dcd4 --- /dev/null +++ b/traincheck/progress.py @@ -0,0 +1,23 @@ +"""Thin tqdm wrapper that can be silenced during invariant checking. + +Import this instead of tqdm in relation and trace modules: + + from traincheck.progress import tqdm + +When traincheck.utils._suppress_inner_progress is True (set by +check_engine() while the outer checking bar is active), all bars +created via this wrapper are disabled so only the single top-level +progress bar is visible. +""" + +from tqdm import tqdm as _tqdm_orig + + +def tqdm(iterable=None, *args, **kwargs): # type: ignore[override] + from traincheck import utils as _utils + + if _utils._suppress_inner_progress and "disable" not in kwargs: + kwargs["disable"] = True + if iterable is not None: + return _tqdm_orig(iterable, *args, **kwargs) + return _tqdm_orig(*args, **kwargs) diff --git a/traincheck/trace/trace_dict.py b/traincheck/trace/trace_dict.py index 7ae2339b..a218a647 100644 --- a/traincheck/trace/trace_dict.py +++ b/traincheck/trace/trace_dict.py @@ -3,10 +3,9 @@ import re from collections import defaultdict -from tqdm import tqdm - from traincheck.config import config from traincheck.instrumentor.tracer import TraceLineType +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( AttrState, diff --git a/traincheck/trace/trace_pandas.py b/traincheck/trace/trace_pandas.py index c51d946d..6263bea6 100644 --- a/traincheck/trace/trace_pandas.py +++ b/traincheck/trace/trace_pandas.py @@ -3,11 +3,11 @@ from typing import Any import pandas as pd -from tqdm import tqdm from traincheck.config import config from traincheck.instrumentor.tracer import TraceLineType from traincheck.instrumentor.types import PTID +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( MD_NONE, @@ -201,6 +201,7 @@ def _rm_incomplete_trailing_func_calls(self): for _, row in tqdm( incomplete_func_call_records.iterrows(), desc="Removing Incomplete Function Calls", + leave=False, ): assert ( row["type"] == TraceLineType.FUNC_CALL_PRE diff --git a/traincheck/trace/trace_polars.py b/traincheck/trace/trace_polars.py index 924637de..243e671f 100644 --- a/traincheck/trace/trace_polars.py +++ b/traincheck/trace/trace_polars.py @@ -2,10 +2,10 @@ import re import polars as pl -from tqdm import tqdm from traincheck.config import config from traincheck.instrumentor.tracer import TraceLineType +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( AttrState, diff --git a/traincheck/utils.py b/traincheck/utils.py index 944fd989..332cfe04 100644 --- a/traincheck/utils.py +++ b/traincheck/utils.py @@ -13,6 +13,10 @@ THREAD_LOCAL = threading.local() +# When True, tqdm bars created via traincheck.progress.tqdm are disabled. +# Set by check_engine() so only the single outer checking bar is visible. +_suppress_inner_progress: bool = False + def safe_getattr(obj, attr, default=None): """Safely get the attribute of an object. From 3f90199be12a57589768757ba01da608b960fd6c Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 16:31:34 -0400 Subject: [PATCH 10/28] refactor: use to_display_name() as the canonical text_description at inference time Previously to_display_name() was only called at HTML render time in checker_report.py, so failed.log, violations_summary.json, and any other text_description consumer still showed raw internal strings like 'FunctionCoverRelation between torch.optim... and torch.optim...'. Now every generate_hypothesis() / infer() site that constructs an Invariant calls to_display_name(params) directly and uses the result as text_description, falling back to the old string only when to_display_name returns None (e.g. unexpected param types). _format_invariant_label() in checker_report.py already falls through to text_description, so the HTML report continues to work unchanged. Co-Authored-By: Claude Sonnet 4.6 --- .../invariant/DistinctArgumentRelation.py | 5 +++- traincheck/invariant/consistency_relation.py | 8 +++++- .../invariant/consistency_transient_vars.py | 27 ++++++++++++++++--- traincheck/invariant/contain_relation.py | 20 +++++++++++--- traincheck/invariant/cover_relation.py | 10 +++++-- traincheck/invariant/lead_relation.py | 10 +++++-- 6 files changed, 66 insertions(+), 14 deletions(-) diff --git a/traincheck/invariant/DistinctArgumentRelation.py b/traincheck/invariant/DistinctArgumentRelation.py index 5a1d8fd1..315966c9 100644 --- a/traincheck/invariant/DistinctArgumentRelation.py +++ b/traincheck/invariant/DistinctArgumentRelation.py @@ -230,7 +230,10 @@ def generate_hypothesis(trace) -> list[Hypothesis]: relation=DistinctArgumentRelation, params=[APIParam(func_name)], precondition=None, - text_description=f"{func_name} has distinct input arguments on difference PT for each step", + text_description=DistinctArgumentRelation.to_display_name( + [APIParam(func_name)] + ) + or f"{func_name} has distinct input arguments on difference PT for each step", ), positive_examples=ExampleList({EXP_GROUP_NAME}), negative_examples=ExampleList({EXP_GROUP_NAME}), diff --git a/traincheck/invariant/consistency_relation.py b/traincheck/invariant/consistency_relation.py index 96a215e1..5e41f2ea 100644 --- a/traincheck/invariant/consistency_relation.py +++ b/traincheck/invariant/consistency_relation.py @@ -266,7 +266,13 @@ def skip_attrs_with_different_dtypes(attr1, attr2): VarTypeParam(var_type=hypo[2], attr_name=hypo[3]), ], precondition=None, - text_description=f"Consistency Relation between {hypo[0]}.{hypo[1]} and {hypo[2]}.{hypo[3]}", + text_description=ConsistencyRelation.to_display_name( + [ + VarTypeParam(var_type=hypo[0], attr_name=hypo[1]), + VarTypeParam(var_type=hypo[2], attr_name=hypo[3]), + ] + ) + or f"Consistency Relation between {hypo[0]}.{hypo[1]} and {hypo[2]}.{hypo[3]}", ), positive_examples=ExampleList({VAR_GROUP_NAME}), negative_examples=ExampleList({VAR_GROUP_NAME}), diff --git a/traincheck/invariant/consistency_transient_vars.py b/traincheck/invariant/consistency_transient_vars.py index e3e9eec9..2dae7a0d 100644 --- a/traincheck/invariant/consistency_transient_vars.py +++ b/traincheck/invariant/consistency_transient_vars.py @@ -391,7 +391,17 @@ def generate_hypothesis(trace) -> list[Hypothesis]: ), ], precondition=None, - text_description=f"{prop} of the tensors returned by the function {func_name} is consistently {prop_val}.", + text_description=ConsistentOutputRelation.to_display_name( + [ + APIParam(api_full_name=func_name), + VarTypeParam( + var_type="torch.Tensor", + attr_name=prop, + const_value=make_hashable(prop_val), + ), + ] + ) + or f"{prop} of the tensors returned by the function {func_name} is consistently {prop_val}.", ), positive_examples=ExampleList({"pre_event"}), negative_examples=ExampleList({"pre_event"}), @@ -777,7 +787,10 @@ def generate_hypothesis(trace: Trace) -> list[Hypothesis]: relation=ConsistentInputOutputRelation, params=[input_param, api_param, output_param], precondition=None, - text_description=f"The value {common_value} is consistent across the input {input_path} and output {output_path} tensors of the function {func_name}.", + text_description=ConsistentInputOutputRelation.to_display_name( + [input_param, api_param, output_param] + ) + or f"The value {common_value} is consistent across the input {input_path} and output {output_path} tensors of the function {func_name}.", ), positive_examples=ExampleList( {"pre_event"} @@ -1209,7 +1222,10 @@ def generate_hypothesis(trace: Trace) -> list[Hypothesis]: input_param, ], # the first param should be larger or equal to the second param precondition=None, - text_description=f"Output tensor's value at {output_param.additional_path} is consistently larger than or equal to the min input threshold {input_param.name} for the function {func_name}.", + text_description=ThresholdRelation.to_display_name( + [output_param, api_param, input_param] + ) + or f"Output tensor's value at {output_param.additional_path} is consistently larger than or equal to the min input threshold {input_param.name} for the function {func_name}.", ), positive_examples=ExampleList( {"pre_event"} @@ -1262,7 +1278,10 @@ def generate_hypothesis(trace: Trace) -> list[Hypothesis]: output_param, ], # the first param should be larger or equal to the second param precondition=None, - text_description=f"Output tensor's value at {output_param.additional_path} is consistently less than or equal to the max input threshold {input_param.name} for the function {func_name}.", + text_description=ThresholdRelation.to_display_name( + [input_param, api_param, output_param] + ) + or f"Output tensor's value at {output_param.additional_path} is consistently less than or equal to the max input threshold {input_param.name} for the function {func_name}.", ), positive_examples=ExampleList( {"pre_event"} diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index f599f51b..33ab5905 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -287,7 +287,10 @@ def _merge_hypotheses(hypotheses: list[Hypothesis]) -> list[Hypothesis]: parent_param, merged_child_param, ], - text_description=f"{parent_param.api_full_name} contains {merged_child_param}", + text_description=APIContainRelation.to_display_name( + [parent_param, merged_child_param] + ) + or f"{parent_param.api_full_name} contains {merged_child_param}", num_positive_examples=len(all_positive_examples), num_negative_examples=len(all_positive_examples), precondition=None, # to be inferred later @@ -735,7 +738,10 @@ def _infer( relation=APIContainRelation, params=[parent_param, child_param], precondition=None, - text_description=f"{parent} contains {child_param} of type {typename(event)}", + text_description=APIContainRelation.to_display_name( + [parent_param, child_param] + ) + or f"{parent} contains {child_param} of type {typename(event)}", ), positive_examples=ExampleList( {PARENT_GROUP_NAME, VAR_GROUP_NAME} @@ -773,7 +779,10 @@ def _infer( relation=APIContainRelation, params=[parent_param, child_param], precondition=None, - text_description=f"{parent} contains {child_param} of type {typename(event)}", + text_description=APIContainRelation.to_display_name( + [parent_param, child_param] + ) + or f"{parent} contains {child_param} of type {typename(event)}", ), positive_examples=ExampleList({PARENT_GROUP_NAME}), negative_examples=ExampleList({PARENT_GROUP_NAME}), @@ -827,7 +836,10 @@ def _merge_child_API_events( relation=APIContainRelation, params=[parent_param, merged_child_param], precondition=None, - text_description=f"{parent} contains {merged_child_param}", + text_description=APIContainRelation.to_display_name( + [parent_param, merged_child_param] + ) + or f"{parent} contains {merged_child_param}", ), positive_examples=ExampleList({PARENT_GROUP_NAME}), negative_examples=ExampleList({PARENT_GROUP_NAME}), diff --git a/traincheck/invariant/cover_relation.py b/traincheck/invariant/cover_relation.py index 704a74be..567e6bec 100644 --- a/traincheck/invariant/cover_relation.py +++ b/traincheck/invariant/cover_relation.py @@ -208,7 +208,10 @@ def generate_hypothesis(trace) -> list[Hypothesis]: APIParam(func_B), ], precondition=None, - text_description=f"FunctionCoverRelation between {func_A} and {func_B}", + text_description=FunctionCoverRelation.to_display_name( + [APIParam(func_A), APIParam(func_B)] + ) + or f"FunctionCoverRelation between {func_A} and {func_B}", ), positive_examples=ExampleList({EXP_GROUP_NAME}), negative_examples=ExampleList({EXP_GROUP_NAME}), @@ -589,7 +592,10 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: relation=FunctionCoverRelation, params=[param for param in merged_value], precondition=key, - text_description="Merged FunctionCoverRelation in Ordered List", + text_description=FunctionCoverRelation.to_display_name( + [param for param in merged_value] + ) + or "Merged FunctionCoverRelation in Ordered List", ) merged_ininvariants.append(new_invariant) logger.debug("End merging invariants") diff --git a/traincheck/invariant/lead_relation.py b/traincheck/invariant/lead_relation.py index 27dcc2ff..427b7687 100644 --- a/traincheck/invariant/lead_relation.py +++ b/traincheck/invariant/lead_relation.py @@ -364,7 +364,10 @@ def generate_hypothesis(trace) -> list[Hypothesis]: APIParam(func_B), ], precondition=None, - text_description=f"FunctionLeadRelation between {func_A} and {func_B}", + text_description=FunctionLeadRelation.to_display_name( + [APIParam(func_A), APIParam(func_B)] + ) + or f"FunctionLeadRelation between {func_A} and {func_B}", ), positive_examples=ExampleList({EXP_GROUP_NAME}), negative_examples=ExampleList({EXP_GROUP_NAME}), @@ -815,7 +818,10 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: relation=FunctionLeadRelation, params=[param for param in merged_value], precondition=key, - text_description="Merged FunctionLeadRelation in Ordered List", + text_description=FunctionLeadRelation.to_display_name( + [param for param in merged_value] + ) + or "Merged FunctionLeadRelation in Ordered List", ) merged_ininvariants.append(new_invariant) logger.debug("End merging invariants") From 5e781871455dc25cd644a601b07d78fc01003ec5 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 17:41:08 -0400 Subject: [PATCH 11/28] fix: wire ANALYSIS_SKIP_FUNC_NAMES through all relation function filters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously lead_relation, DistinctArgumentRelation, and consistency_transient_vars each had their own ad-hoc filtering logic ("._" substring checks, torch.override HACKs) that didn't respect ANALYSIS_SKIP_FUNC_NAMES β€” so the entry added to config.py had no effect on those relations. Now all four relation types use the same ANALYSIS_SKIP_FUNC_NAMES list as the single source of truth for which function names to skip. --- traincheck/config/config.py | 1 + traincheck/invariant/DistinctArgumentRelation.py | 5 +++-- .../invariant/consistency_transient_vars.py | 15 +++++---------- traincheck/invariant/lead_relation.py | 5 +++-- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/traincheck/config/config.py b/traincheck/config/config.py index b7d019b0..ddbc4a2a 100644 --- a/traincheck/config/config.py +++ b/traincheck/config/config.py @@ -86,6 +86,7 @@ "torch.optim.optimizer._get_value", "torch.overrides", "._", # skip all private functions (they can only be the contained, but not containing functions) + "", # skip closures / local functions β€” not meaningful in invariant descriptions ] INSTR_OPTS = None # TODO: set defaults for this variable diff --git a/traincheck/invariant/DistinctArgumentRelation.py b/traincheck/invariant/DistinctArgumentRelation.py index 315966c9..0088715d 100644 --- a/traincheck/invariant/DistinctArgumentRelation.py +++ b/traincheck/invariant/DistinctArgumentRelation.py @@ -2,6 +2,7 @@ from itertools import combinations from typing import Any, Dict, Iterable, List, Set, Tuple +from traincheck.config.config import ANALYSIS_SKIP_FUNC_NAMES from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( # GroupedPreconditions, APIParam, @@ -45,9 +46,9 @@ def get_func_names_to_deal_with(trace: Trace) -> List[str]: # get all functions in the trace all_func_names = trace.get_func_names() - # filtering 1: remove private functions + # filtering 1: skip functions matched by ANALYSIS_SKIP_FUNC_NAMES for func_name in all_func_names: - if "._" in func_name: + if any(skip in func_name for skip in ANALYSIS_SKIP_FUNC_NAMES): continue function_pool.add(func_name) diff --git a/traincheck/invariant/consistency_transient_vars.py b/traincheck/invariant/consistency_transient_vars.py index 2dae7a0d..f02e65bb 100644 --- a/traincheck/invariant/consistency_transient_vars.py +++ b/traincheck/invariant/consistency_transient_vars.py @@ -4,6 +4,7 @@ import pandas as pd +from traincheck.config.config import ANALYSIS_SKIP_FUNC_NAMES from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( _NOT_SET, @@ -245,17 +246,11 @@ def get_input_thresholds( def get_events_of_funcs_with_tensors( all_func_names, trace, output_has_tensors=True, input_has_tensors=True ): - # HACK: remove all torch.overrides + # skip functions matched by ANALYSIS_SKIP_FUNC_NAMES all_func_names = [ - func_name for func_name in all_func_names if "torch.override" not in func_name - ] - # remove all functions with "._" in them - all_func_names = [ - func_name for func_name in all_func_names if "._" not in func_name - ] - # remove all functions with "._is_" in them - all_func_names = [ - func_name for func_name in all_func_names if ".is_" not in func_name + func_name + for func_name in all_func_names + if not any(skip in func_name for skip in ANALYSIS_SKIP_FUNC_NAMES) ] # if os.path.exists(_CACHE_PATH): diff --git a/traincheck/invariant/lead_relation.py b/traincheck/invariant/lead_relation.py index 427b7687..6f433a77 100644 --- a/traincheck/invariant/lead_relation.py +++ b/traincheck/invariant/lead_relation.py @@ -2,6 +2,7 @@ from itertools import permutations from typing import Any, Dict, Iterable, List, Set, Tuple +from traincheck.config.config import ANALYSIS_SKIP_FUNC_NAMES from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( APIParam, @@ -86,9 +87,9 @@ def get_func_names_to_deal_with(trace: Trace) -> List[str]: # get all functions in the trace all_func_names = trace.get_func_names() - # filtering 1: remove private functions + # filtering 1: skip functions matched by ANALYSIS_SKIP_FUNC_NAMES for func_name in all_func_names: - if "._" in func_name: + if any(skip in func_name for skip in ANALYSIS_SKIP_FUNC_NAMES): continue function_pool.add(func_name) From d6e68c93617199042be406c6e81093950eaa440a Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 18:38:31 -0400 Subject: [PATCH 12/28] feat: clean up inference stdout with structured per-phase output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit generate_hypothesis(): - Prints "[Trace N/M] Generating hypotheses" header per trace - After each relation completes, prints " RelationName: N hypotheses (Xs)" via tqdm.write() so it stays above any inner progress bars - Removes the "Merging Hypotheses" tqdm (pure bookkeeping, not useful) prune_incorrect_hypos(): prints "N pruned β†’ M remaining" summary line collect_examples(): silent unless cross-trace work is needed infer_precondition(): - Single outer tqdm bar: "N done Β· M failed P%|β–ˆβ–ˆβ–ˆβ–ˆ| N/total [elapsed dict[Hypothesis, list[int]]: Returns: dict[Hypothesis, list[int]]: A dictionary mapping hypotheses to the indices of traces that support them """ - logger.info("============= GENERATING HYPOTHESIS =============") hypotheses_and_trace_idxs: dict[Hypothesis, list[int]] = {} - hypo_lookup = {} # Dictionary for O(1) lookup of hypotheses + hypo_lookup = {} + n_traces = len(self.traces) + active_relations = [ + r for r in relation_pool if r not in self.disabled_relations + ] + for trace_idx, trace in enumerate(self.traces): - logger.info(f"Processing trace {trace_idx + 1}/{len(self.traces)}") - for relation_idx, relation in enumerate(relation_pool): - logger.info( - f"Processing relation {relation_idx + 1}/{len(relation_pool)}: {relation.__name__}" - ) - if self.disabled_relations and relation in self.disabled_relations: - logger.info( - f"Skipping relation {relation.__name__} as it is disabled" - ) - continue + print(f"\n[Trace {trace_idx + 1}/{n_traces}] Generating hypotheses") + for relation in active_relations: logger.info(f"Generating hypotheses for relation: {relation.__name__}") + t0 = time.time() inferred_hypos = relation.generate_hypothesis(trace) - logger.info( - f"Found {len(inferred_hypos)} hypotheses for relation: {relation.__name__} on trace {trace_idx + 1}/{len(self.traces)}" + elapsed = time.time() - t0 + tqdm.write( + f" {relation.__name__}: {len(inferred_hypos)} hypotheses ({elapsed:.1f}s)" ) logger.info( - f"Merging hypotheses with existing ones, number of existing ones: {len(hypotheses_and_trace_idxs)}" + f"Found {len(inferred_hypos)} hypotheses for {relation.__name__} " + f"on trace {trace_idx + 1}/{n_traces}" ) - for hypo in tqdm( - inferred_hypos, desc="Merging Hypotheses with existing ones" - ): + for hypo in inferred_hypos: if hypo not in hypotheses_and_trace_idxs: hypotheses_and_trace_idxs[hypo] = [trace_idx] - hypo_lookup[hypo] = hypo # Add to lookup dictionary + hypo_lookup[hypo] = hypo else: hypotheses_and_trace_idxs[hypo].append(trace_idx) - original_hypo = hypo_lookup[hypo] # O(1) lookup + original_hypo = hypo_lookup[hypo] orig_num_pos_exps = len(original_hypo.positive_examples) orig_num_neg_exps = len(original_hypo.negative_examples) original_hypo.positive_examples.examples.extend( @@ -102,19 +100,24 @@ def generate_hypothesis(self) -> dict[Hypothesis, list[int]]: ) == orig_num_neg_exps + len( hypo.negative_examples ), f"Expected {orig_num_neg_exps} + {len(hypo.negative_examples)} negative examples, got {len(hypo_lookup[hypo].negative_examples)}" - logger.info(f"Finished processing trace {trace_idx + 1}/{len(self.traces)}") - logger.info( - f"Finished generating hypotheses, found {len(hypotheses_and_trace_idxs)} hypotheses" - ) + + total = len(hypotheses_and_trace_idxs) + print(f"\n {total} hypotheses generated across all relations") + logger.info(f"Finished generating hypotheses, found {total} hypotheses") return hypotheses_and_trace_idxs def collect_examples(self, hypotheses: dict[Hypothesis, list[int]]): logger.info("============= COLLECTING EXAMPLES =============") - logger.info(f"Start collecting examples for {len(hypotheses)} hypotheses") - for hypo, trace_idxs in hypotheses.items(): - logger.info( - f"Collecting examples for hypothesis: {hypo.invariant.text_description}" + cross_trace_hypos = [ + (hypo, trace_idxs) + for hypo, trace_idxs in hypotheses.items() + if len(set(range(len(self.traces))) - set(trace_idxs)) > 0 + ] + if cross_trace_hypos: + print( + f"\nCollecting examples for {len(cross_trace_hypos)} cross-trace hypotheses" ) + for hypo, trace_idxs in cross_trace_hypos: for trace_idx, trace in enumerate(self.traces): if trace_idx in trace_idxs: continue @@ -125,7 +128,6 @@ def collect_examples(self, hypotheses: dict[Hypothesis, list[int]]): def prune_incorrect_hypos(self, hypotheses: dict[Hypothesis, list[int]]): """Prune incorrect hypotheses based on the collected examples""" - incorrect_hypos = [] correct_hypos = {} for hypo, trace_idxs in hypotheses.items(): @@ -135,30 +137,46 @@ def prune_incorrect_hypos(self, hypotheses: dict[Hypothesis, list[int]]): incorrect_hypos.append( FailedHypothesis(hypo, "only one positive example") ) + n_pruned = len(incorrect_hypos) + n_kept = len(correct_hypos) + print(f" {n_pruned} pruned (insufficient examples) β†’ {n_kept} remaining") return correct_hypos, incorrect_hypos def infer_precondition(self, hypotheses: dict[Hypothesis, list[int]]): """TODO: move the precondition inference driving code into Hypothesis.get_invariant()""" logger.info("============= INFERING PRECONDITIONS =============") - logger.info(f"Inferring preconditions for {len(hypotheses)} hypotheses") - all_hypotheses: list[Hypothesis] = [] - for hypo in hypotheses: - all_hypotheses.append(hypo) + all_hypotheses = list(hypotheses.keys()) + total = len(all_hypotheses) + print(f"\nInferring preconditions for {total} hypotheses") invariants = [] failed_hypos = [] - for hypo_idx, hypothesis in enumerate(all_hypotheses): - logger.info( - f"Inferring precondition for hypothesis {hypo_idx + 1}/{len(all_hypotheses)}: {hypothesis.invariant.text_description}" - ) - precondition = find_precondition(hypothesis, self.traces) - if precondition is None: - failed_hypos.append( - FailedHypothesis(hypothesis, "Precondition not found") - ) - else: - hypothesis.invariant.precondition = precondition - invariants.append(hypothesis.get_invariant(self.all_stages)) + bar_fmt = "{desc} {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]" + _tc_utils._suppress_inner_progress = True + try: + with tqdm(total=total, bar_format=bar_fmt, unit="hypo") as pbar: + pbar.set_description("0 done Β· 0 failed") + for hypo_idx, hypothesis in enumerate(all_hypotheses): + logger.info( + f"Inferring precondition for hypothesis {hypo_idx + 1}/{total}: " + f"{hypothesis.invariant.text_description}" + ) + precondition = find_precondition(hypothesis, self.traces) + if precondition is None: + failed_hypos.append( + FailedHypothesis(hypothesis, "Precondition not found") + ) + else: + hypothesis.invariant.precondition = precondition + invariants.append(hypothesis.get_invariant(self.all_stages)) + pbar.set_description( + f"{len(invariants)} done Β· {len(failed_hypos)} failed" + ) + pbar.update(1) + finally: + _tc_utils._suppress_inner_progress = False + + print(f" {len(invariants)} invariants Β· {len(failed_hypos)} failed") return invariants, failed_hypos diff --git a/traincheck/invariant/precondition.py b/traincheck/invariant/precondition.py index 7056c01c..f86ffb95 100644 --- a/traincheck/invariant/precondition.py +++ b/traincheck/invariant/precondition.py @@ -159,8 +159,8 @@ def verify_precondition_safety( """ for example in negative_examples: if precondition.verify(example): - print("Precondition is not safe") - print("Example", example) + logger.debug("Precondition is not safe") + logger.debug("Example %s", example) return False return True @@ -562,7 +562,7 @@ def find_precondition_from_single_group( if len(local_clauses) == 0: # NOTE: this would also happen under the unconditional case, but since the unconditional case is handled separately, we should not reach here - print("example: ", example) + logger.debug("example: %s", example) raise ValueError( "No clauses can be found in the example, precondition will be empty." ) From 242959c686a11e55619dafa25513ab40ed431db3 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 18:42:24 -0400 Subject: [PATCH 13/28] feat: show per-relation progress bar during hypothesis generation Add outer tqdm bar over active relations in generate_hypothesis() so users see which relation is currently running (not just summary lines after each completes). Each completed relation prints elapsed time and hypothesis count via tqdm.write() so lines stay above the live bar. Also fix indentation bug: for-hypo merge loop was outside the for-relation loop, so only the last relation's hypotheses were merged. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/infer_engine.py | 88 +++++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 39 deletions(-) diff --git a/traincheck/infer_engine.py b/traincheck/infer_engine.py index 0f0deab4..267ef72b 100644 --- a/traincheck/infer_engine.py +++ b/traincheck/infer_engine.py @@ -60,46 +60,56 @@ def generate_hypothesis(self) -> dict[Hypothesis, list[int]]: r for r in relation_pool if r not in self.disabled_relations ] + rel_bar_fmt = "{desc} {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]" for trace_idx, trace in enumerate(self.traces): - print(f"\n[Trace {trace_idx + 1}/{n_traces}] Generating hypotheses") - for relation in active_relations: - logger.info(f"Generating hypotheses for relation: {relation.__name__}") - t0 = time.time() - inferred_hypos = relation.generate_hypothesis(trace) - elapsed = time.time() - t0 - tqdm.write( - f" {relation.__name__}: {len(inferred_hypos)} hypotheses ({elapsed:.1f}s)" - ) - logger.info( - f"Found {len(inferred_hypos)} hypotheses for {relation.__name__} " - f"on trace {trace_idx + 1}/{n_traces}" - ) - for hypo in inferred_hypos: - if hypo not in hypotheses_and_trace_idxs: - hypotheses_and_trace_idxs[hypo] = [trace_idx] - hypo_lookup[hypo] = hypo - else: - hypotheses_and_trace_idxs[hypo].append(trace_idx) - original_hypo = hypo_lookup[hypo] - orig_num_pos_exps = len(original_hypo.positive_examples) - orig_num_neg_exps = len(original_hypo.negative_examples) - original_hypo.positive_examples.examples.extend( - hypo.positive_examples.examples - ) - original_hypo.negative_examples.examples.extend( - hypo.negative_examples.examples - ) - - assert len( - hypo_lookup[hypo].positive_examples - ) == orig_num_pos_exps + len( - hypo.positive_examples - ), f"Expected {orig_num_pos_exps} + {len(hypo.positive_examples)} positive examples, got {len(hypo_lookup[hypo].positive_examples)}" - assert len( - hypo_lookup[hypo].negative_examples - ) == orig_num_neg_exps + len( - hypo.negative_examples - ), f"Expected {orig_num_neg_exps} + {len(hypo.negative_examples)} negative examples, got {len(hypo_lookup[hypo].negative_examples)}" + tqdm.write(f"\n[Trace {trace_idx + 1}/{n_traces}] Generating hypotheses") + with tqdm( + active_relations, + bar_format=rel_bar_fmt, + unit="relation", + leave=True, + ) as rel_bar: + for relation in rel_bar: + rel_bar.set_description(f" {relation.__name__}") + logger.info( + f"Generating hypotheses for relation: {relation.__name__}" + ) + t0 = time.time() + inferred_hypos = relation.generate_hypothesis(trace) + elapsed = time.time() - t0 + tqdm.write( + f" {relation.__name__}: {len(inferred_hypos)} hypotheses ({elapsed:.1f}s)" + ) + logger.info( + f"Found {len(inferred_hypos)} hypotheses for {relation.__name__} " + f"on trace {trace_idx + 1}/{n_traces}" + ) + for hypo in inferred_hypos: + if hypo not in hypotheses_and_trace_idxs: + hypotheses_and_trace_idxs[hypo] = [trace_idx] + hypo_lookup[hypo] = hypo + else: + hypotheses_and_trace_idxs[hypo].append(trace_idx) + original_hypo = hypo_lookup[hypo] + orig_num_pos_exps = len(original_hypo.positive_examples) + orig_num_neg_exps = len(original_hypo.negative_examples) + original_hypo.positive_examples.examples.extend( + hypo.positive_examples.examples + ) + original_hypo.negative_examples.examples.extend( + hypo.negative_examples.examples + ) + + assert len( + hypo_lookup[hypo].positive_examples + ) == orig_num_pos_exps + len( + hypo.positive_examples + ), f"Expected {orig_num_pos_exps} + {len(hypo.positive_examples)} positive examples, got {len(hypo_lookup[hypo].positive_examples)}" + assert len( + hypo_lookup[hypo].negative_examples + ) == orig_num_neg_exps + len( + hypo.negative_examples + ), f"Expected {orig_num_neg_exps} + {len(hypo.negative_examples)} negative examples, got {len(hypo_lookup[hypo].negative_examples)}" total = len(hypotheses_and_trace_idxs) print(f"\n {total} hypotheses generated across all relations") From 3c820a35278875094e24b362fe4f1953f42c22b1 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 23:27:33 -0400 Subject: [PATCH 14/28] fix: demote instrumentor internal prints to logger.debug Eliminates per-step stdout noise from control.py (Warmup/Interval/ Skipping step printed every training step), shutdown messages from dumper.py, AST loop/model detection messages from source_file.py, and proxy parameter setup messages from proxy.py. All demoted to logger.debug() so they remain accessible with -d flag but don't clutter normal traincheck-collect output. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/instrumentor/control.py | 14 ++++---- traincheck/instrumentor/dumper.py | 8 ++--- .../instrumentor/proxy_wrapper/proxy.py | 7 ++-- traincheck/instrumentor/source_file.py | 33 ++++++++----------- 4 files changed, 28 insertions(+), 34 deletions(-) diff --git a/traincheck/instrumentor/control.py b/traincheck/instrumentor/control.py index a1f5c87d..84e0ca5f 100644 --- a/traincheck/instrumentor/control.py +++ b/traincheck/instrumentor/control.py @@ -32,17 +32,17 @@ def start_step(): config.DISABLE_WRAPPER = False if current_step < warm_up: - print(f"Warmup step {current_step}") + logger.debug(f"Warmup step {current_step}") config.DISABLE_WRAPPER = False elif (current_step - warm_up) % interval == 0: - print(f"Interval step {current_step}") + logger.debug(f"Interval step {current_step}") config.DISABLE_WRAPPER = False else: - print(f"Skipping step {current_step}") + logger.debug(f"Skipping step {current_step}") config.DISABLE_WRAPPER = True else: # No policy, always enable - print("No policy, always enable") + logger.debug("No policy, always enable") config.DISABLE_WRAPPER = False @@ -65,13 +65,13 @@ def start_eval_step(): config.DISABLE_WRAPPER = False if current_step < warm_up: - print(f"Eval: Warmup step {current_step}") + logger.debug(f"Eval: Warmup step {current_step}") config.DISABLE_WRAPPER = False elif (current_step - warm_up) % interval == 0: - print(f"Eval: Interval step {current_step}") + logger.debug(f"Eval: Interval step {current_step}") config.DISABLE_WRAPPER = False else: - print(f"Eval: Skipping step {current_step}") + logger.debug(f"Eval: Skipping step {current_step}") config.DISABLE_WRAPPER = True else: config.DISABLE_WRAPPER = False diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py index a7471053..a9dcf5e2 100644 --- a/traincheck/instrumentor/dumper.py +++ b/traincheck/instrumentor/dumper.py @@ -92,8 +92,8 @@ def serialize(obj_dict: dict[str, object | str]) -> str: def monitor_main_thread(main_thread, stop_event): main_thread.join() # Wait for the main thread to finish - print("Main thread has finished or encountered an exception") - print("Flushing all buffers to the trace log file") + logger.debug("Main thread has finished or encountered an exception") + logger.debug("Flushing all buffers to the trace log file") stop_event.set() # Signal the logging threads to stop @@ -106,12 +106,12 @@ def trace_dumper(task_queue: Queue, trace_file_name: str, stop_event: threading. ) # wait for 2x the flush interval, this is an arbitrary number, as long as it is larger than the flush interval, it should be fine. except Empty: if stop_event.is_set(): - print("Trace dumper thread has stopped.") + logger.debug("Trace dumper thread has stopped.") break continue f.write(f"{trace}\n") task_queue.task_done() - print("Trace dumper thread has finished normally...") + logger.debug("Trace dumper thread has finished normally.") def get_trace_API_dumper_queue(): diff --git a/traincheck/instrumentor/proxy_wrapper/proxy.py b/traincheck/instrumentor/proxy_wrapper/proxy.py index d9003f10..4def285b 100644 --- a/traincheck/instrumentor/proxy_wrapper/proxy.py +++ b/traincheck/instrumentor/proxy_wrapper/proxy.py @@ -21,6 +21,8 @@ from .proxy_registry import get_global_registry from .utils import print_debug +logger = logging.getLogger(__name__) + class ProxyObjInfo: def __init__(self, var_name: str, last_update_timestamp: int, version: int | None): @@ -118,9 +120,8 @@ def proxy_parameters(module: torch.nn.Module, parent_name="", from_iter=False): time_end = time.perf_counter() if num_params != 0: - print( - "logger_proxy: " - + f"Proxied {num_params} parameters of '{parent_name + module.__class__.__name__}', duration: {time_end - start_time} seconds" + logger.debug( + f"Proxied {num_params} parameters of '{parent_name + module.__class__.__name__}', duration: {time_end - start_time:.3f}s" ) def update_timestamp(self): diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index e056b601..850527fe 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -125,10 +125,10 @@ def _get_loop_context(self, node): if iter_name: if "train" in iter_name: - print(f"Found training loop based on iterator: {iter_name}") + logger.debug(f"Found training loop based on iterator: {iter_name}") return "training" elif any(x in iter_name for x in ["val", "eval", "test"]): - print(f"Found eval loop based on iterator: {iter_name}") + logger.debug(f"Found eval loop based on iterator: {iter_name}") return "eval" # Heuristic 2: Check for calls to .step() or .backward() or .eval() @@ -187,25 +187,25 @@ def _get_loop_context(self, node): if isinstance(expr.func, ast.Attribute): if expr.func.attr == "no_grad": has_eval_signal = True - print(f"Found no_grad context in loop {node}.") + logger.debug(f"Found no_grad context in loop {node}.") if has_training_signal: - print(f"Found training signal in loop {node}.") + logger.debug(f"Found training signal in loop {node}.") return "training" if has_eval_signal: - print(f"Found eval signal in loop {node}.") + logger.debug(f"Found eval signal in loop {node}.") return "eval" # if the number of lines are too few and the function calls do not involve "eval", "train", we omit the loop context # We use statement_count calculated recursively if statement_count < 3: - print( + logger.debug( f"Skipping loop {node} as it is too short ({statement_count} statements) and does not contain eval/train/step/backward signal." ) return None - print(f"Found eval signal in loop {node} (fallback).") + logger.debug(f"Found eval signal in loop {node} (fallback).") return "eval" def _inject_call(self, node, func_name): @@ -465,7 +465,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]: for node in ast.walk(root): for child in ast.iter_child_nodes(node): if child in parent_map and not ast.unparse(child).strip() == "": - print( + logger.debug( f"Node {ast.unparse(child)} already has a parent, {ast.unparse(parent_map[child])}" ) parent_map[child] = node @@ -480,7 +480,7 @@ def instrument_all_model_assignments( Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement after each assignment, depending on the mode. """ - print( + logger.debug( f"Instrumenting model: {model_name}, mode: {mode}, scanning for assignments to {model_name}" ) @@ -529,10 +529,10 @@ def instrument_all_model_assignments( if node in parent_map: parent = parent_map[node] # print(f"Parent node: {ast.unparse(parent)}") - print("\tInstrumenting: ", ast.unparse(node)) + logger.debug("Instrumenting: %s", ast.unparse(node)) if isinstance(parent, ast.For): - print( - "\t\t⬆️ Parent is a for loop, cowardly skipping instrumentation in fear of multiple models with the same 'var_name'" + logger.debug( + "Parent is a for loop, skipping instrumentation to avoid multiple models with the same 'var_name'" ) continue if node in parent.body: # type: ignore @@ -601,25 +601,18 @@ def instrument_model_tracker_proxy( spec = importlib.util.find_spec('traincheck') if spec and spec.origin: traincheck_folder = os.path.dirname(spec.origin) - print("traincheck folder: ", traincheck_folder) else: raise Exception("traincheck is not installed properly") -print("auto observer enabled with observing depth: ", auto_observer_config["enable_auto_observer_depth"]) enable_auto_observer_depth = auto_observer_config["enable_auto_observer_depth"] neglect_hidden_func = auto_observer_config["neglect_hidden_func"] neglect_hidden_module = auto_observer_config["neglect_hidden_module"] observe_then_unproxy = auto_observer_config["observe_then_unproxy"] observe_up_to_depth = auto_observer_config["observe_up_to_depth"] -if observe_up_to_depth: - print("observe up to the depth of the function call") -else: - print("observe only the function call at the depth") from traincheck.static_analyzer.graph_generator.call_graph_parser import add_observer_given_call_graph log_files = glob.glob( os.path.join(traincheck_folder, "static_analyzer", "func_level", "*.log") ) -print("log_files: ", log_files) for log_file in log_files: add_observer_given_call_graph( log_file, @@ -1072,7 +1065,7 @@ def instrument_file( if model_tracker_style == "proxy" or model_tracker_style == "subclass": if model_tracker_style == "subclass": # adjust the proxy config to disable the proxy-specific configs - print( + logger.debug( "Using subclass model tracker, overriding observe_then_unproxy to False" ) adjusted_proxy_config[0]["observe_then_unproxy"] = False From 09c9ea796da6529b61e461272f7f21529a5be883 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 23:30:28 -0400 Subject: [PATCH 15/28] fix: prefix all TrainCheck log output with [TrainCheck] Update logging format strings in all three CLI entry points so every log message is visually identifiable as coming from TrainCheck, not the user's training script. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/checker.py | 3 +++ traincheck/collect_trace.py | 5 +++-- traincheck/infer_engine.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/traincheck/checker.py b/traincheck/checker.py index 9542743f..05dee116 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -263,6 +263,9 @@ def main(): # INFO logs stay in the log file only. stream_handler = logging.StreamHandler() stream_handler.setLevel(logging.WARNING) + stream_handler.setFormatter( + logging.Formatter("[TrainCheck] %(levelname)s: %(message)s") + ) logger.addHandler(stream_handler) results_by_trace: list[tuple[str, list[CheckerResult]]] = [] diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py index e45e996f..2f57f3f9 100644 --- a/traincheck/collect_trace.py +++ b/traincheck/collect_trace.py @@ -432,11 +432,12 @@ def main(): args.warm_up_steps = config.INSTRUMENTATION_POLICY["warm_up"] # set up logging + _log_fmt = "[TrainCheck] %(levelname)s: %(message)s" if args.debug_mode: - logging.basicConfig(level=logging.DEBUG) + logging.basicConfig(level=logging.DEBUG, format=_log_fmt) os.environ["TRAINCHECK_DEBUG"] = "1" else: - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.INFO, format=_log_fmt) logger = logging.getLogger(__name__) diff --git a/traincheck/infer_engine.py b/traincheck/infer_engine.py index 267ef72b..917cb530 100644 --- a/traincheck/infer_engine.py +++ b/traincheck/infer_engine.py @@ -285,7 +285,7 @@ def main(): logging.basicConfig( filename=f'traincheck_infer_engine_{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}_{pid}.log', level=log_level, - format="%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)s - %(funcName)20s()] - %(message)s", + format="%(asctime)s - [TrainCheck] %(levelname)s - [%(filename)s:%(lineno)s - %(funcName)20s()] - %(message)s", ) disabled_relations: list[Relation] = [] From e8ec62e2135582dbb7900e264371c9e60c95bc0c Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 23:33:36 -0400 Subject: [PATCH 16/28] fix: silence instrumentor noise during traincheck-collect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - collect_trace: use force=True + WARNING default level so basicConfig format actually applies and INFO chatter is suppressed at runtime - source_file: demote all annotate_stage insertion logger.info β†’ debug - dumper: demote attribute-dump failure logger.warning β†’ debug (torch internals routinely fail, not actionable) - call_graph_parser: convert all print() β†’ logger.debug(); add logger Co-Authored-By: Claude Sonnet 4.6 --- traincheck/collect_trace.py | 6 +++--- traincheck/instrumentor/dumper.py | 2 +- traincheck/instrumentor/source_file.py | 16 ++++++++-------- .../graph_generator/call_graph_parser.py | 19 ++++++++++--------- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py index 2f57f3f9..35eb9616 100644 --- a/traincheck/collect_trace.py +++ b/traincheck/collect_trace.py @@ -431,13 +431,13 @@ def main(): if args.warm_up_steps is None: args.warm_up_steps = config.INSTRUMENTATION_POLICY["warm_up"] - # set up logging + # set up logging (force=True overrides any handler already added by imported libs) _log_fmt = "[TrainCheck] %(levelname)s: %(message)s" if args.debug_mode: - logging.basicConfig(level=logging.DEBUG, format=_log_fmt) + logging.basicConfig(level=logging.DEBUG, format=_log_fmt, force=True) os.environ["TRAINCHECK_DEBUG"] = "1" else: - logging.basicConfig(level=logging.INFO, format=_log_fmt) + logging.basicConfig(level=logging.WARNING, format=_log_fmt, force=True) logger = logging.getLogger(__name__) diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py index a9dcf5e2..112337d8 100644 --- a/traincheck/instrumentor/dumper.py +++ b/traincheck/instrumentor/dumper.py @@ -392,7 +392,7 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict and isinstance(var, torch.Tensor) and not include_tensor_data ): - logger.warning( + logger.debug( f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." ) if var_type not in skip_attrs_due_to_errs: diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 850527fe..5ddb82f7 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -736,7 +736,7 @@ def has_stage(src: str, name: str) -> bool: for stage_name, present in orig_has.items(): if present: - logger.info( + logger.debug( _ctx( f"Stage '{stage_name}' already present in source; skip adding this stage." ) @@ -819,13 +819,13 @@ def at_attr(name: str) -> bool: indent = m.group(0) new_lines.append(f'{indent}annotate_stage("{stage}")\n') inserted_count[stage] += 1 - logger.info( + logger.debug( _ctx( f"Inserted stage '{stage}' before line {lineno}: {line.strip()}" ) ) else: - logger.info( + logger.debug( _ctx( f"Skip inserting '{stage}' at line {lineno} (previous non-empty line already has it)." ) @@ -858,7 +858,7 @@ def _find_annotate_import_idx(lines): lines_list.insert(insert_idx, "from traincheck import annotate_stage\n") annot_import_idx = insert_idx inserted_count["import"] += 1 - logger.info( + logger.debug( _ctx( f"Inserted import 'from traincheck import annotate_stage' at line {annot_import_idx + 1}." ) @@ -917,13 +917,13 @@ def is_single_line_triple_quoted_string(line: str, quote: str) -> bool: if not (("annotate_stage" in prev) and ("init" in prev)): nl.insert(insert_at, f'{body_indent}annotate_stage("init")\n') inserted_count["init"] += 1 - logger.info( + logger.debug( _ctx( f"Inserted stage 'init' at start of main() body (line {insert_at + 1})." ) ) else: - logger.info( + logger.debug( _ctx( "Skip inserting 'init' inside main(): previous non-empty line already has it." ) @@ -966,13 +966,13 @@ def is_single_line_triple_quoted_string(line: str, quote: str) -> bool: if not (("annotate_stage" in next_line) and ("init" in next_line)): lines2.insert(insert_at, 'annotate_stage("init")\n') inserted_count["init"] += 1 - logger.info( + logger.debug( _ctx( f"Inserted stage 'init' right after annotate_stage import at line {insert_at + 1}." ) ) else: - logger.info( + logger.debug( _ctx( "Skip inserting 'init': next non-empty line after annotate_stage import is already init." ) diff --git a/traincheck/static_analyzer/graph_generator/call_graph_parser.py b/traincheck/static_analyzer/graph_generator/call_graph_parser.py index 5ce9d258..dd997b26 100644 --- a/traincheck/static_analyzer/graph_generator/call_graph_parser.py +++ b/traincheck/static_analyzer/graph_generator/call_graph_parser.py @@ -1,11 +1,14 @@ ## Step 1: filer out the lines from func_level.log with the format - function_depth, e.g. - 2 import importlib +import logging import os import re from traincheck.instrumentor.proxy_wrapper.proxy_observer import add_observer_to_func +logger = logging.getLogger(__name__) + def unparse_module(module_name, level=0): if level > 3: @@ -19,10 +22,9 @@ def unparse_module(module_name, level=0): last_name = module_name.split(".")[-1] try: func_obj = getattr(module, last_name) - # print(f"object {last_name} found in module {'.'.join(module_name.split('.')[:-1])}") return func_obj except AttributeError: - print( + logger.debug( f"object {last_name} not found in module {'.'.join(module_name.split('.')[:-1])}" ) # Ziming: from out observation this typically just mean a function call contains a local class or function call, so we can just pass @@ -39,7 +41,9 @@ def add_observer(module_name, function_name, observe_then_unproxy=False): # module could be a class here, load the class and get the function module = unparse_module(module_name) if module is None: - print(f"error finding {function_name}: module {module_name} not found") + logger.debug( + f"error finding {function_name}: module {module_name} not found" + ) try: # Retrieve the function or property @@ -48,13 +52,13 @@ def add_observer(module_name, function_name, observe_then_unproxy=False): # Check if it's a property before proceeding if isinstance(function, property): - print( + logger.debug( f"Skipping property function: {function_name} in module: {module_name}" ) return # Apply observer to non-property functions - print(f"Observe function: {function_name} found in module: {module}") + logger.debug(f"Observe function: {function_name} found in module: {module}") setattr( module, function_name, @@ -62,9 +66,8 @@ def add_observer(module_name, function_name, observe_then_unproxy=False): ) except AttributeError: - print(f"function {function_name} not found in module {module_name}") + logger.debug(f"function {function_name} not found in module {module_name}") return - # print(f'function: {function} found in module: {module}') # read the func_level.log file @@ -104,11 +107,9 @@ def add_observer_given_call_graph( # save those with function_depth <= depth if observe_up_to_depth: if int(function_depth) <= depth: - # print(f'module_name: {module_name}, function_name: {function_name}, function_depth: {function_depth}') add_observer(module_name, function_name, observe_then_unproxy) else: if int(function_depth) == depth: - # print(f'module_name: {module_name}, function_name: {function_name}, function_depth: {function_depth}') add_observer(module_name, function_name, observe_then_unproxy) From a2be7466e1617de88db191bc2ea5820a54e6761e Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 23:40:05 -0400 Subject: [PATCH 17/28] fix: suppress deprecation warnings during attribute probing in dumper safe_getattr() iterates all attributes of instrumented objects; some (e.g. torchvision dataset's .test_data) fire warnings.warn() on access. Wrap getattr in warnings.catch_warnings(simplefilter=ignore) so third-party deprecation warnings don't leak to the user's stderr. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/instrumentor/dumper.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py index 112337d8..27c458f8 100644 --- a/traincheck/instrumentor/dumper.py +++ b/traincheck/instrumentor/dumper.py @@ -316,11 +316,15 @@ class NOT_FOUND: def safe_getattr(obj, attr_name): + import warnings + try: - attr = getattr(obj, attr_name, NOT_FOUND) - if attr is NOT_FOUND: - if issubclass(type(obj), dict): - attr = dict.get(obj, attr_name, NOT_FOUND) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + attr = getattr(obj, attr_name, NOT_FOUND) + if attr is NOT_FOUND: + if issubclass(type(obj), dict): + attr = dict.get(obj, attr_name, NOT_FOUND) return attr except Exception: return NOT_FOUND From 6ae7c8ea2c8587b03b5bf7e7c5cb1c94eabe818d Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Wed, 18 Mar 2026 23:42:58 -0400 Subject: [PATCH 18/28] fix: suppress warnings during torch module instrumentation Importing torch.distributed and other submodules during the two-pass instrumentation scan fires deprecation UserWarnings (e.g. reduce_op). Wrap both passes in warnings.catch_warnings(simplefilter=ignore) so these don't leak to the user's terminal. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/instrumentor/tracer.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index 50e532a5..8a6ac1e1 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -627,6 +627,8 @@ def __init__( ) def instrument(self) -> int: + import warnings + if not self.instrumenting: return 0 @@ -639,9 +641,11 @@ def instrument(self) -> int: "First pass: Recursive scan of the module" ) assert isinstance(self.target, (types.ModuleType, type)), "Invalid target" - first_pass_instrumented_count += self._instrument_module( - self.target, visited_file_paths, True, 0 - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + first_pass_instrumented_count += self._instrument_module( + self.target, visited_file_paths, True, 0 + ) get_instrumentation_logger_for_process().info( "Files scanned %s", "\n".join(sorted(visited_file_paths)) ) @@ -665,13 +669,15 @@ def instrument(self) -> int: f"Instrumenting module {module_path}" ) - pymodule = importlib.import_module(module_path) - second_pass_instrumented_count += self._instrument_module( - pymodule, - visited_file_paths, - False, - 0, - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pymodule = importlib.import_module(module_path) + second_pass_instrumented_count += self._instrument_module( + pymodule, + visited_file_paths, + False, + 0, + ) get_instrumentation_logger_for_process().info( "Second pass instrumented %d functions", second_pass_instrumented_count ) From db9bf45e5d80633ce67235f3a6a3aa34bc829f3a Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 19 Mar 2026 19:40:40 -0400 Subject: [PATCH 19/28] fix: resolve online checker crashes and add dynamic trace file detection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use .get() for args/kwargs (PRE) and return_values/exception (POST) when populating pt_map, since functions instrumented without argument capture (e.g. Adadelta.step) omit these fields β€” previously caused KeyError that silently dropped the record and prevented APIContainRelation from triggering - Add FolderCreationHandler that watches each trace folder and dynamically attaches a StreamLogHandler for any trace_*/proxy_log.json file created after the checker starts, fixing the checker getting stuck when training had not yet created trace files at startup time - Set float('inf') sentinel in read_time_map after _save_initial_content so files with no live updates don't block min_read_time indefinitely - Rename _get_api_args_map_to_check β†’ all_needed_args_api in Checker_data and sort_inv_file for clarity Co-Authored-By: Claude Sonnet 4.6 --- .../onlinechecker/streamhandler_filesystem.py | 109 +++++++++++++----- traincheck/onlinechecker/utils.py | 4 +- 2 files changed, 79 insertions(+), 34 deletions(-) diff --git a/traincheck/onlinechecker/streamhandler_filesystem.py b/traincheck/onlinechecker/streamhandler_filesystem.py index 2b69c5cc..46439c42 100644 --- a/traincheck/onlinechecker/streamhandler_filesystem.py +++ b/traincheck/onlinechecker/streamhandler_filesystem.py @@ -4,7 +4,7 @@ import re import time -from watchdog.events import FileSystemEventHandler +from watchdog.events import FileCreatedEvent, FileSystemEventHandler from watchdog.observers.polling import PollingObserver from traincheck.config import config @@ -43,7 +43,7 @@ def __init__(self, file_path, checker_data: Checker_data): self.needed_vars = checker_data.needed_vars self.needed_apis = checker_data.needed_apis - self._get_api_args_map_to_check = checker_data._get_api_args_map_to_check + self.all_needed_args_api = checker_data.all_needed_args_api self.min_read_time = checker_data.min_read_time self.lock = checker_data.lock @@ -61,11 +61,22 @@ def _save_initial_content(self): self.logger.info(f"Processing initial content from {self.file_path}") self.fp.seek(0) lines = self.fp.readlines() - if not lines: - return - self._handle_line(lines) - self.logger.info(f"Initial content from {self.file_path} processed.") + if lines: + self._handle_line(lines) + self.logger.info(f"Initial content from {self.file_path} processed.") + + # Mark this file as fully read so it doesn't block other files that + # have records at later timestamps. Live on_modified updates override + # this when new records actually arrive. + with self.cond: + self.checker_data.read_time_map[self.file_path] = float("inf") + pre_min = self.checker_data.min_read_time + self.checker_data.min_read_path, self.checker_data.min_read_time = min( + self.checker_data.read_time_map.items(), default=(None, None) + ) + if pre_min != self.checker_data.min_read_time: + self.checker_data.cond.notify_all() def on_modified(self, event): if os.path.abspath(event.src_path) != os.path.abspath(self.file_path): @@ -181,33 +192,37 @@ def _set_func_map(self, trace_record): ) if trace_type == TraceLineType.FUNC_CALL_PRE: self.pt_map[ptname][func_call_id].pre_record = trace_record - self.pt_map[ptname][func_call_id].args = trace_record["args"] - self.pt_map[ptname][func_call_id].kwargs = trace_record["kwargs"] + self.pt_map[ptname][func_call_id].args = trace_record.get("args") + self.pt_map[ptname][func_call_id].kwargs = trace_record.get( + "kwargs" + ) elif trace_type == TraceLineType.FUNC_CALL_POST: assert self.pt_map[ptname][func_call_id].pre_record is not None self.pt_map[ptname][func_call_id].post_record = trace_record - self.pt_map[ptname][func_call_id].return_values = trace_record[ + self.pt_map[ptname][func_call_id].return_values = trace_record.get( "return_values" - ] + ) elif trace_type == TraceLineType.FUNC_CALL_POST_EXCEPTION: self.pt_map[ptname][func_call_id].post_record = trace_record - self.pt_map[ptname][func_call_id].exception = trace_record[ + self.pt_map[ptname][func_call_id].exception = trace_record.get( "exception" - ] + ) if trace_type == TraceLineType.FUNC_CALL_PRE: - if function_name in self.checker_data._get_api_args_map_to_check: - if "args" in trace_record: - if "meta_vars.step" not in trace_record: - trace_record["meta_vars.step"] = -1 - step = trace_record["meta_vars.step"] - if function_name not in self.args_map: - self.args_map[function_name] = {} - if step not in self.args_map[function_name]: - self.args_map[function_name][step] = {} - if ptid not in self.args_map[function_name][step]: - self.args_map[function_name][step][ptid] = [] - self.args_map[function_name][step][ptid].append(trace_record) + if function_name in self.checker_data.all_needed_args_api: + assert ( + "args" in trace_record and "kwargs" in trace_record + ), f"Trace record for function call {function_name} does not contain args or kwargs: {trace_record}" + if "meta_vars.step" not in trace_record: + trace_record["meta_vars.step"] = -1 + step = trace_record["meta_vars.step"] + if function_name not in self.args_map: + self.args_map[function_name] = {} + if step not in self.args_map[function_name]: + self.args_map[function_name][step] = {} + if ptid not in self.args_map[function_name][step]: + self.args_map[function_name][step][ptid] = [] + self.args_map[function_name][step][ptid].append(trace_record) if ( ".__enter__" in function_name @@ -310,29 +325,59 @@ def _set_read_time(self, trace_record): self.checker_data.cond.notify_all() +class FolderCreationHandler(FileSystemEventHandler): + """Watches a trace folder and dynamically attaches StreamLogHandler for new trace files.""" + + def __init__(self, trace_folder, checker_data: Checker_data, observer): + self.trace_folder = os.path.abspath(trace_folder) + self.checker_data = checker_data + self.observer = observer + self.seen_files: set[str] = set() + self.logger = logging.getLogger(__name__) + + def _is_trace_file(self, filename: str) -> bool: + return filename.startswith("trace_") or filename.endswith("proxy_log.json") + + def attach(self, file_path: str): + """Create and schedule a StreamLogHandler for a newly discovered file.""" + if file_path in self.seen_files: + return + self.seen_files.add(file_path) + self.logger.info(f"New trace file detected, watching: {file_path}") + handler = StreamLogHandler(file_path, self.checker_data) + self.observer.schedule( + handler, path=os.path.dirname(file_path), recursive=False + ) + + def on_created(self, event): + if isinstance(event, FileCreatedEvent): + filename = os.path.basename(event.src_path) + if self._is_trace_file(filename): + self.attach(os.path.abspath(event.src_path)) + + def run_stream_monitor(traces, trace_folders, checker_data: Checker_data): """Run the stream monitor to watch the trace files and folders.""" logger = logging.getLogger(__name__) observer = PollingObserver() - handlers = [] if traces is not None: file_path = os.path.abspath(traces[0]) handler = StreamLogHandler(file_path, checker_data) - handlers.append(handler) watch_dir = os.path.dirname(file_path) observer.schedule(handler, path=watch_dir, recursive=False) logger.info(f"Watching: {file_path}") if trace_folders is not None: for trace_folder in trace_folders: + folder_abs = os.path.abspath(trace_folder) + creation_handler = FolderCreationHandler(folder_abs, checker_data, observer) + observer.schedule(creation_handler, path=folder_abs, recursive=False) + + # Pick up any trace files that already exist in the folder. for file in sorted(os.listdir(trace_folder)): if file.startswith("trace_") or file.endswith("proxy_log.json"): - file_path = os.path.join(trace_folder, file) - handler = StreamLogHandler(file_path, checker_data) - handlers.append(handler) - watch_dir = os.path.dirname(file_path) - observer.schedule(handler, path=watch_dir, recursive=False) - logger.info(f"Watching: {file_path}") + file_path = os.path.join(folder_abs, file) + creation_handler.attach(file_path) observer.start() return observer diff --git a/traincheck/onlinechecker/utils.py b/traincheck/onlinechecker/utils.py index 987f22a6..c174d917 100644 --- a/traincheck/onlinechecker/utils.py +++ b/traincheck/onlinechecker/utils.py @@ -9,10 +9,10 @@ class Checker_data: """Data structure for online checker threads. Holds the needed data and the queue for processing.""" def __init__(self, needed_data): - needed_vars, needed_apis, _get_api_args_map_to_check = needed_data + needed_vars, needed_apis, all_needed_args_api = needed_data self.needed_vars = needed_vars self.needed_apis = needed_apis - self._get_api_args_map_to_check = _get_api_args_map_to_check + self.all_needed_args_api = all_needed_args_api self.check_queue = queue.Queue() self.varid_map = {} From 93a3e9c972c27ee48cebcf0ec60ac74ef65b699c Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 19 Mar 2026 19:40:49 -0400 Subject: [PATCH 20/28] fix: miscellaneous online checker bug fixes - consistency_relation: use .get() instead of direct key access in online_check to avoid KeyError when a variable lacks a tracked attribute - contain_relation: use ASCII arrow in to_display_name for terminal safety - collect_trace: demote InputOutputParam warning to debug to reduce noise Co-Authored-By: Claude Sonnet 4.6 --- traincheck/collect_trace.py | 2 +- traincheck/invariant/consistency_relation.py | 2 +- traincheck/invariant/contain_relation.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py index 35eb9616..2d307e49 100644 --- a/traincheck/collect_trace.py +++ b/traincheck/collect_trace.py @@ -104,7 +104,7 @@ def merge(a: dict, b: dict, path=[]): func_instr_opts[func_name]["dump_args"] = True func_instr_opts[func_name]["dump_ret"] = True # TODO: convert the arguments to instr_opts_dict (currently not possible as the index indicates the index of the argument/ret value among other tensors not all arguments) - logger.warning( + logger.debug( "Currently not supporting fine-grained dumping of arguments and return values for InputOutputParam" ) diff --git a/traincheck/invariant/consistency_relation.py b/traincheck/invariant/consistency_relation.py index 5e41f2ea..11ab8705 100644 --- a/traincheck/invariant/consistency_relation.py +++ b/traincheck/invariant/consistency_relation.py @@ -626,7 +626,7 @@ def get_check_attr(param, varid): for var2 in checker_data.type_map[ref_param.var_type]: if var2 == varid: continue - if checker_data.varid_map[var2][ref_param.attr_name] is None: + if checker_data.varid_map[var2].get(ref_param.attr_name) is None: logger.debug( f"Attribute {ref_param.attr_name} not found in variable {var2}" ) diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index 33ab5905..ba35e9b8 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -371,7 +371,7 @@ def _fmt_val(v: object) -> str: return "non-zero" if v == "non_zero" else str(v) if pre is not _NOT_SET and post is not _NOT_SET: - return f"{parent_short}() changes {var_short}.{attr}: {_fmt_val(pre)} β†’ {_fmt_val(post)}" + return f"{parent_short}() changes {var_short}.{attr}: {_fmt_val(pre)} --> {_fmt_val(post)}" if const is not _NOT_SET: return f"{parent_short}() sees {var_short}.{attr} = {const}" return f"{parent_short}() accesses {var_short}.{attr}" From e8a7520f89686f5f317478a0cf7b4306af543dfd Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 19 Mar 2026 19:41:10 -0400 Subject: [PATCH 21/28] feat: rich online HTML report with step/stage annotations and checking progress MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit checker_online.py: - Track VIOLATION_DETAILS (step/stage pairs and sample trace per invariant) - Track TRIGGERED_INV (invariants checked at least once), ALL_INVS, CURRENT_STEP and CURRENT_STAGE from each processed trace record - Remove bare 'raise e' from API invariant exception handler so a single bad invariant check no longer crashes the entire checker loop - Pass new tracking state to build_online_report_data on every report emit checker_report.py: - Violations sorted by first violation step (earliest first) instead of count - Per-violation: first/last step with stage badge, full step list grouped by stage (e.g. [train] 1,2,3 Β· [eval] 100,101), expandable sample trace table - Stage badges with distinct colors for train/eval/val/test/inference; unknown stages get a hash-derived color from a fallback palette - New Checking Progress panel: stacked bar (passing/failing/not-triggered), collapsible list of not-yet-triggered invariants, pass rate card, and Current Step card showing latest step with stage badge Co-Authored-By: Claude Sonnet 4.6 --- traincheck/checker_online.py | 61 ++- traincheck/reporting/checker_report.py | 672 +++++++++++++++++++++++-- 2 files changed, 680 insertions(+), 53 deletions(-) diff --git a/traincheck/checker_online.py b/traincheck/checker_online.py index 9279b77c..05cdb70d 100644 --- a/traincheck/checker_online.py +++ b/traincheck/checker_online.py @@ -23,6 +23,11 @@ ) NUM_VIOLATIONS = 0 FAILED_INV: dict[Invariant, int] = {} +VIOLATION_DETAILS: dict[Invariant, dict] = {} +TRIGGERED_INV: set[Invariant] = set() +ALL_INVS: list[Invariant] = [] +CURRENT_STEP: int | None = None +CURRENT_STAGE: str | None = None TOTAL_INVARIANTS = 0 RELATION_TOTALS: dict[str, int] = {} REPORTER: ReportEmitter | None = None @@ -100,7 +105,7 @@ def sort_inv_file(invariants): vartype_to_invs: dict[str, dict[str, list[Invariant]]] = {} needed_vars = set() needed_apis = set() - _get_api_args_map_to_check = set() + all_needed_args_api = set() for inv in invs: assert ( inv.precondition is not None @@ -114,7 +119,7 @@ def sort_inv_file(invariants): if needed_api is not None: needed_apis.update(needed_api) if needed_args_api is not None: - _get_api_args_map_to_check.update(needed_args_api) + all_needed_args_api.update(needed_args_api) for param in params: if isinstance(param, VarTypeParam): if param.var_type not in vartype_to_invs: @@ -127,7 +132,7 @@ def sort_inv_file(invariants): param_to_invs[param] = [] param_to_invs[param].append(inv) logger.info("Sorting done.") - needed_data = (needed_vars, needed_apis, _get_api_args_map_to_check) + needed_data = (needed_vars, needed_apis, all_needed_args_api) return invs, param_to_invs, vartype_to_invs, needed_data @@ -139,6 +144,29 @@ def get_violated_pair_hash(trace_pair): return tuple(sorted((h1, h2), reverse=True)) +_MAX_TRACKED_STEPS = 500 # cap on steps stored per invariant + + +def _record_violation_details( + inv: Invariant, result, violation_details: dict[Invariant, dict] +): + """Update per-invariant (step, stage) list and sample trace for the HTML report.""" + trace = result.trace or [] + step_stages = [ + (r["meta_vars.step"], r.get("meta_vars.stage")) + for r in trace + if isinstance(r, dict) and r.get("meta_vars.step") is not None + ] + if inv not in violation_details: + violation_details[inv] = {"step_stages": [], "sample_trace": None} + detail = violation_details[inv] + remaining = _MAX_TRACKED_STEPS - len(detail["step_stages"]) + if remaining > 0: + detail["step_stages"].extend(step_stages[:remaining]) + if detail["sample_trace"] is None and trace: + detail["sample_trace"] = trace[:8] + + def _emit_report(force: bool = False): if REPORTER is None: return @@ -149,6 +177,11 @@ def _emit_report(force: bool = False): total_violations=NUM_VIOLATIONS, failed_inv=FAILED_INV, relation_totals=RELATION_TOTALS, + violation_details=VIOLATION_DETAILS, + triggered_inv=TRIGGERED_INV, + all_invs=ALL_INVS, + current_step=CURRENT_STEP, + current_stage=CURRENT_STAGE, ) report_state = (NUM_VIOLATIONS, len(FAILED_INV)) REPORTER.emit(report_data, force=force, report_state=report_state) @@ -160,6 +193,11 @@ def check( global OBSERVER global NUM_VIOLATIONS global FAILED_INV + global VIOLATION_DETAILS + global TRIGGERED_INV + global ALL_INVS + global CURRENT_STEP + global CURRENT_STAGE global TOTAL_INVARIANTS global RELATION_TOTALS @@ -171,6 +209,7 @@ def check( invs, param_to_invs, vartype_to_invs, needed_data = sort_inv_file(invariants) TOTAL_INVARIANTS = len(invs) + ALL_INVS = list(invs) RELATION_TOTALS = defaultdict(int) for inv in invs: RELATION_TOTALS[inv.relation.__name__] += 1 @@ -198,6 +237,13 @@ def check( else: break + step = trace_record.get("meta_vars.step") + stage = trace_record.get("meta_vars.stage") + if step is not None: + CURRENT_STEP = step + if stage is not None: + CURRENT_STAGE = stage + if "var_name" in trace_record and trace_record["var_name"] is not None: varid = VarInstId( trace_record["process_id"], @@ -216,6 +262,7 @@ def check( result = inv.online_check( trace_record, checker_data, check_relation_first ) + TRIGGERED_INV.add(inv) if not result.check_passed: violated_pair = get_violated_pair_hash(result.trace) if inv not in violated_pairs: @@ -227,6 +274,9 @@ def check( if inv not in FAILED_INV: FAILED_INV[inv] = 0 FAILED_INV[inv] += 1 + _record_violation_details( + inv, result, VIOLATION_DETAILS + ) NUM_VIOLATIONS += 1 result.set_id_and_detection_time( NUM_VIOLATIONS, time.monotonic_ns() @@ -258,16 +308,18 @@ def check( result = inv.online_check( trace_record, checker_data, check_relation_first ) + TRIGGERED_INV.add(inv) if not result.check_passed: if inv not in FAILED_INV: FAILED_INV[inv] = 0 FAILED_INV[inv] += 1 + _record_violation_details(inv, result, VIOLATION_DETAILS) NUM_VIOLATIONS += 1 result.set_id_and_detection_time( NUM_VIOLATIONS, time.monotonic_ns() ) logger.error( - f"Violated id {NUM_VIOLATIONS}:\nInvariant {inv} violated near time {trace_record['time']}" + f"Violated id {NUM_VIOLATIONS}:\nInvariant {inv.text_description} violated near time {trace_record['time'], trace_record['meta_vars.step']}" ) with open(output_file, "a") as f: json.dump( @@ -279,6 +331,7 @@ def check( logger.error( f"Error when checking invariant {inv.text_description} with trace {trace_record}: {e}" ) + raise e _emit_report() diff --git a/traincheck/reporting/checker_report.py b/traincheck/reporting/checker_report.py index c07eebb7..b94a76fa 100644 --- a/traincheck/reporting/checker_report.py +++ b/traincheck/reporting/checker_report.py @@ -195,6 +195,70 @@ def build_offline_report_data( } +_TRACE_DISPLAY_KEYS = ( + "function", + "meta_vars.step", + "meta_vars.stage", + "type", + "var_name", + "var_type", +) +_TRACE_SKIP_PREFIXES = ("attributes._TRAINCHECK_",) + +# Known stage β†’ badge color (bg, text) +_STAGE_COLORS: dict[str, tuple[str, str]] = { + "train": ("#2f6fed", "#fff"), + "training": ("#2f6fed", "#fff"), + "eval": ("#2fb679", "#fff"), + "evaluation": ("#2fb679", "#fff"), + "validation": ("#2fb679", "#fff"), + "val": ("#2fb679", "#fff"), + "test": ("#f2b233", "#333"), + "inference": ("#9b59b6", "#fff"), + "pretrain": ("#1abc9c", "#fff"), +} +_STAGE_FALLBACK_PALETTE = [ + ("#e24c4b", "#fff"), + ("#e67e22", "#fff"), + ("#e91e63", "#fff"), + ("#00bcd4", "#fff"), + ("#607d8b", "#fff"), +] + + +def _stage_badge_style(stage: str) -> str: + """Return inline CSS background/color for a stage badge.""" + key = stage.lower() + if key in _STAGE_COLORS: + bg, fg = _STAGE_COLORS[key] + else: + bg, fg = _STAGE_FALLBACK_PALETTE[hash(key) % len(_STAGE_FALLBACK_PALETTE)] + return f"background:{bg};color:{fg}" + + +def _summarize_trace_records(trace: list[dict] | None) -> list[dict]: + """Return a compact, HTML-safe subset of trace records for display.""" + if not trace: + return [] + out = [] + for rec in trace: + if not isinstance(rec, dict): + continue + row: dict = {} + for key in _TRACE_DISPLAY_KEYS: + val = rec.get(key) + if val is not None: + row[key] = str(val) + # Add first attribute-style key that is not an internal one + for key, val in rec.items(): + if key.startswith("attributes.") and val is not None: + if not any(key.startswith(p) for p in _TRACE_SKIP_PREFIXES): + row[key] = str(val) + break + out.append(row) + return out + + def build_online_report_data( *, generated_at: str, @@ -203,24 +267,77 @@ def build_online_report_data( total_violations: int, failed_inv: dict[Invariant, int], relation_totals: dict[str, int], + violation_details: dict | None = None, + triggered_inv: set | None = None, + all_invs: list | None = None, + current_step: int | None = None, + current_stage: str | None = None, ) -> dict: relation_violations: dict[str, int] = defaultdict(int) for inv in failed_inv: relation_violations[inv.relation.__name__] += 1 - top_pairs = sorted( - ((count, inv) for inv, count in failed_inv.items()), - key=lambda item: item[0], - reverse=True, - )[:10] - top_violations = [ - { + if violation_details is None: + violation_details = {} + + def _make_entry(inv: Invariant, count: int) -> dict: + detail = violation_details.get(inv, {}) + step_stages: list[tuple] = detail.get("step_stages") or [] + sample_trace: list[dict] | None = detail.get("sample_trace") + steps = [s for s, _ in step_stages] + first_step = min(steps) if steps else None + last_step = max(steps) if steps else None + # stage of the first/last violation event + first_stage = ( + next((st for s, st in step_stages if s == first_step), None) + if first_step is not None + else None + ) + last_stage = ( + next((st for s, st in reversed(step_stages) if s == last_step), None) + if last_step is not None + else None + ) + # deduplicated, sorted (step, stage) pairs for the expanded view + unique_step_stages = sorted(set(step_stages), key=lambda x: x[0]) + return { "label": _format_invariant_label(inv), "relation": inv.relation.__name__, "count": count, + "first_step": first_step, + "first_stage": first_stage, + "last_step": last_step, + "last_stage": last_stage, + "step_stages": unique_step_stages[:100], # cap for HTML size + "sample_trace": _summarize_trace_records(sample_trace), } - for count, inv in top_pairs - ] + + # Sort by first violation step (earliest first), then by count descending. + def _sort_key(item): + inv, count = item + detail = violation_details.get(inv, {}) + step_stages = detail.get("step_stages") or [] + steps = [s for s, _ in step_stages] + first = min(steps) if steps else float("inf") + return (first, -count) + + sorted_pairs = sorted(failed_inv.items(), key=_sort_key)[:20] + top_violations = [_make_entry(inv, count) for inv, count in sorted_pairs] + + # Progress tracking + triggered_count = len(triggered_inv) if triggered_inv is not None else 0 + failing_count = len(failed_inv) + passing_count = triggered_count - failing_count + not_triggered_count = total_invariants - triggered_count + pass_rate = ( + round(passing_count / triggered_count * 100, 1) if triggered_count > 0 else None + ) + + not_triggered_labels: list[str] = [] + if all_invs is not None and triggered_inv is not None: + not_triggered_labels = [ + _format_invariant_label(inv) for inv in all_invs if inv not in triggered_inv + ][:50] relations = {} for relation_name, total in relation_totals.items(): @@ -237,6 +354,13 @@ def build_online_report_data( "not_triggered": None, "triggered": None, "violated_invariants": len(failed_inv), + # progress fields + "triggered_count": triggered_count, + "passing_count": passing_count, + "not_triggered_count": not_triggered_count, + "pass_rate": pass_rate, + "current_step": current_step, + "current_stage": current_stage, } return { @@ -247,9 +371,42 @@ def build_online_report_data( "relations": relations, "traces": [], "top_violations": top_violations, + "not_triggered_labels": not_triggered_labels, } +def _render_stage_badge(stage: str | None, esc_fn) -> str: + if not stage: + return "" + style = _stage_badge_style(stage) + return f'{esc_fn(stage)}' + + +def _render_step_stages_html( + step_stages: list[tuple], esc_fn, max_per_group: int = 15 +) -> str: + """Render a compact stage-grouped step list as HTML.""" + if not step_stages: + return "β€”" + # Group consecutive same-stage runs + groups: list[tuple[str | None, list[int]]] = [] + for step, stage in step_stages: + if groups and groups[-1][0] == stage: + groups[-1][1].append(step) + else: + groups.append((stage, [step])) + parts = [] + for stage, steps in groups: + shown = steps[:max_per_group] + more = len(steps) - len(shown) + steps_str = ", ".join(str(s) for s in shown) + if more > 0: + steps_str += f' +{more} more' + badge = _render_stage_badge(stage, esc_fn) + parts.append(f"{badge}{steps_str}") + return ' Β· '.join(parts) + + def _render_bar_segment(width_pct: float, class_name: str) -> str: width_pct = max(0.0, min(100.0, width_pct)) return f'' @@ -269,27 +426,108 @@ def percent(part: int, total: int) -> float: traces = report_data.get("traces", []) top_violations = report_data.get("top_violations", []) - top_items = [] - for entry in top_violations: - label = esc(str(entry.get("label", ""))) - relation = esc(str(entry.get("relation", ""))) - count = entry.get("count") - first_step = entry.get("first_step") - count_html = f'{count}' if count else "" - if first_step is not None: - step_note = f"first seen at step {first_step}" - if count and count > 1: - step_note += f" Β· {count} occurrences" - detail = esc(f"{entry.get('relation', '')} β€” {step_note}") - else: - detail = relation - if count and count > 1: - detail = esc(f"{entry.get('relation', '')} β€” {count} occurrences") - top_items.append( - f'
  • {label}' - f'{detail}{count_html}
  • ' + # Build top violations HTML differently per mode. + # For online mode: step-sorted table with expandable trace rows. + # For offline mode: simple list (unchanged). + top_list = "" # used only in offline mode below + top_table_html = "" # used only in online mode below + + if mode == "online": + rows = [] + for entry in top_violations: + label = esc(str(entry.get("label", ""))) + relation = esc(str(entry.get("relation", ""))) + count = entry.get("count", "") + first_step = entry.get("first_step") + first_stage = entry.get("first_stage") + last_step = entry.get("last_step") + last_stage = entry.get("last_stage") + step_stages: list = entry.get("step_stages") or [] + sample_trace = entry.get("sample_trace") or [] + + def _step_with_badge(step, stage) -> str: + if step is None: + return "β€”" + badge = _render_stage_badge(stage, esc) + return f"{badge}{step}" + + first_step_html = _step_with_badge(first_step, first_stage) + last_step_html = _step_with_badge(last_step, last_stage) + steps_html = _render_step_stages_html(step_stages, esc) + + # Build sample trace table + if sample_trace: + all_keys: list[str] = [] + for rec in sample_trace: + for k in rec: + if k not in all_keys: + all_keys.append(k) + trace_head = "".join(f"{esc(k)}" for k in all_keys) + trace_rows_html = [] + for rec in sample_trace: + cells = [] + for k in all_keys: + val = rec.get(k, "") + # Style stage cells + if k == "meta_vars.stage" and val: + style = _stage_badge_style(val) + cell = ( + f'' + f"{esc(val)}" + ) + else: + cell = f"{esc(str(val))}" + cells.append(cell) + trace_rows_html.append(f"{''.join(cells)}") + trace_body = "\n".join(trace_rows_html) + expand_content = ( + f'
    Steps: {steps_html}
    ' + f'
    ' + f"{trace_head}" + f"{trace_body}
    " + ) + else: + expand_content = f'
    Steps: {steps_html}
    ' + + rows.append( + f"" + f'
    {label}' + f'
    {expand_content}
    ' + f'{relation}' + f'{first_step_html}' + f'{last_step_html}' + f'{count}' + f"" + ) + top_table_html = ( + '' + "" + f"{''.join(rows)}
    InvariantFirst StepLast StepCount
    " + if rows + else "

    No violations yet.

    " ) - top_list = "".join(top_items) or "
  • None
  • " + else: + top_items = [] + for entry in top_violations: + label = esc(str(entry.get("label", ""))) + relation = esc(str(entry.get("relation", ""))) + count = entry.get("count") + first_step = entry.get("first_step") + count_html = f'{count}' if count else "" + if first_step is not None: + step_note = f"first seen at step {first_step}" + if count and count > 1: + step_note += f" Β· {count} occurrences" + detail = esc(f"{entry.get('relation', '')} β€” {step_note}") + else: + detail = relation + if count and count > 1: + detail = esc(f"{entry.get('relation', '')} β€” {count} occurrences") + top_items.append( + f'
  • {label}' + f'{detail}{count_html}
  • ' + ) + top_list = "".join(top_items) or "
  • None
  • " trace_sections = [] for trace in traces: @@ -401,23 +639,114 @@ def percent(part: int, total: int) -> float: ) if mode == "online": + cur_step = overall.get("current_step") + cur_stage = overall.get("current_stage") + triggered_count = overall.get("triggered_count", 0) or 0 + passing_count = overall.get("passing_count", 0) or 0 + not_triggered_count = overall.get("not_triggered_count", 0) or 0 + pass_rate_val = overall.get("pass_rate") + total_invariants_n = overall["total_invariants"] or 0 + violated = overall.get("violated_invariants", 0) or 0 + + # Current step card β€” show stage badge inline + if cur_step is not None: + stage_badge = _render_stage_badge(cur_stage, esc) if cur_stage else "" + step_value_html = f"{stage_badge}{cur_step}" + step_sub = esc(f"stage: {cur_stage}") if cur_stage else "no stage info" + else: + step_value_html = "β€”" + step_sub = "waiting for first trace record" + + pass_rate_display = f"{pass_rate_val}%" if pass_rate_val is not None else "β€”" + card_html = f"""
    Total Invariants
    -
    {overall['total_invariants']}
    +
    {total_invariants_n}
    +
    +
    +
    Triggered
    +
    {triggered_count}
    +
    of {total_invariants_n} loaded
    +
    +
    +
    Pass Rate
    +
    {pass_rate_display}
    +
    {passing_count} passing Β· {violated} failing
    Violations
    {overall['failed']}
    -
    -
    Violated Invariants
    -
    {overall.get('violated_invariants', 0)}
    +
    +
    Current Step
    +
    {step_value_html}
    +
    {step_sub}
    """ relation_header = "RelationTotalViolated" - mode_note = '
    Online mode (partial coverage).
    ' + mode_note = '
    Online mode β€” checking in progress.
    ' + + # Progress panel (checking coverage bar + not-yet-triggered list) + not_triggered_labels: list[str] = report_data.get("not_triggered_labels", []) + bar_total = total_invariants_n or 1 + passing_pct = percent(passing_count, bar_total) + failing_pct = percent(violated, bar_total) + not_triggered_pct = percent(not_triggered_count, bar_total) + progress_bar = ( + _render_bar_segment(passing_pct, "bar-passed") + + _render_bar_segment(failing_pct, "bar-failed") + + _render_bar_segment(not_triggered_pct, "bar-not-triggered") + ) + + if not_triggered_labels: + nt_items = "".join( + f'
  • {esc(lbl)}
  • ' for lbl in not_triggered_labels + ) + suffix = ( + f" (showing first {len(not_triggered_labels)})" + if not_triggered_count > len(not_triggered_labels) + else "" + ) + nt_section = ( + f'
    ' + f"{not_triggered_count} not yet triggered{esc(suffix)}" + f'
      {nt_items}
    ' + f"
    " + ) + elif not_triggered_count > 0: + nt_section = f'

    {not_triggered_count} invariant(s) not yet triggered.

    ' + else: + nt_section = '

    All invariants have been triggered at least once.

    ' + + progress_panel = f""" +
    +
    +
    +

    Checking Progress

    +
    {triggered_count} of {total_invariants_n} invariants triggered so far
    +
    +
    + Passing{passing_count} + Failing{violated} + Not Triggered{not_triggered_count} +
    +
    +
    {progress_bar}
    +
    + Passing ({passing_count}) + Failing ({violated}) + Not Triggered ({not_triggered_count}) +
    + {nt_section} +
    + """ else: + total_checks = overall["total_checks"] or 0 + passed_checks = overall["passed"] or 0 + pass_rate = ( + round(passed_checks / total_checks * 100, 1) if total_checks else 0.0 + ) card_html = f"""
    Total Invariants
    @@ -431,6 +760,10 @@ def percent(part: int, total: int) -> float:
    Passed Checks
    {overall['passed']}
    +
    +
    Pass Rate
    +
    {pass_rate}%
    +
    Not Triggered
    {overall['not_triggered']}
    @@ -438,16 +771,44 @@ def percent(part: int, total: int) -> float: """ relation_header = "RelationFailedPassedNot Triggered" mode_note = "" + progress_panel = "" + + # surface first violation step in the top violations panel header + all_first_steps = [ + v["first_step"] for v in top_violations if v.get("first_step") is not None + ] + if all_first_steps: + first_step_note = ( + f'
    First violation at step {min(all_first_steps)}' + f" Β· {len(top_violations)} distinct invariant(s) violated
    " + ) + elif top_violations: + first_step_note = ( + f'
    {len(top_violations)}' + " distinct invariant(s) violated
    " + ) + else: + first_step_note = "" + + if mode == "online": + panel_subtitle = ( + "Sorted by first violation step β€” click an invariant to expand trace" + ) + panel_content = top_table_html + else: + panel_subtitle = "Most frequent violations observed" + panel_content = f'
      {top_list}
    ' top_panel = f"""
    -

    Top Violations

    -
    Most frequent violations observed
    +

    Violations

    +
    {panel_subtitle}
    + {first_step_note}
    -
      {top_list}
    + {panel_content}
    """ @@ -652,6 +1013,88 @@ def percent(part: int, total: int) -> float: text-transform: uppercase; letter-spacing: 0.04em; }} + .viol-table td {{ vertical-align: top; }} + .step-cell {{ white-space: nowrap; font-variant-numeric: tabular-nums; font-weight: 600; }} + .count-cell {{ white-space: nowrap; font-variant-numeric: tabular-nums; font-weight: 700; color: var(--failed); }} + .stage-badge {{ + display: inline-block; + font-size: 10px; + font-weight: 700; + padding: 1px 6px; + border-radius: 99px; + margin-right: 4px; + text-transform: uppercase; + letter-spacing: 0.05em; + vertical-align: middle; + }} + .step-sep {{ color: var(--muted); }} + .more-steps {{ color: var(--muted); font-size: 11px; }} + .card-sub {{ font-size: 12px; color: var(--muted); margin-top: 4px; }} + .card-pass .value {{ color: var(--passed); }} + .card-step .step-value {{ font-size: 22px; }} + .c-pass {{ color: var(--passed); }} + .c-fail {{ color: var(--failed); }} + .progress-panel h2 {{ margin: 0 0 4px; font-size: 22px; }} + .nt-details summary {{ + cursor: pointer; + font-size: 13px; + color: var(--muted); + padding: 6px 0; + }} + .nt-details summary::-webkit-details-marker {{ display: none; }} + .nt-details summary::before {{ content: "β–Έ "; color: var(--accent); }} + details[open].nt-details summary::before {{ content: "β–Ύ "; }} + .nt-list {{ + list-style: none; + padding: 0; + margin: 8px 0 0; + display: grid; + grid-template-columns: repeat(auto-fill, minmax(280px, 1fr)); + gap: 6px; + }} + .nt-item {{ + font-size: 12px; + color: var(--muted); + background: #f4f6fb; + border: 1px solid var(--border); + border-radius: 6px; + padding: 5px 10px; + }} + .inv-label-summary {{ + cursor: pointer; + font-weight: 600; + font-size: 14px; + list-style: none; + padding: 2px 0; + }} + .inv-label-summary::-webkit-details-marker {{ display: none; }} + .inv-label-summary::before {{ content: "β–Έ "; color: var(--accent); font-size: 11px; }} + details[open] .inv-label-summary::before {{ content: "β–Ύ "; }} + .inv-rel-tag {{ + display: inline-block; + font-size: 11px; + color: var(--muted); + background: #f0f2f8; + border-radius: 4px; + padding: 1px 6px; + margin-top: 4px; + }} + .expand-body {{ + margin-top: 10px; + padding: 10px; + background: #f8f9fd; + border-radius: 8px; + border: 1px solid var(--border); + }} + .trace-steps {{ + font-size: 12px; + color: var(--muted); + margin-bottom: 8px; + word-break: break-all; + }} + .trace-wrap {{ overflow-x: auto; }} + .trace-table {{ font-size: 11px; min-width: 400px; }} + .trace-table th {{ font-size: 10px; }} footer {{ margin-top: 28px; font-size: 12px; @@ -674,6 +1117,8 @@ def percent(part: int, total: int) -> float: {card_html}
    + {progress_panel} + {top_panel} {relation_table} @@ -780,6 +1225,8 @@ def _log_wandb( report_path: str | None, args: argparse.Namespace, ): + import glob + try: import wandb except ImportError: @@ -797,38 +1244,110 @@ def _log_wandb( tags=args.wandb_tags, job_type="checker", ) + run = self._wandb_run + if run is None: + logging.getLogger(__name__).warning("wandb.init() returned None; skipping.") + return overall = report_data["overall"] mode = report_data.get("mode", "offline") + + # --- run config (searchable/filterable in W&B UI) --- + run.config.update( + { + "traincheck/invariants_total": overall["total_invariants"], + "traincheck/output_dir": report_data.get("output_dir", ""), + "traincheck/mode": mode, + }, + allow_val_change=True, + ) + + # --- scalar metrics --- if mode == "online": + total_invariants = overall["total_invariants"] or 0 + violated = overall.get("violated_invariants", 0) or 0 + violation_rate = ( + round(violated / total_invariants * 100, 1) if total_invariants else 0.0 + ) wandb.log( { - "invariants/total": overall["total_invariants"], - "invariants/violated_unique": overall.get("violated_invariants", 0), + "invariants/total": total_invariants, + "invariants/violated_unique": violated, + "invariants/violation_rate_pct": violation_rate, "violations/total": overall["failed"], } ) else: + total_checks = overall["total_checks"] or 0 + passed = overall["passed"] or 0 + pass_rate = round(passed / total_checks * 100, 1) if total_checks else 0.0 wandb.log( { "invariants/total": overall["total_invariants"], - "checks/total": overall["total_checks"], + "checks/total": total_checks, "checks/failed": overall["failed"], - "checks/passed": overall["passed"], + "checks/passed": passed, "checks/not_triggered": overall["not_triggered"], + "checks/pass_rate_pct": pass_rate, } ) - table = wandb.Table(columns=["relation", "failed", "passed", "not_triggered"]) + # --- relation breakdown table --- + rel_table = wandb.Table( + columns=["relation", "failed", "passed", "not_triggered"] + ) for relation_name, rel_counts in report_data["relations"].items(): - table.add_data( + rel_table.add_data( relation_name, rel_counts.get("failed", 0), rel_counts.get("passed", 0), rel_counts.get("not_triggered", 0), ) - wandb.log({"relation_breakdown": table}) + wandb.log({"relation_breakdown": rel_table}) + # --- violated invariants table --- + top_violations = report_data.get("top_violations", []) + if top_violations: + vtable = wandb.Table( + columns=["invariant", "relation_type", "occurrences", "first_step"] + ) + for v in top_violations: + vtable.add_data( + v.get("label", ""), + v.get("relation", ""), + v.get("count", 0), + v.get("first_step"), + ) + wandb.log({"violations": vtable}) + + # --- summary metrics (shown in run comparison table) --- + first_steps = [ + v["first_step"] for v in top_violations if v.get("first_step") is not None + ] + if first_steps: + run.summary["violations/first_step"] = min(first_steps) + run.summary["violations/distinct_invariants"] = len(top_violations) + + # --- violations_summary.json as versioned artifact --- + summary_files = glob.glob( + os.path.join(self.output_dir, "*", "violations_summary.json") + ) + if summary_files: + try: + artifact = wandb.Artifact( + name="violations_summary", + type="checker_output", + description="Per-trace violation summaries from traincheck-check", + ) + for summary_file in summary_files: + artifact.add_file(summary_file) + run.log_artifact(artifact) + except Exception: + logging.getLogger(__name__).warning( + "Failed to attach violations_summary artifact to wandb run." + ) + + # --- HTML report --- if report_path: try: with open(report_path, "r") as f: @@ -852,6 +1371,8 @@ def _log_mlflow( ) return + import glob + if args.mlflow_experiment: mlflow.set_experiment(args.mlflow_experiment) @@ -861,18 +1382,71 @@ def _log_mlflow( overall = report_data["overall"] mode = report_data.get("mode", "offline") + top_violations = report_data.get("top_violations", []) + + # --- run tags (searchable in MLflow UI) --- + mlflow.set_tags( + { + "traincheck.mode": mode, + "traincheck.invariants_total": str(overall["total_invariants"]), + "traincheck.output_dir": report_data.get("output_dir", ""), + } + ) + + # --- scalar metrics --- if mode == "online": - mlflow.log_metric("invariants_total", overall["total_invariants"]) - mlflow.log_metric( - "invariants_violated_unique", overall.get("violated_invariants", 0) + total_invariants = overall["total_invariants"] or 0 + violated = overall.get("violated_invariants", 0) or 0 + violation_rate = ( + round(violated / total_invariants * 100, 1) if total_invariants else 0.0 ) + mlflow.log_metric("invariants_total", total_invariants) + mlflow.log_metric("invariants_violated_unique", violated) + mlflow.log_metric("invariants_violation_rate_pct", violation_rate) mlflow.log_metric("violations_total", overall["failed"]) else: + total_checks = overall["total_checks"] or 0 + passed = overall["passed"] or 0 + pass_rate = round(passed / total_checks * 100, 1) if total_checks else 0.0 mlflow.log_metric("invariants_total", overall["total_invariants"]) - mlflow.log_metric("checks_total", overall["total_checks"]) + mlflow.log_metric("checks_total", total_checks) mlflow.log_metric("checks_failed", overall["failed"]) - mlflow.log_metric("checks_passed", overall["passed"]) + mlflow.log_metric("checks_passed", passed) mlflow.log_metric("checks_not_triggered", overall["not_triggered"]) + mlflow.log_metric("checks_pass_rate_pct", pass_rate) + + # --- violation summary metrics --- + first_steps = [ + v["first_step"] for v in top_violations if v.get("first_step") is not None + ] + if first_steps: + mlflow.log_metric("violations_first_step", min(first_steps)) + mlflow.log_metric("violations_distinct_invariants", len(top_violations)) + + # --- violations table as JSON artifact --- + if top_violations: + try: + mlflow.log_dict( + {"violations": top_violations}, + artifact_file="violations.json", + ) + except Exception: + logging.getLogger(__name__).warning( + "Failed to log violations table to MLflow." + ) + + # --- per-trace violations_summary.json artifacts --- + summary_files = glob.glob( + os.path.join(self.output_dir, "*", "violations_summary.json") + ) + for summary_file in summary_files: + try: + mlflow.log_artifact(summary_file, artifact_path="violations_summaries") + except Exception: + logging.getLogger(__name__).warning( + "Failed to log %s to MLflow.", summary_file + ) + # --- HTML report --- if report_path: mlflow.log_artifact(report_path) From b4f266d9e252073ada148c979b76db06a4eebbf8 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 19 Mar 2026 19:57:39 -0400 Subject: [PATCH 22/28] fix: KeyError on missing varid attribute in APIContainRelation online_check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When iterating all varids of a given type, not every variable instance has every tracked attribute (e.g. _TRAINCHECK_grad_ID may be absent if grad was never observed). Skip varids that don't have the attribute in varid_map rather than crashing with KeyError. Also remove the remaining bare 'raise e' in the API-based invariant check block β€” the var-based block was fixed earlier but this one was missed, causing the checker to crash and stop on any API invariant exception. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/checker_online.py | 34 +++++++++++- traincheck/invariant/contain_relation.py | 2 + traincheck/reporting/checker_report.py | 68 ++++++++++++++++++++++-- 3 files changed, 98 insertions(+), 6 deletions(-) diff --git a/traincheck/checker_online.py b/traincheck/checker_online.py index 05cdb70d..11b24822 100644 --- a/traincheck/checker_online.py +++ b/traincheck/checker_online.py @@ -28,6 +28,8 @@ ALL_INVS: list[Invariant] = [] CURRENT_STEP: int | None = None CURRENT_STAGE: str | None = None +SAMPLING_INTERVAL: int | None = None +WARM_UP_STEPS: int | None = None TOTAL_INVARIANTS = 0 RELATION_TOTALS: dict[str, int] = {} REPORTER: ReportEmitter | None = None @@ -167,6 +169,31 @@ def _record_violation_details( detail["sample_trace"] = trace[:8] +def _read_sampling_config( + trace_folders: list[str] | None, +) -> tuple[int | None, int | None]: + """Parse sampling_interval and warm_up_steps from env_dump.txt in a trace folder.""" + import re + + for folder in trace_folders or []: + env_dump_path = os.path.join(folder, "env_dump.txt") + if not os.path.exists(env_dump_path): + continue + sampling_interval: int | None = None + warm_up_steps: int | None = None + with open(env_dump_path) as fh: + for line in fh: + m = re.match(r"^sampling_interval:\s*(\d+)", line) + if m: + sampling_interval = int(m.group(1)) + m = re.match(r"^warm_up_steps:\s*(\d+)", line) + if m: + warm_up_steps = int(m.group(1)) + if sampling_interval is not None or warm_up_steps is not None: + return sampling_interval, warm_up_steps + return None, None + + def _emit_report(force: bool = False): if REPORTER is None: return @@ -182,6 +209,8 @@ def _emit_report(force: bool = False): all_invs=ALL_INVS, current_step=CURRENT_STEP, current_stage=CURRENT_STAGE, + sampling_interval=SAMPLING_INTERVAL, + warm_up_steps=WARM_UP_STEPS, ) report_state = (NUM_VIOLATIONS, len(FAILED_INV)) REPORTER.emit(report_data, force=force, report_state=report_state) @@ -198,6 +227,8 @@ def check( global ALL_INVS global CURRENT_STEP global CURRENT_STAGE + global SAMPLING_INTERVAL + global WARM_UP_STEPS global TOTAL_INVARIANTS global RELATION_TOTALS @@ -207,6 +238,8 @@ def check( logger.addHandler(logging.StreamHandler()) logger.info("Starting online checker") + SAMPLING_INTERVAL, WARM_UP_STEPS = _read_sampling_config(trace_folders) + invs, param_to_invs, vartype_to_invs, needed_data = sort_inv_file(invariants) TOTAL_INVARIANTS = len(invs) ALL_INVS = list(invs) @@ -331,7 +364,6 @@ def check( logger.error( f"Error when checking invariant {inv.text_description} with trace {trace_record}: {e}" ) - raise e _emit_report() diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index ba35e9b8..90225739 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -1324,6 +1324,8 @@ def online_check( attr_name = child_param.attr_name elif isinstance(child_param, VarTypeParam): attr_name = child_param.attr_name + if attr_name not in checker_data.varid_map[varid]: + continue for i in reversed( range(1, len(checker_data.varid_map[varid][attr_name])) ): diff --git a/traincheck/reporting/checker_report.py b/traincheck/reporting/checker_report.py index b94a76fa..24bc80d9 100644 --- a/traincheck/reporting/checker_report.py +++ b/traincheck/reporting/checker_report.py @@ -272,6 +272,8 @@ def build_online_report_data( all_invs: list | None = None, current_step: int | None = None, current_stage: str | None = None, + sampling_interval: int | None = None, + warm_up_steps: int | None = None, ) -> dict: relation_violations: dict[str, int] = defaultdict(int) for inv in failed_inv: @@ -280,6 +282,16 @@ def build_online_report_data( if violation_details is None: violation_details = {} + # Estimate how many steps have been checked so far. + checked_steps: int | None = None + if ( + current_step is not None + and sampling_interval is not None + and warm_up_steps is not None + and sampling_interval > 0 + ): + checked_steps = max(0, current_step - warm_up_steps) // sampling_interval + 1 + def _make_entry(inv: Invariant, count: int) -> dict: detail = violation_details.get(inv, {}) step_stages: list[tuple] = detail.get("step_stages") or [] @@ -300,6 +312,10 @@ def _make_entry(inv: Invariant, count: int) -> dict: ) # deduplicated, sorted (step, stage) pairs for the expanded view unique_step_stages = sorted(set(step_stages), key=lambda x: x[0]) + unique_viol_steps = len(set(s for s, _ in step_stages)) + viol_rate: float | None = None + if checked_steps is not None and checked_steps > 0: + viol_rate = round(unique_viol_steps / checked_steps * 100, 1) return { "label": _format_invariant_label(inv), "relation": inv.relation.__name__, @@ -310,6 +326,9 @@ def _make_entry(inv: Invariant, count: int) -> dict: "last_stage": last_stage, "step_stages": unique_step_stages[:100], # cap for HTML size "sample_trace": _summarize_trace_records(sample_trace), + "violation_step_count": unique_viol_steps, + "checked_steps": checked_steps, + "violation_rate": viol_rate, } # Sort by first violation step (earliest first), then by count descending. @@ -372,6 +391,9 @@ def _sort_key(item): "traces": [], "top_violations": top_violations, "not_triggered_labels": not_triggered_labels, + "sampling_interval": sampling_interval, + "warm_up_steps": warm_up_steps, + "checked_steps": checked_steps, } @@ -433,6 +455,11 @@ def percent(part: int, total: int) -> float: top_table_html = "" # used only in online mode below if mode == "online": + sampling_interval = report_data.get("sampling_interval") + warm_up_steps_val = report_data.get("warm_up_steps") + checked_steps_total = report_data.get("checked_steps") + has_sampling = sampling_interval is not None + rows = [] for entry in top_violations: label = esc(str(entry.get("label", ""))) @@ -444,6 +471,9 @@ def percent(part: int, total: int) -> float: last_stage = entry.get("last_stage") step_stages: list = entry.get("step_stages") or [] sample_trace = entry.get("sample_trace") or [] + violation_step_count = entry.get("violation_step_count", 0) + entry_checked = entry.get("checked_steps") + viol_rate = entry.get("violation_rate") def _step_with_badge(step, stage) -> str: if step is None: @@ -489,6 +519,18 @@ def _step_with_badge(step, stage) -> str: else: expand_content = f'
    Steps: {steps_html}
    ' + # Frequency cell: prefer rate when sampling info available + if has_sampling and entry_checked is not None and entry_checked > 0: + rate_str = f"{viol_rate}%" if viol_rate is not None else "?" + freq_cell = ( + f'{rate_str}' + f'' + f"{violation_step_count}/{entry_checked} steps" + f"" + ) + else: + freq_cell = f'{count}' + rows.append( f"" f'
    {label}' @@ -496,12 +538,15 @@ def _step_with_badge(step, stage) -> str: f'{relation}' f'{first_step_html}' f'{last_step_html}' - f'{count}' + f'{freq_cell}' f"" ) + + freq_col_header = "Frequency" if has_sampling else "Count" top_table_html = ( - '' - "" + f'
    InvariantFirst StepLast StepCount
    ' + f"" + f"" f"{''.join(rows)}
    InvariantFirst StepLast Step{freq_col_header}
    " if rows else "

    No violations yet.

    " @@ -791,8 +836,18 @@ def _step_with_badge(step, stage) -> str: first_step_note = "" if mode == "online": - panel_subtitle = ( - "Sorted by first violation step — click an invariant to expand trace" + sampling_interval = report_data.get("sampling_interval") + warm_up_steps_val = report_data.get("warm_up_steps") + checked_steps_total = report_data.get("checked_steps") + sampling_ctx = "" + if sampling_interval is not None: + sampling_ctx = f" · sampled every {sampling_interval} steps" + if warm_up_steps_val is not None: + sampling_ctx += f", warm-up {warm_up_steps_val}" + if checked_steps_total is not None: + sampling_ctx += f" ({checked_steps_total} steps checked)" + panel_subtitle = esc( + f"Sorted by first violation step — click to expand trace{sampling_ctx}" ) panel_content = top_table_html else: @@ -1016,6 +1071,9 @@ def _step_with_badge(step, stage) -> str: .viol-table td {{ vertical-align: top; }} .step-cell {{ white-space: nowrap; font-variant-numeric: tabular-nums; font-weight: 600; }} .count-cell {{ white-space: nowrap; font-variant-numeric: tabular-nums; font-weight: 700; color: var(--failed); }} + .freq-cell {{ white-space: nowrap; }} + .freq-rate {{ display: block; font-variant-numeric: tabular-nums; font-weight: 700; color: var(--failed); }} + .freq-detail {{ display: block; font-size: 11px; color: var(--muted); margin-top: 2px; }} .stage-badge {{ display: inline-block; font-size: 10px; From 0a1d1666a883eed4d45aa25b0da5619d38ee1e0c Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 19 Mar 2026 20:09:44 -0400 Subject: [PATCH 23/28] fix: use attr_map to scope varid iteration in APIContainRelation online_check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Checker_data.attr_map (var_type → attr_name → set[VarInstId]), populated in _set_var_map when an attribute is first observed for a variable. Replace the broad type_map iteration in APIContainRelation.online_check and query_var_changes_within_time_and_process with attr_map lookups that only visit varids known to carry the attribute. This eliminates the KeyError when frozen parameters (or any variable lacking a tracked attribute) appear in type_map. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/invariant/contain_relation.py | 54 +++++++++---------- .../onlinechecker/streamhandler_filesystem.py | 6 +++ traincheck/onlinechecker/utils.py | 3 +- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index 90225739..872f34f5 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -1315,35 +1315,35 @@ def online_check( if isinstance(child_param, (VarTypeParam, VarNameParam)): events = [] + attr_name = child_param.attr_name with checker_data.lock: - if child_param.var_type in checker_data.type_map: - for varid in checker_data.type_map[child_param.var_type]: - if isinstance(child_param, VarNameParam): - if varid.var_name != child_param.var_name: - continue - attr_name = child_param.attr_name - elif isinstance(child_param, VarTypeParam): - attr_name = child_param.attr_name - if attr_name not in checker_data.varid_map[varid]: + candidate_varids = checker_data.attr_map.get( + child_param.var_type, {} + ).get(attr_name, set()) + if isinstance(child_param, VarNameParam): + candidate_varids = { + v + for v in candidate_varids + if v.var_name == child_param.var_name + } + for varid in candidate_varids: + for i in reversed( + range(1, len(checker_data.varid_map[varid][attr_name])) + ): + change_time = checker_data.varid_map[varid][attr_name][ + i + ].liveness.start_time + if change_time <= pre_time: + break + if change_time > post_time: continue - for i in reversed( - range(1, len(checker_data.varid_map[varid][attr_name])) - ): - - change_time = checker_data.varid_map[varid][attr_name][ - i - ].liveness.start_time - if change_time <= pre_time: - break - if change_time > post_time: - continue - new_state = checker_data.varid_map[varid][attr_name][i] - old_state = checker_data.varid_map[varid][attr_name][i - 1] - if new_state.value == old_state.value: - continue - if new_state.liveness.end_time is None: - new_state = copy.deepcopy(new_state) - events.append((old_state, new_state)) + new_state = checker_data.varid_map[varid][attr_name][i] + old_state = checker_data.varid_map[varid][attr_name][i - 1] + if new_state.value == old_state.value: + continue + if new_state.liveness.end_time is None: + new_state = copy.deepcopy(new_state) + events.append((old_state, new_state)) if check_relation_first: for event in events: diff --git a/traincheck/onlinechecker/streamhandler_filesystem.py b/traincheck/onlinechecker/streamhandler_filesystem.py index 46439c42..ab2a6366 100644 --- a/traincheck/onlinechecker/streamhandler_filesystem.py +++ b/traincheck/onlinechecker/streamhandler_filesystem.py @@ -34,6 +34,7 @@ def __init__(self, file_path, checker_data: Checker_data): self.varid_map = checker_data.varid_map self.type_map = checker_data.type_map + self.attr_map = checker_data.attr_map self.pt_map = checker_data.pt_map self.process_to_vars = checker_data.process_to_vars self.args_map = checker_data.args_map @@ -155,6 +156,11 @@ def _set_var_map(self, trace_record): if attr_name not in self.varid_map[varid]: self.varid_map[varid][attr_name] = [] + if varid.var_type not in self.attr_map: + self.attr_map[varid.var_type] = {} + if attr_name not in self.attr_map[varid.var_type]: + self.attr_map[varid.var_type][attr_name] = set() + self.attr_map[varid.var_type][attr_name].add(varid) else: self.varid_map[varid][attr_name][-1].liveness.end_time = ( trace_record["time"] diff --git a/traincheck/onlinechecker/utils.py b/traincheck/onlinechecker/utils.py index c174d917..d8515c92 100644 --- a/traincheck/onlinechecker/utils.py +++ b/traincheck/onlinechecker/utils.py @@ -17,6 +17,7 @@ def __init__(self, needed_data): self.check_queue = queue.Queue() self.varid_map = {} self.type_map = {} + self.attr_map: dict[str, dict[str, set]] = {} self.pt_map = {} self.process_to_vars = {} self.args_map = {} @@ -189,7 +190,7 @@ def query_var_changes_within_time_and_process( """Extract all variable change events from the trace, within a specific time range and process.""" events = [] with checker_data.lock: - for varid in checker_data.type_map[var_type]: + for varid in checker_data.attr_map.get(var_type, {}).get(attr_name, set()): for i in reversed(range(1, len(checker_data.varid_map[varid][attr_name]))): change_time = checker_data.varid_map[varid][attr_name][ From 95fb1158f9b45dcf26feb6ab7e5b2236df56b334 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 19 Mar 2026 20:20:32 -0400 Subject: [PATCH 24/28] fix: fail loudly on attr_map/varid_map inconsistency, explicit not-yet-observable returns The .get() pattern silently returned empty sets even in cases that would indicate a population bug. Replace with direct dict access guarded only by explicit "not yet observable" early returns (no vars of this type/attr have been seen yet -- the invariant simply cannot be checked and passes vacuously). Inside the iteration loop, add assertions so any discrepancy between attr_map and varid_map fails loudly rather than being masked. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/invariant/contain_relation.py | 22 +++++++++++++++++++--- traincheck/onlinechecker/utils.py | 12 ++++++++++-- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index 872f34f5..ecf2aec9 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -1317,9 +1317,18 @@ def online_check( events = [] attr_name = child_param.attr_name with checker_data.lock: - candidate_varids = checker_data.attr_map.get( - child_param.var_type, {} - ).get(attr_name, set()) + # No variables of this type/attr have been observed yet — not yet checkable. + if child_param.var_type not in checker_data.attr_map: + return OnlineCheckerResult( + trace=None, invariant=inv, check_passed=True + ) + if attr_name not in checker_data.attr_map[child_param.var_type]: + return OnlineCheckerResult( + trace=None, invariant=inv, check_passed=True + ) + candidate_varids = checker_data.attr_map[child_param.var_type][ + attr_name + ] if isinstance(child_param, VarNameParam): candidate_varids = { v @@ -1327,6 +1336,13 @@ def online_check( if v.var_name == child_param.var_name } for varid in candidate_varids: + # attr_map guarantees attr_name is in varid_map[varid]; fail loudly + # if the population logic ever violates this invariant. + assert attr_name in checker_data.varid_map[varid], ( + f"attr_map/varid_map inconsistency: {varid} is in " + f"attr_map[...][{attr_name}] but " + f"varid_map[{varid}] has no '{attr_name}' entry" + ) for i in reversed( range(1, len(checker_data.varid_map[varid][attr_name])) ): diff --git a/traincheck/onlinechecker/utils.py b/traincheck/onlinechecker/utils.py index d8515c92..6da09e22 100644 --- a/traincheck/onlinechecker/utils.py +++ b/traincheck/onlinechecker/utils.py @@ -188,9 +188,17 @@ def query_var_changes_within_time_and_process( checker_data: Checker_data, ) -> list: """Extract all variable change events from the trace, within a specific time range and process.""" - events = [] + events: list = [] with checker_data.lock: - for varid in checker_data.attr_map.get(var_type, {}).get(attr_name, set()): + if var_type not in checker_data.attr_map: + return events + if attr_name not in checker_data.attr_map[var_type]: + return events + for varid in checker_data.attr_map[var_type][attr_name]: + assert attr_name in checker_data.varid_map[varid], ( + f"attr_map/varid_map inconsistency: {varid} in " + f"attr_map[{var_type}][{attr_name}] but not in varid_map" + ) for i in reversed(range(1, len(checker_data.varid_map[varid][attr_name]))): change_time = checker_data.varid_map[varid][attr_name][ From 02a3a33fd5d7335e9a7c003e699f15bfd9d8fe6f Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 19 Mar 2026 20:39:22 -0400 Subject: [PATCH 25/28] fix: strip _TRAINCHECK_ prefix when displaying internal tensor-tracking attrs Add _display_attr_name() helper that maps '_TRAINCHECK_grad_ID' -> 'grad' etc. Use it in APIContainRelation.to_display_name (removing the return-None guard) and ConsistencyRelation.to_display_name. Remove the now-unnecessary [internal tracking] fallback from _format_invariant_label. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/invariant/base_cls.py | 11 +++++++++++ traincheck/invariant/consistency_relation.py | 3 ++- traincheck/invariant/contain_relation.py | 6 ++---- traincheck/reporting/checker_report.py | 20 +++----------------- 4 files changed, 18 insertions(+), 22 deletions(-) diff --git a/traincheck/invariant/base_cls.py b/traincheck/invariant/base_cls.py index 071ecaea..b9fb041e 100644 --- a/traincheck/invariant/base_cls.py +++ b/traincheck/invariant/base_cls.py @@ -1944,6 +1944,17 @@ def _short_api_name(full_name: str) -> str: return ".".join(parts[-2:]) if len(parts) >= 2 else full_name +def _display_attr_name(attr_name: str) -> str: + """Strip TrainCheck-internal proxy bookkeeping prefix/suffix for display. + + '_TRAINCHECK_grad_ID' → 'grad' + 'dtype' → 'dtype' (unchanged) + """ + if attr_name.startswith("_TRAINCHECK_") and attr_name.endswith("_ID"): + return attr_name[len("_TRAINCHECK_") : -len("_ID")] + return attr_name + + def read_inv_file(file_path: str | list[str]) -> list[Invariant]: if isinstance(file_path, str): file_path = [file_path] diff --git a/traincheck/invariant/consistency_relation.py b/traincheck/invariant/consistency_relation.py index 11ab8705..8a6c26fb 100644 --- a/traincheck/invariant/consistency_relation.py +++ b/traincheck/invariant/consistency_relation.py @@ -14,6 +14,7 @@ Param, Relation, VarTypeParam, + _display_attr_name, ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online @@ -118,7 +119,7 @@ def to_display_name(params: list[Param]) -> str | None: if not isinstance(p, VarTypeParam): return None var_short = p.var_type.split(".")[-1] - attr = p.attr_name + attr = _display_attr_name(p.attr_name) return f"{var_short}.{attr} stays consistent across training steps" @staticmethod diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index ecf2aec9..84fa76d2 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -23,6 +23,7 @@ Relation, VarNameParam, VarTypeParam, + _display_attr_name, _short_api_name, calc_likelihood, construct_api_param, @@ -358,10 +359,7 @@ def to_display_name(params: list[Param]) -> str | None: child_short = _short_api_name(child.api_full_name) return f"{parent_short}() always calls {child_short}()" if isinstance(child, (VarTypeParam, VarNameParam)): - attr = child.attr_name - # Skip internal TrainCheck proxy bookkeeping attributes - if attr.startswith("_TRAINCHECK_"): - return None + attr = _display_attr_name(child.attr_name) var_short = child.var_type.split(".")[-1] pre = child.pre_value post = child.post_value diff --git a/traincheck/reporting/checker_report.py b/traincheck/reporting/checker_report.py index 24bc80d9..a5eaa4e6 100644 --- a/traincheck/reporting/checker_report.py +++ b/traincheck/reporting/checker_report.py @@ -7,28 +7,12 @@ from typing import Iterable from traincheck.invariant import CheckerResult, Invariant -from traincheck.invariant.base_cls import APIParam, VarNameParam, VarTypeParam def _format_invariant_label(invariant: Invariant) -> str: display = invariant.relation.to_display_name(invariant.params) if display: return display - # When to_display_name returns None, fall back — but sanitize params that - # contain internal TrainCheck proxy bookkeeping names (_TRAINCHECK_*) so - # they never surface raw in the UI. - for p in invariant.params: - if isinstance(p, (VarTypeParam, VarNameParam)) and p.attr_name.startswith( - "_TRAINCHECK_" - ): - # Build a minimal label using only the API param, hiding internals - api_parts = [q for q in invariant.params if isinstance(q, APIParam)] - from traincheck.invariant.base_cls import _short_api_name - - if api_parts: - func = _short_api_name(api_parts[0].api_full_name) - return f"{func}() [internal tracking]" - return f"{invariant.relation.__name__} [internal tracking]" if invariant.text_description: return invariant.text_description params = ", ".join(str(param) for param in invariant.params) @@ -290,7 +274,9 @@ def build_online_report_data( and warm_up_steps is not None and sampling_interval > 0 ): - checked_steps = max(0, current_step - warm_up_steps) // sampling_interval + 1 + checked_steps = max(0, current_step - warm_up_steps) // sampling_interval + min( + current_step, warm_up_steps + ) def _make_entry(inv: Invariant, count: int) -> dict: detail = violation_details.get(inv, {}) From 9d8d04c8da6c04a7d3c6733c26c7a636f724e087 Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 19 Mar 2026 20:40:49 -0400 Subject: [PATCH 26/28] fix: keep _ID suffix when stripping _TRAINCHECK_ prefix from attr display names '_TRAINCHECK_grad_ID' -> 'grad_ID', not 'grad'. Co-Authored-By: Claude Sonnet 4.6 --- traincheck/invariant/base_cls.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/traincheck/invariant/base_cls.py b/traincheck/invariant/base_cls.py index b9fb041e..944f57e6 100644 --- a/traincheck/invariant/base_cls.py +++ b/traincheck/invariant/base_cls.py @@ -1945,13 +1945,13 @@ def _short_api_name(full_name: str) -> str: def _display_attr_name(attr_name: str) -> str: - """Strip TrainCheck-internal proxy bookkeeping prefix/suffix for display. + """Strip TrainCheck-internal proxy bookkeeping prefix for display. - '_TRAINCHECK_grad_ID' → 'grad' + '_TRAINCHECK_grad_ID' → 'grad_ID' 'dtype' → 'dtype' (unchanged) """ - if attr_name.startswith("_TRAINCHECK_") and attr_name.endswith("_ID"): - return attr_name[len("_TRAINCHECK_") : -len("_ID")] + if attr_name.startswith("_TRAINCHECK_"): + return attr_name[len("_TRAINCHECK_") :] return attr_name From 9cf8064b8604c90bcefdac5c3a8fbd3b993fa1fd Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Thu, 19 Mar 2026 22:45:20 -0400 Subject: [PATCH 27/28] feat: add step/stage/trace detail to offline report, W&B, and MLflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _count_failed_invariants now tracks last_step, step_stages (step→stage map from all violation traces), and sample_trace (first violation) - Offline HTML violations panel and per-trace failed-invariants lists now use the same expandable table format as the online report: First Step / Last Step / Count columns, stage badges, collapsible step timeline and sample trace rows - W&B violations table gains a last_step column; summary gains violations/last_step - MLflow gains violations_last_step metric Co-Authored-By: Claude Sonnet 4.6 --- traincheck/reporting/checker_report.py | 212 ++++++++++++++++++++----- 1 file changed, 173 insertions(+), 39 deletions(-) diff --git a/traincheck/reporting/checker_report.py b/traincheck/reporting/checker_report.py index a5eaa4e6..bf552994 100644 --- a/traincheck/reporting/checker_report.py +++ b/traincheck/reporting/checker_report.py @@ -91,6 +91,9 @@ def _count_failed_invariants( ) -> list[dict[str, object]]: counter: Counter[tuple[str, str]] = Counter() first_steps: dict[tuple[str, str], int | None] = {} + last_steps: dict[tuple[str, str], int | None] = {} + step_stage_maps: dict[tuple[str, str], dict] = defaultdict(dict) + sample_traces: dict[tuple[str, str], list] = {} for res in results: if not res.check_passed: label = _format_invariant_label(res.invariant) @@ -103,8 +106,26 @@ def _count_failed_invariants( first_steps[key] = ( min(steps) if existing is None else min(existing, min(steps)) ) + existing_last = last_steps.get(key) + last_steps[key] = ( + max(steps) + if existing_last is None + else max(existing_last, max(steps)) + ) elif key not in first_steps: first_steps[key] = None + last_steps[key] = None + # Accumulate step → stage (first stage seen per step wins) + for rec in res.trace or []: + if not isinstance(rec, dict): + continue + step = rec.get("meta_vars.step") + stage = rec.get("meta_vars.stage") + if step is not None and step not in step_stage_maps[key]: + step_stage_maps[key][step] = stage + # One sample trace per invariant (first violation wins) + if key not in sample_traces and res.trace: + sample_traces[key] = _summarize_trace_records(res.trace) top_pairs = counter.most_common(10) return [ { @@ -112,6 +133,9 @@ def _count_failed_invariants( "relation": relation, "count": count, "first_step": first_steps.get((label, relation)), + "last_step": last_steps.get((label, relation)), + "step_stages": sorted(step_stage_maps[(label, relation)].items()), + "sample_trace": sample_traces.get((label, relation), []), } for (label, relation), count in top_pairs ] @@ -434,11 +458,7 @@ def percent(part: int, total: int) -> float: traces = report_data.get("traces", []) top_violations = report_data.get("top_violations", []) - # Build top violations HTML differently per mode. - # For online mode: step-sorted table with expandable trace rows. - # For offline mode: simple list (unchanged). - top_list = "" # used only in offline mode below - top_table_html = "" # used only in online mode below + top_table_html = "" if mode == "online": sampling_interval = report_data.get("sampling_interval") @@ -538,27 +558,77 @@ def _step_with_badge(step, stage) -> str: else "

    No violations yet.

    " ) else: - top_items = [] + rows = [] for entry in top_violations: label = esc(str(entry.get("label", ""))) relation = esc(str(entry.get("relation", ""))) - count = entry.get("count") + count = entry.get("count", "") first_step = entry.get("first_step") - count_html = f'{count}' if count else "" - if first_step is not None: - step_note = f"first seen at step {first_step}" - if count and count > 1: - step_note += f" Β· {count} occurrences" - detail = esc(f"{entry.get('relation', '')} β€” {step_note}") + last_step = entry.get("last_step") + off_step_stages: list = entry.get("step_stages") or [] + off_sample_trace = entry.get("sample_trace") or [] + + def _step_cell(step, _ss=off_step_stages) -> str: + if step is None: + return "β€”" + stage = next((s for st, s in _ss if st == step), None) + badge = _render_stage_badge(stage, esc) + return f"{badge}{step}" + + first_step_html = _step_cell(first_step) + last_step_html = _step_cell(last_step) + steps_html = _render_step_stages_html(off_step_stages, esc) + + if off_sample_trace: + off_keys: list[str] = [] + for rec in off_sample_trace: + for k in rec: + if k not in off_keys: + off_keys.append(k) + trace_head = "".join(f"{esc(k)}" for k in off_keys) + trace_rows_html = [] + for rec in off_sample_trace: + cells = [] + for k in off_keys: + val = rec.get(k, "") + if k == "meta_vars.stage" and val: + style = _stage_badge_style(val) + cell = ( + f'' + f"{esc(val)}" + ) + else: + cell = f"{esc(str(val))}" + cells.append(cell) + trace_rows_html.append(f"{''.join(cells)}") + trace_body = "\n".join(trace_rows_html) + expand_content = ( + f'
    Steps: {steps_html}
    ' + f'
    ' + f"{trace_head}" + f"{trace_body}
    " + ) else: - detail = relation - if count and count > 1: - detail = esc(f"{entry.get('relation', '')} β€” {count} occurrences") - top_items.append( - f'
  • {label}' - f'{detail}{count_html}
  • ' + expand_content = f'
    Steps: {steps_html}
    ' + + rows.append( + f"" + f'
    {label}' + f'
    {expand_content}
    ' + f'{relation}' + f'{first_step_html}' + f'{last_step_html}' + f'{count}' + f"" ) - top_list = "".join(top_items) or "
  • None
  • " + top_table_html = ( + f'' + f"" + f"" + f"{''.join(rows)}
    InvariantFirst StepLast StepCount
    " + if rows + else "

    No violations.

    " + ) trace_sections = [] for trace in traces: @@ -573,27 +643,74 @@ def _step_with_badge(step, stage) -> str: + _render_bar_segment(percent(not_triggered, total), "bar-not-triggered") ) - failed_list_items = [] + failed_rows = [] for failed_item in trace["failed_invariants"][:10]: label = esc(str(failed_item.get("label", ""))) relation = esc(str(failed_item.get("relation", ""))) - count = failed_item.get("count") + count = failed_item.get("count", "") first_step = failed_item.get("first_step") - count_html = f'{count}' if count else "" - if first_step is not None: - step_note = f"first seen at step {first_step}" - if count and count > 1: - step_note += f" Β· {count} occurrences" - detail = esc(f"{relation} β€” {step_note}") + last_step = failed_item.get("last_step") + item_step_stages: list = failed_item.get("step_stages") or [] + item_sample_trace = failed_item.get("sample_trace") or [] + + def _step_cell_trace(step) -> str: + if step is None: + return "β€”" + stage = next((s for st, s in item_step_stages if st == step), None) + badge = _render_stage_badge(stage, esc) + return f"{badge}{step}" + + steps_html = _render_step_stages_html(item_step_stages, esc) + if item_sample_trace: + item_keys: list[str] = [] + for rec in item_sample_trace: + for k in rec: + if k not in item_keys: + item_keys.append(k) + trace_head = "".join(f"{esc(k)}" for k in item_keys) + trace_rows_html = [] + for rec in item_sample_trace: + cells = [] + for k in item_keys: + val = rec.get(k, "") + if k == "meta_vars.stage" and val: + style = _stage_badge_style(val) + cell = ( + f'' + f"{esc(val)}" + ) + else: + cell = f"{esc(str(val))}" + cells.append(cell) + trace_rows_html.append(f"{''.join(cells)}") + trace_body = "\n".join(trace_rows_html) + expand_content = ( + f'
    Steps: {steps_html}
    ' + f'
    ' + f"{trace_head}" + f"{trace_body}
    " + ) else: - detail = relation - if count and count > 1: - detail = esc(f"{relation} β€” {count} occurrences") - failed_list_items.append( - f'
  • {label}' - f'{detail}{count_html}
  • ' + expand_content = f'
    Steps: {steps_html}
    ' + + failed_rows.append( + f"" + f'
    {label}' + f'
    {expand_content}
    ' + f'{relation}' + f'{_step_cell_trace(first_step)}' + f'{_step_cell_trace(last_step)}' + f'{count}' + f"" ) - failed_list_html = "".join(failed_list_items) or "
  • None
  • " + failed_list_html = ( + f'' + f"" + f"" + f"{''.join(failed_rows)}
    InvariantFirst StepLast StepCount
    " + if failed_rows + else "

    None

    " + ) relation_rows = [] for relation_name, rel_counts in sorted(trace["relations"].items()): @@ -630,7 +747,7 @@ def _step_with_badge(step, stage) -> str:

    Failed invariants (top 10)

    -
      {failed_list_html}
    + {failed_list_html}

    Relation breakdown

    @@ -837,8 +954,8 @@ def _step_with_badge(step, stage) -> str: ) panel_content = top_table_html else: - panel_subtitle = "Most frequent violations observed" - panel_content = f'
      {top_list}
    ' + panel_subtitle = "Sorted by first violation step β€” click to expand trace" + panel_content = top_table_html top_panel = f"""
    @@ -1353,7 +1470,13 @@ def _log_wandb( top_violations = report_data.get("top_violations", []) if top_violations: vtable = wandb.Table( - columns=["invariant", "relation_type", "occurrences", "first_step"] + columns=[ + "invariant", + "relation_type", + "occurrences", + "first_step", + "last_step", + ] ) for v in top_violations: vtable.add_data( @@ -1361,6 +1484,7 @@ def _log_wandb( v.get("relation", ""), v.get("count", 0), v.get("first_step"), + v.get("last_step"), ) wandb.log({"violations": vtable}) @@ -1368,8 +1492,13 @@ def _log_wandb( first_steps = [ v["first_step"] for v in top_violations if v.get("first_step") is not None ] + last_steps_wandb = [ + v["last_step"] for v in top_violations if v.get("last_step") is not None + ] if first_steps: run.summary["violations/first_step"] = min(first_steps) + if last_steps_wandb: + run.summary["violations/last_step"] = max(last_steps_wandb) run.summary["violations/distinct_invariants"] = len(top_violations) # --- violations_summary.json as versioned artifact --- @@ -1463,8 +1592,13 @@ def _log_mlflow( first_steps = [ v["first_step"] for v in top_violations if v.get("first_step") is not None ] + last_steps = [ + v["last_step"] for v in top_violations if v.get("last_step") is not None + ] if first_steps: mlflow.log_metric("violations_first_step", min(first_steps)) + if last_steps: + mlflow.log_metric("violations_last_step", max(last_steps)) mlflow.log_metric("violations_distinct_invariants", len(top_violations)) # --- violations table as JSON artifact --- From 6c4caf44ef7fb2af9b0f7059618b2091e49815fa Mon Sep 17 00:00:00 2001 From: Yuxuan Jiang Date: Fri, 20 Mar 2026 00:38:22 -0400 Subject: [PATCH 28/28] feat: log per-step violation counts to W&B and MLflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _build_violation_steps_map() helper: step β†’ count of distinct invariants violated at that step (across all failed CheckerResults) - Propagate violation_steps_map through build_offline_report_data and build_online_report_data so downstream loggers can consume it - W&B: log traincheck/violations as a metric at each step via wandb.log({...}, step=N) so violations appear on the same x-axis as training loss; add --wandb-run-id CLI arg to attach to an existing run - MLflow: log traincheck_violations per step via mlflow.log_metric(step=N); switch violations table from log_dict to log_table() for proper UI Co-Authored-By: Claude Sonnet 4.6 --- traincheck/checker.py | 6 +++ traincheck/checker_online.py | 6 +++ traincheck/reporting/checker_report.py | 53 ++++++++++++++++++++++++-- 3 files changed, 61 insertions(+), 4 deletions(-) diff --git a/traincheck/checker.py b/traincheck/checker.py index 05dee116..249945c7 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -170,6 +170,12 @@ def main(): default=None, help="Weights & Biases tags.", ) + parser.add_argument( + "--wandb-run-id", + type=str, + default=None, + help="Attach to an existing Weights & Biases run ID (e.g. to overlay violation metrics on a training run).", + ) parser.add_argument( "--report-mlflow", action="store_true", diff --git a/traincheck/checker_online.py b/traincheck/checker_online.py index 11b24822..d28c2a8e 100644 --- a/traincheck/checker_online.py +++ b/traincheck/checker_online.py @@ -477,6 +477,12 @@ def main(): default=None, help="Weights & Biases tags.", ) + parser.add_argument( + "--wandb-run-id", + type=str, + default=None, + help="Attach to an existing Weights & Biases run ID (e.g. to overlay violation metrics on a training run).", + ) parser.add_argument( "--report-mlflow", action="store_true", diff --git a/traincheck/reporting/checker_report.py b/traincheck/reporting/checker_report.py index bf552994..1ea3c4a7 100644 --- a/traincheck/reporting/checker_report.py +++ b/traincheck/reporting/checker_report.py @@ -41,6 +41,17 @@ def _build_violation_entry(result: CheckerResult) -> dict: } +def _build_violation_steps_map(results: list[CheckerResult]) -> dict[int, int]: + """Map step β†’ count of distinct invariants violated at that step.""" + step_to_invs: dict[int, set[str]] = defaultdict(set) + for res in results: + if not res.check_passed: + label = _format_invariant_label(res.invariant) + for step in _extract_violation_steps(res.trace): + step_to_invs[step].add(label) + return {step: len(invs) for step, invs in step_to_invs.items()} + + def build_violations_summary(results: list[CheckerResult]) -> dict: """Build a pre-digested summary of all violations for machine and human consumption.""" failed = [r for r in results if not r.check_passed] @@ -191,6 +202,7 @@ def build_offline_report_data( all_failed_invariants.extend([res for res in results if not res.check_passed]) top_violations = _count_failed_invariants(all_failed_invariants) + violation_steps_map = _build_violation_steps_map(all_failed_invariants) return { "mode": "offline", @@ -200,6 +212,7 @@ def build_offline_report_data( "relations": dict(overall_relation_counts), "traces": trace_sections, "top_violations": top_violations, + "violation_steps_map": violation_steps_map, } @@ -341,6 +354,14 @@ def _make_entry(inv: Invariant, count: int) -> dict: "violation_rate": viol_rate, } + # Build step β†’ distinct-invariant count map from all violation_details + step_to_invs: dict[int, set[str]] = defaultdict(set) + for inv, detail in violation_details.items(): + lbl = _format_invariant_label(inv) + for step, _ in detail.get("step_stages") or []: + step_to_invs[step].add(lbl) + violation_steps_map = {step: len(invs) for step, invs in step_to_invs.items()} + # Sort by first violation step (earliest first), then by count descending. def _sort_key(item): inv, count = item @@ -404,6 +425,7 @@ def _sort_key(item): "sampling_interval": sampling_interval, "warm_up_steps": warm_up_steps, "checked_steps": checked_steps, + "violation_steps_map": violation_steps_map, } @@ -1397,7 +1419,7 @@ def _log_wandb( return if self._wandb_run is None: - self._wandb_run = wandb.init( + init_kwargs: dict = dict( project=args.wandb_project, entity=args.wandb_entity, name=args.wandb_run_name, @@ -1405,6 +1427,11 @@ def _log_wandb( tags=args.wandb_tags, job_type="checker", ) + run_id = getattr(args, "wandb_run_id", None) + if run_id: + init_kwargs["id"] = run_id + init_kwargs["resume"] = "allow" + self._wandb_run = wandb.init(**init_kwargs) # type: ignore[assignment] run = self._wandb_run if run is None: logging.getLogger(__name__).warning("wandb.init() returned None; skipping.") @@ -1530,6 +1557,11 @@ def _log_wandb( "Failed to attach HTML report to wandb run." ) + # --- per-step violation time-series (overlays with training loss curve) --- + violation_steps_map: dict[int, int] = report_data.get("violation_steps_map", {}) + for step, count in sorted(violation_steps_map.items()): + wandb.log({"traincheck/violations": count}, step=step) + def _log_mlflow( self, report_data: dict, @@ -1601,11 +1633,24 @@ def _log_mlflow( mlflow.log_metric("violations_last_step", max(last_steps)) mlflow.log_metric("violations_distinct_invariants", len(top_violations)) - # --- violations table as JSON artifact --- + # --- per-step violation time-series (overlays with training loss curve) --- + violation_steps_map: dict[int, int] = report_data.get("violation_steps_map", {}) + for step, count in sorted(violation_steps_map.items()): + mlflow.log_metric("traincheck_violations", count, step=step) + + # --- violations table (mlflow.log_table for proper UI rendering) --- if top_violations: try: - mlflow.log_dict( - {"violations": top_violations}, + mlflow.log_table( + data={ + "invariant": [v.get("label", "") for v in top_violations], + "relation_type": [ + v.get("relation", "") for v in top_violations + ], + "occurrences": [v.get("count", 0) for v in top_violations], + "first_step": [v.get("first_step") for v in top_violations], + "last_step": [v.get("last_step") for v in top_violations], + }, artifact_file="violations.json", ) except Exception: