Skip to content

Commit 6e9d692

Browse files
committed
feat: depth-stratified gadget search with conditional depth-2 escalation
Split candidate gadgets into a shallow tier (depth <= 1) and a deep tier (depth >= 2). The shallow tier runs first for every stage; depth-2 is escalated only when coverage (unique_outputs / 2^K, where K is the number of comparison pairs in the stage) falls below --depth2-threshold (default 10%). Depth-2 workers receive exclude_outputs from the shallow phase so the Z3 solver skips already-known output states, increasing solution diversity. When gadget-depth=1, deep_candidates is empty; when coverage meets the threshold, escalation is skipped — in either case there is zero overhead vs. the previous behavior.
1 parent 6463068 commit 6e9d692

5 files changed

Lines changed: 349 additions & 36 deletions

File tree

vxsort/smallsort/codegen/src/bitonic_compiler.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ def generate_bitonic_sorter(
357357
max_workers: int | None = None,
358358
max_tasks_per_child: int | None = 1000,
359359
retroactive_input: bool = False,
360+
depth2_threshold: float = 0.1,
360361
):
361362
"""
362363
Generate bitonic sorter with super-optimized permutation sequences.
@@ -385,6 +386,7 @@ def generate_bitonic_sorter(
385386
llvm_mca_path: Explicit path to llvm-mca binary. If None, auto-detected.
386387
max_tasks_per_child: Maximum tasks per worker process before recycling.
387388
Limits memory growth in long runs. None disables recycling.
389+
depth2_threshold: Coverage threshold for depth-2 escalation (default: 0.1).
388390
389391
Returns:
390392
List of SolutionNode trees representing different optimized solutions
@@ -447,6 +449,7 @@ def generate_bitonic_sorter(
447449
resume_data=resume_data,
448450
max_workers=max_workers,
449451
max_tasks_per_child=max_tasks_per_child,
452+
depth2_threshold=depth2_threshold,
450453
)
451454

452455
print(f"Found {len(solutions)} root solutions")
@@ -715,6 +718,15 @@ def main():
715718
help="Maximum tasks per worker process before recycling (default: 1000). "
716719
"Limits memory growth in long runs. Set to 0 to disable recycling.",
717720
)
721+
parser.add_argument(
722+
"--depth2-threshold",
723+
type=float,
724+
default=0.1,
725+
metavar="FRAC",
726+
help="Coverage threshold for depth-2 escalation (default: 0.1). "
727+
"After depth-1 completes, if coverage >= threshold, depth-2 is skipped. "
728+
"Set to 0 to always escalate. Values > 1.0 are allowed as multipliers.",
729+
)
718730
parser.add_argument(
719731
"--list-cpus",
720732
action="store_true",
@@ -804,7 +816,6 @@ def main():
804816
vm,
805817
depth_limit=args.depth_limit,
806818
top_k=args.top_k,
807-
output_formats=args.output_format, # Will be None if not specified, handled by function default
808819
gadget_depth=args.gadget_depth,
809820
smt2_dump_dir=smt2_dump_dir,
810821
natural_order=args.natural_order,
@@ -819,6 +830,7 @@ def main():
819830
max_workers=args.max_workers,
820831
max_tasks_per_child=max_tasks_per_child,
821832
retroactive_input=args.retroactive_input,
833+
depth2_threshold=args.depth2_threshold,
822834
)
823835

824836

vxsort/smallsort/codegen/src/bitonic_super_optimizer.py

Lines changed: 145 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_match_dispatch_rule,
2222
get_available_intrinsics,
2323
_validate_gadget_worker,
24+
graph_max_depth,
2425
)
2526
except ImportError:
2627
from success_progress import SuccessProgress
@@ -38,6 +39,7 @@
3839
_match_dispatch_rule,
3940
get_available_intrinsics,
4041
_validate_gadget_worker,
42+
graph_max_depth,
4143
)
4244

4345
# Re-export all public names so existing importers keep working.
@@ -55,6 +57,7 @@
5557
"StageTracker",
5658
"_BATCH_THRESHOLDS",
5759
"apply_retroactive_input",
60+
"graph_max_depth",
5861
]
5962

6063

@@ -142,6 +145,11 @@ class StageTracker:
142145
# Progress
143146
progress_task_id: int | None = None
144147

148+
# Depth-stratified search
149+
shallow_complete: bool = False
150+
depth_escalated: bool = False
151+
all_input_states: list = field(default_factory=list)
152+
145153
@property
146154
def is_complete(self) -> bool:
147155
"""True when all submitted jobs are done AND no more inputs coming."""
@@ -237,6 +245,7 @@ def build_solution_tree(
237245
resume_data: dict | None = None,
238246
max_workers: int | None = None,
239247
max_tasks_per_child: int | None = 1000,
248+
depth2_threshold: float = 0.1,
240249
) -> tuple[list[SolutionNode], bool]:
241250
"""
242251
Iteratively explore all stage transitions to build solution tree.
@@ -257,6 +266,11 @@ def build_solution_tree(
257266
to os.cpu_count().
258267
max_tasks_per_child: Maximum tasks per worker process before recycling.
259268
Limits memory growth in long runs. None disables recycling.
269+
depth2_threshold: Coverage threshold for depth-2 escalation.
270+
After depth-1 completes for a stage, coverage is computed as
271+
unique_outputs / 2^K (K = number of comparison pairs).
272+
If coverage >= threshold, depth-2 is skipped for that stage.
273+
Default: 0.1 (10%).
260274
261275
Returns:
262276
Tuple of (root nodes, all_stages_complete) where all_stages_complete
@@ -290,8 +304,10 @@ def build_solution_tree(
290304
self._stages_completed: set[int] = set()
291305

292306
initial_state = self._create_initial_state()
293-
# Pre-compute all candidates once - they're independent of stage/input state
294-
all_candidates = self.synthesizer.precompute_all_candidates(gadget_depth)
307+
# Pre-compute candidates split by depth tier
308+
shallow_candidates, deep_candidates = (
309+
self.synthesizer.precompute_candidates_stratified(gadget_depth)
310+
)
295311

296312
# Build checkpoint config if checkpoint_dir is provided
297313
checkpoint_config = None
@@ -316,7 +332,9 @@ def build_solution_tree(
316332
nodes_by_path = self._build_tree_pipelined(
317333
input_states_with_context,
318334
depth_limit,
319-
all_candidates,
335+
shallow_candidates,
336+
deep_candidates,
337+
depth2_threshold=depth2_threshold,
320338
max_unique_outputs=max_unique_outputs,
321339
checkpoint_dir=checkpoint_dir,
322340
checkpoint_config=checkpoint_config,
@@ -345,8 +363,15 @@ def _make_jobs_for_inputs(
345363
stage_pairs: list,
346364
perm_gadget_candidates: list[tuple],
347365
max_unique_outputs: int,
366+
exclude_outputs: list[tuple] | None = None,
348367
) -> list[tuple]:
349-
"""Create validation jobs for a set of input states (Phase 1)."""
368+
"""Create validation jobs for a set of input states (Phase 1).
369+
370+
Args:
371+
exclude_outputs: Optional list of (top_tuple, bottom_tuple) output
372+
states to exclude from solver enumeration. Workers will skip
373+
these outputs, increasing diversity of newly-found solutions.
374+
"""
350375
jobs = []
351376
for input_state, parent_path in input_states_with_context:
352377
metadata = {
@@ -355,6 +380,8 @@ def _make_jobs_for_inputs(
355380
"stage_idx": stage_idx,
356381
"max_unique_outputs": max_unique_outputs,
357382
}
383+
if exclude_outputs:
384+
metadata["exclude_outputs"] = exclude_outputs
358385
if (
359386
self._natural_order_stage is not None
360387
and stage_idx == self._natural_order_stage
@@ -502,7 +529,9 @@ def _build_tree_pipelined(
502529
self,
503530
initial_inputs: list[tuple[VectorState, tuple]],
504531
depth_limit: int | None,
505-
all_candidates: list[tuple],
532+
shallow_candidates: list[tuple],
533+
deep_candidates: list[tuple],
534+
depth2_threshold: float = 0.1,
506535
max_unique_outputs: int = 3,
507536
checkpoint_dir: str | None = None,
508537
checkpoint_config=None,
@@ -514,6 +543,10 @@ def _build_tree_pipelined(
514543
soon as unique outputs from stage N are discovered (at geometric-series
515544
thresholds), allowing overlap.
516545
546+
Depth-stratified search: shallow_candidates (depth <= 1) are tried
547+
first. If coverage after depth-1 is below depth2_threshold,
548+
deep_candidates (depth >= 2) are submitted for that stage.
549+
517550
The final output is deterministic: ``_finalize_stage`` sorts transitions
518551
and gadgets in canonical order.
519552
"""
@@ -663,62 +696,52 @@ def _forward_outputs(
663696
next_stage_idx: int,
664697
new_outputs: list[tuple[VectorState, tuple]],
665698
) -> None:
666-
"""Submit jobs for newly-discovered outputs to the next stage."""
699+
"""Submit shallow jobs for newly-discovered outputs to the next stage."""
667700
if next_stage_idx >= effective_limit:
668701
return
669702

670703
stage_pairs = self.bitonic_sorter.stages[next_stage_idx]
704+
is_new = next_stage_idx not in trackers
671705

672-
if next_stage_idx not in trackers:
673-
# Launch new stage
706+
if is_new:
674707
tracker = StageTracker(
675708
stage_idx=next_stage_idx, stage_pairs=stage_pairs
676709
)
677710
tracker.progress_task_id = stage_task_ids.get(next_stage_idx)
678711
trackers[next_stage_idx] = tracker
712+
else:
713+
tracker = trackers[next_stage_idx]
679714

680-
jobs = self._make_jobs_for_inputs(
681-
new_outputs,
682-
next_stage_idx,
683-
stage_pairs,
684-
all_candidates,
685-
max_unique_outputs,
686-
)
687-
self._submit_stage_jobs(
688-
pool, tracker, jobs, pending_count, completion_queue
689-
)
715+
# Common: create jobs, submit, record inputs
716+
jobs = self._make_jobs_for_inputs(
717+
new_outputs,
718+
next_stage_idx,
719+
stage_pairs,
720+
shallow_candidates,
721+
max_unique_outputs,
722+
)
723+
self._submit_stage_jobs(
724+
pool, tracker, jobs, pending_count, completion_queue
725+
)
726+
tracker.all_input_states.extend(new_outputs)
690727

728+
if is_new:
691729
if tracker.progress_task_id is not None:
692730
progress.start_task(tracker.progress_task_id)
693731
progress.update(
694732
tracker.progress_task_id,
695733
total=tracker.total_jobs_submitted,
696734
)
697-
698735
progress.console.print(
699736
f"Stage {next_stage_idx}: Launched {tracker.total_jobs_submitted} jobs "
700737
f"(pipelined from stage {next_stage_idx - 1})"
701738
)
702739
else:
703-
# Add more jobs to existing stage
704-
tracker = trackers[next_stage_idx]
705-
jobs = self._make_jobs_for_inputs(
706-
new_outputs,
707-
next_stage_idx,
708-
stage_pairs,
709-
all_candidates,
710-
max_unique_outputs,
711-
)
712-
self._submit_stage_jobs(
713-
pool, tracker, jobs, pending_count, completion_queue
714-
)
715-
716740
if tracker.progress_task_id is not None:
717741
progress.update(
718742
tracker.progress_task_id,
719743
total=tracker.total_jobs_submitted,
720744
)
721-
722745
progress.console.print(
723746
f"Stage {next_stage_idx}: Added {len(jobs)} jobs "
724747
f"(total: {tracker.total_jobs_submitted})"
@@ -750,13 +773,14 @@ def _forward_outputs(
750773
current_inputs,
751774
start_stage,
752775
new_tracker.stage_pairs,
753-
all_candidates,
776+
shallow_candidates,
754777
max_unique_outputs,
755778
)
756779
self._submit_stage_jobs(
757780
pool, new_tracker, jobs, pending_count, completion_queue
758781
)
759782
new_tracker.total_jobs_expected = new_tracker.total_jobs_submitted
783+
new_tracker.all_input_states = list(current_inputs)
760784

761785
if new_tracker.progress_task_id is not None:
762786
progress.start_task(new_tracker.progress_task_id)
@@ -770,6 +794,73 @@ def _forward_outputs(
770794
f"for {len(current_inputs)} inputs"
771795
)
772796

797+
def _escalate_stage(stage_idx: int) -> None:
798+
"""Escalate a stage to depth-2 if coverage is below threshold."""
799+
if not deep_candidates:
800+
return
801+
tracker = trackers[stage_idx]
802+
if tracker.depth_escalated:
803+
return
804+
tracker.depth_escalated = True
805+
806+
# Coverage check: unique_outputs / 2^K
807+
k = len(tracker.stage_pairs)
808+
comparison_space = 2**k
809+
n_outputs = len(tracker.unique_outputs)
810+
coverage = n_outputs / comparison_space
811+
812+
if coverage >= depth2_threshold:
813+
progress.console.print(
814+
f"Stage {stage_idx}: {n_outputs} unique outputs = "
815+
f"{coverage:.1%} of 2^{k} ({comparison_space}) — "
816+
f"above {depth2_threshold:.0%} threshold, skipping depth-2"
817+
)
818+
return
819+
820+
progress.console.print(
821+
f"Stage {stage_idx}: {n_outputs} unique outputs = "
822+
f"{coverage:.1%} of 2^{k} ({comparison_space}) — "
823+
f"below {depth2_threshold:.0%} threshold, escalating to depth-2 "
824+
f"for {len(tracker.all_input_states)} inputs "
825+
f"({len(deep_candidates)} deep candidates)"
826+
)
827+
828+
# Pass known outputs so depth-2 workers skip them,
829+
# increasing diversity of newly-found solutions.
830+
known_outputs = list(tracker.unique_outputs.keys())
831+
832+
jobs = self._make_jobs_for_inputs(
833+
tracker.all_input_states,
834+
stage_idx,
835+
tracker.stage_pairs,
836+
deep_candidates,
837+
max_unique_outputs,
838+
exclude_outputs=known_outputs,
839+
)
840+
shallow_jobs = tracker.total_jobs_submitted
841+
self._submit_stage_jobs(
842+
pool, tracker, jobs, pending_count, completion_queue
843+
)
844+
deep_jobs = tracker.total_jobs_submitted - shallow_jobs
845+
tracker.total_jobs_expected = tracker.total_jobs_submitted
846+
tracker.next_threshold_index = 0
847+
848+
progress.console.print(
849+
f"Stage {stage_idx}: Submitted {deep_jobs} depth-2 jobs "
850+
f"(total: {tracker.total_jobs_submitted}, "
851+
f"shallow: {shallow_jobs}, deep: {deep_jobs})"
852+
)
853+
854+
if tracker.progress_task_id is not None:
855+
progress.update(
856+
tracker.progress_task_id,
857+
description=f"Stage {stage_idx} (depth-2)",
858+
total=tracker.total_jobs_submitted,
859+
)
860+
861+
# Free accumulated inputs — no longer needed after job creation
862+
tracker.all_input_states = []
863+
773864
# Event loop: process completed results via callback queue
774865
while pending_count[0] > 0:
775866
stage_idx, result, error = completion_queue.get()
@@ -812,7 +903,24 @@ def _forward_outputs(
812903

813904
# Check completion
814905
if tracker.is_complete and not tracker.finalized:
815-
_finalize_and_cascade(stage_idx)
906+
if (
907+
not tracker.shallow_complete
908+
and deep_candidates
909+
and tracker.all_inputs_received
910+
):
911+
# Shallow phase done. Process remaining, then escalate.
912+
tracker.shallow_complete = True
913+
if tracker.unprocessed_results:
914+
new_outputs = self._process_batch(tracker)
915+
if new_outputs and stage_idx + 1 < effective_limit:
916+
_forward_outputs(stage_idx + 1, new_outputs)
917+
_escalate_stage(stage_idx)
918+
# After escalation, re-check: finalize unless deep
919+
# jobs made is_complete False again.
920+
if tracker.is_complete:
921+
_finalize_and_cascade(stage_idx)
922+
else:
923+
_finalize_and_cascade(stage_idx)
816924

817925
# Finalize any stages that haven't been finalized yet
818926
# (e.g., stages with 0 jobs due to no inputs)
@@ -872,6 +980,7 @@ def synthesize_all_stages(
872980
resume_data: dict | None = None,
873981
max_workers: int | None = None,
874982
max_tasks_per_child: int | None = 1000,
983+
depth2_threshold: float = 0.1,
875984
) -> tuple[list[SolutionNode], bool]:
876985
"""Entry point: builds solution tree for all stages.
877986
@@ -907,4 +1016,5 @@ def synthesize_all_stages(
9071016
resume_data=resume_data,
9081017
max_workers=max_workers,
9091018
max_tasks_per_child=max_tasks_per_child,
1019+
depth2_threshold=depth2_threshold,
9101020
)

0 commit comments

Comments
 (0)