         _match_dispatch_rule,
         get_available_intrinsics,
         _validate_gadget_worker,
+        graph_max_depth,
     )
 except ImportError:
     from success_progress import SuccessProgress
@@ -38,6 +39,7 @@
         _match_dispatch_rule,
         get_available_intrinsics,
         _validate_gadget_worker,
+        graph_max_depth,
     )

 # Re-export all public names so existing importers keep working.
5557 "StageTracker" ,
5658 "_BATCH_THRESHOLDS" ,
5759 "apply_retroactive_input" ,
60+ "graph_max_depth" ,
5861]
5962
6063
@@ -142,6 +145,11 @@ class StageTracker:
     # Progress
     progress_task_id: int | None = None

+    # Depth-stratified search
+    shallow_complete: bool = False
+    depth_escalated: bool = False
+    all_input_states: list = field(default_factory=list)
+
     @property
     def is_complete(self) -> bool:
         """True when all submitted jobs are done AND no more inputs coming."""
@@ -237,6 +245,7 @@ def build_solution_tree(
         resume_data: dict | None = None,
         max_workers: int | None = None,
         max_tasks_per_child: int | None = 1000,
+        depth2_threshold: float = 0.1,
     ) -> tuple[list[SolutionNode], bool]:
         """
         Iteratively explore all stage transitions to build solution tree.
@@ -257,6 +266,11 @@ def build_solution_tree(
                 to os.cpu_count().
             max_tasks_per_child: Maximum tasks per worker process before recycling.
                 Limits memory growth in long runs. None disables recycling.
+            depth2_threshold: Coverage threshold for depth-2 escalation.
+                After depth-1 completes for a stage, coverage is computed as
+                unique_outputs / 2^K (K = number of comparison pairs).
+                If coverage >= threshold, depth-2 is skipped for that stage.
+                Default: 0.1 (10%).

         Returns:
             Tuple of (root nodes, all_stages_complete) where all_stages_complete
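A worked example of the rule described in the new docstring, as a small self-contained check (the helper name below is illustrative, not part of the diff): a stage with K = 4 comparison pairs has 2^4 = 16 possible output patterns, so with the default threshold of 0.1 two unique outputs (12.5%) already skip depth-2, while a single output (6.25%) triggers escalation.

def needs_depth2(n_unique: int, k: int, threshold: float = 0.1) -> bool:
    """Escalate when coverage = n_unique / 2**k falls below the threshold."""
    return n_unique / (2 ** k) < threshold

assert needs_depth2(1, 4) is True    # 1/16 = 6.25%  < 10%  -> run depth-2
assert needs_depth2(2, 4) is False   # 2/16 = 12.5% >= 10%  -> skip depth-2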
@@ -290,8 +304,10 @@ def build_solution_tree(
         self._stages_completed: set[int] = set()

         initial_state = self._create_initial_state()
-        # Pre-compute all candidates once - they're independent of stage/input state
-        all_candidates = self.synthesizer.precompute_all_candidates(gadget_depth)
+        # Pre-compute candidates split by depth tier
+        shallow_candidates, deep_candidates = (
+            self.synthesizer.precompute_candidates_stratified(gadget_depth)
+        )

         # Build checkpoint config if checkpoint_dir is provided
         checkpoint_config = None
@@ -316,7 +332,9 @@ def build_solution_tree(
         nodes_by_path = self._build_tree_pipelined(
             input_states_with_context,
             depth_limit,
-            all_candidates,
+            shallow_candidates,
+            deep_candidates,
+            depth2_threshold=depth2_threshold,
             max_unique_outputs=max_unique_outputs,
             checkpoint_dir=checkpoint_dir,
             checkpoint_config=checkpoint_config,
@@ -345,8 +363,15 @@ def _make_jobs_for_inputs(
         stage_pairs: list,
         perm_gadget_candidates: list[tuple],
         max_unique_outputs: int,
+        exclude_outputs: list[tuple] | None = None,
     ) -> list[tuple]:
-        """Create validation jobs for a set of input states (Phase 1)."""
+        """Create validation jobs for a set of input states (Phase 1).
+
+        Args:
+            exclude_outputs: Optional list of (top_tuple, bottom_tuple) output
+                states to exclude from solver enumeration. Workers will skip
+                these outputs, increasing diversity of newly-found solutions.
+        """
         jobs = []
         for input_state, parent_path in input_states_with_context:
             metadata = {
@@ -355,6 +380,8 @@ def _make_jobs_for_inputs(
355380 "stage_idx" : stage_idx ,
356381 "max_unique_outputs" : max_unique_outputs ,
357382 }
383+ if exclude_outputs :
384+ metadata ["exclude_outputs" ] = exclude_outputs
358385 if (
359386 self ._natural_order_stage is not None
360387 and stage_idx == self ._natural_order_stage
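How a consumer of this metadata might use the new key, sketched with hypothetical output-state tuples (the actual skipping happens inside the worker, which is not part of this diff):

known = [((1, 0), (0, 1))]                        # hypothetical outputs found during depth-1
metadata = {"exclude_outputs": known}

candidates = [((1, 0), (0, 1)), ((1, 1), (0, 0))]  # hypothetical enumeration results
excluded = set(metadata.get("exclude_outputs", []))
fresh = [c for c in candidates if c not in excluded]
assert fresh == [((1, 1), (0, 0))]                 # only unseen outputs remain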
@@ -502,7 +529,9 @@ def _build_tree_pipelined(
         self,
         initial_inputs: list[tuple[VectorState, tuple]],
         depth_limit: int | None,
-        all_candidates: list[tuple],
+        shallow_candidates: list[tuple],
+        deep_candidates: list[tuple],
+        depth2_threshold: float = 0.1,
         max_unique_outputs: int = 3,
         checkpoint_dir: str | None = None,
         checkpoint_config=None,
@@ -514,6 +543,10 @@ def _build_tree_pipelined(
         soon as unique outputs from stage N are discovered (at geometric-series
         thresholds), allowing overlap.

+        Depth-stratified search: shallow_candidates (depth <= 1) are tried
+        first. If coverage after depth-1 is below depth2_threshold,
+        deep_candidates (depth >= 2) are submitted for that stage.
+
         The final output is deterministic: ``_finalize_stage`` sorts transitions
         and gadgets in canonical order.
         """
@@ -663,62 +696,52 @@ def _forward_outputs(
             next_stage_idx: int,
             new_outputs: list[tuple[VectorState, tuple]],
         ) -> None:
-            """Submit jobs for newly-discovered outputs to the next stage."""
+            """Submit shallow jobs for newly-discovered outputs to the next stage."""
             if next_stage_idx >= effective_limit:
                 return

             stage_pairs = self.bitonic_sorter.stages[next_stage_idx]
+            is_new = next_stage_idx not in trackers

-            if next_stage_idx not in trackers:
-                # Launch new stage
+            if is_new:
                 tracker = StageTracker(
                     stage_idx=next_stage_idx, stage_pairs=stage_pairs
                 )
                 tracker.progress_task_id = stage_task_ids.get(next_stage_idx)
                 trackers[next_stage_idx] = tracker
+            else:
+                tracker = trackers[next_stage_idx]

-                jobs = self._make_jobs_for_inputs(
-                    new_outputs,
-                    next_stage_idx,
-                    stage_pairs,
-                    all_candidates,
-                    max_unique_outputs,
-                )
-                self._submit_stage_jobs(
-                    pool, tracker, jobs, pending_count, completion_queue
-                )
+            # Common: create jobs, submit, record inputs
+            jobs = self._make_jobs_for_inputs(
+                new_outputs,
+                next_stage_idx,
+                stage_pairs,
+                shallow_candidates,
+                max_unique_outputs,
+            )
+            self._submit_stage_jobs(
+                pool, tracker, jobs, pending_count, completion_queue
+            )
+            tracker.all_input_states.extend(new_outputs)

+            if is_new:
                 if tracker.progress_task_id is not None:
                     progress.start_task(tracker.progress_task_id)
                     progress.update(
                         tracker.progress_task_id,
                         total=tracker.total_jobs_submitted,
                     )
-
                 progress.console.print(
                     f"Stage {next_stage_idx}: Launched {tracker.total_jobs_submitted} jobs "
                     f"(pipelined from stage {next_stage_idx - 1})"
                 )
             else:
-                # Add more jobs to existing stage
-                tracker = trackers[next_stage_idx]
-                jobs = self._make_jobs_for_inputs(
-                    new_outputs,
-                    next_stage_idx,
-                    stage_pairs,
-                    all_candidates,
-                    max_unique_outputs,
-                )
-                self._submit_stage_jobs(
-                    pool, tracker, jobs, pending_count, completion_queue
-                )
-
                 if tracker.progress_task_id is not None:
                     progress.update(
                         tracker.progress_task_id,
                         total=tracker.total_jobs_submitted,
                     )
-
                 progress.console.print(
                     f"Stage {next_stage_idx}: Added {len(jobs)} jobs "
                     f"(total: {tracker.total_jobs_submitted})"
@@ -750,13 +773,14 @@ def _forward_outputs(
             current_inputs,
             start_stage,
             new_tracker.stage_pairs,
-            all_candidates,
+            shallow_candidates,
             max_unique_outputs,
         )
         self._submit_stage_jobs(
             pool, new_tracker, jobs, pending_count, completion_queue
         )
         new_tracker.total_jobs_expected = new_tracker.total_jobs_submitted
+        new_tracker.all_input_states = list(current_inputs)

         if new_tracker.progress_task_id is not None:
             progress.start_task(new_tracker.progress_task_id)
@@ -770,6 +794,73 @@ def _forward_outputs(
770794 f"for { len (current_inputs )} inputs"
771795 )
772796
797+ def _escalate_stage (stage_idx : int ) -> None :
798+ """Escalate a stage to depth-2 if coverage is below threshold."""
799+ if not deep_candidates :
800+ return
801+ tracker = trackers [stage_idx ]
802+ if tracker .depth_escalated :
803+ return
804+ tracker .depth_escalated = True
805+
806+ # Coverage check: unique_outputs / 2^K
807+ k = len (tracker .stage_pairs )
808+ comparison_space = 2 ** k
809+ n_outputs = len (tracker .unique_outputs )
810+ coverage = n_outputs / comparison_space
811+
812+ if coverage >= depth2_threshold :
813+ progress .console .print (
814+ f"Stage { stage_idx } : { n_outputs } unique outputs = "
815+ f"{ coverage :.1%} of 2^{ k } ({ comparison_space } ) — "
816+ f"above { depth2_threshold :.0%} threshold, skipping depth-2"
817+ )
818+ return
819+
820+ progress .console .print (
821+ f"Stage { stage_idx } : { n_outputs } unique outputs = "
822+ f"{ coverage :.1%} of 2^{ k } ({ comparison_space } ) — "
823+ f"below { depth2_threshold :.0%} threshold, escalating to depth-2 "
824+ f"for { len (tracker .all_input_states )} inputs "
825+ f"({ len (deep_candidates )} deep candidates)"
826+ )
827+
828+ # Pass known outputs so depth-2 workers skip them,
829+ # increasing diversity of newly-found solutions.
830+ known_outputs = list (tracker .unique_outputs .keys ())
831+
832+ jobs = self ._make_jobs_for_inputs (
833+ tracker .all_input_states ,
834+ stage_idx ,
835+ tracker .stage_pairs ,
836+ deep_candidates ,
837+ max_unique_outputs ,
838+ exclude_outputs = known_outputs ,
839+ )
840+ shallow_jobs = tracker .total_jobs_submitted
841+ self ._submit_stage_jobs (
842+ pool , tracker , jobs , pending_count , completion_queue
843+ )
844+ deep_jobs = tracker .total_jobs_submitted - shallow_jobs
845+ tracker .total_jobs_expected = tracker .total_jobs_submitted
846+ tracker .next_threshold_index = 0
847+
848+ progress .console .print (
849+ f"Stage { stage_idx } : Submitted { deep_jobs } depth-2 jobs "
850+ f"(total: { tracker .total_jobs_submitted } , "
851+ f"shallow: { shallow_jobs } , deep: { deep_jobs } )"
852+ )
853+
854+ if tracker .progress_task_id is not None :
855+ progress .update (
856+ tracker .progress_task_id ,
857+ description = f"Stage { stage_idx } (depth-2)" ,
858+ total = tracker .total_jobs_submitted ,
859+ )
860+
861+ # Free accumulated inputs — no longer needed after job creation
862+ tracker .all_input_states = []
863+
773864 # Event loop: process completed results via callback queue
774865 while pending_count [0 ] > 0 :
775866 stage_idx , result , error = completion_queue .get ()
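For a sense of what the default threshold means in practice across stage sizes: the smallest number of unique outputs that clears 10% coverage grows with the number of comparison pairs, so a one-pair stage clears the bar with a single output while a six-pair stage needs seven. A quick check (the values follow directly from coverage = n / 2^K; this snippet is illustrative only):

for k in (1, 2, 4, 6):
    space = 2 ** k
    needed = -(-space // 10)  # ceil(space * 0.1): minimum unique outputs to skip depth-2
    print(f"K={k}: {space} possible patterns, depth-2 skipped once >= {needed} unique outputs")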
@@ -812,7 +903,24 @@ def _forward_outputs(

             # Check completion
             if tracker.is_complete and not tracker.finalized:
-                _finalize_and_cascade(stage_idx)
+                if (
+                    not tracker.shallow_complete
+                    and deep_candidates
+                    and tracker.all_inputs_received
+                ):
+                    # Shallow phase done. Process remaining, then escalate.
+                    tracker.shallow_complete = True
+                    if tracker.unprocessed_results:
+                        new_outputs = self._process_batch(tracker)
+                        if new_outputs and stage_idx + 1 < effective_limit:
+                            _forward_outputs(stage_idx + 1, new_outputs)
+                    _escalate_stage(stage_idx)
+                    # After escalation, re-check: finalize unless deep
+                    # jobs made is_complete False again.
+                    if tracker.is_complete:
+                        _finalize_and_cascade(stage_idx)
+                else:
+                    _finalize_and_cascade(stage_idx)

         # Finalize any stages that haven't been finalized yet
         # (e.g., stages with 0 jobs due to no inputs)
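The completion path above now has two exits: a stage that drains its shallow jobs either finalizes immediately (no deep candidates, or escalation added no jobs) or stays open until the depth-2 jobs drain. A toy re-enactment of just that branching, with a stand-in tracker (field names mirror StageTracker, but the logic is a simplified sketch, not the real event loop):

from dataclasses import dataclass

@dataclass
class ToyTracker:  # stand-in for StageTracker
    jobs_pending: int = 0
    shallow_complete: bool = False

    @property
    def is_complete(self) -> bool:
        return self.jobs_pending == 0

def on_drained(tracker: ToyTracker, deep_jobs_added: int) -> str:
    if not tracker.shallow_complete:
        tracker.shallow_complete = True
        tracker.jobs_pending += deep_jobs_added   # escalation may reopen the stage
        return "finalize" if tracker.is_complete else "wait for depth-2"
    return "finalize"

assert on_drained(ToyTracker(), deep_jobs_added=0) == "finalize"
assert on_drained(ToyTracker(), deep_jobs_added=3) == "wait for depth-2"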
@@ -872,6 +980,7 @@ def synthesize_all_stages(
         resume_data: dict | None = None,
         max_workers: int | None = None,
         max_tasks_per_child: int | None = 1000,
+        depth2_threshold: float = 0.1,
     ) -> tuple[list[SolutionNode], bool]:
         """Entry point: builds solution tree for all stages.

@@ -907,4 +1016,5 @@ def synthesize_all_stages(
             resume_data=resume_data,
             max_workers=max_workers,
             max_tasks_per_child=max_tasks_per_child,
+            depth2_threshold=depth2_threshold,
         )