diff --git a/docs/technical-doc.md b/docs/technical-doc.md index c4923738..a49d6a4f 100644 --- a/docs/technical-doc.md +++ b/docs/technical-doc.md @@ -1,8 +1,5 @@ # TrainCheck Documentation -🚜 This documentation is under construction. We welcome any feedback or questions through GitHub Issues or [our Discord server](https://discord.gg/DPEd7Xeg). - - TrainCheck is a lightweight, invariant-based instrumentation and analysis tool for identifying silent correctness issues in PyTorch training pipelines. It infers behavioral invariants from correct reference runs (e.g., official examples or clean configurations), then checks other scripts for behavioral violations. TrainCheck is designed to be minimally intrusive—requiring no code modifications or rewrites of training logic. ## 🔧 System Overview diff --git a/docs/usage-guide.md b/docs/usage-guide.md index 06e2b935..39f12cbd 100644 --- a/docs/usage-guide.md +++ b/docs/usage-guide.md @@ -4,7 +4,7 @@ TrainCheck helps detect and diagnose silent errors in deep learning training run ## 🚀 Quick Start -Check out the [5-minute guide](./docs/5-min.md) for a minimal working example. +Check out the [5-minute guide](5-min-tutorial.md) for a minimal working example. ## ✅ Common Use Cases diff --git a/tests/test_display_names.py b/tests/test_display_names.py new file mode 100644 index 00000000..68ffd112 --- /dev/null +++ b/tests/test_display_names.py @@ -0,0 +1,278 @@ +"""Semantic unit tests for Relation.to_display_name(). + +These tests verify that key *meaning* tokens appear in the output for each +relation type given a known params list. They do NOT test inference logic — +the params are constructed directly, so the tests remain stable even if the +inference algorithm changes. +""" + +import pytest + +from traincheck.invariant.base_cls import ( + _NOT_SET, + APIParam, + InputOutputParam, + VarTypeParam, +) +from traincheck.invariant.consistency_relation import ConsistencyRelation +from traincheck.invariant.consistency_transient_vars import ( + ConsistentInputOutputRelation, + ConsistentOutputRelation, + ThresholdRelation, +) +from traincheck.invariant.contain_relation import APIContainRelation +from traincheck.invariant.cover_relation import FunctionCoverRelation +from traincheck.invariant.DistinctArgumentRelation import DistinctArgumentRelation +from traincheck.invariant.lead_relation import FunctionLeadRelation + + +class TestAPIContainRelationDisplayName: + def test_state_transition(self): + params = [ + APIParam("torch.optim.optimizer.Optimizer.zero_grad"), + VarTypeParam( + "torch.nn.Parameter", "grad", pre_value="non_zero", post_value=None + ), + ] + name = APIContainRelation.to_display_name(params) + assert name is not None + assert "zero_grad" in name + assert "grad" in name + assert "non" in name.lower() # "non-zero" + + def test_api_calls_api(self): + params = [ + APIParam("torch.optim.optimizer.Optimizer.step"), + APIParam("torch.optim.adadelta.adadelta"), + ] + name = APIContainRelation.to_display_name(params) + assert name is not None + assert "step" in name + assert "adadelta" in name + + def test_const_value(self): + params = [ + APIParam("torch.nn.modules.module.Module.forward"), + VarTypeParam("torch.nn.Parameter", "requires_grad", const_value=True), + ] + name = APIContainRelation.to_display_name(params) + assert name is not None + assert "forward" in name + assert "requires_grad" in name + + def test_post_value_non_zero_normalized(self): + """non_zero post-value should render as 'non-zero', not 'non_zero'.""" + params = [ + 
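+            # Hand-built params stand in for inference output here; "non_zero" is
+            # the raw sentinel form that display rendering is expected to hyphenate.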
APIParam("torch.optim.sgd.SGD.step"), + VarTypeParam( + "torch.nn.Parameter", + "data", + pre_value="non_zero", + post_value="non_zero", + ), + ] + name = APIContainRelation.to_display_name(params) + assert name is not None + assert "non_zero" not in name + assert "non-zero" in name + + def test_traincheck_internal_attr_hidden(self): + """Attributes starting with _TRAINCHECK_ are internal proxy IDs and should be filtered.""" + params = [ + APIParam("torch.optim.sgd.SGD.step"), + VarTypeParam( + "torch.nn.Parameter", + "_TRAINCHECK_grad_ID", + pre_value="above_zero", + post_value="above_zero", + ), + ] + assert APIContainRelation.to_display_name(params) is None + + def test_returns_none_for_empty_params(self): + assert APIContainRelation.to_display_name([]) is None + + def test_returns_none_for_single_param(self): + assert APIContainRelation.to_display_name([APIParam("torch.foo")]) is None + + +class TestConsistencyRelationDisplayName: + def test_basic(self): + params = [VarTypeParam("torch.nn.Parameter", "grad")] + name = ConsistencyRelation.to_display_name(params) + assert name is not None + assert "Parameter" in name + assert "grad" in name + assert any(w in name.lower() for w in ("consistent", "stay", "step")) + + def test_returns_none_for_empty(self): + assert ConsistencyRelation.to_display_name([]) is None + + def test_returns_none_for_non_vartype(self): + assert ConsistencyRelation.to_display_name([APIParam("torch.foo.bar")]) is None + + +class TestFunctionCoverRelationDisplayName: + def test_cover_direction(self): + params = [ + APIParam("torch.distributed.is_initialized"), + APIParam("torch.nn.modules.module.Module.eval"), + ] + name = FunctionCoverRelation.to_display_name(params) + assert name is not None + assert "is_initialized" in name + assert "eval" in name + assert any(w in name.lower() for w in ("occurs", "cover", "when")) + + def test_returns_none_for_insufficient_params(self): + assert FunctionCoverRelation.to_display_name([APIParam("torch.foo")]) is None + + +class TestFunctionLeadRelationDisplayName: + def test_ordering(self): + params = [ + APIParam("torch.Tensor.backward"), + APIParam("torch.optim.optimizer.Optimizer.step"), + ] + name = FunctionLeadRelation.to_display_name(params) + assert name is not None + assert "backward" in name + assert "step" in name + assert any(w in name.lower() for w in ("precede", "before", "lead")) + + def test_merged_three_params(self): + """Merged lead invariants can have 3 APIParams; display uses first and last.""" + params = [ + APIParam("torch.Tensor.backward"), + APIParam("torch.optim.optimizer.Optimizer.zero_grad"), + APIParam("torch.optim.optimizer.Optimizer.step"), + ] + name = FunctionLeadRelation.to_display_name(params) + assert name is not None + assert "backward" in name + assert "step" in name + + def test_returns_none_for_single_param(self): + assert FunctionLeadRelation.to_display_name([APIParam("torch.foo")]) is None + + +class TestDistinctArgumentRelationDisplayName: + def test_basic(self): + params = [APIParam("torch.nn.init.normal_")] + name = DistinctArgumentRelation.to_display_name(params) + assert name is not None + assert "normal_" in name + assert any(w in name.lower() for w in ("distinct", "different", "argument")) + + def test_returns_none_for_empty(self): + assert DistinctArgumentRelation.to_display_name([]) is None + + def test_returns_none_for_non_api_param(self): + params = [VarTypeParam("torch.nn.Parameter", "grad")] + assert DistinctArgumentRelation.to_display_name(params) is None + + +class 
TestConsistentOutputRelationDisplayName: + def test_with_const_value(self): + params = [ + APIParam("torch.nn.functional.relu"), + VarTypeParam("torch.Tensor", "dtype", const_value="float32"), + ] + name = ConsistentOutputRelation.to_display_name(params) + assert name is not None + assert "relu" in name + assert "dtype" in name + assert "float32" in name + assert any(w in name.lower() for w in ("consistent", "return")) + + def test_without_const_value(self): + params = [ + APIParam("torch.nn.functional.relu"), + VarTypeParam("torch.Tensor", "ndim"), + ] + name = ConsistentOutputRelation.to_display_name(params) + assert name is not None + assert "relu" in name + assert "ndim" in name + + def test_returns_none_for_insufficient_params(self): + assert ConsistentOutputRelation.to_display_name([APIParam("torch.foo")]) is None + + +class TestConsistentInputOutputRelationDisplayName: + def test_basic(self): + in_p = InputOutputParam( + name="input", + index=0, + type="torch.Tensor", + additional_path=("itemsize",), + api_name="kaiming_uniform_", + is_input=True, + ) + out_p = InputOutputParam( + name="output", + index=0, + type="torch.Tensor", + additional_path=("ndim",), + api_name="kaiming_uniform_", + is_input=False, + ) + api_p = APIParam("torch.nn.init.kaiming_uniform_") + name = ConsistentInputOutputRelation.to_display_name([in_p, api_p, out_p]) + assert name is not None + assert "kaiming_uniform_" in name + assert "itemsize" in name + assert "ndim" in name + assert "input" in name.lower() + assert "output" in name.lower() + + def test_returns_none_for_insufficient_params(self): + api_p = APIParam("torch.foo") + assert ConsistentInputOutputRelation.to_display_name([api_p]) is None + + +class TestThresholdRelationDisplayName: + def _make_output_param(self, api_name: str) -> InputOutputParam: + return InputOutputParam( + name="output_tensors", + index=0, + type="torch.Tensor", + additional_path=("value",), + api_name=api_name, + is_input=False, + ) + + def _make_threshold_param(self, name: str, api_name: str) -> InputOutputParam: + return InputOutputParam( + name=name, + index=None, + type="float", + additional_path=None, + api_name=api_name, + is_input=True, + ) + + def test_min_threshold_gte(self): + """params=[output, api, threshold] → output ≥ threshold.""" + api_p = APIParam("torch.optim.optimizer.Optimizer.step") + out_p = self._make_output_param("Optimizer.step") + thresh_p = self._make_threshold_param("lr", "Optimizer.step") + name = ThresholdRelation.to_display_name([out_p, api_p, thresh_p]) + assert name is not None + assert "Optimizer.step" in name + assert "lr" in name + assert "≥" in name + + def test_max_threshold_lte(self): + """params=[threshold, api, output] → output ≤ threshold.""" + api_p = APIParam("torch.optim.optimizer.Optimizer.step") + out_p = self._make_output_param("Optimizer.step") + thresh_p = self._make_threshold_param("lr", "Optimizer.step") + name = ThresholdRelation.to_display_name([thresh_p, api_p, out_p]) + assert name is not None + assert "Optimizer.step" in name + assert "lr" in name + assert "≤" in name + + def test_returns_none_for_insufficient_params(self): + assert ThresholdRelation.to_display_name([APIParam("torch.foo")]) is None diff --git a/tests/test_violation_summary.py b/tests/test_violation_summary.py new file mode 100644 index 00000000..db4962f8 --- /dev/null +++ b/tests/test_violation_summary.py @@ -0,0 +1,167 @@ +"""Semantic unit tests for violation summary helpers. 
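+
+Summary schema exercised below: distinct_invariants_violated,
+first_violation_step, and a violations list whose entries carry
+display_name, relation_type, first_step, last_step, and occurrences.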
+ +Tests verify the extraction and aggregation logic — pure function tests that +do not depend on trace loading or the inference algorithm. +""" + +import pytest + +from traincheck.invariant.base_cls import APIParam, CheckerResult, Invariant +from traincheck.invariant.cover_relation import FunctionCoverRelation +from traincheck.reporting.checker_report import ( + _build_violation_entry, + _extract_violation_steps, + build_violations_summary, +) + +# --------------------------------------------------------------------------- +# _extract_violation_steps +# --------------------------------------------------------------------------- + + +def test_extract_steps_basic(): + trace = [{"meta_vars.step": 1}, {"meta_vars.step": 3}] + assert _extract_violation_steps(trace) == [1, 3] + + +def test_extract_steps_missing_key(): + trace = [{"function": "foo"}, {"meta_vars.step": 5}] + assert _extract_violation_steps(trace) == [5] + + +def test_extract_steps_none_trace(): + assert _extract_violation_steps(None) == [] + + +def test_extract_steps_empty_trace(): + assert _extract_violation_steps([]) == [] + + +def test_extract_steps_none_value_skipped(): + trace = [{"meta_vars.step": None}, {"meta_vars.step": 2}] + assert _extract_violation_steps(trace) == [2] + + +# --------------------------------------------------------------------------- +# Helper to build minimal CheckerResult fixtures +# --------------------------------------------------------------------------- + + +def _make_invariant() -> Invariant: + return Invariant( + relation=FunctionCoverRelation, + params=[ + APIParam("torch.distributed.is_initialized"), + APIParam("torch.nn.modules.module.Module.eval"), + ], + precondition=None, + text_description="test invariant", + ) + + +def _make_result(steps: list[int] | None, check_passed: bool = False) -> CheckerResult: + trace = [{"meta_vars.step": s} for s in steps] if steps is not None else None + return CheckerResult( + invariant=_make_invariant(), + trace=trace, + check_passed=check_passed, + triggered=True, + ) + + +# --------------------------------------------------------------------------- +# _build_violation_entry +# --------------------------------------------------------------------------- + + +def test_build_entry_fields_present(): + result = _make_result(steps=[1, 2, 3]) + entry = _build_violation_entry(result) + assert "display_name" in entry + assert "relation_type" in entry + assert "first_step" in entry + assert "last_step" in entry + assert "occurrences" in entry + + +def test_build_entry_step_values(): + result = _make_result(steps=[1, 5, 3]) + entry = _build_violation_entry(result) + assert entry["first_step"] == 1 + assert entry["last_step"] == 5 + assert entry["occurrences"] == 3 + + +def test_build_entry_no_steps(): + # trace records that have no meta_vars.step key + trace = [{"function": "foo"}] + result = CheckerResult( + invariant=_make_invariant(), + trace=trace, + check_passed=False, + triggered=True, + ) + entry = _build_violation_entry(result) + assert entry["first_step"] is None + assert entry["last_step"] is None + + +def test_build_entry_display_name_is_string(): + result = _make_result(steps=[1]) + entry = _build_violation_entry(result) + assert isinstance(entry["display_name"], str) + assert len(entry["display_name"]) > 0 + + +def test_build_entry_relation_type(): + result = _make_result(steps=[1]) + entry = _build_violation_entry(result) + assert entry["relation_type"] == "FunctionCoverRelation" + + +# 
--------------------------------------------------------------------------- +# build_violations_summary +# --------------------------------------------------------------------------- + + +def test_summary_no_failures(): + results = [_make_result(steps=[1], check_passed=True)] + summary = build_violations_summary(results) + assert summary["distinct_invariants_violated"] == 0 + assert summary["violations"] == [] + assert summary["first_violation_step"] is None + + +def test_summary_with_failures(): + results = [ + _make_result(steps=[2, 4], check_passed=False), + _make_result(steps=[1], check_passed=False), + _make_result(steps=[5], check_passed=True), + ] + summary = build_violations_summary(results) + assert summary["distinct_invariants_violated"] == 2 + assert summary["first_violation_step"] == 1 + assert len(summary["violations"]) == 2 + + +def test_summary_first_step_across_violations(): + results = [ + _make_result(steps=[10, 20], check_passed=False), + _make_result(steps=[3], check_passed=False), + ] + summary = build_violations_summary(results) + assert summary["first_violation_step"] == 3 + + +def test_summary_no_step_data(): + # violation with trace records that carry no step information + result = CheckerResult( + invariant=_make_invariant(), + trace=[{"function": "foo"}], + check_passed=False, + triggered=True, + ) + summary = build_violations_summary([result]) + assert summary["distinct_invariants_violated"] == 1 + assert summary["first_violation_step"] is None + assert summary["violations"][0]["first_step"] is None diff --git a/traincheck/checker.py b/traincheck/checker.py index 45086ea8..249945c7 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -6,8 +6,13 @@ from tqdm import tqdm +import traincheck.utils as _tc_utils from traincheck.invariant import CheckerResult, Invariant, read_inv_file -from traincheck.reporting import ReportEmitter, build_offline_report_data +from traincheck.reporting import ( + ReportEmitter, + build_offline_report_data, + build_violations_summary, +) from traincheck.trace import MDNONEJSONEncoder, Trace, select_trace_implementation from traincheck.utils import register_custom_excepthook @@ -36,17 +41,39 @@ def check_engine( ) -> list[CheckerResult]: logger = logging.getLogger(__name__) results = [] - for inv in tqdm( - invariants, desc="Checking invariants", unit="invariant", leave=False - ): - assert ( - inv.precondition is not None - ), "Invariant precondition is None. It should at least be 'Unconditional' or an empty list. Please check the invariant file and the inference process." - logger.info("=====================================") - res = inv.check(trace, check_relation_first) - res.calc_and_set_time_precentage(trace.get_start_time(), trace.get_end_time()) - logger.info("Invariant %s on trace %s: %s", inv, trace, res) - results.append(res) + total = len(invariants) + n_violated = 0 + bar_fmt = ( + "{desc} {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]" + ) + with tqdm( + total=total, + bar_format=bar_fmt, + unit="inv", + desc=f"0 checked · {total} left · 0 violated", + ) as pbar: + _tc_utils._suppress_inner_progress = True + try: + for i, inv in enumerate(invariants): + assert ( + inv.precondition is not None + ), "Invariant precondition is None. It should at least be 'Unconditional' or an empty list. Please check the invariant file and the inference process." 
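+                    # Each invariant is checked against the full trace; the bar
+                    # description below tracks checked/left/violated counts live.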
+ logger.info("=====================================") + res = inv.check(trace, check_relation_first) + res.calc_and_set_time_precentage( + trace.get_start_time(), trace.get_end_time() + ) + logger.info("Invariant %s on trace %s: %s", inv, trace, res) + results.append(res) + if not res.check_passed: + n_violated += 1 + done = i + 1 + pbar.set_description( + f"{done} checked · {total - done} left · {n_violated} violated" + ) + pbar.update(1) + finally: + _tc_utils._suppress_inner_progress = False return results @@ -143,6 +170,12 @@ def main(): default=None, help="Weights & Biases tags.", ) + parser.add_argument( + "--wandb-run-id", + type=str, + default=None, + help="Attach to an existing Weights & Biases run ID (e.g. to overlay violation metrics on a training run).", + ) parser.add_argument( "--report-mlflow", action="store_true", @@ -232,7 +265,15 @@ def main(): logger.info("Reading traces from %s", "\n".join(trace_files)) traces.append(read_trace_file(trace_files)) - logger.addHandler(logging.StreamHandler()) + # Warnings and above go to stderr so the user sees them; detailed per-invariant + # INFO logs stay in the log file only. + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.WARNING) + stream_handler.setFormatter( + logging.Formatter("[TrainCheck] %(levelname)s: %(message)s") + ) + logger.addHandler(stream_handler) + results_by_trace: list[tuple[str, list[CheckerResult]]] = [] for trace, trace_parent_folder in zip(traces, trace_parent_folders): @@ -244,6 +285,15 @@ def main(): res for res in results_per_trace if res.triggered is False ] + n_checked = len(results_per_trace) + n_failed = len(results_per_trace_failed) + n_passed = n_checked - n_failed + n_not_triggered = len(results_per_trace_not_triggered) + print(f"Checking finished. {n_checked} invariants checked") + print(f"Total failed invariants: {n_failed}/{n_checked}") + print(f"Total passed invariants: {n_passed}/{n_checked}") + print(f"Total invariants that are not triggered: {n_not_triggered}/{n_checked}") + logger.info("Checking finished. 
%d invariants checked", len(results_per_trace)) logger.info( "Total failed invariants: %d/%d", @@ -292,6 +342,15 @@ def main(): json.dump(res.to_dict(), f, indent=4, cls=MDNONEJSONEncoder) f.write("\n") + violations_summary = build_violations_summary(results_per_trace) + with open( + os.path.join( + args.output_dir, trace_parent_folder, "violations_summary.json" + ), + "w", + ) as f: + json.dump(violations_summary, f, indent=2) + results_by_trace.append((trace_parent_folder, results_per_trace)) report_data = build_offline_report_data( diff --git a/traincheck/checker_online.py b/traincheck/checker_online.py index 9279b77c..d28c2a8e 100644 --- a/traincheck/checker_online.py +++ b/traincheck/checker_online.py @@ -23,6 +23,13 @@ ) NUM_VIOLATIONS = 0 FAILED_INV: dict[Invariant, int] = {} +VIOLATION_DETAILS: dict[Invariant, dict] = {} +TRIGGERED_INV: set[Invariant] = set() +ALL_INVS: list[Invariant] = [] +CURRENT_STEP: int | None = None +CURRENT_STAGE: str | None = None +SAMPLING_INTERVAL: int | None = None +WARM_UP_STEPS: int | None = None TOTAL_INVARIANTS = 0 RELATION_TOTALS: dict[str, int] = {} REPORTER: ReportEmitter | None = None @@ -100,7 +107,7 @@ def sort_inv_file(invariants): vartype_to_invs: dict[str, dict[str, list[Invariant]]] = {} needed_vars = set() needed_apis = set() - _get_api_args_map_to_check = set() + all_needed_args_api = set() for inv in invs: assert ( inv.precondition is not None @@ -114,7 +121,7 @@ def sort_inv_file(invariants): if needed_api is not None: needed_apis.update(needed_api) if needed_args_api is not None: - _get_api_args_map_to_check.update(needed_args_api) + all_needed_args_api.update(needed_args_api) for param in params: if isinstance(param, VarTypeParam): if param.var_type not in vartype_to_invs: @@ -127,7 +134,7 @@ def sort_inv_file(invariants): param_to_invs[param] = [] param_to_invs[param].append(inv) logger.info("Sorting done.") - needed_data = (needed_vars, needed_apis, _get_api_args_map_to_check) + needed_data = (needed_vars, needed_apis, all_needed_args_api) return invs, param_to_invs, vartype_to_invs, needed_data @@ -139,6 +146,54 @@ def get_violated_pair_hash(trace_pair): return tuple(sorted((h1, h2), reverse=True)) +_MAX_TRACKED_STEPS = 500 # cap on steps stored per invariant + + +def _record_violation_details( + inv: Invariant, result, violation_details: dict[Invariant, dict] +): + """Update per-invariant (step, stage) list and sample trace for the HTML report.""" + trace = result.trace or [] + step_stages = [ + (r["meta_vars.step"], r.get("meta_vars.stage")) + for r in trace + if isinstance(r, dict) and r.get("meta_vars.step") is not None + ] + if inv not in violation_details: + violation_details[inv] = {"step_stages": [], "sample_trace": None} + detail = violation_details[inv] + remaining = _MAX_TRACKED_STEPS - len(detail["step_stages"]) + if remaining > 0: + detail["step_stages"].extend(step_stages[:remaining]) + if detail["sample_trace"] is None and trace: + detail["sample_trace"] = trace[:8] + + +def _read_sampling_config( + trace_folders: list[str] | None, +) -> tuple[int | None, int | None]: + """Parse sampling_interval and warm_up_steps from env_dump.txt in a trace folder.""" + import re + + for folder in trace_folders or []: + env_dump_path = os.path.join(folder, "env_dump.txt") + if not os.path.exists(env_dump_path): + continue + sampling_interval: int | None = None + warm_up_steps: int | None = None + with open(env_dump_path) as fh: + for line in fh: + m = re.match(r"^sampling_interval:\s*(\d+)", line) + if m: + sampling_interval = 
int(m.group(1))
+                m = re.match(r"^warm_up_steps:\s*(\d+)", line)
+                if m:
+                    warm_up_steps = int(m.group(1))
+        if sampling_interval is not None or warm_up_steps is not None:
+            return sampling_interval, warm_up_steps
+    return None, None
+
+
 def _emit_report(force: bool = False):
     if REPORTER is None:
         return
@@ -149,6 +204,13 @@ def _emit_report(force: bool = False):
         total_violations=NUM_VIOLATIONS,
         failed_inv=FAILED_INV,
         relation_totals=RELATION_TOTALS,
+        violation_details=VIOLATION_DETAILS,
+        triggered_inv=TRIGGERED_INV,
+        all_invs=ALL_INVS,
+        current_step=CURRENT_STEP,
+        current_stage=CURRENT_STAGE,
+        sampling_interval=SAMPLING_INTERVAL,
+        warm_up_steps=WARM_UP_STEPS,
     )
     report_state = (NUM_VIOLATIONS, len(FAILED_INV))
     REPORTER.emit(report_data, force=force, report_state=report_state)
@@ -160,6 +222,13 @@ def check(
     global OBSERVER
     global NUM_VIOLATIONS
     global FAILED_INV
+    global VIOLATION_DETAILS
+    global TRIGGERED_INV
+    global ALL_INVS
+    global CURRENT_STEP
+    global CURRENT_STAGE
+    global SAMPLING_INTERVAL
+    global WARM_UP_STEPS
     global TOTAL_INVARIANTS
     global RELATION_TOTALS
@@ -169,8 +238,11 @@
     logger.addHandler(logging.StreamHandler())
     logger.info("Starting online checker")

+    SAMPLING_INTERVAL, WARM_UP_STEPS = _read_sampling_config(trace_folders)
+
     invs, param_to_invs, vartype_to_invs, needed_data = sort_inv_file(invariants)
     TOTAL_INVARIANTS = len(invs)
+    ALL_INVS = list(invs)
     RELATION_TOTALS = defaultdict(int)
     for inv in invs:
         RELATION_TOTALS[inv.relation.__name__] += 1
@@ -198,6 +270,13 @@
         else:
             break

+        step = trace_record.get("meta_vars.step")
+        stage = trace_record.get("meta_vars.stage")
+        if step is not None:
+            CURRENT_STEP = step
+        if stage is not None:
+            CURRENT_STAGE = stage
+
         if "var_name" in trace_record and trace_record["var_name"] is not None:
             varid = VarInstId(
                 trace_record["process_id"],
@@ -216,6 +295,7 @@
                     result = inv.online_check(
                         trace_record, checker_data, check_relation_first
                     )
+                    TRIGGERED_INV.add(inv)
                     if not result.check_passed:
                         violated_pair = get_violated_pair_hash(result.trace)
                         if inv not in violated_pairs:
@@ -227,6 +307,9 @@
                         if inv not in FAILED_INV:
                             FAILED_INV[inv] = 0
                         FAILED_INV[inv] += 1
+                        _record_violation_details(
+                            inv, result, VIOLATION_DETAILS
+                        )
                         NUM_VIOLATIONS += 1
                         result.set_id_and_detection_time(
                             NUM_VIOLATIONS, time.monotonic_ns()
                         )
@@ -258,16 +341,18 @@
                 result = inv.online_check(
                     trace_record, checker_data, check_relation_first
                 )
+                TRIGGERED_INV.add(inv)
                 if not result.check_passed:
                     if inv not in FAILED_INV:
                         FAILED_INV[inv] = 0
                     FAILED_INV[inv] += 1
+                    _record_violation_details(inv, result, VIOLATION_DETAILS)
                     NUM_VIOLATIONS += 1
                     result.set_id_and_detection_time(
                         NUM_VIOLATIONS, time.monotonic_ns()
                     )
                     logger.error(
-                        f"Violated id {NUM_VIOLATIONS}:\nInvariant {inv} violated near time {trace_record['time']}"
+                        f"Violated id {NUM_VIOLATIONS}:\nInvariant {inv.text_description} violated near time {trace_record['time']} (step {trace_record['meta_vars.step']})"
                     )
                     with open(output_file, "a") as f:
                         json.dump(
@@ -392,6 +477,12 @@
         default=None,
         help="Weights & Biases tags.",
     )
+    parser.add_argument(
+        "--wandb-run-id",
+        type=str,
+        default=None,
+        help="Attach to an existing Weights & Biases run ID (e.g. 
to overlay violation metrics on a training run).",
+    )
     parser.add_argument(
         "--report-mlflow",
         action="store_true",
diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py
index e45e996f..2d307e49 100644
--- a/traincheck/collect_trace.py
+++ b/traincheck/collect_trace.py
@@ -104,7 +104,7 @@ def merge(a: dict, b: dict, path=[]):
             func_instr_opts[func_name]["dump_args"] = True
             func_instr_opts[func_name]["dump_ret"] = True
             # TODO: convert the arguments to instr_opts_dict (currently not possible as the index indicates the index of the argument/ret value among other tensors not all arguments)
-            logger.warning(
+            logger.debug(
                 "Currently not supporting fine-grained dumping of arguments and return values for InputOutputParam"
             )
@@ -431,12 +431,13 @@ def main():
     if args.warm_up_steps is None:
         args.warm_up_steps = config.INSTRUMENTATION_POLICY["warm_up"]

-    # set up logging
+    # set up logging (force=True overrides any handler already added by imported libs)
+    _log_fmt = "[TrainCheck] %(levelname)s: %(message)s"
     if args.debug_mode:
-        logging.basicConfig(level=logging.DEBUG)
+        logging.basicConfig(level=logging.DEBUG, format=_log_fmt, force=True)
         os.environ["TRAINCHECK_DEBUG"] = "1"
     else:
-        logging.basicConfig(level=logging.INFO)
+        logging.basicConfig(level=logging.WARNING, format=_log_fmt, force=True)

     logger = logging.getLogger(__name__)
diff --git a/traincheck/config/config.py b/traincheck/config/config.py
index b7d019b0..ddbc4a2a 100644
--- a/traincheck/config/config.py
+++ b/traincheck/config/config.py
@@ -86,6 +86,7 @@
     "torch.optim.optimizer._get_value",
     "torch.overrides",
     "._",  # skip all private functions (they can only be the contained, but not containing functions)
+    "<locals>",  # skip closures / local functions — not meaningful in invariant descriptions
 ]

 INSTR_OPTS = None  # TODO: set defaults for this variable
diff --git a/traincheck/infer_engine.py b/traincheck/infer_engine.py
index 3b452c6d..917cb530 100644
--- a/traincheck/infer_engine.py
+++ b/traincheck/infer_engine.py
@@ -9,6 +9,7 @@
 from tqdm import tqdm

 import traincheck.config.config as config
+import traincheck.utils as _tc_utils
 from traincheck.invariant import (
     FailedHypothesis,
     Hypothesis,
@@ -51,70 +52,82 @@ def generate_hypothesis(self) -> dict[Hypothesis, list[int]]:
         Returns:
             dict[Hypothesis, list[int]]: A dictionary mapping hypotheses to the indices of traces that support them
         """
-        logger.info("============= GENERATING HYPOTHESIS =============")
         hypotheses_and_trace_idxs: dict[Hypothesis, list[int]] = {}
-        hypo_lookup = {}  # Dictionary for O(1) lookup of hypotheses
+        hypo_lookup = {}
+        n_traces = len(self.traces)
+        active_relations = [
+            r for r in relation_pool if r not in self.disabled_relations
+        ]
+
+        rel_bar_fmt = "{desc} {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]"

        for trace_idx, trace in enumerate(self.traces):
-            logger.info(f"Processing trace {trace_idx + 1}/{len(self.traces)}")
-            for relation_idx, relation in enumerate(relation_pool):
-                logger.info(
-                    f"Processing relation {relation_idx + 1}/{len(relation_pool)}: {relation.__name__}"
-                )
-                if self.disabled_relations and relation in self.disabled_relations:
+            tqdm.write(f"\n[Trace {trace_idx + 1}/{n_traces}] Generating hypotheses")
+            with tqdm(
+                active_relations,
+                bar_format=rel_bar_fmt,
+                unit="relation",
+                leave=True,
+            ) as rel_bar:
+                for relation in rel_bar:
+                    rel_bar.set_description(f"  {relation.__name__}")
                    logger.info(
-                        f"Skipping relation {relation.__name__} as it is disabled"
+                        f"Generating hypotheses for relation: 
{relation.__name__}" ) - continue - logger.info(f"Generating hypotheses for relation: {relation.__name__}") - inferred_hypos = relation.generate_hypothesis(trace) - logger.info( - f"Found {len(inferred_hypos)} hypotheses for relation: {relation.__name__} on trace {trace_idx + 1}/{len(self.traces)}" - ) - logger.info( - f"Merging hypotheses with existing ones, number of existing ones: {len(hypotheses_and_trace_idxs)}" - ) - for hypo in tqdm( - inferred_hypos, desc="Merging Hypotheses with existing ones" - ): - if hypo not in hypotheses_and_trace_idxs: - hypotheses_and_trace_idxs[hypo] = [trace_idx] - hypo_lookup[hypo] = hypo # Add to lookup dictionary - else: - hypotheses_and_trace_idxs[hypo].append(trace_idx) - original_hypo = hypo_lookup[hypo] # O(1) lookup - orig_num_pos_exps = len(original_hypo.positive_examples) - orig_num_neg_exps = len(original_hypo.negative_examples) - original_hypo.positive_examples.examples.extend( - hypo.positive_examples.examples - ) - original_hypo.negative_examples.examples.extend( - hypo.negative_examples.examples - ) - - assert len( - hypo_lookup[hypo].positive_examples - ) == orig_num_pos_exps + len( - hypo.positive_examples - ), f"Expected {orig_num_pos_exps} + {len(hypo.positive_examples)} positive examples, got {len(hypo_lookup[hypo].positive_examples)}" - assert len( - hypo_lookup[hypo].negative_examples - ) == orig_num_neg_exps + len( - hypo.negative_examples - ), f"Expected {orig_num_neg_exps} + {len(hypo.negative_examples)} negative examples, got {len(hypo_lookup[hypo].negative_examples)}" - logger.info(f"Finished processing trace {trace_idx + 1}/{len(self.traces)}") - logger.info( - f"Finished generating hypotheses, found {len(hypotheses_and_trace_idxs)} hypotheses" - ) + t0 = time.time() + inferred_hypos = relation.generate_hypothesis(trace) + elapsed = time.time() - t0 + tqdm.write( + f" {relation.__name__}: {len(inferred_hypos)} hypotheses ({elapsed:.1f}s)" + ) + logger.info( + f"Found {len(inferred_hypos)} hypotheses for {relation.__name__} " + f"on trace {trace_idx + 1}/{n_traces}" + ) + for hypo in inferred_hypos: + if hypo not in hypotheses_and_trace_idxs: + hypotheses_and_trace_idxs[hypo] = [trace_idx] + hypo_lookup[hypo] = hypo + else: + hypotheses_and_trace_idxs[hypo].append(trace_idx) + original_hypo = hypo_lookup[hypo] + orig_num_pos_exps = len(original_hypo.positive_examples) + orig_num_neg_exps = len(original_hypo.negative_examples) + original_hypo.positive_examples.examples.extend( + hypo.positive_examples.examples + ) + original_hypo.negative_examples.examples.extend( + hypo.negative_examples.examples + ) + + assert len( + hypo_lookup[hypo].positive_examples + ) == orig_num_pos_exps + len( + hypo.positive_examples + ), f"Expected {orig_num_pos_exps} + {len(hypo.positive_examples)} positive examples, got {len(hypo_lookup[hypo].positive_examples)}" + assert len( + hypo_lookup[hypo].negative_examples + ) == orig_num_neg_exps + len( + hypo.negative_examples + ), f"Expected {orig_num_neg_exps} + {len(hypo.negative_examples)} negative examples, got {len(hypo_lookup[hypo].negative_examples)}" + + total = len(hypotheses_and_trace_idxs) + print(f"\n {total} hypotheses generated across all relations") + logger.info(f"Finished generating hypotheses, found {total} hypotheses") return hypotheses_and_trace_idxs def collect_examples(self, hypotheses: dict[Hypothesis, list[int]]): logger.info("============= COLLECTING EXAMPLES =============") - logger.info(f"Start collecting examples for {len(hypotheses)} hypotheses") - for hypo, trace_idxs in 
hypotheses.items(): - logger.info( - f"Collecting examples for hypothesis: {hypo.invariant.text_description}" + cross_trace_hypos = [ + (hypo, trace_idxs) + for hypo, trace_idxs in hypotheses.items() + if len(set(range(len(self.traces))) - set(trace_idxs)) > 0 + ] + if cross_trace_hypos: + print( + f"\nCollecting examples for {len(cross_trace_hypos)} cross-trace hypotheses" ) + for hypo, trace_idxs in cross_trace_hypos: for trace_idx, trace in enumerate(self.traces): if trace_idx in trace_idxs: continue @@ -125,7 +138,6 @@ def collect_examples(self, hypotheses: dict[Hypothesis, list[int]]): def prune_incorrect_hypos(self, hypotheses: dict[Hypothesis, list[int]]): """Prune incorrect hypotheses based on the collected examples""" - incorrect_hypos = [] correct_hypos = {} for hypo, trace_idxs in hypotheses.items(): @@ -135,30 +147,46 @@ def prune_incorrect_hypos(self, hypotheses: dict[Hypothesis, list[int]]): incorrect_hypos.append( FailedHypothesis(hypo, "only one positive example") ) + n_pruned = len(incorrect_hypos) + n_kept = len(correct_hypos) + print(f" {n_pruned} pruned (insufficient examples) → {n_kept} remaining") return correct_hypos, incorrect_hypos def infer_precondition(self, hypotheses: dict[Hypothesis, list[int]]): """TODO: move the precondition inference driving code into Hypothesis.get_invariant()""" logger.info("============= INFERING PRECONDITIONS =============") - logger.info(f"Inferring preconditions for {len(hypotheses)} hypotheses") - all_hypotheses: list[Hypothesis] = [] - for hypo in hypotheses: - all_hypotheses.append(hypo) + all_hypotheses = list(hypotheses.keys()) + total = len(all_hypotheses) + print(f"\nInferring preconditions for {total} hypotheses") invariants = [] failed_hypos = [] - for hypo_idx, hypothesis in enumerate(all_hypotheses): - logger.info( - f"Inferring precondition for hypothesis {hypo_idx + 1}/{len(all_hypotheses)}: {hypothesis.invariant.text_description}" - ) - precondition = find_precondition(hypothesis, self.traces) - if precondition is None: - failed_hypos.append( - FailedHypothesis(hypothesis, "Precondition not found") - ) - else: - hypothesis.invariant.precondition = precondition - invariants.append(hypothesis.get_invariant(self.all_stages)) + bar_fmt = "{desc} {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]" + _tc_utils._suppress_inner_progress = True + try: + with tqdm(total=total, bar_format=bar_fmt, unit="hypo") as pbar: + pbar.set_description("0 done · 0 failed") + for hypo_idx, hypothesis in enumerate(all_hypotheses): + logger.info( + f"Inferring precondition for hypothesis {hypo_idx + 1}/{total}: " + f"{hypothesis.invariant.text_description}" + ) + precondition = find_precondition(hypothesis, self.traces) + if precondition is None: + failed_hypos.append( + FailedHypothesis(hypothesis, "Precondition not found") + ) + else: + hypothesis.invariant.precondition = precondition + invariants.append(hypothesis.get_invariant(self.all_stages)) + pbar.set_description( + f"{len(invariants)} done · {len(failed_hypos)} failed" + ) + pbar.update(1) + finally: + _tc_utils._suppress_inner_progress = False + + print(f" {len(invariants)} invariants · {len(failed_hypos)} failed") return invariants, failed_hypos @@ -257,7 +285,7 @@ def main(): logging.basicConfig( filename=f'traincheck_infer_engine_{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}_{pid}.log', level=log_level, - format="%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)s - %(funcName)20s()] - %(message)s", + format="%(asctime)s - [TrainCheck] %(levelname)s - 
[%(filename)s:%(lineno)s - %(funcName)20s()] - %(message)s", ) disabled_relations: list[Relation] = [] diff --git a/traincheck/instrumentor/control.py b/traincheck/instrumentor/control.py index a1f5c87d..84e0ca5f 100644 --- a/traincheck/instrumentor/control.py +++ b/traincheck/instrumentor/control.py @@ -32,17 +32,17 @@ def start_step(): config.DISABLE_WRAPPER = False if current_step < warm_up: - print(f"Warmup step {current_step}") + logger.debug(f"Warmup step {current_step}") config.DISABLE_WRAPPER = False elif (current_step - warm_up) % interval == 0: - print(f"Interval step {current_step}") + logger.debug(f"Interval step {current_step}") config.DISABLE_WRAPPER = False else: - print(f"Skipping step {current_step}") + logger.debug(f"Skipping step {current_step}") config.DISABLE_WRAPPER = True else: # No policy, always enable - print("No policy, always enable") + logger.debug("No policy, always enable") config.DISABLE_WRAPPER = False @@ -65,13 +65,13 @@ def start_eval_step(): config.DISABLE_WRAPPER = False if current_step < warm_up: - print(f"Eval: Warmup step {current_step}") + logger.debug(f"Eval: Warmup step {current_step}") config.DISABLE_WRAPPER = False elif (current_step - warm_up) % interval == 0: - print(f"Eval: Interval step {current_step}") + logger.debug(f"Eval: Interval step {current_step}") config.DISABLE_WRAPPER = False else: - print(f"Eval: Skipping step {current_step}") + logger.debug(f"Eval: Skipping step {current_step}") config.DISABLE_WRAPPER = True else: config.DISABLE_WRAPPER = False diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py index a7471053..27c458f8 100644 --- a/traincheck/instrumentor/dumper.py +++ b/traincheck/instrumentor/dumper.py @@ -92,8 +92,8 @@ def serialize(obj_dict: dict[str, object | str]) -> str: def monitor_main_thread(main_thread, stop_event): main_thread.join() # Wait for the main thread to finish - print("Main thread has finished or encountered an exception") - print("Flushing all buffers to the trace log file") + logger.debug("Main thread has finished or encountered an exception") + logger.debug("Flushing all buffers to the trace log file") stop_event.set() # Signal the logging threads to stop @@ -106,12 +106,12 @@ def trace_dumper(task_queue: Queue, trace_file_name: str, stop_event: threading. ) # wait for 2x the flush interval, this is an arbitrary number, as long as it is larger than the flush interval, it should be fine. 
except Empty: if stop_event.is_set(): - print("Trace dumper thread has stopped.") + logger.debug("Trace dumper thread has stopped.") break continue f.write(f"{trace}\n") task_queue.task_done() - print("Trace dumper thread has finished normally...") + logger.debug("Trace dumper thread has finished normally.") def get_trace_API_dumper_queue(): @@ -316,11 +316,15 @@ class NOT_FOUND: def safe_getattr(obj, attr_name): + import warnings + try: - attr = getattr(obj, attr_name, NOT_FOUND) - if attr is NOT_FOUND: - if issubclass(type(obj), dict): - attr = dict.get(obj, attr_name, NOT_FOUND) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + attr = getattr(obj, attr_name, NOT_FOUND) + if attr is NOT_FOUND: + if issubclass(type(obj), dict): + attr = dict.get(obj, attr_name, NOT_FOUND) return attr except Exception: return NOT_FOUND @@ -392,7 +396,7 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict and isinstance(var, torch.Tensor) and not include_tensor_data ): - logger.warning( + logger.debug( f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." ) if var_type not in skip_attrs_due_to_errs: diff --git a/traincheck/instrumentor/proxy_wrapper/proxy.py b/traincheck/instrumentor/proxy_wrapper/proxy.py index d9003f10..4def285b 100644 --- a/traincheck/instrumentor/proxy_wrapper/proxy.py +++ b/traincheck/instrumentor/proxy_wrapper/proxy.py @@ -21,6 +21,8 @@ from .proxy_registry import get_global_registry from .utils import print_debug +logger = logging.getLogger(__name__) + class ProxyObjInfo: def __init__(self, var_name: str, last_update_timestamp: int, version: int | None): @@ -118,9 +120,8 @@ def proxy_parameters(module: torch.nn.Module, parent_name="", from_iter=False): time_end = time.perf_counter() if num_params != 0: - print( - "logger_proxy: " - + f"Proxied {num_params} parameters of '{parent_name + module.__class__.__name__}', duration: {time_end - start_time} seconds" + logger.debug( + f"Proxied {num_params} parameters of '{parent_name + module.__class__.__name__}', duration: {time_end - start_time:.3f}s" ) def update_timestamp(self): diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index e056b601..5ddb82f7 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -125,10 +125,10 @@ def _get_loop_context(self, node): if iter_name: if "train" in iter_name: - print(f"Found training loop based on iterator: {iter_name}") + logger.debug(f"Found training loop based on iterator: {iter_name}") return "training" elif any(x in iter_name for x in ["val", "eval", "test"]): - print(f"Found eval loop based on iterator: {iter_name}") + logger.debug(f"Found eval loop based on iterator: {iter_name}") return "eval" # Heuristic 2: Check for calls to .step() or .backward() or .eval() @@ -187,25 +187,25 @@ def _get_loop_context(self, node): if isinstance(expr.func, ast.Attribute): if expr.func.attr == "no_grad": has_eval_signal = True - print(f"Found no_grad context in loop {node}.") + logger.debug(f"Found no_grad context in loop {node}.") if has_training_signal: - print(f"Found training signal in loop {node}.") + logger.debug(f"Found training signal in loop {node}.") return "training" if has_eval_signal: - print(f"Found eval signal in loop {node}.") + logger.debug(f"Found eval signal in loop {node}.") return "eval" # if the number of lines are too few and the function calls do not involve "eval", "train", we omit 
the loop context # We use statement_count calculated recursively if statement_count < 3: - print( + logger.debug( f"Skipping loop {node} as it is too short ({statement_count} statements) and does not contain eval/train/step/backward signal." ) return None - print(f"Found eval signal in loop {node} (fallback).") + logger.debug(f"Found eval signal in loop {node} (fallback).") return "eval" def _inject_call(self, node, func_name): @@ -465,7 +465,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]: for node in ast.walk(root): for child in ast.iter_child_nodes(node): if child in parent_map and not ast.unparse(child).strip() == "": - print( + logger.debug( f"Node {ast.unparse(child)} already has a parent, {ast.unparse(parent_map[child])}" ) parent_map[child] = node @@ -480,7 +480,7 @@ def instrument_all_model_assignments( Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement after each assignment, depending on the mode. """ - print( + logger.debug( f"Instrumenting model: {model_name}, mode: {mode}, scanning for assignments to {model_name}" ) @@ -529,10 +529,10 @@ def instrument_all_model_assignments( if node in parent_map: parent = parent_map[node] # print(f"Parent node: {ast.unparse(parent)}") - print("\tInstrumenting: ", ast.unparse(node)) + logger.debug("Instrumenting: %s", ast.unparse(node)) if isinstance(parent, ast.For): - print( - "\t\t⬆️ Parent is a for loop, cowardly skipping instrumentation in fear of multiple models with the same 'var_name'" + logger.debug( + "Parent is a for loop, skipping instrumentation to avoid multiple models with the same 'var_name'" ) continue if node in parent.body: # type: ignore @@ -601,25 +601,18 @@ def instrument_model_tracker_proxy( spec = importlib.util.find_spec('traincheck') if spec and spec.origin: traincheck_folder = os.path.dirname(spec.origin) - print("traincheck folder: ", traincheck_folder) else: raise Exception("traincheck is not installed properly") -print("auto observer enabled with observing depth: ", auto_observer_config["enable_auto_observer_depth"]) enable_auto_observer_depth = auto_observer_config["enable_auto_observer_depth"] neglect_hidden_func = auto_observer_config["neglect_hidden_func"] neglect_hidden_module = auto_observer_config["neglect_hidden_module"] observe_then_unproxy = auto_observer_config["observe_then_unproxy"] observe_up_to_depth = auto_observer_config["observe_up_to_depth"] -if observe_up_to_depth: - print("observe up to the depth of the function call") -else: - print("observe only the function call at the depth") from traincheck.static_analyzer.graph_generator.call_graph_parser import add_observer_given_call_graph log_files = glob.glob( os.path.join(traincheck_folder, "static_analyzer", "func_level", "*.log") ) -print("log_files: ", log_files) for log_file in log_files: add_observer_given_call_graph( log_file, @@ -743,7 +736,7 @@ def has_stage(src: str, name: str) -> bool: for stage_name, present in orig_has.items(): if present: - logger.info( + logger.debug( _ctx( f"Stage '{stage_name}' already present in source; skip adding this stage." ) @@ -826,13 +819,13 @@ def at_attr(name: str) -> bool: indent = m.group(0) new_lines.append(f'{indent}annotate_stage("{stage}")\n') inserted_count[stage] += 1 - logger.info( + logger.debug( _ctx( f"Inserted stage '{stage}' before line {lineno}: {line.strip()}" ) ) else: - logger.info( + logger.debug( _ctx( f"Skip inserting '{stage}' at line {lineno} (previous non-empty line already has it)." 
) @@ -865,7 +858,7 @@ def _find_annotate_import_idx(lines): lines_list.insert(insert_idx, "from traincheck import annotate_stage\n") annot_import_idx = insert_idx inserted_count["import"] += 1 - logger.info( + logger.debug( _ctx( f"Inserted import 'from traincheck import annotate_stage' at line {annot_import_idx + 1}." ) @@ -924,13 +917,13 @@ def is_single_line_triple_quoted_string(line: str, quote: str) -> bool: if not (("annotate_stage" in prev) and ("init" in prev)): nl.insert(insert_at, f'{body_indent}annotate_stage("init")\n') inserted_count["init"] += 1 - logger.info( + logger.debug( _ctx( f"Inserted stage 'init' at start of main() body (line {insert_at + 1})." ) ) else: - logger.info( + logger.debug( _ctx( "Skip inserting 'init' inside main(): previous non-empty line already has it." ) @@ -973,13 +966,13 @@ def is_single_line_triple_quoted_string(line: str, quote: str) -> bool: if not (("annotate_stage" in next_line) and ("init" in next_line)): lines2.insert(insert_at, 'annotate_stage("init")\n') inserted_count["init"] += 1 - logger.info( + logger.debug( _ctx( f"Inserted stage 'init' right after annotate_stage import at line {insert_at + 1}." ) ) else: - logger.info( + logger.debug( _ctx( "Skip inserting 'init': next non-empty line after annotate_stage import is already init." ) @@ -1072,7 +1065,7 @@ def instrument_file( if model_tracker_style == "proxy" or model_tracker_style == "subclass": if model_tracker_style == "subclass": # adjust the proxy config to disable the proxy-specific configs - print( + logger.debug( "Using subclass model tracker, overriding observe_then_unproxy to False" ) adjusted_proxy_config[0]["observe_then_unproxy"] = False diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index 50e532a5..8a6ac1e1 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -627,6 +627,8 @@ def __init__( ) def instrument(self) -> int: + import warnings + if not self.instrumenting: return 0 @@ -639,9 +641,11 @@ def instrument(self) -> int: "First pass: Recursive scan of the module" ) assert isinstance(self.target, (types.ModuleType, type)), "Invalid target" - first_pass_instrumented_count += self._instrument_module( - self.target, visited_file_paths, True, 0 - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + first_pass_instrumented_count += self._instrument_module( + self.target, visited_file_paths, True, 0 + ) get_instrumentation_logger_for_process().info( "Files scanned %s", "\n".join(sorted(visited_file_paths)) ) @@ -665,13 +669,15 @@ def instrument(self) -> int: f"Instrumenting module {module_path}" ) - pymodule = importlib.import_module(module_path) - second_pass_instrumented_count += self._instrument_module( - pymodule, - visited_file_paths, - False, - 0, - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + pymodule = importlib.import_module(module_path) + second_pass_instrumented_count += self._instrument_module( + pymodule, + visited_file_paths, + False, + 0, + ) get_instrumentation_logger_for_process().info( "Second pass instrumented %d functions", second_pass_instrumented_count ) diff --git a/traincheck/invariant/DistinctArgumentRelation.py b/traincheck/invariant/DistinctArgumentRelation.py index 025f7c69..0088715d 100644 --- a/traincheck/invariant/DistinctArgumentRelation.py +++ b/traincheck/invariant/DistinctArgumentRelation.py @@ -1,8 +1,8 @@ +import logging from itertools import combinations from typing import Any, Dict, Iterable, List, Set, Tuple -from tqdm 
import tqdm - +from traincheck.config.config import ANALYSIS_SKIP_FUNC_NAMES from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( # GroupedPreconditions, APIParam, @@ -15,12 +15,16 @@ OnlineCheckerResult, Param, Relation, + _short_api_name, ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.utils import safe_isnan +logger = logging.getLogger(__name__) + EXP_GROUP_NAME = "distinct_arg" MAX_FUNC_NUM_CONSECUTIVE_CALL = 6 IOU_THRESHHOLD = 0.1 # pre-defined threshhold for IOU @@ -42,9 +46,9 @@ def get_func_names_to_deal_with(trace: Trace) -> List[str]: # get all functions in the trace all_func_names = trace.get_func_names() - # filtering 1: remove private functions + # filtering 1: skip functions matched by ANALYSIS_SKIP_FUNC_NAMES for func_name in all_func_names: - if "._" in func_name: + if any(skip in func_name for skip in ANALYSIS_SKIP_FUNC_NAMES): continue function_pool.add(func_name) @@ -198,7 +202,7 @@ class DistinctArgumentRelation(Relation): def generate_hypothesis(trace) -> list[Hypothesis]: """Generate hypothesis for the DistinctArgumentRelation on trace.""" # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") listed_arguments: Dict[ str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]] ] = {} @@ -209,7 +213,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: function_pool, listed_arguments = get_event_data_per_function_per_step( trace, function_pool ) - print("End preprocessing") + logger.debug("End preprocessing") # If there is no filtered function, return [], [] if not function_pool: @@ -220,24 +224,27 @@ def generate_hypothesis(trace) -> list[Hypothesis]: # function_pool.add("torch.nn.init.normal_") # 2. Generating hypothesis - print("Start generating hypo...") + logger.debug("Start generating hypo...") hypothesis_with_examples = { func_name: Hypothesis( invariant=Invariant( relation=DistinctArgumentRelation, params=[APIParam(func_name)], precondition=None, - text_description=f"{func_name} has distinct input arguments on difference PT for each step", + text_description=DistinctArgumentRelation.to_display_name( + [APIParam(func_name)] + ) + or f"{func_name} has distinct input arguments on difference PT for each step", ), positive_examples=ExampleList({EXP_GROUP_NAME}), negative_examples=ExampleList({EXP_GROUP_NAME}), ) for func_name in function_pool } - print("End generating hypo") + logger.debug("End generating hypo") # 3. Add positive and negative examples - print("Start adding examples...") + logger.debug("Start adding examples...") for func_name in tqdm(function_pool): flag = False for step, records in listed_arguments[func_name].items(): @@ -279,7 +286,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: if not flag: hypothesis_with_examples.pop(func_name) - print("End adding examples") + logger.debug("End adding examples") return list(hypothesis_with_examples.values()) @@ -289,7 +296,7 @@ def collect_examples(trace, hypothesis): inv = hypothesis.invariant # 1. 
Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") listed_arguments: Dict[ str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]] ] = {} @@ -305,7 +312,7 @@ def collect_examples(trace, hypothesis): trace, function_pool ) - print("End preprocessing") + logger.debug("End preprocessing") if not function_pool: return @@ -343,7 +350,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: # DistinctArgumentRelation.collect_examples(trace, hypothesis) # 4. Precondition inference - print("Start precondition inference...") + logger.debug("Start precondition inference...") failed_hypothesis = [] for hypothesis in all_hypotheses.copy(): preconditions = find_precondition(hypothesis, [trace]) @@ -354,7 +361,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: FailedHypothesis(hypothesis, "Precondition not found") ) all_hypotheses.remove(hypothesis) - print("End precondition inference") + logger.debug("End precondition inference") return ( list([hypo.invariant for hypo in all_hypotheses]), @@ -389,7 +396,7 @@ def static_check_all( assert inv.precondition is not None, "Invariant should have a precondition." # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") listed_arguments: Dict[ str, Dict[int, Dict[Tuple[str, str], List[dict[str, Any]]]] ] = {} @@ -414,7 +421,7 @@ def static_check_all( ) events_list = get_event_list(trace, function_pool) - print("End preprocessing") + logger.debug("End preprocessing") if not inv.precondition.verify(events_list, EXP_GROUP_NAME, trace): return CheckerResult( @@ -465,6 +472,13 @@ def _get_variables_to_check(inv: Invariant): def _get_apis_to_check(inv: Invariant): return None + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if not params or not isinstance(params[0], APIParam): + return None + func_short = _short_api_name(params[0].api_full_name) + return f"{func_short}() receives distinct arguments at each step" + @staticmethod def _get_api_args_map_to_check(inv): return [inv.params[0].api_full_name] diff --git a/traincheck/invariant/base_cls.py b/traincheck/invariant/base_cls.py index d8649377..944f57e6 100644 --- a/traincheck/invariant/base_cls.py +++ b/traincheck/invariant/base_cls.py @@ -1921,6 +1921,39 @@ def online_check( """Check the invariant online, i.e. during the trace collection process.""" pass + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + """Return a short, human-readable label for this invariant given its params. + + Returns None to fall back to text_description / raw param rendering. + Subclasses should override this to provide meaningful natural-language names. + """ + return None + + +def _short_api_name(full_name: str) -> str: + """Shorten a fully-qualified API name for display. + + 'torch.optim.optimizer.Optimizer.zero_grad' → 'Optimizer.zero_grad' + 'torch.nn.functional.dropout' → 'functional.dropout' + """ + parts = full_name.split(".") + for i, part in enumerate(parts): + if part and part[0].isupper(): + return ".".join(parts[i:]) + return ".".join(parts[-2:]) if len(parts) >= 2 else full_name + + +def _display_attr_name(attr_name: str) -> str: + """Strip TrainCheck-internal proxy bookkeeping prefix for display. 
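+
+    These prefixed attributes are internal bookkeeping fields written by the
+    proxy instrumentor; stripping the prefix keeps display names user-facing.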
+ + '_TRAINCHECK_grad_ID' → 'grad_ID' + 'dtype' → 'dtype' (unchanged) + """ + if attr_name.startswith("_TRAINCHECK_"): + return attr_name[len("_TRAINCHECK_") :] + return attr_name + def read_inv_file(file_path: str | list[str]) -> list[Invariant]: if isinstance(file_path, str): diff --git a/traincheck/invariant/consistency_relation.py b/traincheck/invariant/consistency_relation.py index 06b24714..8a6c26fb 100644 --- a/traincheck/invariant/consistency_relation.py +++ b/traincheck/invariant/consistency_relation.py @@ -2,8 +2,6 @@ import time from itertools import combinations -from tqdm import tqdm - from traincheck.config import config from traincheck.invariant.base_cls import ( CheckerResult, @@ -16,9 +14,11 @@ Param, Relation, VarTypeParam, + _display_attr_name, ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import Liveness, VarInstId @@ -111,6 +111,17 @@ class ConsistencyRelation(Relation): where the consistency relationships of the variables across different nodes is crucial to maintain. """ + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if not params: + return None + p = params[0] + if not isinstance(p, VarTypeParam): + return None + var_short = p.var_type.split(".")[-1] + attr = _display_attr_name(p.attr_name) + return f"{var_short}.{attr} stays consistent across training steps" + @staticmethod def infer(trace: Trace) -> tuple[list[Invariant], list[FailedHypothesis]]: """Infer Invariants for the ConsistencyRelation.""" @@ -180,6 +191,7 @@ def skip_attrs_with_different_dtypes(attr1, attr2): combinations(var_insts, 2), desc="Generating Hypothesis for Consistency Relation", total=len(var_insts) * (len(var_insts) - 1) // 2, + leave=False, ): for attr in var_insts[var_inst]: for other_attr in var_insts[other_var_inst]: @@ -255,7 +267,13 @@ def skip_attrs_with_different_dtypes(attr1, attr2): VarTypeParam(var_type=hypo[2], attr_name=hypo[3]), ], precondition=None, - text_description=f"Consistency Relation between {hypo[0]}.{hypo[1]} and {hypo[2]}.{hypo[3]}", + text_description=ConsistencyRelation.to_display_name( + [ + VarTypeParam(var_type=hypo[0], attr_name=hypo[1]), + VarTypeParam(var_type=hypo[2], attr_name=hypo[3]), + ] + ) + or f"Consistency Relation between {hypo[0]}.{hypo[1]} and {hypo[2]}.{hypo[3]}", ), positive_examples=ExampleList({VAR_GROUP_NAME}), negative_examples=ExampleList({VAR_GROUP_NAME}), @@ -421,7 +439,9 @@ def static_check_all( start_time_collecting_pairs = time.time() num_collected_pairs = 0 value_pairs_to_check: dict[float, list[tuple]] = {} - for i, var1_id in enumerate(tqdm(type1_attr1, desc="Collecting Value Pairs")): + for i, var1_id in enumerate( + tqdm(type1_attr1, desc="Collecting Value Pairs", leave=False) + ): for j, var2_id in enumerate(type2_attr2): if var_type1 == var_type2 and attr1 == attr2 and i >= j: continue @@ -462,7 +482,9 @@ def static_check_all( start_time_checking_pairs = time.time() num_checked_pairs = 0 value_pairs_to_check = dict(sorted(value_pairs_to_check.items())) - for time_pair in tqdm(value_pairs_to_check, desc="Checking Value Pairs"): + for time_pair in tqdm( + value_pairs_to_check, desc="Checking Value Pairs", leave=False + ): for attr1_val, attr2_val in value_pairs_to_check[time_pair]: traces = [attr1_val.traces[-1], attr2_val.traces[-1]] num_checked_pairs += 1 @@ -605,7 +627,7 @@ def get_check_attr(param, 
varid): for var2 in checker_data.type_map[ref_param.var_type]: if var2 == varid: continue - if checker_data.varid_map[var2][ref_param.attr_name] is None: + if checker_data.varid_map[var2].get(ref_param.attr_name) is None: logger.debug( f"Attribute {ref_param.attr_name} not found in variable {var2}" ) diff --git a/traincheck/invariant/consistency_transient_vars.py b/traincheck/invariant/consistency_transient_vars.py index 485800e9..f02e65bb 100644 --- a/traincheck/invariant/consistency_transient_vars.py +++ b/traincheck/invariant/consistency_transient_vars.py @@ -3,10 +3,11 @@ from typing import Hashable import pandas as pd -from tqdm import tqdm +from traincheck.config.config import ANALYSIS_SKIP_FUNC_NAMES from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( + _NOT_SET, APIParam, Arguments, CheckerResult, @@ -20,10 +21,12 @@ Param, Relation, VarTypeParam, + _short_api_name, make_hashable, ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( FuncCallEvent, @@ -243,17 +246,11 @@ def get_input_thresholds( def get_events_of_funcs_with_tensors( all_func_names, trace, output_has_tensors=True, input_has_tensors=True ): - # HACK: remove all torch.overrides + # skip functions matched by ANALYSIS_SKIP_FUNC_NAMES all_func_names = [ - func_name for func_name in all_func_names if "torch.override" not in func_name - ] - # remove all functions with "._" in them - all_func_names = [ - func_name for func_name in all_func_names if "._" not in func_name - ] - # remove all functions with "._is_" in them - all_func_names = [ - func_name for func_name in all_func_names if ".is_" not in func_name + func_name + for func_name in all_func_names + if not any(skip in func_name for skip in ANALYSIS_SKIP_FUNC_NAMES) ] # if os.path.exists(_CACHE_PATH): @@ -389,7 +386,17 @@ def generate_hypothesis(trace) -> list[Hypothesis]: ), ], precondition=None, - text_description=f"{prop} of the tensors returned by the function {func_name} is consistently {prop_val}.", + text_description=ConsistentOutputRelation.to_display_name( + [ + APIParam(api_full_name=func_name), + VarTypeParam( + var_type="torch.Tensor", + attr_name=prop, + const_value=make_hashable(prop_val), + ), + ] + ) + or f"{prop} of the tensors returned by the function {func_name} is consistently {prop_val}.", ), positive_examples=ExampleList({"pre_event"}), negative_examples=ExampleList({"pre_event"}), @@ -532,7 +539,9 @@ def static_check_all( triggered = False # for each function call, check if the property holds for func_call_id in tqdm( - func_call_ids, desc=f"Checking invariant {inv.text_description}" + func_call_ids, + desc=f"Checking invariant {inv.text_description}", + leave=False, ): func_call_event = trace.query_func_call_event(func_call_id) if isinstance( @@ -580,6 +589,20 @@ def static_check_all( ) # raise NotImplementedError + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 2: + return None + api, vt = params[0], params[1] + if not isinstance(api, APIParam) or not isinstance(vt, VarTypeParam): + return None + func_short = _short_api_name(api.api_full_name) + attr = vt.attr_name + const = vt.const_value + if const is not _NOT_SET: + return f"{func_short}() consistently returns tensors with {attr}={const}" + return f"{func_short}() consistently returns tensors with 
consistent {attr}" + @staticmethod def _get_identifying_params(inv: Invariant) -> list[Param]: return [inv.params[0]] @@ -759,7 +782,10 @@ def generate_hypothesis(trace: Trace) -> list[Hypothesis]: relation=ConsistentInputOutputRelation, params=[input_param, api_param, output_param], precondition=None, - text_description=f"The value {common_value} is consistent across the input {input_path} and output {output_path} tensors of the function {func_name}.", + text_description=ConsistentInputOutputRelation.to_display_name( + [input_param, api_param, output_param] + ) + or f"The value {common_value} is consistent across the input {input_path} and output {output_path} tensors of the function {func_name}.", ), positive_examples=ExampleList( {"pre_event"} @@ -922,7 +948,9 @@ def static_check_all(trace, inv, check_relation_first): triggered = False for func_call_id in tqdm( - func_call_ids, desc=f"Checking invariant {inv.text_description}" + func_call_ids, + desc=f"Checking invariant {inv.text_description}", + leave=False, ): func_call_event = trace.query_func_call_event(func_call_id) if isinstance( @@ -965,6 +993,40 @@ def static_check_all(trace, inv, check_relation_first): triggered=triggered, ) + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 3: + return None + input_param, api_param, output_param = params[0], params[1], params[2] + if not isinstance(api_param, APIParam): + return None + if not isinstance(input_param, InputOutputParam) or not isinstance( + output_param, InputOutputParam + ): + return None + func_short = _short_api_name(api_param.api_full_name) + in_path = ( + ".".join(str(p) for p in input_param.additional_path) + if input_param.additional_path + else (input_param.name or "?") + ) + out_path = ( + ".".join(str(p) for p in output_param.additional_path) + if output_param.additional_path + else (output_param.name or "?") + ) + in_ref = ( + f"input[{input_param.index}].{in_path}" + if input_param.index is not None + else f"input.{in_path}" + ) + out_ref = ( + f"output[{output_param.index}].{out_path}" + if output_param.index is not None + else f"output.{out_path}" + ) + return f"{func_short}(): {in_ref} consistent with {out_ref}" + @staticmethod def _get_identifying_params(inv: Invariant) -> list[Param]: return [inv.params[1]] @@ -1155,7 +1217,10 @@ def generate_hypothesis(trace: Trace) -> list[Hypothesis]: input_param, ], # the first param should be larger or equal to the second param precondition=None, - text_description=f"Output tensor's value at {output_param.additional_path} is consistently larger than or equal to the min input threshold {input_param.name} for the function {func_name}.", + text_description=ThresholdRelation.to_display_name( + [output_param, api_param, input_param] + ) + or f"Output tensor's value at {output_param.additional_path} is consistently larger than or equal to the min input threshold {input_param.name} for the function {func_name}.", ), positive_examples=ExampleList( {"pre_event"} @@ -1208,7 +1273,10 @@ def generate_hypothesis(trace: Trace) -> list[Hypothesis]: output_param, ], # the first param should be larger or equal to the second param precondition=None, - text_description=f"Output tensor's value at {output_param.additional_path} is consistently less than or equal to the max input threshold {input_param.name} for the function {func_name}.", + text_description=ThresholdRelation.to_display_name( + [input_param, api_param, output_param] + ) + or f"Output tensor's value at {output_param.additional_path} is 
consistently less than or equal to the max input threshold {input_param.name} for the function {func_name}.", ), positive_examples=ExampleList( {"pre_event"} @@ -1391,7 +1459,9 @@ def static_check_all(trace, inv, check_relation_first): triggered = False for func_call_id in tqdm( - func_call_ids, desc=f"Checking invariant {inv.text_description}" + func_call_ids, + desc=f"Checking invariant {inv.text_description}", + leave=False, ): func_call_event = trace.query_func_call_event(func_call_id) if isinstance( @@ -1453,6 +1523,46 @@ def _get_apis_to_check(inv: Invariant): assert isinstance(inv.params[1], APIParam) return [inv.params[1].api_full_name] + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 3: + return None + first, api_param, second = params[0], params[1], params[2] + if not isinstance(api_param, APIParam): + return None + if not isinstance(first, InputOutputParam) or not isinstance( + second, InputOutputParam + ): + return None + func_short = _short_api_name(api_param.api_full_name) + # min case: params=[output, api, input_threshold] → output ≥ threshold + if not first.is_input and second.is_input: + out_path = ( + ".".join(str(p) for p in first.additional_path) + if first.additional_path + else "value" + ) + out_ref = ( + f"output[{first.index}].{out_path}" + if first.index is not None + else "output" + ) + return f"{func_short}(): {out_ref} ≥ {second.name}" + # max case: params=[input_threshold, api, output] → output ≤ threshold + if first.is_input and not second.is_input: + out_path = ( + ".".join(str(p) for p in second.additional_path) + if second.additional_path + else "value" + ) + out_ref = ( + f"output[{second.index}].{out_path}" + if second.index is not None + else "output" + ) + return f"{func_short}(): {out_ref} ≤ {first.name}" + return None + @staticmethod def _get_api_args_map_to_check(inv): return None diff --git a/traincheck/invariant/contain_relation.py b/traincheck/invariant/contain_relation.py index 0c4d0e95..84fa76d2 100644 --- a/traincheck/invariant/contain_relation.py +++ b/traincheck/invariant/contain_relation.py @@ -5,11 +5,11 @@ from typing import Type import numpy as np -from tqdm import tqdm from traincheck.config.config import ANALYSIS_SKIP_FUNC_NAMES from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( + _NOT_SET, APIParam, Arguments, CheckerResult, @@ -23,6 +23,8 @@ Relation, VarNameParam, VarTypeParam, + _display_attr_name, + _short_api_name, calc_likelihood, construct_api_param, construct_var_param_from_var_change, @@ -36,6 +38,7 @@ get_var_raw_event_before_time, set_meta_vars_online, ) +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( ALL_EVENT_TYPES, @@ -276,14 +279,19 @@ def _merge_hypotheses(hypotheses: list[Hypothesis]) -> list[Hypothesis]: setattr(merged_child_param, field, generalized_value) # construct the merged hypotheses + parent_param = hypotheses[0].invariant.params[0] + assert isinstance(parent_param, APIParam) merged_hypothesis = Hypothesis( invariant=Invariant( relation=hypotheses[0].invariant.relation, params=[ - hypotheses[0].invariant.params[0], + parent_param, merged_child_param, ], - text_description="TBD merged", + text_description=APIContainRelation.to_display_name( + [parent_param, merged_child_param] + ) + or f"{parent_param.api_full_name} contains {merged_child_param}", num_positive_examples=len(all_positive_examples), num_negative_examples=len(all_positive_examples), 
                precondition=None,  # to be inferred later
@@ -338,6 +346,35 @@ class APIContainRelation(Relation):
     - [ ] Make the Dynamic Analysis part less ad-hoc as of its current form in the code
     """

+    @staticmethod
+    def to_display_name(params: list[Param]) -> str | None:
+        if len(params) < 2:
+            return None
+        parent = params[0]
+        child = params[1]
+        if not isinstance(parent, APIParam):
+            return None
+        parent_short = _short_api_name(parent.api_full_name)
+        if isinstance(child, APIParam):
+            child_short = _short_api_name(child.api_full_name)
+            return f"{parent_short}() always calls {child_short}()"
+        if isinstance(child, (VarTypeParam, VarNameParam)):
+            # _TRAINCHECK_* attributes are internal proxy IDs; hide invariants
+            # over them entirely instead of rendering a display name.
+            if child.attr_name.startswith("_TRAINCHECK_"):
+                return None
+            attr = _display_attr_name(child.attr_name)
+            var_short = child.var_type.split(".")[-1]
+            pre = child.pre_value
+            post = child.post_value
+            const = child.const_value
+
+            def _fmt_val(v: object) -> str:
+                return "non-zero" if v == "non_zero" else str(v)
+
+            if pre is not _NOT_SET and post is not _NOT_SET:
+                return f"{parent_short}() changes {var_short}.{attr}: {_fmt_val(pre)} --> {_fmt_val(post)}"
+            if const is not _NOT_SET:
+                return f"{parent_short}() sees {var_short}.{attr} = {const}"
+            return f"{parent_short}() accesses {var_short}.{attr}"
+        return None
+
     @staticmethod
     def generate_hypothesis(trace) -> list[Hypothesis]:
         # let's play it dumb here,
@@ -699,7 +736,10 @@ def _infer(
                         relation=APIContainRelation,
                         params=[parent_param, child_param],
                         precondition=None,
-                        text_description=f"{parent} contains {child_param} of type {typename(event)}",
+                        text_description=APIContainRelation.to_display_name(
+                            [parent_param, child_param]
+                        )
+                        or f"{parent} contains {child_param} of type {typename(event)}",
                     ),
                     positive_examples=ExampleList(
                         {PARENT_GROUP_NAME, VAR_GROUP_NAME}
@@ -737,7 +777,10 @@ def _infer(
                         relation=APIContainRelation,
                         params=[parent_param, child_param],
                         precondition=None,
-                        text_description=f"{parent} contains {child_param} of type {typename(event)}",
+                        text_description=APIContainRelation.to_display_name(
+                            [parent_param, child_param]
+                        )
+                        or f"{parent} contains {child_param} of type {typename(event)}",
                     ),
                     positive_examples=ExampleList({PARENT_GROUP_NAME}),
                     negative_examples=ExampleList({PARENT_GROUP_NAME}),
@@ -791,7 +834,10 @@ def _merge_child_API_events(
                 relation=APIContainRelation,
                 params=[parent_param, merged_child_param],
                 precondition=None,
-                text_description=f"{parent} contains {merged_child_param}",
+                text_description=APIContainRelation.to_display_name(
+                    [parent_param, merged_child_param]
+                )
+                or f"{parent} contains {merged_child_param}",
             ),
             positive_examples=ExampleList({PARENT_GROUP_NAME}),
             negative_examples=ExampleList({PARENT_GROUP_NAME}),
@@ -1267,33 +1313,51 @@ def online_check(
         if isinstance(child_param, (VarTypeParam, VarNameParam)):
             events = []
+            attr_name = child_param.attr_name
             with checker_data.lock:
-                if child_param.var_type in checker_data.type_map:
-                    for varid in checker_data.type_map[child_param.var_type]:
-                        if isinstance(child_param, VarNameParam):
-                            if varid.var_name != child_param.var_name:
-                                continue
-                            attr_name = child_param.attr_name
-                        elif isinstance(child_param, VarTypeParam):
-                            attr_name = child_param.attr_name
-                        for i in reversed(
-                            range(1, len(checker_data.varid_map[varid][attr_name]))
-                        ):
-
-                            change_time = checker_data.varid_map[varid][attr_name][
-                                i
-                            ].liveness.start_time
-                            if change_time <= pre_time:
-                                break
-                            if change_time > post_time:
-                                continue
-                            new_state = checker_data.varid_map[varid][attr_name][i]
-                            old_state = checker_data.varid_map[varid][attr_name][i - 1]
-                            if new_state.value ==
old_state.value: - continue - if new_state.liveness.end_time is None: - new_state = copy.deepcopy(new_state) - events.append((old_state, new_state)) + # No variables of this type/attr have been observed yet — not yet checkable. + if child_param.var_type not in checker_data.attr_map: + return OnlineCheckerResult( + trace=None, invariant=inv, check_passed=True + ) + if attr_name not in checker_data.attr_map[child_param.var_type]: + return OnlineCheckerResult( + trace=None, invariant=inv, check_passed=True + ) + candidate_varids = checker_data.attr_map[child_param.var_type][ + attr_name + ] + if isinstance(child_param, VarNameParam): + candidate_varids = { + v + for v in candidate_varids + if v.var_name == child_param.var_name + } + for varid in candidate_varids: + # attr_map guarantees attr_name is in varid_map[varid]; fail loudly + # if the population logic ever violates this invariant. + assert attr_name in checker_data.varid_map[varid], ( + f"attr_map/varid_map inconsistency: {varid} is in " + f"attr_map[...][{attr_name}] but " + f"varid_map[{varid}] has no '{attr_name}' entry" + ) + for i in reversed( + range(1, len(checker_data.varid_map[varid][attr_name])) + ): + change_time = checker_data.varid_map[varid][attr_name][ + i + ].liveness.start_time + if change_time <= pre_time: + break + if change_time > post_time: + continue + new_state = checker_data.varid_map[varid][attr_name][i] + old_state = checker_data.varid_map[varid][attr_name][i - 1] + if new_state.value == old_state.value: + continue + if new_state.liveness.end_time is None: + new_state = copy.deepcopy(new_state) + events.append((old_state, new_state)) if check_relation_first: for event in events: diff --git a/traincheck/invariant/cover_relation.py b/traincheck/invariant/cover_relation.py index d4e7c2d9..567e6bec 100644 --- a/traincheck/invariant/cover_relation.py +++ b/traincheck/invariant/cover_relation.py @@ -2,8 +2,6 @@ from itertools import permutations from typing import Any, Dict, List, Set, Tuple -from tqdm import tqdm - from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( APIParam, @@ -17,6 +15,7 @@ OnlineCheckerResult, Param, Relation, + _short_api_name, ) from traincheck.invariant.lead_relation import ( check_same_level, @@ -27,9 +26,12 @@ ) from traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.trace_pandas import TracePandas +logger = logging.getLogger(__name__) + EXP_GROUP_NAME = "func_cover" @@ -103,6 +105,17 @@ class FunctionCoverRelation(Relation): every time function B is called, a function A invocation exists before it. """ + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 2: + return None + a, b = params[0], params[1] + if not isinstance(a, APIParam) or not isinstance(b, APIParam): + return None + a_short = _short_api_name(a.api_full_name) + b_short = _short_api_name(b.api_full_name) + return f"{a_short}() always occurs when {b_short}() is called" + @staticmethod def generate_hypothesis(trace) -> list[Hypothesis]: """Generate hypothesis for the FunctionCoverRelation on trace.""" @@ -110,7 +123,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: logger = logging.getLogger(__name__) # 1. 
Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -144,9 +157,9 @@ def generate_hypothesis(trace) -> list[Hypothesis]: trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -158,13 +171,13 @@ def generate_hypothesis(trace) -> list[Hypothesis]: valid_relations = trace.valid_relations_cover else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -182,10 +195,10 @@ def generate_hypothesis(trace) -> list[Hypothesis]: valid_relations[(func_A, func_B)] = True trace.same_level_func_cover = same_level_func trace.valid_relations_cover = valid_relations - print("End same level checking") + logger.debug("End same level checking") # 3. Generating hypothesis - print("Start generating hypo...") + logger.debug("Start generating hypo...") hypothesis_with_examples = { (func_A, func_B): Hypothesis( invariant=Invariant( @@ -195,19 +208,22 @@ def generate_hypothesis(trace) -> list[Hypothesis]: APIParam(func_B), ], precondition=None, - text_description=f"FunctionCoverRelation between {func_A} and {func_B}", + text_description=FunctionCoverRelation.to_display_name( + [APIParam(func_A), APIParam(func_B)] + ) + or f"FunctionCoverRelation between {func_A} and {func_B}", ), positive_examples=ExampleList({EXP_GROUP_NAME}), negative_examples=ExampleList({EXP_GROUP_NAME}), ) for (func_A, func_B), _ in valid_relations.items() } - print("End generating hypo") + logger.debug("End generating hypo") # 4. Add positive and negative examples - print("Start adding examples...") + logger.debug("Start adding examples...") for (process_id, thread_id), events_list in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Group" + listed_events.items(), ascii=True, leave=False, desc="Group" ): for (func_A, func_B), _ in tqdm( @@ -308,7 +324,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: (func_A, func_B) ].negative_examples.add_example(example) - print("End adding examples") + logger.debug("End adding examples") return list(hypothesis_with_examples.values()) @@ -319,7 +335,7 @@ def collect_examples(trace, hypothesis): logger = logging.getLogger(__name__) # 1. 
Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -353,9 +369,9 @@ def collect_examples(trace, hypothesis): trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -367,13 +383,13 @@ def collect_examples(trace, hypothesis): valid_relations = trace.valid_relations_cover else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -391,7 +407,7 @@ def collect_examples(trace, hypothesis): valid_relations[(func_A, func_B)] = True trace.same_level_func_cover = same_level_func trace.valid_relations_cover = valid_relations - print("End same level checking") + logger.debug("End same level checking") inv = hypothesis.invariant @@ -409,12 +425,12 @@ def collect_examples(trace, hypothesis): function_pool = set(function_pool).intersection(function_pool_temp) if len(function_pool) == 0: - print( + logger.debug( "No relevant function calls found in the trace, skipping the collecting" ) return - print("Starting collecting iteration...") + logger.debug("Starting collecting iteration...") # for i in tqdm(range(invariant_length - 1)): for i in range(invariant_length - 1): param_A = inv.params[i] @@ -528,7 +544,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: if_merge = True - print("Start precondition inference...") + logger.debug("Start precondition inference...") failed_hypothesis = [] for hypothesis in all_hypotheses.copy(): preconditions = find_precondition(hypothesis, [trace]) @@ -539,17 +555,17 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: FailedHypothesis(hypothesis, "Precondition not found") ) all_hypotheses.remove(hypothesis) - print("End precondition inference") + logger.debug("End precondition inference") if not if_merge: return ( list([hypo.invariant for hypo in all_hypotheses]), failed_hypothesis, ) - print("End precondition inference") + logger.debug("End precondition inference") # 6. 
Merge invariants - print("Start merging invariants...") + logger.debug("Start merging invariants...") relation_pool: Dict[ GroupedPreconditions | None, List[Tuple[APIParam, APIParam]] ] = {} @@ -576,10 +592,13 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: relation=FunctionCoverRelation, params=[param for param in merged_value], precondition=key, - text_description="Merged FunctionCoverRelation in Ordered List", + text_description=FunctionCoverRelation.to_display_name( + [param for param in merged_value] + ) + or "Merged FunctionCoverRelation in Ordered List", ) merged_ininvariants.append(new_invariant) - print("End merging invariants") + logger.debug("End merging invariants") return merged_ininvariants, failed_hypothesis @@ -613,7 +632,7 @@ def static_check_all( logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -655,9 +674,9 @@ def static_check_all( trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -669,13 +688,13 @@ def static_check_all( valid_relations = trace.valid_relations_cover else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -693,7 +712,7 @@ def static_check_all( valid_relations[(func_A, func_B)] = True trace.same_level_func_cover = same_level_func trace.valid_relations_cover = valid_relations - print("End same level checking") + logger.debug("End same level checking") inv_triggered = False @@ -710,7 +729,7 @@ def static_check_all( function_pool = set(function_pool).intersection(set(function_pool_temp)) # type: ignore if len(function_pool) == 0: - print( + logger.debug( "No relevant function calls found in the trace, skipping the checking" ) return CheckerResult( @@ -720,8 +739,8 @@ def static_check_all( triggered=False, ) - print("Starting checking iteration...") - for i in tqdm(range(invariant_length - 1)): + logger.debug("Starting checking iteration...") + for i in tqdm(range(invariant_length - 1), leave=False): param_A = inv.params[i] param_B = inv.params[i + 1] diff --git a/traincheck/invariant/lead_relation.py b/traincheck/invariant/lead_relation.py index 2db839a9..6f433a77 100644 --- a/traincheck/invariant/lead_relation.py +++ b/traincheck/invariant/lead_relation.py @@ -2,8 +2,7 @@ from itertools import permutations from typing import Any, Dict, Iterable, List, Set, Tuple -from tqdm import tqdm - +from traincheck.config.config import ANALYSIS_SKIP_FUNC_NAMES from traincheck.instrumentor.tracer import TraceLineType from traincheck.invariant.base_cls import ( APIParam, @@ -17,12 +16,16 @@ OnlineCheckerResult, Param, Relation, + _short_api_name, ) from 
traincheck.invariant.precondition import find_precondition from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.trace_pandas import TracePandas +logger = logging.getLogger(__name__) + EXP_GROUP_NAME = "func_lead" MAX_FUNC_NUM_CONSECUTIVE_CALL = 4 # ideally this should be proportional to the number of training and testing iterations in the trace @@ -84,9 +87,9 @@ def get_func_names_to_deal_with(trace: Trace) -> List[str]: # get all functions in the trace all_func_names = trace.get_func_names() - # filtering 1: remove private functions + # filtering 1: skip functions matched by ANALYSIS_SKIP_FUNC_NAMES for func_name in all_func_names: - if "._" in func_name: + if any(skip in func_name for skip in ANALYSIS_SKIP_FUNC_NAMES): continue function_pool.add(func_name) @@ -275,7 +278,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -311,9 +314,9 @@ def generate_hypothesis(trace) -> list[Hypothesis]: trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -325,13 +328,13 @@ def generate_hypothesis(trace) -> list[Hypothesis]: valid_relations = trace.valid_relations_lead else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -349,10 +352,10 @@ def generate_hypothesis(trace) -> list[Hypothesis]: valid_relations[(func_A, func_B)] = True trace.same_level_func_lead = same_level_func trace.valid_relations_lead = valid_relations - print("End same level checking") + logger.debug("End same level checking") # 3. Generating hypothesis - print("Start generating hypo...") + logger.debug("Start generating hypo...") hypothesis_with_examples = { (func_A, func_B): Hypothesis( invariant=Invariant( @@ -362,19 +365,22 @@ def generate_hypothesis(trace) -> list[Hypothesis]: APIParam(func_B), ], precondition=None, - text_description=f"FunctionLeadRelation between {func_A} and {func_B}", + text_description=FunctionLeadRelation.to_display_name( + [APIParam(func_A), APIParam(func_B)] + ) + or f"FunctionLeadRelation between {func_A} and {func_B}", ), positive_examples=ExampleList({EXP_GROUP_NAME}), negative_examples=ExampleList({EXP_GROUP_NAME}), ) for (func_A, func_B), _ in valid_relations.items() } - print("End generating hypo") + logger.debug("End generating hypo") # 4. 
Add positive and negative examples - print("Start adding examples...") + logger.debug("Start adding examples...") for (process_id, thread_id), events_list in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Group" + listed_events.items(), ascii=True, leave=False, desc="Group" ): for (func_A, func_B), _ in tqdm( @@ -511,7 +517,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]: (func_A, func_B) ].negative_examples.add_example(example) - print("End adding examples") + logger.debug("End adding examples") return list(hypothesis_with_examples.values()) @@ -522,7 +528,7 @@ def collect_examples(trace, hypothesis): logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -558,9 +564,9 @@ def collect_examples(trace, hypothesis): trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -572,13 +578,13 @@ def collect_examples(trace, hypothesis): valid_relations = trace.valid_relations_lead else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -596,7 +602,7 @@ def collect_examples(trace, hypothesis): valid_relations[(func_A, func_B)] = True trace.same_level_func_lead = same_level_func trace.valid_relations_lead = valid_relations - print("End same level checking") + logger.debug("End same level checking") inv = hypothesis.invariant @@ -613,12 +619,12 @@ def collect_examples(trace, hypothesis): function_pool = set(function_pool).intersection(function_pool_temp) if len(function_pool) == 0: - print( + logger.debug( "No relevant function calls found in the trace, skipping the collecting" ) return - print("Starting collecting iteration...") + logger.debug("Starting collecting iteration...") for i in range(invariant_length - 1): param_A = inv.params[i] param_B = inv.params[i + 1] @@ -764,7 +770,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: # for hypothesis in all_hypotheses: # FunctionLeadRelation.collect_examples(trace, hypothesis) - print("Start precondition inference...") + logger.debug("Start precondition inference...") failed_hypothesis = [] for hypothesis in all_hypotheses.copy(): preconditions = find_precondition(hypothesis, [trace]) @@ -775,7 +781,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: FailedHypothesis(hypothesis, "Precondition not found") ) all_hypotheses.remove(hypothesis) - print("End precondition inference") + logger.debug("End precondition inference") if_merge = True @@ -786,7 +792,7 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: ) # 6. 
Merge invariants - print("Start merging invariants...") + logger.debug("Start merging invariants...") relation_pool: Dict[ GroupedPreconditions | None, List[Tuple[APIParam, APIParam]] ] = {} @@ -813,10 +819,13 @@ def infer(trace: Trace) -> Tuple[List[Invariant], List[FailedHypothesis]]: relation=FunctionLeadRelation, params=[param for param in merged_value], precondition=key, - text_description="Merged FunctionLeadRelation in Ordered List", + text_description=FunctionLeadRelation.to_display_name( + [param for param in merged_value] + ) + or "Merged FunctionLeadRelation in Ordered List", ) merged_ininvariants.append(new_invariant) - print("End merging invariants") + logger.debug("End merging invariants") return merged_ininvariants, failed_hypothesis @@ -851,7 +860,7 @@ def static_check_all( logger = logging.getLogger(__name__) # 1. Pre-process all the events - print("Start preprocessing....") + logger.debug("Start preprocessing....") function_times: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {} function_id_map: Dict[Tuple[str, str], Dict[str, List[str]]] = {} listed_events: Dict[Tuple[str, str], List[dict[str, Any]]] = {} @@ -892,9 +901,9 @@ def static_check_all( trace.function_times = function_times trace.function_id_map = function_id_map trace.listed_events = listed_events - print("End preprocessing") + logger.debug("End preprocessing") - print("Start same level checking...") + logger.debug("Start same level checking...") same_level_func: Dict[Tuple[str, str], Dict[str, Any]] = {} valid_relations: Dict[Tuple[str, str], bool] = {} @@ -906,13 +915,13 @@ def static_check_all( valid_relations = trace.valid_relations_lead else: for (process_id, thread_id), _ in tqdm( - listed_events.items(), ascii=True, leave=True, desc="Groups Processed" + listed_events.items(), ascii=True, leave=False, desc="Groups Processed" ): same_level_func[(process_id, thread_id)] = {} for func_A, func_B in tqdm( permutations(function_pool, 2), ascii=True, - leave=True, + leave=False, desc="Combinations Checked", total=len(function_pool) ** 2, ): @@ -930,7 +939,7 @@ def static_check_all( valid_relations[(func_A, func_B)] = True trace.same_level_func_lead = same_level_func trace.valid_relations_lead = valid_relations - print("End same level checking") + logger.debug("End same level checking") inv_triggered = False @@ -947,7 +956,7 @@ def static_check_all( function_pool = set(function_pool).intersection(set(function_pool_temp)) if len(function_pool) == 0: - print( + logger.debug( "No relevant function calls found in the trace, skipping the checking" ) return CheckerResult( @@ -957,7 +966,7 @@ def static_check_all( triggered=False, ) - print("Starting checking iteration...") + logger.debug("Starting checking iteration...") for i in range(invariant_length - 1): param_A = inv.params[i] param_B = inv.params[i + 1] @@ -1102,6 +1111,17 @@ def _get_apis_to_check(inv: Invariant): api_name_list.append(param.api_full_name) return api_name_list + @staticmethod + def to_display_name(params: list[Param]) -> str | None: + if len(params) < 2: + return None + a, b = params[0], params[-1] + if not isinstance(a, APIParam) or not isinstance(b, APIParam): + return None + a_short = _short_api_name(a.api_full_name) + b_short = _short_api_name(b.api_full_name) + return f"{a_short}() always precedes {b_short}()" + @staticmethod def _get_api_args_map_to_check(inv): return None diff --git a/traincheck/invariant/precondition.py b/traincheck/invariant/precondition.py index b2040d83..f86ffb95 100644 --- a/traincheck/invariant/precondition.py 
+++ b/traincheck/invariant/precondition.py @@ -2,8 +2,6 @@ from itertools import combinations from typing import Hashable -from tqdm import tqdm - import traincheck.config.config as config from traincheck.invariant.base_cls import ( PT, @@ -14,6 +12,7 @@ Preconditions, UnconditionalPrecondition, ) +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import MD_NONE from traincheck.utils import safe_isnan @@ -160,8 +159,8 @@ def verify_precondition_safety( """ for example in negative_examples: if precondition.verify(example): - print("Precondition is not safe") - print("Example", example) + logger.debug("Precondition is not safe") + logger.debug("Example %s", example) return False return True @@ -563,7 +562,7 @@ def find_precondition_from_single_group( if len(local_clauses) == 0: # NOTE: this would also happen under the unconditional case, but since the unconditional case is handled separately, we should not reach here - print("example: ", example) + logger.debug("example: %s", example) raise ValueError( "No clauses can be found in the example, precondition will be empty." ) diff --git a/traincheck/onlinechecker/streamhandler_filesystem.py b/traincheck/onlinechecker/streamhandler_filesystem.py index 2b69c5cc..ab2a6366 100644 --- a/traincheck/onlinechecker/streamhandler_filesystem.py +++ b/traincheck/onlinechecker/streamhandler_filesystem.py @@ -4,7 +4,7 @@ import re import time -from watchdog.events import FileSystemEventHandler +from watchdog.events import FileCreatedEvent, FileSystemEventHandler from watchdog.observers.polling import PollingObserver from traincheck.config import config @@ -34,6 +34,7 @@ def __init__(self, file_path, checker_data: Checker_data): self.varid_map = checker_data.varid_map self.type_map = checker_data.type_map + self.attr_map = checker_data.attr_map self.pt_map = checker_data.pt_map self.process_to_vars = checker_data.process_to_vars self.args_map = checker_data.args_map @@ -43,7 +44,7 @@ def __init__(self, file_path, checker_data: Checker_data): self.needed_vars = checker_data.needed_vars self.needed_apis = checker_data.needed_apis - self._get_api_args_map_to_check = checker_data._get_api_args_map_to_check + self.all_needed_args_api = checker_data.all_needed_args_api self.min_read_time = checker_data.min_read_time self.lock = checker_data.lock @@ -61,11 +62,22 @@ def _save_initial_content(self): self.logger.info(f"Processing initial content from {self.file_path}") self.fp.seek(0) lines = self.fp.readlines() - if not lines: - return - self._handle_line(lines) - self.logger.info(f"Initial content from {self.file_path} processed.") + if lines: + self._handle_line(lines) + self.logger.info(f"Initial content from {self.file_path} processed.") + + # Mark this file as fully read so it doesn't block other files that + # have records at later timestamps. Live on_modified updates override + # this when new records actually arrive. 
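+        # (Intent, as a sketch: read_time_map records how far each trace file
+        # has been consumed, and min_read_time is the minimum of those
+        # timestamps — the slowest file — so an inf entry simply drops a
+        # fully drained file from that minimum.)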
+        with self.cond:
+            self.checker_data.read_time_map[self.file_path] = float("inf")
+            pre_min = self.checker_data.min_read_time
+            # min() over (path, time) pairs would compare paths first, so key
+            # explicitly on the timestamp to find the slowest file.
+            self.checker_data.min_read_path, self.checker_data.min_read_time = min(
+                self.checker_data.read_time_map.items(),
+                key=lambda item: item[1],
+                default=(None, None),
+            )
+            if pre_min != self.checker_data.min_read_time:
+                self.checker_data.cond.notify_all()

     def on_modified(self, event):
         if os.path.abspath(event.src_path) != os.path.abspath(self.file_path):
@@ -144,6 +156,11 @@ def _set_var_map(self, trace_record):
                 if attr_name not in self.varid_map[varid]:
                     self.varid_map[varid][attr_name] = []
+                if varid.var_type not in self.attr_map:
+                    self.attr_map[varid.var_type] = {}
+                if attr_name not in self.attr_map[varid.var_type]:
+                    self.attr_map[varid.var_type][attr_name] = set()
+                self.attr_map[varid.var_type][attr_name].add(varid)
             else:
                 self.varid_map[varid][attr_name][-1].liveness.end_time = (
                     trace_record["time"]
@@ -181,33 +198,37 @@ def _set_func_map(self, trace_record):
                 )
             if trace_type == TraceLineType.FUNC_CALL_PRE:
                 self.pt_map[ptname][func_call_id].pre_record = trace_record
-                self.pt_map[ptname][func_call_id].args = trace_record["args"]
-                self.pt_map[ptname][func_call_id].kwargs = trace_record["kwargs"]
+                self.pt_map[ptname][func_call_id].args = trace_record.get("args")
+                self.pt_map[ptname][func_call_id].kwargs = trace_record.get(
+                    "kwargs"
+                )
             elif trace_type == TraceLineType.FUNC_CALL_POST:
                 assert self.pt_map[ptname][func_call_id].pre_record is not None
                 self.pt_map[ptname][func_call_id].post_record = trace_record
-                self.pt_map[ptname][func_call_id].return_values = trace_record[
+                self.pt_map[ptname][func_call_id].return_values = trace_record.get(
                     "return_values"
-                ]
+                )
             elif trace_type == TraceLineType.FUNC_CALL_POST_EXCEPTION:
                 self.pt_map[ptname][func_call_id].post_record = trace_record
-                self.pt_map[ptname][func_call_id].exception = trace_record[
-                    "exception"
-                ]
+                self.pt_map[ptname][func_call_id].exception = trace_record.get(
+                    "exception"
+                )

             if trace_type == TraceLineType.FUNC_CALL_PRE:
-                if function_name in self.checker_data._get_api_args_map_to_check:
-                    if "args" in trace_record:
-                        if "meta_vars.step" not in trace_record:
-                            trace_record["meta_vars.step"] = -1
-                        step = trace_record["meta_vars.step"]
-                        if function_name not in self.args_map:
-                            self.args_map[function_name] = {}
-                        if step not in self.args_map[function_name]:
-                            self.args_map[function_name][step] = {}
-                        if ptid not in self.args_map[function_name][step]:
-                            self.args_map[function_name][step][ptid] = []
-                        self.args_map[function_name][step][ptid].append(trace_record)
+                if function_name in self.checker_data.all_needed_args_api:
+                    assert (
+                        "args" in trace_record and "kwargs" in trace_record
+                    ), f"Trace record for function call {function_name} does not contain args or kwargs: {trace_record}"
+                    if "meta_vars.step" not in trace_record:
+                        trace_record["meta_vars.step"] = -1
+                    step = trace_record["meta_vars.step"]
+                    if function_name not in self.args_map:
+                        self.args_map[function_name] = {}
+                    if step not in self.args_map[function_name]:
+                        self.args_map[function_name][step] = {}
+                    if ptid not in self.args_map[function_name][step]:
+                        self.args_map[function_name][step][ptid] = []
+                    self.args_map[function_name][step][ptid].append(trace_record)

             if (
                 ".__enter__" in function_name
@@ -310,29 +331,59 @@ def _set_read_time(self, trace_record):
             self.checker_data.cond.notify_all()


+class FolderCreationHandler(FileSystemEventHandler):
+    """Watches a trace folder and dynamically attaches StreamLogHandler for new trace files."""
+
+    def __init__(self,
trace_folder, checker_data: Checker_data, observer): + self.trace_folder = os.path.abspath(trace_folder) + self.checker_data = checker_data + self.observer = observer + self.seen_files: set[str] = set() + self.logger = logging.getLogger(__name__) + + def _is_trace_file(self, filename: str) -> bool: + return filename.startswith("trace_") or filename.endswith("proxy_log.json") + + def attach(self, file_path: str): + """Create and schedule a StreamLogHandler for a newly discovered file.""" + if file_path in self.seen_files: + return + self.seen_files.add(file_path) + self.logger.info(f"New trace file detected, watching: {file_path}") + handler = StreamLogHandler(file_path, self.checker_data) + self.observer.schedule( + handler, path=os.path.dirname(file_path), recursive=False + ) + + def on_created(self, event): + if isinstance(event, FileCreatedEvent): + filename = os.path.basename(event.src_path) + if self._is_trace_file(filename): + self.attach(os.path.abspath(event.src_path)) + + def run_stream_monitor(traces, trace_folders, checker_data: Checker_data): """Run the stream monitor to watch the trace files and folders.""" logger = logging.getLogger(__name__) observer = PollingObserver() - handlers = [] if traces is not None: file_path = os.path.abspath(traces[0]) handler = StreamLogHandler(file_path, checker_data) - handlers.append(handler) watch_dir = os.path.dirname(file_path) observer.schedule(handler, path=watch_dir, recursive=False) logger.info(f"Watching: {file_path}") if trace_folders is not None: for trace_folder in trace_folders: + folder_abs = os.path.abspath(trace_folder) + creation_handler = FolderCreationHandler(folder_abs, checker_data, observer) + observer.schedule(creation_handler, path=folder_abs, recursive=False) + + # Pick up any trace files that already exist in the folder. for file in sorted(os.listdir(trace_folder)): if file.startswith("trace_") or file.endswith("proxy_log.json"): - file_path = os.path.join(trace_folder, file) - handler = StreamLogHandler(file_path, checker_data) - handlers.append(handler) - watch_dir = os.path.dirname(file_path) - observer.schedule(handler, path=watch_dir, recursive=False) - logger.info(f"Watching: {file_path}") + file_path = os.path.join(folder_abs, file) + creation_handler.attach(file_path) observer.start() return observer diff --git a/traincheck/onlinechecker/utils.py b/traincheck/onlinechecker/utils.py index 987f22a6..6da09e22 100644 --- a/traincheck/onlinechecker/utils.py +++ b/traincheck/onlinechecker/utils.py @@ -9,14 +9,15 @@ class Checker_data: """Data structure for online checker threads. 
Holds the needed data and the queue for processing.""" def __init__(self, needed_data): - needed_vars, needed_apis, _get_api_args_map_to_check = needed_data + needed_vars, needed_apis, all_needed_args_api = needed_data self.needed_vars = needed_vars self.needed_apis = needed_apis - self._get_api_args_map_to_check = _get_api_args_map_to_check + self.all_needed_args_api = all_needed_args_api self.check_queue = queue.Queue() self.varid_map = {} self.type_map = {} + self.attr_map: dict[str, dict[str, set]] = {} self.pt_map = {} self.process_to_vars = {} self.args_map = {} @@ -187,9 +188,17 @@ def query_var_changes_within_time_and_process( checker_data: Checker_data, ) -> list: """Extract all variable change events from the trace, within a specific time range and process.""" - events = [] + events: list = [] with checker_data.lock: - for varid in checker_data.type_map[var_type]: + if var_type not in checker_data.attr_map: + return events + if attr_name not in checker_data.attr_map[var_type]: + return events + for varid in checker_data.attr_map[var_type][attr_name]: + assert attr_name in checker_data.varid_map[varid], ( + f"attr_map/varid_map inconsistency: {varid} in " + f"attr_map[{var_type}][{attr_name}] but not in varid_map" + ) for i in reversed(range(1, len(checker_data.varid_map[varid][attr_name]))): change_time = checker_data.varid_map[varid][attr_name][ diff --git a/traincheck/progress.py b/traincheck/progress.py new file mode 100644 index 00000000..7947dcd4 --- /dev/null +++ b/traincheck/progress.py @@ -0,0 +1,23 @@ +"""Thin tqdm wrapper that can be silenced during invariant checking. + +Import this instead of tqdm in relation and trace modules: + + from traincheck.progress import tqdm + +When traincheck.utils._suppress_inner_progress is True (set by +check_engine() while the outer checking bar is active), all bars +created via this wrapper are disabled so only the single top-level +progress bar is visible. 
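+
+A minimal usage sketch (assuming traincheck.utils initializes
+_suppress_inner_progress to False at import time):
+
+    from traincheck import utils
+    from traincheck.progress import tqdm
+
+    utils._suppress_inner_progress = True
+    for _ in tqdm(range(100)):  # wrapper passes disable=True; nothing renders
+        pass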
+""" + +from tqdm import tqdm as _tqdm_orig + + +def tqdm(iterable=None, *args, **kwargs): # type: ignore[override] + from traincheck import utils as _utils + + if _utils._suppress_inner_progress and "disable" not in kwargs: + kwargs["disable"] = True + if iterable is not None: + return _tqdm_orig(iterable, *args, **kwargs) + return _tqdm_orig(*args, **kwargs) diff --git a/traincheck/reporting/__init__.py b/traincheck/reporting/__init__.py index 75bd0c9e..f3c090ec 100644 --- a/traincheck/reporting/__init__.py +++ b/traincheck/reporting/__init__.py @@ -2,4 +2,5 @@ ReportEmitter, build_offline_report_data, build_online_report_data, + build_violations_summary, ) diff --git a/traincheck/reporting/checker_report.py b/traincheck/reporting/checker_report.py index 92e36cee..1ea3c4a7 100644 --- a/traincheck/reporting/checker_report.py +++ b/traincheck/reporting/checker_report.py @@ -10,12 +10,61 @@ def _format_invariant_label(invariant: Invariant) -> str: + display = invariant.relation.to_display_name(invariant.params) + if display: + return display if invariant.text_description: return invariant.text_description params = ", ".join(str(param) for param in invariant.params) return f"{invariant.relation.__name__}({params})" +def _extract_violation_steps(trace: list[dict] | None) -> list[int]: + """Extract training step numbers from a violation trace.""" + if not trace: + return [] + return [ + r["meta_vars.step"] + for r in trace + if isinstance(r, dict) and r.get("meta_vars.step") is not None + ] + + +def _build_violation_entry(result: CheckerResult) -> dict: + steps = _extract_violation_steps(result.trace) + return { + "display_name": _format_invariant_label(result.invariant), + "relation_type": result.invariant.relation.__name__, + "first_step": min(steps) if steps else None, + "last_step": max(steps) if steps else None, + "occurrences": len(result.trace) if result.trace else 1, + } + + +def _build_violation_steps_map(results: list[CheckerResult]) -> dict[int, int]: + """Map step → count of distinct invariants violated at that step.""" + step_to_invs: dict[int, set[str]] = defaultdict(set) + for res in results: + if not res.check_passed: + label = _format_invariant_label(res.invariant) + for step in _extract_violation_steps(res.trace): + step_to_invs[step].add(label) + return {step: len(invs) for step, invs in step_to_invs.items()} + + +def build_violations_summary(results: list[CheckerResult]) -> dict: + """Build a pre-digested summary of all violations for machine and human consumption.""" + failed = [r for r in results if not r.check_passed] + all_steps = [] + for r in failed: + all_steps.extend(_extract_violation_steps(r.trace)) + return { + "first_violation_step": min(all_steps) if all_steps else None, + "distinct_invariants_violated": len(failed), + "violations": [_build_violation_entry(r) for r in failed], + } + + def _summarize_results(results: Iterable[CheckerResult]) -> dict[str, int]: failed = sum(1 for res in results if not res.check_passed) not_triggered = sum(1 for res in results if res.triggered is False) @@ -52,14 +101,53 @@ def _count_failed_invariants( results: Iterable[CheckerResult], ) -> list[dict[str, object]]: counter: Counter[tuple[str, str]] = Counter() + first_steps: dict[tuple[str, str], int | None] = {} + last_steps: dict[tuple[str, str], int | None] = {} + step_stage_maps: dict[tuple[str, str], dict] = defaultdict(dict) + sample_traces: dict[tuple[str, str], list] = {} for res in results: if not res.check_passed: label = _format_invariant_label(res.invariant) relation 
= res.invariant.relation.__name__ - counter[(label, relation)] += 1 + key = (label, relation) + counter[key] += 1 + steps = _extract_violation_steps(res.trace) + if steps: + existing = first_steps.get(key) + first_steps[key] = ( + min(steps) if existing is None else min(existing, min(steps)) + ) + existing_last = last_steps.get(key) + last_steps[key] = ( + max(steps) + if existing_last is None + else max(existing_last, max(steps)) + ) + elif key not in first_steps: + first_steps[key] = None + last_steps[key] = None + # Accumulate step → stage (first stage seen per step wins) + for rec in res.trace or []: + if not isinstance(rec, dict): + continue + step = rec.get("meta_vars.step") + stage = rec.get("meta_vars.stage") + if step is not None and step not in step_stage_maps[key]: + step_stage_maps[key][step] = stage + # One sample trace per invariant (first violation wins) + if key not in sample_traces and res.trace: + sample_traces[key] = _summarize_trace_records(res.trace) top_pairs = counter.most_common(10) return [ - {"label": label, "relation": relation, "count": count} + { + "label": label, + "relation": relation, + "count": count, + "first_step": first_steps.get((label, relation)), + "last_step": last_steps.get((label, relation)), + "step_stages": sorted(step_stage_maps[(label, relation)].items()), + "sample_trace": sample_traces.get((label, relation), []), + } for (label, relation), count in top_pairs ] @@ -114,6 +202,7 @@ def build_offline_report_data( all_failed_invariants.extend([res for res in results if not res.check_passed]) top_violations = _count_failed_invariants(all_failed_invariants) + violation_steps_map = _build_violation_steps_map(all_failed_invariants) return { "mode": "offline", @@ -123,9 +212,74 @@ def build_offline_report_data( "relations": dict(overall_relation_counts), "traces": trace_sections, "top_violations": top_violations, + "violation_steps_map": violation_steps_map, } +_TRACE_DISPLAY_KEYS = ( + "function", + "meta_vars.step", + "meta_vars.stage", + "type", + "var_name", + "var_type", +) +_TRACE_SKIP_PREFIXES = ("attributes._TRAINCHECK_",) + +# Known stage → badge color (bg, text) +_STAGE_COLORS: dict[str, tuple[str, str]] = { + "train": ("#2f6fed", "#fff"), + "training": ("#2f6fed", "#fff"), + "eval": ("#2fb679", "#fff"), + "evaluation": ("#2fb679", "#fff"), + "validation": ("#2fb679", "#fff"), + "val": ("#2fb679", "#fff"), + "test": ("#f2b233", "#333"), + "inference": ("#9b59b6", "#fff"), + "pretrain": ("#1abc9c", "#fff"), +} +_STAGE_FALLBACK_PALETTE = [ + ("#e24c4b", "#fff"), + ("#e67e22", "#fff"), + ("#e91e63", "#fff"), + ("#00bcd4", "#fff"), + ("#607d8b", "#fff"), +] + + +def _stage_badge_style(stage: str) -> str: + """Return inline CSS background/color for a stage badge.""" + key = stage.lower() + if key in _STAGE_COLORS: + bg, fg = _STAGE_COLORS[key] + else: + bg, fg = _STAGE_FALLBACK_PALETTE[hash(key) % len(_STAGE_FALLBACK_PALETTE)] + return f"background:{bg};color:{fg}" + + +def _summarize_trace_records(trace: list[dict] | None) -> list[dict]: + """Return a compact, HTML-safe subset of trace records for display.""" + if not trace: + return [] + out = [] + for rec in trace: + if not isinstance(rec, dict): + continue + row: dict = {} + for key in _TRACE_DISPLAY_KEYS: + val = rec.get(key) + if val is not None: + row[key] = str(val) + # Add first attribute-style key that is not an internal one + for key, val in rec.items(): + if key.startswith("attributes.") and val is not None: + if not any(key.startswith(p) for p in _TRACE_SKIP_PREFIXES): + row[key] = 
str(val) + break + out.append(row) + return out + + def build_online_report_data( *, generated_at: str, @@ -134,24 +288,106 @@ def build_online_report_data( total_violations: int, failed_inv: dict[Invariant, int], relation_totals: dict[str, int], + violation_details: dict | None = None, + triggered_inv: set | None = None, + all_invs: list | None = None, + current_step: int | None = None, + current_stage: str | None = None, + sampling_interval: int | None = None, + warm_up_steps: int | None = None, ) -> dict: relation_violations: dict[str, int] = defaultdict(int) for inv in failed_inv: relation_violations[inv.relation.__name__] += 1 - top_pairs = sorted( - ((count, inv) for inv, count in failed_inv.items()), - key=lambda item: item[0], - reverse=True, - )[:10] - top_violations = [ - { + if violation_details is None: + violation_details = {} + + # Estimate how many steps have been checked so far. + checked_steps: int | None = None + if ( + current_step is not None + and sampling_interval is not None + and warm_up_steps is not None + and sampling_interval > 0 + ): + checked_steps = max(0, current_step - warm_up_steps) // sampling_interval + min( + current_step, warm_up_steps + ) + + def _make_entry(inv: Invariant, count: int) -> dict: + detail = violation_details.get(inv, {}) + step_stages: list[tuple] = detail.get("step_stages") or [] + sample_trace: list[dict] | None = detail.get("sample_trace") + steps = [s for s, _ in step_stages] + first_step = min(steps) if steps else None + last_step = max(steps) if steps else None + # stage of the first/last violation event + first_stage = ( + next((st for s, st in step_stages if s == first_step), None) + if first_step is not None + else None + ) + last_stage = ( + next((st for s, st in reversed(step_stages) if s == last_step), None) + if last_step is not None + else None + ) + # deduplicated, sorted (step, stage) pairs for the expanded view + unique_step_stages = sorted(set(step_stages), key=lambda x: x[0]) + unique_viol_steps = len(set(s for s, _ in step_stages)) + viol_rate: float | None = None + if checked_steps is not None and checked_steps > 0: + viol_rate = round(unique_viol_steps / checked_steps * 100, 1) + return { "label": _format_invariant_label(inv), "relation": inv.relation.__name__, "count": count, + "first_step": first_step, + "first_stage": first_stage, + "last_step": last_step, + "last_stage": last_stage, + "step_stages": unique_step_stages[:100], # cap for HTML size + "sample_trace": _summarize_trace_records(sample_trace), + "violation_step_count": unique_viol_steps, + "checked_steps": checked_steps, + "violation_rate": viol_rate, } - for count, inv in top_pairs - ] + + # Build step → distinct-invariant count map from all violation_details + step_to_invs: dict[int, set[str]] = defaultdict(set) + for inv, detail in violation_details.items(): + lbl = _format_invariant_label(inv) + for step, _ in detail.get("step_stages") or []: + step_to_invs[step].add(lbl) + violation_steps_map = {step: len(invs) for step, invs in step_to_invs.items()} + + # Sort by first violation step (earliest first), then by count descending. 
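+    # e.g. an invariant first violated at step 3 with 2 hits sorts ahead of one
+    # first violated at step 10 with 500 hits; invariants with no recorded
+    # steps (first = inf) sink to the end.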
+    def _sort_key(item):
+        inv, count = item
+        detail = violation_details.get(inv, {})
+        step_stages = detail.get("step_stages") or []
+        steps = [s for s, _ in step_stages]
+        first = min(steps) if steps else float("inf")
+        return (first, -count)
+
+    sorted_pairs = sorted(failed_inv.items(), key=_sort_key)[:20]
+    top_violations = [_make_entry(inv, count) for inv, count in sorted_pairs]
+
+    # Progress tracking
+    triggered_count = len(triggered_inv) if triggered_inv is not None else 0
+    failing_count = len(failed_inv)
+    passing_count = triggered_count - failing_count
+    not_triggered_count = total_invariants - triggered_count
+    pass_rate = (
+        round(passing_count / triggered_count * 100, 1) if triggered_count > 0 else None
+    )
+
+    not_triggered_labels: list[str] = []
+    if all_invs is not None and triggered_inv is not None:
+        not_triggered_labels = [
+            _format_invariant_label(inv) for inv in all_invs if inv not in triggered_inv
+        ][:50]
 
     relations = {}
     for relation_name, total in relation_totals.items():
@@ -168,6 +404,13 @@ def build_online_report_data(
         "not_triggered": None,
         "triggered": None,
         "violated_invariants": len(failed_inv),
+        # progress fields
+        "triggered_count": triggered_count,
+        "passing_count": passing_count,
+        "not_triggered_count": not_triggered_count,
+        "pass_rate": pass_rate,
+        "current_step": current_step,
+        "current_stage": current_stage,
     }
 
     return {
@@ -178,9 +421,46 @@ def build_online_report_data(
         "relations": relations,
         "traces": [],
         "top_violations": top_violations,
+        "not_triggered_labels": not_triggered_labels,
+        "sampling_interval": sampling_interval,
+        "warm_up_steps": warm_up_steps,
+        "checked_steps": checked_steps,
+        "violation_steps_map": violation_steps_map,
     }
 
 
+def _render_stage_badge(stage: str | None, esc_fn) -> str:
+    if not stage:
+        return ""
+    style = _stage_badge_style(stage)
+    return f'<span class="stage-badge" style="{style}">{esc_fn(stage)}</span>'
+
+
+def _render_step_stages_html(
+    step_stages: list[tuple], esc_fn, max_per_group: int = 15
+) -> str:
+    """Render a compact stage-grouped step list as HTML."""
+    if not step_stages:
+        return "—"
+    # Group consecutive same-stage runs
+    groups: list[tuple[str | None, list[int]]] = []
+    for step, stage in step_stages:
+        if groups and groups[-1][0] == stage:
+            groups[-1][1].append(step)
+        else:
+            groups.append((stage, [step]))
+    parts = []
+    for stage, steps in groups:
+        shown = steps[:max_per_group]
+        more = len(steps) - len(shown)
+        steps_str = ", ".join(str(s) for s in shown)
+        if more > 0:
+            steps_str += f' <span class="more-steps">+{more} more</span>'
+        badge = _render_stage_badge(stage, esc_fn)
+        parts.append(f"{badge}{steps_str}")
+    return ' <span class="step-sep">·</span> '.join(parts)
+
+
 def _render_bar_segment(width_pct: float, class_name: str) -> str:
     width_pct = max(0.0, min(100.0, width_pct))
     return f'<span class="{class_name}" style="width:{width_pct}%"></span>'
@@ -200,17 +480,177 @@ def percent(part: int, total: int) -> float:
     traces = report_data.get("traces", [])
     top_violations = report_data.get("top_violations", [])
 
-        top_items = []
-        for entry in top_violations:
-            label = esc(str(entry.get("label", "")))
-            detail = esc(str(entry.get("relation", "")))
-            count = entry.get("count")
-            count_html = f'<span class="count">{count}</span>' if count else ""
-            top_items.append(
-                f'<li>{label}'
-                f'<span class="relation">{detail}</span>{count_html}</li>'
-            )
+        top_table_html = ""
+
+        if mode == "online":
+            sampling_interval = report_data.get("sampling_interval")
+            warm_up_steps_val = report_data.get("warm_up_steps")
+            checked_steps_total = report_data.get("checked_steps")
+            has_sampling = sampling_interval is not None
+
+            rows = []
+            for entry in top_violations:
+                label = esc(str(entry.get("label", "")))
+                relation = esc(str(entry.get("relation", "")))
+                count = entry.get("count", "")
+                first_step = entry.get("first_step")
+                first_stage = entry.get("first_stage")
+                last_step = entry.get("last_step")
+                last_stage = entry.get("last_stage")
+                step_stages: list = entry.get("step_stages") or []
+                sample_trace = entry.get("sample_trace") or []
+                violation_step_count = entry.get("violation_step_count", 0)
+                entry_checked = entry.get("checked_steps")
+                viol_rate = entry.get("violation_rate")
+
+                def _step_with_badge(step, stage) -> str:
+                    if step is None:
+                        return "—"
+                    badge = _render_stage_badge(stage, esc)
+                    return f"{badge}{step}"
+
+                first_step_html = _step_with_badge(first_step, first_stage)
+                last_step_html = _step_with_badge(last_step, last_stage)
+                steps_html = _render_step_stages_html(step_stages, esc)
+
+                # Build sample trace table
+                if sample_trace:
+                    all_keys: list[str] = []
+                    for rec in sample_trace:
+                        for k in rec:
+                            if k not in all_keys:
+                                all_keys.append(k)
+                    trace_head = "".join(f"<th>{esc(k)}</th>" for k in all_keys)
+                    trace_rows_html = []
+                    for rec in sample_trace:
+                        cells = []
+                        for k in all_keys:
+                            val = rec.get(k, "")
+                            # Style stage cells
+                            if k == "meta_vars.stage" and val:
+                                style = _stage_badge_style(val)
+                                cell = (
+                                    f'<td><span class="stage-badge" style="{style}">'
+                                    f"{esc(val)}</span></td>"
+                                )
+                            else:
+                                cell = f"<td>{esc(str(val))}</td>"
+                            cells.append(cell)
+                        trace_rows_html.append(f"<tr>{''.join(cells)}</tr>")
+                    trace_body = "\n".join(trace_rows_html)
+                    expand_content = (
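+                        # Expand body: step/stage summary line, then a compact
+                        # table of the sampled violating trace records.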
+                        f'<div class="trace-steps">Steps: {steps_html}</div>'
+                        f'<div class="trace-wrap"><table class="trace-table">'
+                        f"<thead><tr>{trace_head}</tr></thead>"
+                        f"<tbody>{trace_body}</tbody></table></div>"
+                    )
+                else:
+                    expand_content = f'<div class="trace-steps">Steps: {steps_html}</div>'
+
+                # Frequency cell: prefer rate when sampling info available
+                if has_sampling and entry_checked is not None and entry_checked > 0:
+                    rate_str = f"{viol_rate}%" if viol_rate is not None else "?"
+                    freq_cell = (
+                        f'<span class="freq-rate">{rate_str}</span>'
+                        f'<span class="freq-detail">'
+                        f"{violation_step_count}/{entry_checked} steps"
+                        f"</span>"
+                    )
+                else:
+                    freq_cell = f'<span class="count-cell">{count}</span>'
+
+                rows.append(
+                    f"<tr>"
+                    f'<td><details><summary class="inv-label-summary">{label}</summary>'
+                    f'<div class="expand-body">{expand_content}</div></details>'
+                    f'<div class="inv-rel-tag">{relation}</div></td>'
+                    f'<td class="step-cell">{first_step_html}</td>'
+                    f'<td class="step-cell">{last_step_html}</td>'
+                    f'<td class="freq-cell">{freq_cell}</td>'
+                    f"</tr>"
+                )
+
+            freq_col_header = "Frequency" if has_sampling else "Count"
+            top_table_html = (
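+                # Four columns: expandable invariant, first/last violating step,
+                # and a frequency (or raw count) column.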
+                f'<table class="viol-table">'
+                f"<thead><tr><th>Invariant</th><th>First Step</th>"
+                f"<th>Last Step</th><th>{freq_col_header}</th></tr></thead>"
+                f"<tbody>{''.join(rows)}</tbody></table>"
+                if rows
+                else '<p class="empty-note">No violations yet.</p>'
+            )
+        else:
+            rows = []
+            for entry in top_violations:
+                label = esc(str(entry.get("label", "")))
+                relation = esc(str(entry.get("relation", "")))
+                count = entry.get("count", "")
+                first_step = entry.get("first_step")
+                last_step = entry.get("last_step")
+                off_step_stages: list = entry.get("step_stages") or []
+                off_sample_trace = entry.get("sample_trace") or []
+
+                def _step_cell(step, _ss=off_step_stages) -> str:
+                    if step is None:
+                        return "—"
+                    stage = next((s for st, s in _ss if st == step), None)
+                    badge = _render_stage_badge(stage, esc)
+                    return f"{badge}{step}"
+
+                first_step_html = _step_cell(first_step)
+                last_step_html = _step_cell(last_step)
+                steps_html = _render_step_stages_html(off_step_stages, esc)
+
+                if off_sample_trace:
+                    off_keys: list[str] = []
+                    for rec in off_sample_trace:
+                        for k in rec:
+                            if k not in off_keys:
+                                off_keys.append(k)
+                    trace_head = "".join(f"<th>{esc(k)}</th>" for k in off_keys)
+                    trace_rows_html = []
+                    for rec in off_sample_trace:
+                        cells = []
+                        for k in off_keys:
+                            val = rec.get(k, "")
+                            if k == "meta_vars.stage" and val:
+                                style = _stage_badge_style(val)
+                                cell = (
+                                    f'<td><span class="stage-badge" style="{style}">'
+                                    f"{esc(val)}</span></td>"
+                                )
+                            else:
+                                cell = f"<td>{esc(str(val))}</td>"
+                            cells.append(cell)
+                        trace_rows_html.append(f"<tr>{''.join(cells)}</tr>")
+                    trace_body = "\n".join(trace_rows_html)
+                    expand_content = (
+                        f'<div class="trace-steps">Steps: {steps_html}</div>'
+                        f'<div class="trace-wrap"><table class="trace-table">'
+                        f"<thead><tr>{trace_head}</tr></thead>"
+                        f"<tbody>{trace_body}</tbody></table></div>"
+                    )
+                else:
+                    expand_content = f'<div class="trace-steps">Steps: {steps_html}</div>'
+
+                rows.append(
+                    f"<tr>"
+                    f'<td><details><summary class="inv-label-summary">{label}</summary>'
+                    f'<div class="expand-body">{expand_content}</div></details>'
+                    f'<div class="inv-rel-tag">{relation}</div></td>'
+                    f'<td class="step-cell">{first_step_html}</td>'
+                    f'<td class="step-cell">{last_step_html}</td>'
+                    f'<td class="count-cell">{count}</td>'
+                    f"</tr>"
+                )
+            top_table_html = (
+                f'<table class="viol-table">'
+                f"<thead><tr><th>Invariant</th><th>First Step</th>"
+                f"<th>Last Step</th><th>Count</th></tr></thead>"
+                f"<tbody>{''.join(rows)}</tbody></table>"
+                if rows
+                else '<p class="empty-note">No violations.</p>'
+            )
-        top_list = "".join(top_items) or "<li>None</li>"
 
     trace_sections = []
     for trace in traces:
@@ -225,17 +665,74 @@ def percent(part: int, total: int) -> float:
             + _render_bar_segment(percent(not_triggered, total), "bar-not-triggered")
         )
 
-        failed_list_items = []
+        failed_rows = []
        for failed_item in trace["failed_invariants"][:10]:
            label = esc(str(failed_item.get("label", "")))
-            detail = esc(str(failed_item.get("relation", "")))
-            count = failed_item.get("count")
-            count_html = f'<span class="count">{count}</span>' if count else ""
-            failed_list_items.append(
-                f'<li>{label}'
-                f'<span class="relation">{detail}</span>{count_html}</li>'
-            )
+            relation = esc(str(failed_item.get("relation", "")))
+            count = failed_item.get("count", "")
+            first_step = failed_item.get("first_step")
+            last_step = failed_item.get("last_step")
+            item_step_stages: list = failed_item.get("step_stages") or []
+            item_sample_trace = failed_item.get("sample_trace") or []
+
+            def _step_cell_trace(step) -> str:
+                if step is None:
+                    return "—"
+                stage = next((s for st, s in item_step_stages if st == step), None)
+                badge = _render_stage_badge(stage, esc)
+                return f"{badge}{step}"
+
+            steps_html = _render_step_stages_html(item_step_stages, esc)
+            if item_sample_trace:
+                item_keys: list[str] = []
+                for rec in item_sample_trace:
+                    for k in rec:
+                        if k not in item_keys:
+                            item_keys.append(k)
+                trace_head = "".join(f"<th>{esc(k)}</th>" for k in item_keys)
+                trace_rows_html = []
+                for rec in item_sample_trace:
+                    cells = []
+                    for k in item_keys:
+                        val = rec.get(k, "")
+                        if k == "meta_vars.stage" and val:
+                            style = _stage_badge_style(val)
+                            cell = (
+                                f'<td><span class="stage-badge" style="{style}">'
+                                f"{esc(val)}</span></td>"
+                            )
+                        else:
+                            cell = f"<td>{esc(str(val))}</td>"
+                        cells.append(cell)
+                    trace_rows_html.append(f"<tr>{''.join(cells)}</tr>")
+                trace_body = "\n".join(trace_rows_html)
+                expand_content = (
+                    f'<div class="trace-steps">Steps: {steps_html}</div>'
+                    f'<div class="trace-wrap"><table class="trace-table">'
+                    f"<thead><tr>{trace_head}</tr></thead>"
+                    f"<tbody>{trace_body}</tbody></table></div>"
+                )
+            else:
+                expand_content = f'<div class="trace-steps">Steps: {steps_html}</div>'
+
+            failed_rows.append(
+                f"<tr>"
+                f'<td><details><summary class="inv-label-summary">{label}</summary>'
+                f'<div class="expand-body">{expand_content}</div></details>'
+                f'<div class="inv-rel-tag">{relation}</div></td>'
+                f'<td class="step-cell">{_step_cell_trace(first_step)}</td>'
+                f'<td class="step-cell">{_step_cell_trace(last_step)}</td>'
+                f'<td class="count-cell">{count}</td>'
+                f"</tr>"
+            )
-        failed_list_html = "".join(failed_list_items) or "<li>None</li>"
+        failed_list_html = (
+            f'<table class="viol-table">'
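+            # Per-trace variant of the global violations table (same four columns).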
+            f"<thead><tr><th>Invariant</th><th>First Step</th>"
+            f"<th>Last Step</th><th>Count</th></tr></thead>"
+            f"<tbody>{''.join(failed_rows)}</tbody></table>"
+            if failed_rows
+            else '<p class="empty-note">None</p>'
+        )
 
         relation_rows = []
        for relation_name, rel_counts in sorted(trace["relations"].items()):
@@ -272,7 +769,7 @@ def percent(part: int, total: int) -> float:

             <h3>Failed invariants (top 10)</h3>
-            <ul class="failed-list">{failed_list_html}</ul>
+            {failed_list_html}
             <h3>Relation breakdown</h3>
@@ -312,23 +809,114 @@ def percent(part: int, total: int) -> float:
         )
 
     if mode == "online":
+        cur_step = overall.get("current_step")
+        cur_stage = overall.get("current_stage")
+        triggered_count = overall.get("triggered_count", 0) or 0
+        passing_count = overall.get("passing_count", 0) or 0
+        not_triggered_count = overall.get("not_triggered_count", 0) or 0
+        pass_rate_val = overall.get("pass_rate")
+        total_invariants_n = overall["total_invariants"] or 0
+        violated = overall.get("violated_invariants", 0) or 0
+
+        # Current step card — show stage badge inline
+        if cur_step is not None:
+            stage_badge = _render_stage_badge(cur_stage, esc) if cur_stage else ""
+            step_value_html = f"{stage_badge}{cur_step}"
+            step_sub = esc(f"stage: {cur_stage}") if cur_stage else "no stage info"
+        else:
+            step_value_html = "—"
+            step_sub = "waiting for first trace record"
+
+        pass_rate_display = f"{pass_rate_val}%" if pass_rate_val is not None else "—"
+        card_html = f"""
         <div class="card">
           <div class="card-title">Total Invariants</div>
-          <div class="value">{overall['total_invariants']}</div>
+          <div class="value">{total_invariants_n}</div>
         </div>
+        <div class="card">
+          <div class="card-title">Triggered</div>
+          <div class="value">{triggered_count}</div>
+          <div class="card-sub">of {total_invariants_n} loaded</div>
+        </div>
+        <div class="card card-pass">
+          <div class="card-title">Pass Rate</div>
+          <div class="value">{pass_rate_display}</div>
+          <div class="card-sub">{passing_count} passing · {violated} failing</div>
+        </div>
         <div class="card">
           <div class="card-title">Violations</div>
           <div class="value">{overall['failed']}</div>
         </div>
-        <div class="card">
-          <div class="card-title">Violated Invariants</div>
-          <div class="value">{overall.get('violated_invariants', 0)}</div>
-        </div>
+        <div class="card card-step">
+          <div class="card-title">Current Step</div>
+          <div class="value step-value">{step_value_html}</div>
+          <div class="card-sub">{step_sub}</div>
+        </div>
         """
 
         relation_header = "<th>Relation</th><th>Total</th><th>Violated</th>"
-        mode_note = '<p class="mode-note">Online mode (partial coverage).</p>'
+        mode_note = '<p class="mode-note">Online mode — checking in progress.</p>'
+
+        # Progress panel (checking coverage bar + not-yet-triggered list)
+        not_triggered_labels: list[str] = report_data.get("not_triggered_labels", [])
+        bar_total = total_invariants_n or 1
+        passing_pct = percent(passing_count, bar_total)
+        failing_pct = percent(violated, bar_total)
+        not_triggered_pct = percent(not_triggered_count, bar_total)
+        progress_bar = (
+            _render_bar_segment(passing_pct, "bar-passed")
+            + _render_bar_segment(failing_pct, "bar-failed")
+            + _render_bar_segment(not_triggered_pct, "bar-not-triggered")
+        )
+
+        if not_triggered_labels:
+            nt_items = "".join(
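+                # one pill per not-yet-triggered invariant; grid layout via .nt-list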
+                f'<li class="nt-item">{esc(lbl)}</li>' for lbl in not_triggered_labels
+            )
+            suffix = (
+                f" (showing first {len(not_triggered_labels)})"
+                if not_triggered_count > len(not_triggered_labels)
+                else ""
+            )
+            nt_section = (
+                f'<details class="nt-details"><summary>'
+                f"{not_triggered_count} not yet triggered{esc(suffix)}"
+                f'</summary><ul class="nt-list">{nt_items}</ul>'
+                f"</details>"
+            )
+        elif not_triggered_count > 0:
+            nt_section = (
+                f'<p class="muted">{not_triggered_count} invariant(s) not yet triggered.</p>'
+            )
+        else:
+            nt_section = (
+                '<p class="muted">All invariants have been triggered at least once.</p>'
+            )
+
+        progress_panel = f"""
+        <section class="panel progress-panel">
+          <h2>Checking Progress</h2>
+          <div class="panel-sub">{triggered_count} of {total_invariants_n} invariants triggered so far</div>
+          <div class="legend">
+            <span class="c-pass">Passing {passing_count}</span>
+            <span class="c-fail">Failing {violated}</span>
+            <span>Not Triggered {not_triggered_count}</span>
+          </div>
+          <div class="bar">{progress_bar}</div>
+          <div class="legend">
+            <span>Passing ({passing_count})</span>
+            <span>Failing ({violated})</span>
+            <span>Not Triggered ({not_triggered_count})</span>
+          </div>
+          {nt_section}
+        </section>
+        """
     else:
+        total_checks = overall["total_checks"] or 0
+        passed_checks = overall["passed"] or 0
+        pass_rate = (
+            round(passed_checks / total_checks * 100, 1) if total_checks else 0.0
+        )
         card_html = f"""
         <div class="card">
           <div class="card-title">Total Invariants</div>
@@ -342,6 +930,10 @@ def percent(part: int, total: int) -> float:
         <div class="card">
           <div class="card-title">Passed Checks</div>
           <div class="value">{overall['passed']}</div>
         </div>
+        <div class="card card-pass">
+          <div class="card-title">Pass Rate</div>
+          <div class="value">{pass_rate}%</div>
+        </div>
         <div class="card">
           <div class="card-title">Not Triggered</div>
           <div class="value">{overall['not_triggered']}</div>
         </div>
         """
         relation_header = "<th>Relation</th><th>Failed</th><th>Passed</th><th>Not Triggered</th>"
         mode_note = ""
+        progress_panel = ""
+
+    # surface first violation step in the top violations panel header
+    all_first_steps = [
+        v["first_step"] for v in top_violations if v.get("first_step") is not None
+    ]
+    if all_first_steps:
+        first_step_note = (
+            f'<p class="muted">First violation at step {min(all_first_steps)}'
+            f" · {len(top_violations)} distinct invariant(s) violated</p>"
+        )
+    elif top_violations:
+        first_step_note = (
+            f'<p class="muted">{len(top_violations)}'
+            " distinct invariant(s) violated</p>"
+        )
+    else:
+        first_step_note = ""
+
+    if mode == "online":
+        sampling_interval = report_data.get("sampling_interval")
+        warm_up_steps_val = report_data.get("warm_up_steps")
+        checked_steps_total = report_data.get("checked_steps")
+        sampling_ctx = ""
+        if sampling_interval is not None:
+            sampling_ctx = f" · sampled every {sampling_interval} steps"
+            if warm_up_steps_val is not None:
+                sampling_ctx += f", warm-up {warm_up_steps_val}"
+            if checked_steps_total is not None:
+                sampling_ctx += f" ({checked_steps_total} steps checked)"
+        panel_subtitle = esc(
+            f"Sorted by first violation step — click to expand trace{sampling_ctx}"
+        )
+        panel_content = top_table_html
+    else:
+        panel_subtitle = "Sorted by first violation step — click to expand trace"
+        panel_content = top_table_html
 
     top_panel = f"""
    -

-        <h2>Top Violations</h2>
-        <div class="panel-sub">Most frequent violations observed</div>
+        <h2>Violations</h2>
+        <div class="panel-sub">{panel_subtitle}</div>
+        {first_step_note}
-        <ul class="top-list">{top_list}</ul>
+        {panel_content}
    """ @@ -563,6 +1193,91 @@ def percent(part: int, total: int) -> float: text-transform: uppercase; letter-spacing: 0.04em; }} + .viol-table td {{ vertical-align: top; }} + .step-cell {{ white-space: nowrap; font-variant-numeric: tabular-nums; font-weight: 600; }} + .count-cell {{ white-space: nowrap; font-variant-numeric: tabular-nums; font-weight: 700; color: var(--failed); }} + .freq-cell {{ white-space: nowrap; }} + .freq-rate {{ display: block; font-variant-numeric: tabular-nums; font-weight: 700; color: var(--failed); }} + .freq-detail {{ display: block; font-size: 11px; color: var(--muted); margin-top: 2px; }} + .stage-badge {{ + display: inline-block; + font-size: 10px; + font-weight: 700; + padding: 1px 6px; + border-radius: 99px; + margin-right: 4px; + text-transform: uppercase; + letter-spacing: 0.05em; + vertical-align: middle; + }} + .step-sep {{ color: var(--muted); }} + .more-steps {{ color: var(--muted); font-size: 11px; }} + .card-sub {{ font-size: 12px; color: var(--muted); margin-top: 4px; }} + .card-pass .value {{ color: var(--passed); }} + .card-step .step-value {{ font-size: 22px; }} + .c-pass {{ color: var(--passed); }} + .c-fail {{ color: var(--failed); }} + .progress-panel h2 {{ margin: 0 0 4px; font-size: 22px; }} + .nt-details summary {{ + cursor: pointer; + font-size: 13px; + color: var(--muted); + padding: 6px 0; + }} + .nt-details summary::-webkit-details-marker {{ display: none; }} + .nt-details summary::before {{ content: "▸ "; color: var(--accent); }} + details[open].nt-details summary::before {{ content: "▾ "; }} + .nt-list {{ + list-style: none; + padding: 0; + margin: 8px 0 0; + display: grid; + grid-template-columns: repeat(auto-fill, minmax(280px, 1fr)); + gap: 6px; + }} + .nt-item {{ + font-size: 12px; + color: var(--muted); + background: #f4f6fb; + border: 1px solid var(--border); + border-radius: 6px; + padding: 5px 10px; + }} + .inv-label-summary {{ + cursor: pointer; + font-weight: 600; + font-size: 14px; + list-style: none; + padding: 2px 0; + }} + .inv-label-summary::-webkit-details-marker {{ display: none; }} + .inv-label-summary::before {{ content: "▸ "; color: var(--accent); font-size: 11px; }} + details[open] .inv-label-summary::before {{ content: "▾ "; }} + .inv-rel-tag {{ + display: inline-block; + font-size: 11px; + color: var(--muted); + background: #f0f2f8; + border-radius: 4px; + padding: 1px 6px; + margin-top: 4px; + }} + .expand-body {{ + margin-top: 10px; + padding: 10px; + background: #f8f9fd; + border-radius: 8px; + border: 1px solid var(--border); + }} + .trace-steps {{ + font-size: 12px; + color: var(--muted); + margin-bottom: 8px; + word-break: break-all; + }} + .trace-wrap {{ overflow-x: auto; }} + .trace-table {{ font-size: 11px; min-width: 400px; }} + .trace-table th {{ font-size: 10px; }} footer {{ margin-top: 28px; font-size: 12px; @@ -585,6 +1300,8 @@ def percent(part: int, total: int) -> float: {card_html}
    + {progress_panel} + {top_panel} {relation_table} @@ -691,6 +1408,8 @@ def _log_wandb( report_path: str | None, args: argparse.Namespace, ): + import glob + try: import wandb except ImportError: @@ -700,7 +1419,7 @@ def _log_wandb( return if self._wandb_run is None: - self._wandb_run = wandb.init( + init_kwargs: dict = dict( project=args.wandb_project, entity=args.wandb_entity, name=args.wandb_run_name, @@ -708,38 +1427,127 @@ def _log_wandb( tags=args.wandb_tags, job_type="checker", ) + run_id = getattr(args, "wandb_run_id", None) + if run_id: + init_kwargs["id"] = run_id + init_kwargs["resume"] = "allow" + self._wandb_run = wandb.init(**init_kwargs) # type: ignore[assignment] + run = self._wandb_run + if run is None: + logging.getLogger(__name__).warning("wandb.init() returned None; skipping.") + return overall = report_data["overall"] mode = report_data.get("mode", "offline") + + # --- run config (searchable/filterable in W&B UI) --- + run.config.update( + { + "traincheck/invariants_total": overall["total_invariants"], + "traincheck/output_dir": report_data.get("output_dir", ""), + "traincheck/mode": mode, + }, + allow_val_change=True, + ) + + # --- scalar metrics --- if mode == "online": + total_invariants = overall["total_invariants"] or 0 + violated = overall.get("violated_invariants", 0) or 0 + violation_rate = ( + round(violated / total_invariants * 100, 1) if total_invariants else 0.0 + ) wandb.log( { - "invariants/total": overall["total_invariants"], - "invariants/violated_unique": overall.get("violated_invariants", 0), + "invariants/total": total_invariants, + "invariants/violated_unique": violated, + "invariants/violation_rate_pct": violation_rate, "violations/total": overall["failed"], } ) else: + total_checks = overall["total_checks"] or 0 + passed = overall["passed"] or 0 + pass_rate = round(passed / total_checks * 100, 1) if total_checks else 0.0 wandb.log( { "invariants/total": overall["total_invariants"], - "checks/total": overall["total_checks"], + "checks/total": total_checks, "checks/failed": overall["failed"], - "checks/passed": overall["passed"], + "checks/passed": passed, "checks/not_triggered": overall["not_triggered"], + "checks/pass_rate_pct": pass_rate, } ) - table = wandb.Table(columns=["relation", "failed", "passed", "not_triggered"]) + # --- relation breakdown table --- + rel_table = wandb.Table( + columns=["relation", "failed", "passed", "not_triggered"] + ) for relation_name, rel_counts in report_data["relations"].items(): - table.add_data( + rel_table.add_data( relation_name, rel_counts.get("failed", 0), rel_counts.get("passed", 0), rel_counts.get("not_triggered", 0), ) - wandb.log({"relation_breakdown": table}) + wandb.log({"relation_breakdown": rel_table}) + + # --- violated invariants table --- + top_violations = report_data.get("top_violations", []) + if top_violations: + vtable = wandb.Table( + columns=[ + "invariant", + "relation_type", + "occurrences", + "first_step", + "last_step", + ] + ) + for v in top_violations: + vtable.add_data( + v.get("label", ""), + v.get("relation", ""), + v.get("count", 0), + v.get("first_step"), + v.get("last_step"), + ) + wandb.log({"violations": vtable}) + + # --- summary metrics (shown in run comparison table) --- + first_steps = [ + v["first_step"] for v in top_violations if v.get("first_step") is not None + ] + last_steps_wandb = [ + v["last_step"] for v in top_violations if v.get("last_step") is not None + ] + if first_steps: + run.summary["violations/first_step"] = min(first_steps) + if last_steps_wandb: + 
run.summary["violations/last_step"] = max(last_steps_wandb) + run.summary["violations/distinct_invariants"] = len(top_violations) + + # --- violations_summary.json as versioned artifact --- + summary_files = glob.glob( + os.path.join(self.output_dir, "*", "violations_summary.json") + ) + if summary_files: + try: + artifact = wandb.Artifact( + name="violations_summary", + type="checker_output", + description="Per-trace violation summaries from traincheck-check", + ) + for summary_file in summary_files: + artifact.add_file(summary_file) + run.log_artifact(artifact) + except Exception: + logging.getLogger(__name__).warning( + "Failed to attach violations_summary artifact to wandb run." + ) + # --- HTML report --- if report_path: try: with open(report_path, "r") as f: @@ -749,6 +1557,11 @@ def _log_wandb( "Failed to attach HTML report to wandb run." ) + # --- per-step violation time-series (overlays with training loss curve) --- + violation_steps_map: dict[int, int] = report_data.get("violation_steps_map", {}) + for step, count in sorted(violation_steps_map.items()): + wandb.log({"traincheck/violations": count}, step=step) + def _log_mlflow( self, report_data: dict, @@ -763,6 +1576,8 @@ def _log_mlflow( ) return + import glob + if args.mlflow_experiment: mlflow.set_experiment(args.mlflow_experiment) @@ -772,18 +1587,89 @@ def _log_mlflow( overall = report_data["overall"] mode = report_data.get("mode", "offline") + top_violations = report_data.get("top_violations", []) + + # --- run tags (searchable in MLflow UI) --- + mlflow.set_tags( + { + "traincheck.mode": mode, + "traincheck.invariants_total": str(overall["total_invariants"]), + "traincheck.output_dir": report_data.get("output_dir", ""), + } + ) + + # --- scalar metrics --- if mode == "online": - mlflow.log_metric("invariants_total", overall["total_invariants"]) - mlflow.log_metric( - "invariants_violated_unique", overall.get("violated_invariants", 0) + total_invariants = overall["total_invariants"] or 0 + violated = overall.get("violated_invariants", 0) or 0 + violation_rate = ( + round(violated / total_invariants * 100, 1) if total_invariants else 0.0 ) + mlflow.log_metric("invariants_total", total_invariants) + mlflow.log_metric("invariants_violated_unique", violated) + mlflow.log_metric("invariants_violation_rate_pct", violation_rate) mlflow.log_metric("violations_total", overall["failed"]) else: + total_checks = overall["total_checks"] or 0 + passed = overall["passed"] or 0 + pass_rate = round(passed / total_checks * 100, 1) if total_checks else 0.0 mlflow.log_metric("invariants_total", overall["total_invariants"]) - mlflow.log_metric("checks_total", overall["total_checks"]) + mlflow.log_metric("checks_total", total_checks) mlflow.log_metric("checks_failed", overall["failed"]) - mlflow.log_metric("checks_passed", overall["passed"]) + mlflow.log_metric("checks_passed", passed) mlflow.log_metric("checks_not_triggered", overall["not_triggered"]) + mlflow.log_metric("checks_pass_rate_pct", pass_rate) + + # --- violation summary metrics --- + first_steps = [ + v["first_step"] for v in top_violations if v.get("first_step") is not None + ] + last_steps = [ + v["last_step"] for v in top_violations if v.get("last_step") is not None + ] + if first_steps: + mlflow.log_metric("violations_first_step", min(first_steps)) + if last_steps: + mlflow.log_metric("violations_last_step", max(last_steps)) + mlflow.log_metric("violations_distinct_invariants", len(top_violations)) + + # --- per-step violation time-series (overlays with training loss curve) 
--- + violation_steps_map: dict[int, int] = report_data.get("violation_steps_map", {}) + for step, count in sorted(violation_steps_map.items()): + mlflow.log_metric("traincheck_violations", count, step=step) + + # --- violations table (mlflow.log_table for proper UI rendering) --- + if top_violations: + try: + mlflow.log_table( + data={ + "invariant": [v.get("label", "") for v in top_violations], + "relation_type": [ + v.get("relation", "") for v in top_violations + ], + "occurrences": [v.get("count", 0) for v in top_violations], + "first_step": [v.get("first_step") for v in top_violations], + "last_step": [v.get("last_step") for v in top_violations], + }, + artifact_file="violations.json", + ) + except Exception: + logging.getLogger(__name__).warning( + "Failed to log violations table to MLflow." + ) + + # --- per-trace violations_summary.json artifacts --- + summary_files = glob.glob( + os.path.join(self.output_dir, "*", "violations_summary.json") + ) + for summary_file in summary_files: + try: + mlflow.log_artifact(summary_file, artifact_path="violations_summaries") + except Exception: + logging.getLogger(__name__).warning( + "Failed to log %s to MLflow.", summary_file + ) + # --- HTML report --- if report_path: mlflow.log_artifact(report_path) diff --git a/traincheck/static_analyzer/graph_generator/call_graph_parser.py b/traincheck/static_analyzer/graph_generator/call_graph_parser.py index 5ce9d258..dd997b26 100644 --- a/traincheck/static_analyzer/graph_generator/call_graph_parser.py +++ b/traincheck/static_analyzer/graph_generator/call_graph_parser.py @@ -1,11 +1,14 @@ ## Step 1: filer out the lines from func_level.log with the format - function_depth, e.g. - 2 import importlib +import logging import os import re from traincheck.instrumentor.proxy_wrapper.proxy_observer import add_observer_to_func +logger = logging.getLogger(__name__) + def unparse_module(module_name, level=0): if level > 3: @@ -19,10 +22,9 @@ def unparse_module(module_name, level=0): last_name = module_name.split(".")[-1] try: func_obj = getattr(module, last_name) - # print(f"object {last_name} found in module {'.'.join(module_name.split('.')[:-1])}") return func_obj except AttributeError: - print( + logger.debug( f"object {last_name} not found in module {'.'.join(module_name.split('.')[:-1])}" ) # Ziming: from out observation this typically just mean a function call contains a local class or function call, so we can just pass @@ -39,7 +41,9 @@ def add_observer(module_name, function_name, observe_then_unproxy=False): # module could be a class here, load the class and get the function module = unparse_module(module_name) if module is None: - print(f"error finding {function_name}: module {module_name} not found") + logger.debug( + f"error finding {function_name}: module {module_name} not found" + ) try: # Retrieve the function or property @@ -48,13 +52,13 @@ def add_observer(module_name, function_name, observe_then_unproxy=False): # Check if it's a property before proceeding if isinstance(function, property): - print( + logger.debug( f"Skipping property function: {function_name} in module: {module_name}" ) return # Apply observer to non-property functions - print(f"Observe function: {function_name} found in module: {module}") + logger.debug(f"Observe function: {function_name} found in module: {module}") setattr( module, function_name, @@ -62,9 +66,8 @@ def add_observer(module_name, function_name, observe_then_unproxy=False): ) except AttributeError: - print(f"function {function_name} not found in module {module_name}") 
+ logger.debug(f"function {function_name} not found in module {module_name}") return - # print(f'function: {function} found in module: {module}') # read the func_level.log file @@ -104,11 +107,9 @@ def add_observer_given_call_graph( # save those with function_depth <= depth if observe_up_to_depth: if int(function_depth) <= depth: - # print(f'module_name: {module_name}, function_name: {function_name}, function_depth: {function_depth}') add_observer(module_name, function_name, observe_then_unproxy) else: if int(function_depth) == depth: - # print(f'module_name: {module_name}, function_name: {function_name}, function_depth: {function_depth}') add_observer(module_name, function_name, observe_then_unproxy) diff --git a/traincheck/trace/trace_dict.py b/traincheck/trace/trace_dict.py index 7ae2339b..a218a647 100644 --- a/traincheck/trace/trace_dict.py +++ b/traincheck/trace/trace_dict.py @@ -3,10 +3,9 @@ import re from collections import defaultdict -from tqdm import tqdm - from traincheck.config import config from traincheck.instrumentor.tracer import TraceLineType +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( AttrState, diff --git a/traincheck/trace/trace_pandas.py b/traincheck/trace/trace_pandas.py index c51d946d..6263bea6 100644 --- a/traincheck/trace/trace_pandas.py +++ b/traincheck/trace/trace_pandas.py @@ -3,11 +3,11 @@ from typing import Any import pandas as pd -from tqdm import tqdm from traincheck.config import config from traincheck.instrumentor.tracer import TraceLineType from traincheck.instrumentor.types import PTID +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( MD_NONE, @@ -201,6 +201,7 @@ def _rm_incomplete_trailing_func_calls(self): for _, row in tqdm( incomplete_func_call_records.iterrows(), desc="Removing Incomplete Function Calls", + leave=False, ): assert ( row["type"] == TraceLineType.FUNC_CALL_PRE diff --git a/traincheck/trace/trace_polars.py b/traincheck/trace/trace_polars.py index 924637de..243e671f 100644 --- a/traincheck/trace/trace_polars.py +++ b/traincheck/trace/trace_polars.py @@ -2,10 +2,10 @@ import re import polars as pl -from tqdm import tqdm from traincheck.config import config from traincheck.instrumentor.tracer import TraceLineType +from traincheck.progress import tqdm from traincheck.trace.trace import Trace from traincheck.trace.types import ( AttrState, diff --git a/traincheck/utils.py b/traincheck/utils.py index 944fd989..332cfe04 100644 --- a/traincheck/utils.py +++ b/traincheck/utils.py @@ -13,6 +13,10 @@ THREAD_LOCAL = threading.local() +# When True, tqdm bars created via traincheck.progress.tqdm are disabled. +# Set by check_engine() so only the single outer checking bar is visible. +_suppress_inner_progress: bool = False + def safe_getattr(obj, attr, default=None): """Safely get the attribute of an object.