def get_model_short_name(model_name: str) -> str:
    """Return a short, filesystem-safe identifier derived from a model path.

    Keeps only the final path component and swaps spaces and colons for
    filesystem-friendly characters.
    """
    basename = model_name.rsplit("/", 1)[-1]
    return basename.replace(" ", "_").replace(":", "-")
def remove_last_paragraph(s: str) -> str:
    """Drop the fixed 143-character tail from *s*.

    Strings of 143 characters or fewer are returned unchanged.
    # NOTE(review): 143 presumably matches a boilerplate closing paragraph
    # appended by the dataset — confirm against the dataset text.
    """
    if len(s) <= 143:
        return s
    return s[: len(s) - 143]
+ "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def extract_solution_game24(text): + boxed_pattern = r"\\boxed\{" + matches = list(re.finditer(boxed_pattern, text)) + if not matches: + return None + last_match = matches[-1] + start = last_match.end() + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == "{": + brace_count += 1 + elif text[end] == "}": + brace_count -= 1 + end += 1 + expr = text[start:end - 1].strip() + + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + + replacements = { + r"\times": "*", + r"\cdot": "*", + r"\div": "/", + } + for latex, op in replacements.items(): + expr = expr.replace(latex, op) + + expr = expr.replace(r"\\,", "").replace(r"\\ ", "") + expr = re.sub(r"\)\s*\(", ")*(", expr) + expr = re.sub(r"\)\s*(\d)", r")*\1", expr) + expr = re.sub(r"(\d)\s*\(", r"\1*(", expr) + + return expr + + +def extract_numbers_from_expr(expr): + numbers = re.findall(r"\d+\.?\d*", expr) + return [int(float(n)) if float(n).is_integer() else float(n) for n in numbers] + + +def validate_numbers_used(expr, expected_nums): + used_nums = extract_numbers_from_expr(expr) + return sorted(used_nums) == sorted(expected_nums) + + +def evaluate_expression(expr, expected_nums=None): + try: + if expected_nums is not None and not validate_numbers_used(expr, expected_nums): + return False + value = eval(expr, {"__builtins__": None}, {}) + return abs(value - 24) < 1e-6 + except Exception: + return False + + +def evaluate_game24_answer(answer, nums): + expr = extract_solution_game24(answer) + if not expr: + return False, None, "No expression found" + if evaluate_expression(expr, expected_nums=nums): + return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" + used_nums = 
def extract_solution_mcq(text):
    """Pull the final multiple-choice letter (A-D) out of model output.

    Tries progressively looser patterns in order (\\boxed{}, bare boxed{},
    bold **X**, "answer: X", line-initial letters) and finally falls back to
    the last standalone capital A-D anywhere in the text. Returns the
    upper-cased letter, or None when nothing matches.
    """
    candidate_patterns = (
        r"\\boxed\{([^}]*)\}",
        r"boxed\{([^}]*)\}",
        r"\*\*([A-D])\*\*",
        r"answer[:\s]*([A-D])",
        r"(?:^|\n)([A-D])(?:\s|$|\.)",
    )

    for candidate in candidate_patterns:
        hits = re.findall(candidate, text, re.IGNORECASE)
        if not hits:
            continue
        # Use the last hit: models often restate earlier guesses first.
        letter = re.search(r"\b([ABCD])\b", hits[-1].strip(), flags=re.IGNORECASE)
        if letter:
            return letter.group(1).upper()

    bare_letters = re.findall(r"\b([ABCD])\b", text)
    return bare_letters[-1].upper() if bare_letters else None
def extract_solution_zebralogic(text):
    """Extract a house-assignment JSON dict from ZebraLogic model output.

    Tries progressively looser strategies, in order:
      1. JSON inside ```json / generic ``` code fences, or inline JSON whose
         text begins with a quoted "House N" key.
      2. A single regex grab of the last large JSON-looking region.
      3. A brace-matching scan over every balanced top-level {...} span,
         tried from the last span backwards.

    Wrapper objects of the form {"answer": {"House 1": ...}} are unwrapped
    to the inner dict. Returns the parsed dict, or None if nothing parses.
    """
    if not text:
        return None

    def _try_parse(candidate: str):
        # Parse one candidate string; returns a dict or None, never raises.
        try:
            parsed = json.loads(candidate)
            if isinstance(parsed, dict):
                # Unwrap if it's a wrapper with "answer" key
                if "answer" in parsed and isinstance(parsed["answer"], dict):
                    inner = parsed["answer"]
                    if any(re.match(r"^house\s*\d+$", str(k).strip(), flags=re.IGNORECASE) for k in inner.keys()):
                        return inner
                return parsed
        except json.JSONDecodeError:
            return None
        return None

    # Try to extract JSON from code blocks
    patterns = [
        r"```json\s*({.*?})\s*```",  # Markdown code block
        r"```\s*({.*?})\s*```",  # Generic code block
        r"({\s*['\"]House\s*\d+['\"].*?})",  # Direct JSON starting with House
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
        if matches:
            # Use the last match: models often emit drafts before the final
            # answer. NOTE(review): the non-greedy {.*?} can truncate nested
            # objects; the brace-matching fallback below compensates.
            json_str = matches[-1].strip()
            solution = _try_parse(json_str)
            if solution is not None:
                return solution

    # Try parsing entire last large JSON-like structure
    try:
        # Find potential JSON starting with {
        json_match = re.search(r"({\s*(?:['\"]House|['{\"\[])+[\s\S]*})", text)
        if json_match:
            json_str = json_match.group(1)
            solution = _try_parse(json_str)
            if solution is not None:
                return solution
    # NOTE(review): _try_parse already swallows JSONDecodeError, so this
    # guard looks redundant — confirm before removing.
    except (json.JSONDecodeError, AttributeError):
        pass

    # Last-chance extraction: parse top-level JSON object spans and keep the
    # last one that parses and looks like a house assignment dictionary.
    stack = []
    spans = []
    for idx, ch in enumerate(text):
        if ch == "{":
            stack.append(idx)
        elif ch == "}" and stack:
            start = stack.pop()
            if not stack:
                # A balanced top-level {...} region.
                spans.append((start, idx + 1))
    for start, end in reversed(spans):
        candidate = text[start:end]
        solution = _try_parse(candidate)
        if solution is not None:
            # Handle wrapped solution with "answer" key
            if isinstance(solution, dict) and "answer" in solution:
                answer = solution["answer"]
                if isinstance(answer, dict) and any(
                    re.match(r"^house\s*\d+$", str(key).strip(), flags=re.IGNORECASE)
                    for key in answer.keys()
                ):
                    return answer
            # Direct house keys
            if any(
                re.match(r"^house\s*\d+$", str(key).strip(), flags=re.IGNORECASE)
                for key in solution.keys()
            ):
                return solution

    return None
No extra text.\n\n" + ) + full_prompt = system_prefix + prompt + result = await stream_completion( + full_prompt, + llm_server=llm_server, + monitors=[], + add_delay=False, + termination_requires_validation=False, + async_execution=True, + ) + return result.strip() + + +async def finalize_zebralogic_json(problem: str, trajectory: str, llm_server: Dict[str, Any]) -> str: + """Ask the model to convert an existing trajectory into strict final JSON only.""" + prompt = ( + "Convert the reasoning into the final Zebra Logic answer JSON.\n" + "Output ONLY valid JSON (no markdown, no explanation).\n" + "Use exact feature/value names from the puzzle.\n\n" + "PUZZLE:\n" + f"{problem}\n\n" + "REASONING:\n" + f"{trajectory}\n" + ) + return await _request_zebralogic_json(prompt, llm_server) + + +async def solve_zebralogic_json_direct(problem: str, llm_server: Dict[str, Any]) -> str: + """Directly solve ZebraLogic and return strict final JSON.""" + prompt = ( + "Solve the Zebra Logic puzzle and provide the final house assignments.\n" + "Output ONLY valid JSON with keys like 'House 1', 'House 2', etc.\n" + "Use exact feature/value names from the puzzle text.\n\n" + "PUZZLE:\n" + f"{problem}\n" + ) + return await _request_zebralogic_json(prompt, llm_server) + + +def _raw_example_to_zebra_problem(example): + """Convert a raw HuggingFace ZebraLogic example into the processed format + that zebra_correctness expects (matching process_zebralogic_problem output).""" + solution = example.get("solution", {}) + header = solution.get("header", []) + rows = solution.get("rows", []) + size = example.get("size", "") + n_houses, n_features = map(int, size.split("*")) + + # Build processed solution dict: {"House 1": {"feature": "value", ...}, ...} + processed_solution = {} + features = {} + for house_i, row in enumerate(rows): + house_dict = {} + for fname, value in zip(header[1:], row[1:]): + fname_l = fname.lower() + val_l = value.lower() + house_dict[fname_l] = val_l + 
def evaluate_zebralogic_answer(answer, example):
    """Score a ZebraLogic answer string against the example's ground truth.

    Extracts candidate JSON (strict extractor first, legacy fallback second),
    normalizes house keys and lower-cases feature names/values to match the
    processed ground truth, then scores via zebra_correctness.

    Returns:
        (is_correct, normalized_candidate_or_None, message)
    """
    candidate = extract_last_json(answer)
    if candidate is None:
        # Legacy extractor copes with non-standard output formats.
        candidate = extract_solution_zebralogic(answer)
    if candidate is None:
        return False, None, "Could not extract valid JSON solution"

    normed_candidate = {}
    for house_key, attrs in candidate.items():
        match = re.search(r"House\s*(\d+)", house_key, re.IGNORECASE)
        canonical_key = house_key if match is None else f"House {match.group(1)}"
        if not isinstance(attrs, dict):
            normed_candidate[canonical_key] = attrs
            continue
        # Ground truth is stored lower-cased, so normalize here too.
        normed_candidate[canonical_key] = {
            feature.lower(): value.lower() if isinstance(value, str) else value
            for feature, value in attrs.items()
        }

    problem = _raw_example_to_zebra_problem(example)
    correct, skipped, missing, total = zebra_correctness(problem, normed_candidate)
    is_correct = correct == total
    msg = f"Correct={correct}/{total}, skipped={skipped}, missing={missing}"
    return is_correct, normed_candidate, msg
ValueError(f"Unsupported task: {task}") + + +def resolve_indices(task, dataset_len, args): + if args.indices: + return [int(x.strip()) for x in args.indices.split(",")] + if args.xrange: + parts = args.xrange.split("-") + if len(parts) == 2: + try: + start = int(parts[0].strip()) + end = int(parts[1].strip()) + return list(range(start, end)) + except ValueError: + raise ValueError(f"Invalid xrange format: {args.xrange}. Use 'start-end'") + if args.num_examples: + return list(np.linspace(0, dataset_len - 1, args.num_examples, dtype=int)) + start = args.start if args.start is not None else 0 + end = args.end if args.end is not None else dataset_len + return list(range(start, end)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run Tree-of-Thought search on a subset of the supported tasks", + ) + parser.add_argument("--task", choices=["game24", "maze", "spatialmap", "zebralogic", "verina", "verina_spec"], required=True) + parser.add_argument("--k", type=int, default=1, help="Unused placeholder to mirror other baselines") + parser.add_argument("--num_examples", "-n", type=int, default=None) + parser.add_argument("--indices", type=str, default=None) + parser.add_argument("--xrange", type=str, default=None) + parser.add_argument("--start", type=int, default=None) + parser.add_argument("--end", type=int, default=None) + parser.add_argument("--model", default="Qwen/Qwen3-30B-A3B-Thinking-2507") + parser.add_argument("--llm_url", default="http://localhost:{port}/v1/completions") + parser.add_argument( + "--ports", + default="8000", + help="Comma-separated list of vLLM ports to round-robin across", + ) + parser.add_argument("--temperature", type=float, default=0.6) + parser.add_argument("--top_p", type=float, default=0.95) + parser.add_argument("--top_k", type=int, default=20) + parser.add_argument("--max_tokens", type=int, default=8192) + # parser.add_argument("--max_context_length", type=int, default=40960) + 
def build_llm_server(args: argparse.Namespace, port: int) -> Dict[str, Any]:
    """Assemble the llm_server dict used by stream_completion for one port.

    Fills the {port} placeholder in args.llm_url and copies the sampling
    parameters from the parsed CLI namespace into the request payload.
    """
    sampling_payload: Dict[str, Any] = {
        "model": args.model,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "max_tokens": args.max_tokens,
        "stream": True,
        "seed": args.seed,
    }
    endpoint = args.llm_url.format(port=port)
    return {
        "url": endpoint,
        "headers": {"Content-Type": "application/json"},
        "payload": sampling_payload,
    }
def ensure_output_dir(base_dir: str, task: str) -> Path:
    """Create (if needed) and return the resolved output directory.

    NOTE(review): the *task* argument is currently unused here; it is kept
    for signature compatibility with callers.
    """
    out_path = Path(base_dir).expanduser()
    out_path = out_path.resolve()
    out_path.mkdir(parents=True, exist_ok=True)
    return out_path
eval_result["compiles"]: + msg_parts = [] + if eval_result["precond_sound_total"] > 0: + msg_parts.append(f"precond_sound={eval_result['precond_sound_pass']}/{eval_result['precond_sound_total']}") + if eval_result["precond_complete_total"] > 0: + msg_parts.append(f"precond_complete={eval_result['precond_complete_pass']}/{eval_result['precond_complete_total']}") + if eval_result["postcond_sound_total"] > 0: + msg_parts.append(f"postcond_sound={eval_result['postcond_sound_pass']}/{eval_result['postcond_sound_total']}") + if eval_result["postcond_complete_total"] > 0: + msg_parts.append(f"postcond_complete={eval_result['postcond_complete_pass']}/{eval_result['postcond_complete_total']}") + return False, generated_spec, f"Compilation succeeded but tests: {', '.join(msg_parts)}" + else: + error_preview = eval_result.get("compile_error", "")[:300] + return False, generated_spec, f"Compilation failed: {error_preview}" + + +def prepare_eval(task: str, example: Dict[str, Any], idx: int = 0) -> Tuple: + if task == "game24": + nums = list(example.get("numbers", [])) + return (lambda output: evaluate_game24_answer(output, nums), {"numbers": nums}) + if task == "zebralogic": + # Pass raw example so zebra_correctness can evaluate against processed solution + gt_problem = _raw_example_to_zebra_problem(example) + meta = { + "ground_truth_sample": str(example.get("solution", {}))[:100], + "ground_truth_solution": gt_problem["solution"], + } + return (lambda output: evaluate_zebralogic_answer(output, example), meta) + if task == "verina": + meta = {"data_id": example.data_id} + return (lambda output: evaluate_verina_answer(output, example, idx), meta) + if task == "verina_spec": + meta = {"data_id": example.data_id} + return (lambda output: evaluate_verina_spec_answer(output, example, idx), meta) + gt = str(example.get("ground_truth", "")).strip() + target_options = ["A", "B"] if gt == "Q4" else ["A", "B", "C", "D"] + if task == "maze": + _, user_prompt = build_maze_prompt(example) 
+ else: + _, user_prompt = build_spatialmap_prompt(example) + options = extract_options_from_prompt(user_prompt, target_options) + meta = {"options": options, "ground_truth": gt} + return (lambda output: evaluate_mcq_answer(output, options, gt), meta) + + +async def run_single_example( + idx: int, + task: str, + example: Dict[str, Any], + tot_config: ToTSearchConfig, + llm_server: Dict[str, Any], +) -> Dict[str, Any]: + eval_fn, eval_meta = prepare_eval(task, example, idx) + nums = example.get("numbers") if hasattr(example, "get") else None + problem = build_tot_problem(task, example, nums=nums) + tot = TreeOfThoughtSearch(tot_config) + search_result = await tot.search(task, problem, llm_server) + best_traj = search_result.get("best_trajectory", "") + best_value = search_result.get("best_value", 0.0) + + synthesized_code = None + synthesized_spec = None + if task == "verina" and best_traj.strip(): + try: + synthesis_prompt = build_verina_synthesis_prompt(problem, best_traj) + synthesized_code = await stream_completion( + synthesis_prompt, + llm_server=llm_server, + monitors=[], + add_delay=False, + termination_requires_validation=False, + async_execution=True, + ) + is_correct, extracted, message = eval_fn(synthesized_code) + except Exception as exc: + LOGGER.warning("Verina synthesis failed for index %s: %s", idx, exc) + is_correct, extracted, message = eval_fn(best_traj) + elif task == "verina_spec" and best_traj.strip(): + try: + synthesis_prompt = build_verina_spec_synthesis_prompt(problem, best_traj) + synthesized_spec = await stream_completion( + synthesis_prompt, + llm_server=llm_server, + monitors=[], + add_delay=False, + termination_requires_validation=False, + async_execution=True, + ) + is_correct, extracted, message = eval_fn(synthesized_spec) + except Exception as exc: + LOGGER.warning("Verina spec synthesis failed for index %s: %s", idx, exc) + is_correct, extracted, message = eval_fn(best_traj) + else: + is_correct, extracted, message = 
async def run_tot_baseline(args: argparse.Namespace) -> None:
    """Top-level driver: run ToT search over the selected dataset indices.

    Sets up file-only logging (keeps the tqdm progress bar clean), loads the
    dataset, round-robins requests across the configured vLLM ports under a
    concurrency cap, writes one JSON record per example as it finishes, and
    finally writes an aggregate summary file.
    """
    log_level = getattr(logging, args.log_level.upper(), logging.INFO)
    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    # Send all logs to a file instead of stdout/stderr (keeps tqdm clean)
    log_file = Path(args.output_dir) / args.log_file
    log_file.parent.mkdir(parents=True, exist_ok=True)
    fh = logging.FileHandler(str(log_file), mode="a")
    fh.setLevel(log_level)
    fh.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s"))
    root_logger.addHandler(fh)
    # Remove default stderr handler so logs don't clobber the progress bar
    for h in root_logger.handlers[:]:
        if isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler):
            root_logger.removeHandler(h)
    dataset = load_dataset_for_task(args.task)
    indices = resolve_indices(args.task, len(dataset), args)
    output_dir = ensure_output_dir(args.output_dir, args.task)
    tot_config = build_tot_config(args)
    ports = parse_port_list(args.ports)
    if not ports:
        raise ValueError("At least one port must be specified via --ports")
    concurrency = max(1, args.concurrency)
    port_lock = asyncio.Lock()
    # Mutable cell so the nested coroutine can advance the round-robin cursor.
    port_index = {"value": 0}

    async def next_port() -> int:
        # Round-robin port selection; the lock serializes cursor updates.
        async with port_lock:
            port = ports[port_index["value"] % len(ports)]
            port_index["value"] += 1
            return port

    semaphore = asyncio.Semaphore(concurrency)

    async def process_index(idx: int) -> Dict[str, Any]:
        # Run one example end-to-end and persist its record immediately so
        # partial progress survives a crash.
        async with semaphore:
            example = dataset[int(idx)]
            port = await next_port()
            llm_server = build_llm_server(args, port)
            LOGGER.info("Running ToT on example %s via port %s", idx, port)
            try:
                record = await run_single_example(idx, args.task, example, tot_config, llm_server)
            except Exception as exc:  # pragma: no cover
                LOGGER.exception("Failed example %s", idx)
                record = {
                    "index": int(idx),
                    "error": str(exc),
                    "best_trajectory": "",
                    "correct": False,
                }
            example_path = output_dir / f"example_{idx}.json"
            with example_path.open("w", encoding="utf-8") as handle:
                json.dump(record, handle, indent=2)
            return record

    processed = await tqdm_asyncio.gather(
        *[process_index(idx) for idx in indices],
        desc="ToT examples",
    )

    total = len(processed)
    correct = sum(1 for r in processed if r.get("correct"))
    summary = {
        "task": args.task,
        "model": args.model,
        "total_examples": total,
        "correct": correct,
        "accuracy": (correct / total) if total else 0.0,
        "search_method": args.search_method,
        "config": {
            "branching_factor": args.branching_factor,
            "max_depth": args.max_depth,
            "beam_width": args.beam_width,
            "sure_threshold": args.sure_threshold,
            "likely_threshold": args.likely_threshold,
            "impossible_threshold": args.impossible_threshold,
            "max_candidates_per_level": args.max_candidates_per_level,
            "early_termination": args.early_termination,
            "cache_evaluations": not args.no_cache,
            "ports": ports,
            "concurrency": concurrency,
        },
    }
    summary_path = output_dir / args.summary_file
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)
    LOGGER.info("Accuracy %.2f (%d/%d)", summary["accuracy"], correct, total)
def build_game24_prompt(nums: List[int]) -> str:
    """Return the canonical Game24 instruction block used across baselines.

    Raises:
        ValueError: If *nums* does not contain exactly four numbers.
    """
    if len(nums) != 4:
        raise ValueError("Game24 requires exactly four numbers.")
    first, second, third, fourth = nums
    answer_tag = r"\\boxed{}"
    sections = [
        "You are solving the Game of 24.\n\n",
        f"You are given four numbers: {first}, {second}, {third}, {fourth}\n\n",
        "Your job is to produce a valid arithmetic expression using:\n",
        "- ALL four numbers exactly once\n- ONLY +, -, *, /\n",
        "- The expression must evaluate to exactly 24.\n\n",
        "Please reason step by step, and put your final answer containing",
        f" only the expression within {answer_tag}.",
    ]
    return "".join(sections)
Carefully read the following " + "multiple-choice question and think through the solution step-by-step " + "before providing your final answer. Provide the final answer option by " + "enclosing it within \\boxed{A/B/C/D}." + ) + description = str(example.get("prompt", "")) + return f"{pre_prompt}\n\n{description.strip()}" + + +def build_zebralogic_prompt(example: Dict[str, Any]) -> str: + """Construct the Zebra Logic puzzle solving instructions for TOT experiments.""" + puzzle_text = str(example.get("puzzle", "")) + prompt = ( + "# Problem Description\n\n" + "You are solving a house grid logic puzzle. You are given:\n" + "1. Features and Domains\n" + " - A fixed number of houses, indexed sequentially (e.g., House 1, House 2, …) from left to right.\n" + " - A set of features (e.g., color, name, pet, book genre).\n" + " - Each feature has a finite domain of possible values.\n" + "2. Constraints:\n" + " - Each house has exactly one value per feature.\n" + " - No two houses share the same value for the same feature.\n" + "3. Clues / Constraints describing:\n" + " - Houses and their positions\n" + " - Feature values\n" + " - Relative ordering (e.g., 'next to', 'to the left of', '2 houses away from')\n\n" + "Solve this puzzle to your best ability by determining the arrangement of features across the houses.\n\n" + "# Puzzle\n\n" + f"{puzzle_text}\n\n" + "# Solution Format\n\n" + "Provide your final answer in this exact JSON format:\n" + "```json\n" + '{\n' + ' "House 1": { "feature1": "value1", "feature2": "value2", ... },\n' + ' "House 2": { "feature1": "value1", "feature2": "value2", ... },\n' + ' ...\n' + '}\n' + "```\n\n" + "Make sure to use the exact feature/value names as given in the puzzle.\n" + "Ensure the JSON is valid and parsable." 
def _build_verina_problem_text(example) -> str:
    """Build the shared verina problem description used by both propose and synthesis prompts.

    *example* is a BenchmarkData instance exposing ``signature`` (dict),
    ``lean_data`` (dict) and ``description`` (str).
    """
    bench = example  # BenchmarkData instance
    sig = bench.signature
    func_name = sig.get("name", "solution")
    return_type = sig.get("return_type", "Bool")
    params = sig.get("parameters", [])
    # "(name : Type)" declarations and "(name)" application forms for Lean 4.
    param_decls = " ".join(f"({p['param_name']} : {p['param_type']})" for p in params)
    param_names = " ".join(f"({p['param_name']})" for p in params)
    precond_name = f"{func_name}_precond"

    lean = bench.lean_data
    solution_aux = lean.get("solution_aux", "").strip()
    task_aux = lean.get("task_aux", "").strip()
    code_aux = lean.get("code_aux", "").strip()
    precond = lean.get("precond", "True").strip()
    postcond = lean.get("postcond", "").strip()

    # Helper section is emitted only when at least one auxiliary snippet exists.
    aux_blob = "\n".join(filter(None, [solution_aux, task_aux, code_aux]))
    helper_section = (
        f"\n## Helper Definitions\n```lean4\n{aux_blob}\n```\n" if aux_blob else ""
    )

    return (
        f"## Task\n{bench.description}\n\n"
        f"## Function Signature\n```lean4\n"
        f"def {func_name} {param_decls} (h_precond : {precond_name} {param_names}) : {return_type} :=\n"
        f"  -- YOUR CODE HERE\n```\n\n"
        f"## Precondition\n```lean4\ndef {precond_name} {param_decls} : Prop := {precond}\n```\n\n"
        f"## Postcondition\n```lean4\n"
        f"def {func_name}_postcond {param_decls} (result: {return_type}) : Prop := {postcond}\n```\n"
        f"{helper_section}"
    )


def build_verina_tot_prompt(example) -> str:
    """Build a plain-text verina problem prompt for ToT (no chat template)."""
    header = (
        "You are an expert Lean 4 programmer. Think step-by-step about how to "
        "implement the function body described below.\n\n"
    )
    return header + _build_verina_problem_text(example)


def build_verina_synthesis_prompt(problem: str, trajectory: str) -> str:
    """Build a prompt that converts a reasoning trajectory into final Lean 4 code."""
    return (
        "You are an expert Lean 4 programmer. Based on the reasoning below, "
        "produce the final function body expression. Output ONLY the code "
        "wrapped in [CODE]...[/CODE] tags — no explanation.\n\n"
        f"PROBLEM:\n{problem}\n\n"
        f"REASONING TRAJECTORY:\n{trajectory}\n\n"
        "Write the final Lean 4 function body expression inside [CODE]...[/CODE] tags."
    )


def _build_verina_spec_problem_text(example) -> str:
    """Build the shared verina spec problem description used by both propose and synthesis prompts."""
    bench = example  # BenchmarkData instance
    sig = bench.signature
    func_name = sig.get("name", "solution")
    return_type = sig.get("return_type", "Bool")
    params = sig.get("parameters", [])
    param_decls = " ".join(f"({p['param_name']} : {p['param_type']})" for p in params)
    param_names = " ".join(f"({p['param_name']})" for p in params)
    precond_name = f"{func_name}_precond"
    postcond_name = f"{func_name}_postcond"

    # Reference implementation and auxiliary code, when provided.
    lean = bench.lean_data
    code = lean.get("code", "").strip()
    code_aux = lean.get("code_aux", "").strip()
    task_aux = lean.get("task_aux", "").strip()

    # Natural-language hints for the spec, when provided.
    precond_desc = bench.spec_desc.get("precond_desc", "")
    postcond_desc = bench.spec_desc.get("postcond_desc", "")

    spec_desc_section = ""
    if precond_desc or postcond_desc:
        spec_desc_section = (
            f"\n## Specification Hints\n"
            f"Precondition: {precond_desc if precond_desc else 'Derive from task description'}\n"
            f"Postcondition: {postcond_desc if postcond_desc else 'Derive from task description'}\n"
        )

    aux_blob = "\n".join(filter(None, [task_aux, code_aux]))
    helper_section = (
        f"\n## Helper Definitions\n```lean4\n{aux_blob}\n```\n" if aux_blob else ""
    )

    code_section = ""
    if code:
        code_section = (
            f"\n## Reference Implementation\n```lean4\n"
            f"def {func_name} {param_decls} (h_precond : {precond_name} {param_names}) : {return_type} :=\n"
            f"  {code}\n```\n"
        )

    return (
        f"## Task\n{bench.description}\n\n"
        f"## Function Signature\n"
        f"- Function name: {func_name}\n"
        f"- Parameters: {param_decls}\n"
        f"- Return type: {return_type}\n\n"
        f"## Expected Output Format\n```lean4\n"
        f"-- Precondition auxiliary (optional)\n"
        f"[PRECOND_AUX]\n"
        f"-- helper definitions for precondition\n"
        f"[/PRECOND_AUX]\n\n"
        f"-- Precondition: when should the function be allowed to run?\n"
        f"def {precond_name} {param_decls} : Prop :=\n"
        f"  [PRECOND] -- your precondition here [/PRECOND]\n\n"
        f"-- Postcondition auxiliary (optional)\n"
        f"[POSTCOND_AUX]\n"
        f"-- helper definitions for postcondition\n"
        f"[/POSTCOND_AUX]\n\n"
        f"-- Postcondition: what must be true about the result?\n"
        f"def {postcond_name} {param_decls} (result: {return_type}) "
        f"(h_precond : {precond_name} {param_names}) : Prop :=\n"
        f"  [POSTCOND] -- your postcondition here [/POSTCOND]\n```\n"
        f"{spec_desc_section}{helper_section}{code_section}"
    )
[task_aux, code_aux])) + if all_aux: + helper_section = f"\n## Helper Definitions\n```lean4\n{all_aux}\n```\n" + + code_section = "" + if code: + code_section = ( + f"\n## Reference Implementation\n```lean4\n" + f"def {func_name} {param_list} (h_precond : {precond_name} {param_names_str}) : {return_type} :=\n" + f" {code}\n```\n" + ) + + return ( + f"## Task\n{data.description}\n\n" + f"## Function Signature\n" + f"- Function name: {func_name}\n" + f"- Parameters: {param_list}\n" + f"- Return type: {return_type}\n\n" + f"## Expected Output Format\n```lean4\n" + f"-- Precondition auxiliary (optional)\n" + f"[PRECOND_AUX]\n" + f"-- helper definitions for precondition\n" + f"[/PRECOND_AUX]\n\n" + f"-- Precondition: when should the function be allowed to run?\n" + f"def {precond_name} {param_list} : Prop :=\n" + f" [PRECOND] -- your precondition here [/PRECOND]\n\n" + f"-- Postcondition auxiliary (optional)\n" + f"[POSTCOND_AUX]\n" + f"-- helper definitions for postcondition\n" + f"[/POSTCOND_AUX]\n\n" + f"-- Postcondition: what must be true about the result?\n" + f"def {postcond_name} {param_list} (result: {return_type}) " + f"(h_precond : {precond_name} {param_names_str}) : Prop :=\n" + f" [POSTCOND] -- your postcondition here [/POSTCOND]\n```\n" + f"{spec_desc_section}{helper_section}{code_section}" + ) + + +def build_verina_spec_tot_prompt(example) -> str: + """Build a plain-text verina spec-generation problem prompt for ToT.""" + return ( + "You are an expert Lean 4 programmer specializing in formal specifications. " + "Think step-by-step about how to write the precondition and postcondition " + "for the function described below.\n\n" + + _build_verina_spec_problem_text(example) + ) + + +def build_verina_spec_synthesis_prompt(problem: str, trajectory: str) -> str: + """Build a prompt that converts a reasoning trajectory into final Lean 4 spec.""" + return ( + "You are an expert Lean 4 programmer specializing in formal specifications. 
" + "Based on the reasoning below, produce the final precondition and postcondition. " + "Output ONLY the spec wrapped in tags — no explanation.\n\n" + "Use [PRECOND]...[/PRECOND] for the precondition body.\n" + "Use [POSTCOND]...[/POSTCOND] for the postcondition body.\n" + "Optionally use [PRECOND_AUX]...[/PRECOND_AUX] and [POSTCOND_AUX]...[/POSTCOND_AUX] " + "for helper definitions.\n\n" + "Use correct lean 4 syntax\n" + "Do not add imports in your generation specs\n\n" + f"PROBLEM:\n{problem}\n\n" + f"REASONING TRAJECTORY:\n{trajectory}\n\n" + "Write the final Lean 4 specification inside the tags above." + ) + + +def build_tot_problem(task: str, example: Dict[str, Any], nums: Optional[List[int]] = None) -> str: + """Helper that mirrors the best-of-k prompt builders for ToT runs.""" + task_lower = task.lower() + if task_lower == "game24": + numbers = nums or example.get("numbers") + if numbers is None: + raise ValueError("Game24 prompt requires 'numbers' in the example") + return build_game24_prompt(list(numbers)) + if task_lower == "maze": + return build_maze_prompt(example) + if task_lower == "spatialmap": + return build_spatialmap_prompt(example) + if task_lower == "zebralogic": + return build_zebralogic_prompt(example) + if task_lower == "verina": + return build_verina_tot_prompt(example) + if task_lower == "verina_spec": + return build_verina_spec_tot_prompt(example) + raise ValueError(f"Unsupported task for ToT prompt building: {task}") + + +def build_tot_value_prompt( + task: str, + problem: str, + trajectory: str, + use_fewshot: bool = True +) -> str: + """ + Build value prompt for Tree of Thought evaluation. 
class SearchMethod(Enum):
    """Search algorithm types."""
    BFS = "bfs"
    DFS = "dfs"
    BEAM = "beam"


@dataclass
class TreeNode:
    """A node in the Tree of Thought.

    Identity (hash/equality) is defined by ``trajectory`` alone, so nodes at
    different depths with the same partial solution compare equal.
    """
    trajectory: str
    depth: int
    value: float = 0.5
    parent: Optional['TreeNode'] = None
    children: List['TreeNode'] = field(default_factory=list)
    is_terminal: bool = False
    proposals: List[str] = field(default_factory=list)
    evaluation_log: Dict[str, Any] = field(default_factory=dict)

    # Manual __eq__/__hash__ override the dataclass-generated ones: equality is
    # trajectory-only by design (dataclass decorators skip names already
    # defined in the class body).
    def __hash__(self):
        return hash(self.trajectory)

    def __eq__(self, other):
        return isinstance(other, TreeNode) and self.trajectory == other.trajectory


@dataclass
class ToTSearchConfig:
    """Configuration for Tree of Thought search."""
    branching_factor: int = 4
    max_depth: int = 6
    search_method: SearchMethod = SearchMethod.BFS
    beam_width: int = 2

    # Value thresholds
    sure_threshold: float = 0.7
    likely_threshold: float = 0.5
    impossible_threshold: float = 0.2

    # Optimization settings
    early_termination: bool = True
    cache_evaluations: bool = True
    max_candidates_per_level: int = 3


class TreeOfThoughtSearch:
    """
    Tree of Thought search controller compatible with interwhen's streaming.

    Provides propose/evaluate/search methods that work with vLLM API calls
    via the llm_server interface used in interwhen.
    """

    def __init__(self, config: ToTSearchConfig = None):
        self.config = config or ToTSearchConfig()
        self.evaluation_cache = {}
        self.proposal_cache = {}
        # Counters reported alongside search results.
        self.search_stats = {
            "nodes_explored": 0,
            "evaluations_performed": 0,
            "branches_pruned": 0,
            "cache_hits": 0,
            "solutions_found": 0,
            "total_nodes_in_tree": 0,
        }
        self.decision_tree = []   # chronological log of propose/evaluate events
        self.root = None          # set by search()

    # ===================== PROPOSE FUNCTION =====================

    async def propose_next_steps(
        self,
        task: str,
        problem: str,
        current_trajectory: str,
        llm_server: Dict,
        num_proposals: Optional[int] = None,
    ) -> List[str]:
        """
        Generate candidate next steps using the model's propose capability.

        Args:
            task: The task type (e.g., "game24", "maze", "spatialmap")
            problem: The original problem statement
            current_trajectory: Current partial solution
            llm_server: vLLM server config (url, headers, payload template)
            num_proposals: Number of proposals to generate (defaults to branching_factor)

        Returns:
            List of proposed next steps
        """
        if num_proposals is None:
            num_proposals = self.config.branching_factor

        # Serve repeated (problem, trajectory) pairs from the cache.
        cache_key = f"propose_{hash(problem)}_{hash(current_trajectory)}"
        if self.config.cache_evaluations and cache_key in self.proposal_cache:
            self.search_stats["cache_hits"] += 1
            return self.proposal_cache[cache_key]

        self.search_stats["nodes_explored"] += 1

        prompt = self._build_propose_prompt(task, problem, current_trajectory, num_proposals)
        raw_response = await self._call_llm_streaming(llm_server, prompt)
        proposals = self._parse_proposals(raw_response, num_proposals)

        logger.info(
            "Generated %d proposals at depth hint=%s",
            len(proposals),
            "root" if not current_trajectory.strip() else "non-root",
        )

        # Record the full decision point for later analysis/debugging.
        self.decision_tree.append({
            "type": "proposal_generation",
            "timestamp": time.time(),
            "problem_hash": hash(problem),
            "trajectory": current_trajectory,
            "prompt": prompt,
            "prompt_preview": prompt[:200] + "..." if len(prompt) > 200 else prompt,
            "raw_response": raw_response,
            "raw_response_preview": raw_response[:300] + "..." if len(raw_response) > 300 else raw_response,
            "parsed_proposals": proposals,
        })

        if self.config.cache_evaluations:
            self.proposal_cache[cache_key] = proposals

        return proposals
# --- TreeOfThoughtSearch propose-prompt helpers (methods of the class) ---

def _build_propose_prompt(
    self,
    task: str,
    problem: str,
    trajectory: str,
    num_proposals: int
) -> str:
    """Build a prompt requesting proposals for next steps.

    Dispatches to a task-specific builder for maze/spatialmap/verina tasks and
    falls back to a generic proposal prompt otherwise.
    """
    if task == "maze":
        return self._build_maze_propose_prompt(problem, trajectory, num_proposals)
    if task == "spatialmap":
        return self._build_spatialmap_propose_prompt(problem, trajectory, num_proposals)
    if task == "verina":
        return self._build_verina_propose_prompt(problem, trajectory, num_proposals)
    if task == "verina_spec":
        return self._build_verina_spec_propose_prompt(problem, trajectory, num_proposals)

    return f"""Given the following problem and current progress, propose {num_proposals} possible next steps.

PROBLEM:
{problem}

CURRENT PROGRESS/TRAJECTORY:
{trajectory if trajectory.strip() else "Starting fresh - no progress yet"}

Generate {num_proposals} distinct next steps that could advance the solution. Be specific and actionable.

Format each proposal clearly, one per line:
1. [Proposal 1]
2. [Proposal 2]
...

Think step by step about what makes each proposal viable.
"""

def _detect_maze_question_type(self, problem: str) -> str:
    """Detect maze subtype for proposal steering (Q0/Q2/Q4)."""
    lower = problem.lower()
    if "how many right turns" in lower:
        logger.debug("Detected Q0: right turn counting")
        return "q0"
    if "how many turns" in lower and "right turns" not in lower:
        logger.debug("Detected Q2: total turn counting")
        return "q2"
    if "starting from s" in lower and "where is e" in lower:
        logger.debug("Detected Q4: spatial relation")
        return "q4"
    if "relative" in lower and "s" in lower and "e" in lower:
        logger.debug("Detected Q4: spatial relation (relative)")
        return "q4"
    logger.warning(f"Maze question type not recognized, using generic. Problem preview: {lower[:200]}")
    return "generic"

def _extract_last_move_info(self, trajectory: str) -> dict:
    """Extract previous direction and counters from the last step in *trajectory*.

    Returns a dict with ``prev_direction`` (``"UP"``/``"DOWN"``/``"LEFT"``/
    ``"RIGHT"`` or None) and integer ``right_count``/``left_count``/
    ``total_count`` (0 when absent).
    """
    if not trajectory.strip():
        return {"prev_direction": None, "right_count": 0, "left_count": 0, "total_count": 0}

    lines = trajectory.strip().split('\n')
    last_line = lines[-1] if lines else ""

    # FIX: dropped a redundant function-local `import re` — the module already
    # imports `re` at the top.
    direction_match = re.search(r'Next move:\s*(UP|DOWN|LEFT|RIGHT)', last_line, re.IGNORECASE)
    prev_direction = direction_match.group(1).upper() if direction_match else None

    # Bookkeeping counters emitted by earlier proposal steps.
    right_match = re.search(r'Right-turn count:\s*(\d+)', last_line)
    left_match = re.search(r'Left-turn count:\s*(\d+)', last_line)
    total_match = re.search(r'Total-turn count:\s*(\d+)', last_line)

    return {
        "prev_direction": prev_direction,
        "right_count": int(right_match.group(1)) if right_match else 0,
        "left_count": int(left_match.group(1)) if left_match else 0,
        "total_count": int(total_match.group(1)) if total_match else 0,
    }

def _build_maze_propose_prompt(
    self,
    problem: str,
    trajectory: str,
    num_proposals: int,
) -> str:
    """Build maze-specific atomic next-step proposal prompts by question type.

    Q0/Q2 prompts enumerate the exact valid "Next move" lines (with updated
    turn counters) so the model only has to pick from them.
    """
    question_type = self._detect_maze_question_type(problem)
    logger.debug(f"Detected maze question type: {question_type}")

    # Extract previous move info for bookkeeping (used by Q0/Q2).
    prev_info = self._extract_last_move_info(trajectory)

    # FIX: removed the dead `last_step_hint` local — it was computed but never
    # interpolated into any maze prompt branch.
    current = trajectory.strip() if trajectory.strip() else "Starting fresh - no progress yet"

    # Mapping: previous heading -> {next heading -> turn classification}.
    turn_map = {
        "UP": {"RIGHT": "RIGHT", "LEFT": "LEFT", "UP": "STRAIGHT", "DOWN": "STRAIGHT"},
        "DOWN": {"LEFT": "RIGHT", "RIGHT": "LEFT", "DOWN": "STRAIGHT", "UP": "STRAIGHT"},
        "LEFT": {"UP": "RIGHT", "DOWN": "LEFT", "LEFT": "STRAIGHT", "RIGHT": "STRAIGHT"},
        "RIGHT": {"DOWN": "RIGHT", "UP": "LEFT", "RIGHT": "STRAIGHT", "LEFT": "STRAIGHT"}
    }

    if question_type == "q0":
        prev_dir = prev_info["prev_direction"]
        prev_count = prev_info["right_count"]

        if prev_dir is None:
            # First move - all directions result in STRAIGHT with count 0
            examples = f"""PARENT: first move, count=0

Valid answers (pick {num_proposals}):
Next move: UP | Turn: STRAIGHT | Right-turn count: 0
Next move: DOWN | Turn: STRAIGHT | Right-turn count: 0
Next move: LEFT | Turn: STRAIGHT | Right-turn count: 0
Next move: RIGHT | Turn: STRAIGHT | Right-turn count: 0"""
        else:
            moves = turn_map.get(prev_dir, {})
            examples_list = []
            for next_dir, turn_type in moves.items():
                # Q0 counts only RIGHT turns.
                new_count = prev_count + 1 if turn_type == "RIGHT" else prev_count
                examples_list.append(f"Next move: {next_dir} | Turn: {turn_type} | Right-turn count: {new_count}")

            examples = f"""PARENT: direction={prev_dir}, count={prev_count}

Valid answers (pick {num_proposals}):
{chr(10).join(examples_list)}"""

        return f"""{examples}

DO NOT explain. DO NOT reason. Just output {num_proposals} lines from above."""

    if question_type == "q2":
        prev_dir = prev_info["prev_direction"]
        prev_count = prev_info["total_count"]

        if prev_dir is None:
            examples = f"""PARENT: first move, count=0

Valid answers (pick {num_proposals}):
Next move: UP | Turn: STRAIGHT | Total-turn count: 0
Next move: DOWN | Turn: STRAIGHT | Total-turn count: 0
Next move: LEFT | Turn: STRAIGHT | Total-turn count: 0
Next move: RIGHT | Turn: STRAIGHT | Total-turn count: 0"""
        else:
            moves = turn_map.get(prev_dir, {})
            examples_list = []
            for next_dir, turn_type in moves.items():
                # Q2 counts both LEFT and RIGHT turns.
                new_count = prev_count + 1 if turn_type in ["RIGHT", "LEFT"] else prev_count
                examples_list.append(f"Next move: {next_dir} | Turn: {turn_type} | Total-turn count: {new_count}")

            examples = f"""PARENT: direction={prev_dir}, count={prev_count}

Valid answers (pick {num_proposals}):
{chr(10).join(examples_list)}"""

        return f"""{examples}

DO NOT explain. DO NOT reason. Just output {num_proposals} lines from above."""

    if question_type == "q4":
        # Q4 should also be structured - no long reasoning
        return f"""Maze spatial question. Generate {num_proposals} brief factual statements.

{current if current else "Starting."}

Output {num_proposals} lines. Each line: one short fact. NO long explanations."""

    # Generic fallback - also keep it structured
    return f"""Maze question. Generate {num_proposals} next steps.

{current if current else "Starting."}

Output {num_proposals} lines. Each line: one short step. NO explanations."""
def _detect_spatialmap_question_type(self, problem: str) -> str:
    """Detect spatialmap subtype for proposal steering (direction/object/counting)."""
    text = problem.lower()
    if "how many" in text and any(w in text for w in ("objects", "places", "locations")):
        return "counting"
    if any(w in text for w in ("which object", "what object", "which place", "which location")):
        return "object"
    if any(w in text for w in ("in which direction", "what direction", "relative to")):
        return "direction"
    return "generic"

def _build_spatialmap_propose_prompt(
    self,
    problem: str,
    trajectory: str,
    num_proposals: int,
) -> str:
    """Build spatialmap-specific atomic next-step proposal prompts by question type."""
    question_type = self._detect_spatialmap_question_type(problem)
    logger.debug(f"Detected spatialmap question type: {question_type}")

    # Summarise progress and remind the model of the last completed step so it
    # does not repeat it.
    stripped = trajectory.strip()
    if not stripped:
        current = "Starting fresh - no progress yet"
        last_step_hint = ""
    else:
        current = stripped
        steps = [ln.strip() for ln in current.split("\n") if ln.strip()]
        last = steps[-1] if steps else ""
        last_step_hint = (
            f"\nLAST COMPLETED STEP: {last}\n"
            "Now generate the NEXT atomic step after this (do NOT repeat this step).\n"
        )

    if question_type == "direction":
        return f"""You are solving a spatial-map DIRECTION question.

PROBLEM:
{problem}

TRAJECTORY SO FAR:
{current}
{last_step_hint}
Your task: Propose {num_proposals} ATOMIC next steps only.
- Each step must advance exactly ONE concrete spatial inference
- Prefer one of: parse one relation, apply reversibility once, apply transitivity once, or map target-vs-reference direction
- Do NOT restate the whole map

Output format (one line per proposal, no preamble):
1. [Atomic spatial inference]
2. [Atomic spatial inference]
..."""

    if question_type == "object":
        return f"""You are solving a spatial-map OBJECT-IDENTIFICATION question.

PROBLEM:
{problem}

TRAJECTORY SO FAR:
{current}
{last_step_hint}
Your task: Propose {num_proposals} ATOMIC next steps only.
- Each step should do one action: identify candidate set, eliminate one candidate, or validate one relation against query direction
- Keep steps local and specific to the asked direction/object
- Do NOT rewrite all relationships

Output format (one line per proposal, no preamble):
1. [Atomic candidate/evidence step]
2. [Atomic candidate/evidence step]
..."""

    if question_type == "counting":
        # NOTE(review): the example lines below look like they once contained a
        # placeholder (e.g. "Qualifies: <object>") that was lost in transit —
        # confirm against the upstream file.
        return f"""You are solving a spatial-map COUNTING question.

PROBLEM:
{problem}

TRAJECTORY SO FAR:
{current}
{last_step_hint}
Your task: Propose {num_proposals} ATOMIC next steps only.
- Each step should do one action: identify one qualifying object, rule out one object, or update running count by exactly one justified change
- Keep a clear running count state
- Do NOT provide final answer yet unless count is fully justified

Output format (one line per proposal, no preamble):
1. [Atomic counting step, e.g., "Qualifies: ; running count = n"]
2. [Atomic counting step, e.g., "Ruled out: ; running count = n"]
..."""

    return f"""You are solving a spatial-map reasoning question.

PROBLEM:
{problem}

TRAJECTORY SO FAR:
{current}
{last_step_hint}
Propose {num_proposals} atomic, actionable next steps.
Each step must add ONE new spatial fact/inference and be different from prior steps.

Output format (one line per proposal, no preamble):
1. ...
2. ...
..."""
def _build_verina_propose_prompt(
    self,
    problem: str,
    trajectory: str,
    num_proposals: int,
) -> str:
    """Build reasoning-step proposal prompts for verina Lean 4 code generation."""
    stripped = trajectory.strip()
    if not stripped:
        current = "Starting fresh - no progress yet"
        last_step_hint = ""
    else:
        current = stripped
        steps = [ln.strip() for ln in current.split("\n") if ln.strip()]
        last = steps[-1] if steps else ""
        last_step_hint = (
            f"\nLAST COMPLETED STEP: {last}\n"
            "Now generate the NEXT reasoning step after this (do NOT repeat this step).\n"
        )

    return f"""You are reasoning step-by-step toward implementing a Lean 4 function body.
Do NOT write any code yet. Only output reasoning about the approach.

PROBLEM:
{problem}

REASONING SO FAR:
{current}
{last_step_hint}
Your task: Propose {num_proposals} DISTINCT next reasoning steps.
Each step MUST add NEW information not already present in the reasoning so far. Valid new information includes:
- A specific Lean 4 function name, operator, or syntax not yet mentioned
- An edge case or branch not yet considered
- A type or signature detail relevant to the implementation
- A postcondition argument not yet made

Do NOT propose steps that merely rephrase, verify, or confirm what is already in the reasoning. "Confirm that X works" or "Verify that Y is correct" are NOT valid proposals unless they add concrete new details.
Do NOT output [CODE] tags. Only reasoning about how to solve the problem.

Output format (one step per line, no preamble):
1. [Reasoning step]
2. [Reasoning step]
..."""

def _build_verina_spec_propose_prompt(
    self,
    problem: str,
    trajectory: str,
    num_proposals: int,
) -> str:
    """Build reasoning-step proposal prompts for verina Lean 4 spec generation."""
    stripped = trajectory.strip()
    if not stripped:
        current = "Starting fresh - no progress yet"
        last_step_hint = ""
    else:
        current = stripped
        steps = [ln.strip() for ln in current.split("\n") if ln.strip()]
        last = steps[-1] if steps else ""
        last_step_hint = (
            f"\nLAST COMPLETED STEP: {last}\n"
            "Now generate the NEXT reasoning step after this (do NOT repeat this step).\n"
        )

    return f"""You are reasoning step-by-step toward writing Lean 4 preconditions and postconditions.
Do NOT write any spec code yet. Only output reasoning about what the precondition and postcondition should capture.

PROBLEM:
{problem}

REASONING SO FAR:
{current}
{last_step_hint}
Your task: Propose {num_proposals} DISTINCT next reasoning steps.
Each step MUST add NEW information not already present in the reasoning so far. Valid new information includes:
- A specific Lean 4 construct, type, or operator for the spec not yet mentioned
- An input constraint or edge case for the precondition not yet considered
- A postcondition property or invariant not yet identified
- A relationship between input and output that the spec should capture

Do NOT propose steps that merely rephrase, verify, or confirm what is already in the reasoning.
Do NOT output [PRECOND], [POSTCOND], [PRECOND_AUX], or [POSTCOND_AUX] tags. Only reasoning about what the spec should express.

Output format (one step per line, no preamble):
1. [Reasoning step]
2. [Reasoning step]
..."""

@staticmethod
def _clean_proposal_text(text: str) -> str:
    """Strip chain-of-thought delimiter artifacts and surrounding whitespace.

    NOTE(review): the delimiter strings below appear empty — they look like
    ``<think>``/``</think>`` markers that were lost when this file was
    transported; as written the function reduces to ``text.strip()``.
    Confirm against the upstream file before relying on it.
    """
    think_idx = text.find("")
    if think_idx != -1:
        text = text[think_idx + len(""):]
    text = re.sub(r"", "", text)
    return text.strip()

def _parse_proposals(self, response: str, num_proposals: int) -> List[str]:
    """Parse up to *num_proposals* proposals from a model response.

    Tries three formats in order: exact "Next move: ... | Turn: ... | count:"
    lines, numbered multi-line blocks, then generic bullet/numbered lines.
    """
    proposals: List[str] = []

    # Pass 1: exact maze-format lines, matched per line to avoid cross-line
    # pollution.
    for raw_line in response.split('\n'):
        candidate = raw_line.strip()
        if 'Next move:' in candidate and 'Turn:' in candidate and 'count:' in candidate:
            proposals.append(candidate)
            if len(proposals) >= num_proposals:
                return proposals[:num_proposals]

    if len(proposals) >= num_proposals:
        return proposals[:num_proposals]

    # Pass 2: numbered multiline blocks (continuation lines folded into one).
    numbered_blocks = re.findall(
        r"(?:^|\n)\s*\d+[\.)]\s*(.+?)(?=(?:\n\s*\d+[\.)]\s*)|\Z)",
        response,
        flags=re.DOTALL,
    )
    for block in numbered_blocks:
        folded = " ".join(block.strip().split())
        if folded and len(folded) > 3:
            proposals.append(folded)
            if len(proposals) >= num_proposals:
                return proposals[:num_proposals]

    # Pass 3: fallback line parser for bullets/single-line proposals.
    for raw_line in response.split("\n"):
        candidate = raw_line.strip()
        if not candidate or candidate in ["Next steps:", "Proposals:", "possible next steps"]:
            continue

        candidate = re.sub(r"^\s*(?:\d+[\.)]|[-•*])\s*", "", candidate)
        candidate = re.sub(r"^\s*\[\d+\]\s*", "", candidate)
        candidate = re.sub(r"^\s*proposal\s*[:\-]\s*", "", candidate, flags=re.IGNORECASE)
        candidate = " ".join(candidate.split())

        if candidate and len(candidate) > 3:
            proposals.append(candidate)
            if len(proposals) >= num_proposals:
                return proposals[:num_proposals]

    return proposals[:num_proposals]
# ===================== EVALUATE FUNCTION =====================

async def evaluate_state(
    self,
    task: str,
    problem: str,
    trajectory: str,
    llm_server: Dict,
) -> float:
    """
    Evaluate the quality/progress of current state.

    Args:
        task: The task type (e.g., "game24", "maze", "spatialmap")
        problem: Original problem
        trajectory: Current solution trajectory
        llm_server: vLLM server config

    Returns:
        Value score between 0.0 and 1.0
    """
    # Serve repeated (problem, trajectory) evaluations from the cache.
    cache_key = f"evaluate_{hash(problem)}_{hash(trajectory)}"
    if self.config.cache_evaluations and cache_key in self.evaluation_cache:
        self.search_stats["cache_hits"] += 1
        return self.evaluation_cache[cache_key]

    self.search_stats["evaluations_performed"] += 1

    eval_prompt = self._build_evaluation_prompt(task, problem, trajectory)
    eval_response = await self._call_llm_streaming(llm_server, eval_prompt)

    score = self._parse_evaluation(eval_response)
    confidence_label = self._extract_confidence_label(eval_response, score)

    # Attach the evaluation to the most recent decision-tree entry so it is
    # grouped with the proposal round that produced this trajectory.
    eval_log = {
        "type": "state_evaluation",
        "timestamp": time.time(),
        "trajectory": trajectory,
        "prompt_preview": eval_prompt[:200] + "...",
        "response_preview": eval_response[:200] + "...",
        "score": score,
        "confidence": confidence_label,
    }
    if self.decision_tree:
        if "evaluations" not in self.decision_tree[-1]:
            self.decision_tree[-1]["evaluations"] = []
        self.decision_tree[-1]["evaluations"].append(eval_log)

    if self.config.cache_evaluations:
        self.evaluation_cache[cache_key] = score

    return score

def _build_evaluation_prompt(self, task: str, problem: str, trajectory: str) -> str:
    """Build dataset-aware evaluation prompts reused by ToT scoring."""
    return build_tot_value_prompt(task, problem, trajectory)

def _parse_evaluation(self, response: str) -> float:
    """
    Parse evaluation response into a scalar score [0, 1].
    Adds small random noise (±0.05) to break ties between nodes
    that receive the same confidence label.
    """
    response_lower = response.lower()

    confidence_keywords = {
        "sure": 0.9, "certain": 0.9, "confident": 0.9,
        "likely": 0.7, "probably": 0.7,
        "possible": 0.5, "maybe": 0.5,
        "unlikely": 0.3, "doubtful": 0.3,
        "impossible": 0.1, "blocked": 0.1,
    }

    base_score = None
    for keyword, score in confidence_keywords.items():
        # FIX: match on word boundaries instead of raw substrings — plain
        # `in` mapped "unlikely" to the "likely" score (0.7 instead of 0.3),
        # "impossible" to "possible" (0.5 instead of 0.1), and "unsure" to
        # "sure" (0.9).
        if re.search(rf"\b{keyword}\b", response_lower):
            base_score = score
            break

    if base_score is None:
        # Fallback: first digit 1-9 in the response, read as a 1-9 scale.
        for char in response:
            if char.isdigit():
                digit = int(char)
                if 1 <= digit <= 9:
                    base_score = digit / 9.0
                    break

    if base_score is None:
        base_score = 0.5  # Default neutral score

    # Add small uniform noise to break ties
    noise = random.uniform(-0.05, 0.05)
    return max(0.0, min(1.0, base_score + noise))

def _score_to_confidence(self, score: float) -> str:
    """Map scalar score [0,1] to confidence bucket."""
    if score >= 0.8:
        return "sure"
    if score >= 0.6:
        return "likely"
    if score >= 0.4:
        return "possible"
    if score >= 0.2:
        return "unlikely"
    return "impossible"

def _extract_confidence_label(self, response: str, score: float) -> str:
    """Extract confidence label from value response; fallback to score mapping."""
    lower = response.lower()
    for label in ["sure", "likely", "possible", "unlikely", "impossible"]:
        # FIX: word-boundary match — plain substring matching returned
        # "possible" for "impossible" and "likely" for "unlikely".
        if re.search(rf"\b{label}\b", lower):
            return label
    return self._score_to_confidence(score)

def _log_proposal_transition(
    self,
    depth: int,
    parent_trajectory: str,
    proposal: str,
    next_state: str,
    value: float,
    is_terminal: bool,
    pruned: bool,
) -> None:
    """Log proposal -> next-state -> value transition for debugging/analysis."""
    self.decision_tree.append(
        {
            "type": "proposal_transition",
            "timestamp": time.time(),
            "depth": depth,
            "parent_trajectory": parent_trajectory,
            "proposal": proposal,
            "next_state": next_state,
            "value": value,
            "value_confidence": self._score_to_confidence(value),
            "is_terminal": is_terminal,
            "pruned": pruned,
        }
    )

    logger.info(
        "ToT transition | depth=%d | proposal=%s | value=%.3f | confidence=%s | pruned=%s | terminal=%s",
        depth,
        proposal,
        value,
        self._score_to_confidence(value),
        pruned,
        is_terminal,
    )
# ===================== SEARCH IMPLEMENTATION =====================

async def search(
    self,
    task: str,
    problem: str,
    llm_server: Dict,
) -> Dict[str, Any]:
    """
    Perform Tree of Thought search on the problem.

    Args:
        task: The task type (e.g., "game24", "maze", "spatialmap")
        problem: Problem statement
        llm_server: vLLM server config

    Returns:
        Dictionary with best_trajectory, best_value, search_log
    """
    logger.info(f"Starting ToT search with method={self.config.search_method.value}")

    # Fresh root for every search invocation.
    self.root = TreeNode(trajectory="", depth=0, value=0.5)

    method = self.config.search_method
    if method == SearchMethod.BFS:
        return await self._bfs_search(task, problem, llm_server)
    if method == SearchMethod.BEAM:
        return await self._beam_search(task, problem, llm_server)
    return await self._dfs_search(task, problem, llm_server)

async def _bfs_search(self, task: str, problem: str, llm_server: Dict) -> Dict[str, Any]:
    """Breadth-first expansion: evaluate every child per level, prune
    below-threshold branches, and keep at most ``max_candidates_per_level``
    nodes for the next level."""
    frontier = [self.root]
    best_terminal = None          # highest-valued terminal node seen
    best_value = 0.0
    best_candidate = None         # highest-valued node of any kind (fallback)
    best_candidate_value = float('-inf')

    for depth in range(self.config.max_depth):
        if not frontier:
            break

        next_frontier = []

        for parent in frontier:
            # Expand this node into candidate next steps.
            steps = await self.propose_next_steps(
                task,
                problem,
                parent.trajectory,
                llm_server,
                self.config.branching_factor
            )
            parent.proposals = steps

            for step in steps:
                extended = f"{parent.trajectory}\n{step}" if parent.trajectory else step
                child = TreeNode(
                    trajectory=extended,
                    depth=depth + 1,
                    parent=parent,
                )

                child.value = score = await self.evaluate_state(task, problem, extended, llm_server)

                # Track best candidate regardless of terminal status.
                if score > best_candidate_value:
                    best_candidate = child
                    best_candidate_value = score

                if self._is_terminal(extended, task):
                    child.is_terminal = True
                    self.search_stats["solutions_found"] += 1
                    if score > best_value:
                        best_value = score
                        best_terminal = child
                    reached_terminal = True

                    # Early termination on a confident terminal state.
                    if self.config.early_termination and score >= self.config.sure_threshold:
                        self._log_proposal_transition(
                            depth=depth + 1,
                            parent_trajectory=parent.trajectory,
                            proposal=step,
                            next_state=extended,
                            value=score,
                            is_terminal=reached_terminal,
                            pruned=False,
                        )
                        return self._format_search_result(best_terminal, problem)
                else:
                    reached_terminal = False

                # Prune hopeless branches.
                if score < self.config.impossible_threshold:
                    self.search_stats["branches_pruned"] += 1
                    self._log_proposal_transition(
                        depth=depth + 1,
                        parent_trajectory=parent.trajectory,
                        proposal=step,
                        next_state=extended,
                        value=score,
                        is_terminal=reached_terminal,
                        pruned=True,
                    )
                    continue

                self._log_proposal_transition(
                    depth=depth + 1,
                    parent_trajectory=parent.trajectory,
                    proposal=step,
                    next_state=extended,
                    value=score,
                    is_terminal=reached_terminal,
                    pruned=False,
                )

                parent.children.append(child)
                next_frontier.append(child)
                self.search_stats["total_nodes_in_tree"] += 1

        # NOTE(review): the level cap keeps the first N surviving children in
        # insertion order, not the N highest-valued ones — confirm intended.
        frontier = next_frontier[:self.config.max_candidates_per_level]

    return self._format_search_result(best_terminal or best_candidate, problem)
[self.root] + best_terminal = None + best_value = 0.0 + best_candidate = None + best_candidate_value = float('-inf') + + for depth in range(self.config.max_depth): + candidates = [] + + for node in beam: + # Generate and evaluate proposals + proposals = await self.propose_next_steps( + task, + problem, + node.trajectory, + llm_server, + self.config.branching_factor + ) + node.proposals = proposals + + for prop in proposals: + new_trajectory = f"{node.trajectory}\n{prop}" if node.trajectory else prop + value = await self.evaluate_state(task, problem, new_trajectory, llm_server) + + child = TreeNode( + trajectory=new_trajectory, + depth=depth + 1, + value=value, + parent=node, + ) + + candidates.append((child, value)) + + if value > best_candidate_value: + best_candidate = child + best_candidate_value = value + self.search_stats["total_nodes_in_tree"] += 1 + + if self._is_terminal(new_trajectory, task): + child.is_terminal = True + self.search_stats["solutions_found"] += 1 + if value > best_value: + best_value = value + best_terminal = child + is_terminal = True + + if self.config.early_termination and value >= self.config.sure_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + return self._format_search_result(best_terminal, problem) + else: + is_terminal = False + + pruned = value < self.config.impossible_threshold + if pruned: + self.search_stats["branches_pruned"] += 1 + + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=pruned, + ) + + if pruned: + continue + + # Keep top-k by value + candidates.sort(key=lambda x: x[1], reverse=True) + beam = [child for child, _ in candidates[:self.config.beam_width]] + + if not beam: + break + + return 
self._format_search_result(best_terminal or best_candidate, problem) + + async def _dfs_search(self, task: str, problem: str, llm_server: Dict) -> Dict[str, Any]: + """Depth-First Search implementation""" + best_terminal = None + best_value = 0.0 + best_candidate = None + best_candidate_value = float('-inf') + + async def dfs(node: TreeNode, depth: int): + nonlocal best_terminal, best_value + + if depth >= self.config.max_depth: + return + + # Generate proposals + proposals = await self.propose_next_steps( + task, + problem, + node.trajectory, + llm_server, + self.config.branching_factor + ) + node.proposals = proposals + + for prop in proposals: + new_trajectory = f"{node.trajectory}\n{prop}" if node.trajectory else prop + value = await self.evaluate_state(task, problem, new_trajectory, llm_server) + + child = TreeNode( + trajectory=new_trajectory, + depth=depth + 1, + value=value, + parent=node, + ) + node.children.append(child) + + if value > best_candidate_value: + best_candidate = child + best_candidate_value = value + self.search_stats["total_nodes_in_tree"] += 1 + + if self._is_terminal(new_trajectory, task): + child.is_terminal = True + self.search_stats["solutions_found"] += 1 + if value > best_value: + best_value = value + best_terminal = child + is_terminal = True + + if self.config.early_termination and value >= self.config.sure_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + return + else: + is_terminal = False + + # Prune + if value >= self.config.impossible_threshold: + self._log_proposal_transition( + depth=depth + 1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=False, + ) + await dfs(child, depth + 1) + else: + self.search_stats["branches_pruned"] += 1 + self._log_proposal_transition( + depth=depth + 
1, + parent_trajectory=node.trajectory, + proposal=prop, + next_state=new_trajectory, + value=value, + is_terminal=is_terminal, + pruned=True, + ) + + await dfs(self.root, 0) + return self._format_search_result(best_terminal or best_candidate, problem) + + # ===================== UTILITIES ===================== + + def _is_terminal(self, trajectory: str, task: str = "") -> bool: + """Check if trajectory represents a complete solution""" + # Verina: terminal when the trajectory contains a [CODE] block, + # consistent with how other tasks detect final answers. + if task == "verina": + return bool(re.search(r'\[CODE\]', trajectory, re.IGNORECASE)) + # Verina spec: terminal when trajectory contains both [PRECOND] and [POSTCOND]. + if task == "verina_spec": + has_precond = bool(re.search(r'\[PRECOND\]', trajectory, re.IGNORECASE)) + has_postcond = bool(re.search(r'\[POSTCOND\]', trajectory, re.IGNORECASE)) + return has_precond and has_postcond + keywords = [ + "final answer", + "reached goal", + "solution:", + "answer:", + "conclusion:", + "result:", + ] + trajectory_lower = trajectory.lower() + return any(kw in trajectory_lower for kw in keywords) + + def _format_search_result( + self, + best_node: Optional[TreeNode], + problem: str + ) -> Dict[str, Any]: + """Format search results for return""" + if best_node: + best_trajectory = best_node.trajectory + best_value = best_node.value + else: + best_trajectory = "" + best_value = 0.0 + + return { + "best_trajectory": best_trajectory, + "best_value": best_value, + "search_stats": self.search_stats, + "decision_tree": self.decision_tree, + "root_node": self.root, + } + + async def _call_llm_streaming( + self, + llm_server: Dict, + prompt: str + ) -> str: + """Call the LLM via interwhen's stream_completion (no monitors).""" + return await stream_completion( + prompt, + llm_server=llm_server, + monitors=[], + add_delay=False, + termination_requires_validation=False, + async_execution=True, + ) + + def 
get_decision_tree_json(self) -> str: + """Export decision tree as JSON""" + return json.dumps({ + "search_stats": self.search_stats, + "decision_points": self.decision_tree, + "num_decision_points": len(self.decision_tree), + }, indent=2, default=str) + + def _serialize_node(self, node: Optional[TreeNode], max_depth: Optional[int] = None) -> Optional[Dict[str, Any]]: + """Serialize tree nodes recursively for debugging/state inspection.""" + if node is None: + return None + + if max_depth is not None and node.depth >= max_depth: + return { + "depth": node.depth, + "value": node.value, + "is_terminal": node.is_terminal, + "trajectory": node.trajectory, + "num_children": len(node.children), + "children": [], + "truncated": True, + } + + return { + "depth": node.depth, + "value": node.value, + "is_terminal": node.is_terminal, + "trajectory": node.trajectory, + "num_children": len(node.children), + "children": [self._serialize_node(child, max_depth=max_depth) for child in node.children], + } + + def get_state_snapshot( + self, + include_tree: bool = True, + max_tree_depth: Optional[int] = None, + decision_tail: Optional[int] = 50, + include_cache_samples: bool = True, + cache_sample_size: int = 5, + ) -> Dict[str, Any]: + """Return a comprehensive snapshot of the current ToT search state.""" + proposal_cache_keys = list(self.proposal_cache.keys()) + evaluation_cache_keys = list(self.evaluation_cache.keys()) + + snapshot: Dict[str, Any] = { + "config": { + "branching_factor": self.config.branching_factor, + "max_depth": self.config.max_depth, + "search_method": self.config.search_method.value, + "beam_width": self.config.beam_width, + "sure_threshold": self.config.sure_threshold, + "likely_threshold": self.config.likely_threshold, + "impossible_threshold": self.config.impossible_threshold, + "early_termination": self.config.early_termination, + "cache_evaluations": self.config.cache_evaluations, + "max_candidates_per_level": self.config.max_candidates_per_level, + }, + 
"search_stats": dict(self.search_stats), + "decision_tree_size": len(self.decision_tree), + "cache_state": { + "proposal_cache_size": len(self.proposal_cache), + "evaluation_cache_size": len(self.evaluation_cache), + }, + "root_present": self.root is not None, + "root_depth": self.root.depth if self.root is not None else None, + "root_value": self.root.value if self.root is not None else None, + "root_is_terminal": self.root.is_terminal if self.root is not None else None, + } + + if decision_tail is None: + snapshot["decision_tree"] = self.decision_tree + else: + snapshot["decision_tree_tail"] = self.decision_tree[-decision_tail:] + + if include_cache_samples: + sample_size = max(0, cache_sample_size) + snapshot["cache_state"]["proposal_cache_key_samples"] = proposal_cache_keys[:sample_size] + snapshot["cache_state"]["evaluation_cache_key_samples"] = evaluation_cache_keys[:sample_size] + + if include_tree: + snapshot["tree"] = self._serialize_node(self.root, max_depth=max_tree_depth) + + return snapshot + + def get_state_snapshot_json( + self, + include_tree: bool = True, + max_tree_depth: Optional[int] = None, + decision_tail: Optional[int] = 50, + include_cache_samples: bool = True, + cache_sample_size: int = 5, + indent: int = 2, + ) -> str: + """Return JSON string for the current ToT state snapshot.""" + snapshot = self.get_state_snapshot( + include_tree=include_tree, + max_tree_depth=max_tree_depth, + decision_tail=decision_tail, + include_cache_samples=include_cache_samples, + cache_sample_size=cache_sample_size, + ) + return json.dumps(snapshot, indent=indent, default=str) \ No newline at end of file diff --git a/examples/TTSwithVerification/ToT/utils/value_prompts.py b/examples/TTSwithVerification/ToT/utils/value_prompts.py new file mode 100644 index 0000000..7a1e52f --- /dev/null +++ b/examples/TTSwithVerification/ToT/utils/value_prompts.py @@ -0,0 +1,967 @@ +""" +Value prompts with few-shot examples for Tree of Thought evaluation across datasets. 
+""" + +# ============================================================================ +# GAME24 VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + +GAME24_VALUE_PROMPT_WITH_FEWSHOT = """ +Evaluate if given numbers can reach 24 (sure/likely/impossible). + +PROBLEM STATEMENT: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Here are examples of how to evaluate Game of 24 trajectories: + +EXAMPLE 1 - SURE: +Numbers: 4 4 6 8 +Trajectory: 4 + 8 = 12 (left: 4 6 12), 6 - 4 = 2 (left: 2 12), 2 * 12 = 24 +Analysis: Reaches exactly 24 using each number exactly once. +Confidence: sure (9) + +EXAMPLE 2 - SURE: +Numbers: 2 9 10 12 +Trajectory: 12 * 2 = 24 (left: 9 10 24), then 24 * (10 - 9) = 24 * 1 = 24 +Analysis: Valid path found with each number used once, equals 24. +Confidence: sure (9) + +EXAMPLE 3 - SURE: +Numbers: 10 14 +Trajectory: 10 + 14 = 24 +Analysis: Direct calculation using both numbers reaches exactly 24. +Confidence: sure (9) + +EXAMPLE 4 - SURE: +Numbers: 4 4 10 +Trajectory: (10 - 4) * 4 = 6 * 4 = 24 +Analysis: Uses all three numbers exactly once with factorization that reaches 24. +Confidence: sure (9) + +EXAMPLE 5 - SURE: +Numbers: 4 9 11 +Trajectory: 9 + 11 + 4 = 24 +Analysis: All numbers used, arithmetic valid, equals 24. +Confidence: sure (9) + +EXAMPLE 6 - LIKELY: +Numbers: 5 7 8 +Trajectory: 5 + 7 + 8 = 20, or (8 - 5) * 7 = 21 +Analysis: Cannot reach 24 immediately, but numbers in reasonable arithmetic range where 24 might be +achievable. +Confidence: likely (7) + +EXAMPLE 7 - LIKELY: +Numbers: 5 6 6 +Trajectory: 5 + 6 + 6 = 17, or (6 - 5) * 6 = 6 +Analysis: Current attempts don't reach 24, but numbers are within reasonable range. +Confidence: likely (7) + +EXAMPLE 8 - IMPOSSIBLE: +Numbers: 1 3 3 +Trajectory: 1 * 3 * 3 = 9, or (1 + 3) * 3 = 12 +Analysis: Maximum reachable with any operations is much less than 24. Numbers all too small. 
+Confidence: impossible (1) + +EXAMPLE 9 - IMPOSSIBLE: +Numbers: 10 10 11 +Trajectory: 10 + 10 + 11 = 31, or (11 - 10) * 10 = 10 +Analysis: Sum exceeds 24, factorizations fall short. Cannot reach exactly 24. +Confidence: impossible (1) + +EXAMPLE 10 - IMPOSSIBLE: +Numbers: 11 12 +Trajectory: 11 + 12 = 23, or 12 - 11 = 1, or 11 * 12 = 132, or 11 / 12 ≈ 0.91 +Analysis: No operation reaches 24. Sum close but not exact. +Confidence: impossible (1) + +Rubric: +- sure (9): Reaches 24 using each number exactly once +- likely (7): Cannot reach 24 yet, but numbers in reasonable range +- possible (5): Uncertain if 24 is reachable +- unlikely (3): Numbers seem misaligned +- impossible (1): Numbers demonstrably cannot reach 24 + +Respond with "Confidence: " followed by brief justification tied to arithmetic evaluations. +""" + +GAME24_VALUE_PROMPT_SIMPLE = """Evaluate if given numbers can reach 24. + +PROBLEM STATEMENT: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Rate confidence on this rubric: +- sure (9): Reaches 24 using each number exactly once +- likely (7): Cannot reach 24 yet, but numbers in reasonable range +- possible (5): Uncertain if 24 is reachable +- unlikely (3): Numbers seem misaligned +- impossible (1): Numbers cannot reach 24 + +Respond with "Confidence: " and brief justification. +""" + + + + +# ============================================================================ +# MAZE VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + + + +MAZE_VALUE_PROMPT_WITH_FEWSHOT = """Verify a maze reasoning trace. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +EXAMPLE 1 - SURE: +Question: Count right turns in path X from S to E +Trajectory: Carefully trace X-marked path. Starting at S, move UP (initial direction). +Then RIGHT (90 degrees clockwise = right turn 1). Then DOWN (90 degrees clockwise = right turn 2). +Then RIGHT (90 degrees clockwise = right turn 3). 
Continuing pattern: 6 right turns total. +Answer: B (6 right turns) +Analysis: Systematic path tracing with correct turn geometry, defensible count. +Confidence: sure (9) + +EXAMPLE 2 - SURE: +Question: What is the sequence of grid direction? +Trajectory: Following marked path from S: [0,0] then [0,1] (UP) then [1,1] (RIGHT) then [1,0] +(DOWN) then [2,0] (RIGHT). Each step verified against grid. +Answer: UP, RIGHT, DOWN, RIGHT +Analysis: Clear coordinate tracking, systematic verification. +Confidence: sure (9) + +EXAMPLE 3 - LIKELY: +Question: Count right turns in path from S to E +Trajectory: Observing marked path shows mostly straight movements with a zigzag pattern. Zigzags +suggest mostly left turns. Likely 0-2 right turns based on pattern. +Answer: A (0 right turns) +Confidence: likely (7) +Analysis: Shows reasonable spatial intuition but lacks systematic verification. + +EXAMPLE 4 - LIKELY: +Question: Is path continuous S to E? +Trajectory: I trace the marked path and it appears to connect from S all the way to E without +breaks. The X marks form a continuous line. +Answer: Yes +Confidence: likely (7) +Analysis: Reasonable assessment but could benefit from detailed step verification. + +EXAMPLE 5 - POSSIBLE: +Question: Navigate maze from S to E +Trajectory: Following path X... I see turns but the specific sequence is unclear to me. Could be 3 +or 4 right turns. +Answer: Uncertain between options +Confidence: possible (5) +Analysis: Recognizes task but cannot decisively trace path geometry. + +EXAMPLE 6 - UNLIKELY: +Question: Count right turns in path X +Trajectory: Tracing marked path X from S. Moving DOWN, then RIGHT (left turn?), then DOWN, then +RIGHT (right turn?). I'm confused about turn geometry. +Answer: Some right turns but not sure +Confidence: unlikely (3) +Analysis: Confused about path following and angle identification. 
+ +EXAMPLE 7 - IMPOSSIBLE: +Question: Navigate from S to E following marked path +Trajectory: I move RIGHT, then up, then left, I'm not sure where the path goes. I think I hit a +wall. +Answer: I'm stuck +Confidence: impossible (1) +Analysis: Abandons task without following clearly marked X path provided. + +Rubric: sure (9), likely (7), possible (5), unlikely (3), impossible (1) +Respond with "Confidence: " + explanation referencing moves/directions. +""" + +MAZE_VALUE_PROMPT_SIMPLE = """Verify a maze reasoning trace. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Judge if reasoning is consistent with maze/spatial relationships and if final answer is defensible. +Rubric: sure (9), likely (7), possible (5), unlikely (3), impossible (1) +Respond with "Confidence: " + explanation referencing moves/directions. +""" + + + + +# ============================================================================ +# SPATIAL REASONING VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + + + + +SPATIALMAP_VALUE_PROMPT_WITH_FEWSHOT = """ +You are verifying a spatial reasoning multiple-choice trace. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Here are examples of how to evaluate spatial reasoning: + +EXAMPLE 1: +Question: Based on the map, which location is northeast of the library? +Trajectory: I look at the map and see the library in the center. To the northeast means both north +AND east of that point. Looking at that quadrant, I see the museum is northeast of the library. +Answer: A (Museum) + +Analysis: The student correctly understands the spatial direction (northeast = north AND east), +correctly identifies it on the map, and selects the correct option. +Confidence: sure/certain (9) +Justification: The reasoning correctly applies spatial relationships and identifies the appropriate +location. + +EXAMPLE 2: +Question: Which building is closest to the park? 
+Trajectory: The park looks like it's in the middle of the map. Near it I see a building... looks +like it could be the school or maybe the library. I think the school is closer. +Answer: C (School) + +Analysis: The student makes reasonable spatial observations but doesn't verify distances or compare +alternatives systematically. +Confidence: likely/probably (7) +Justification: The reasoning shows spatial awareness but lacks systematic comparison of distances +to verify the answer. + +EXAMPLE 3: +Question: What is north of the train station? +Trajectory: I see the train station. North is... up on the map. I see some buildings, but I'm not +sure exactly which one. Could be the post office or the police station. +Answer: I'm not sure + +Analysis: The student recognizes the direction but fails to identify the specific location clearly. +Confidence: possible/maybe (5) +Justification: The student understands the spatial direction but cannot decisively identify which +building is in that location. + +EXAMPLE 4: +Question: If you're at the bank facing east, what's behind you? +Trajectory: At the bank facing east means I'm looking east. Behind me would be... west? I need to +think about what's west of the bank. I think there's a hotel or a store but I'm not sure. +Answer: Maybe the hotel + +Analysis: The student correctly understands relative directions (east/behind = west) but isn't +certain about the specific feature. +Confidence: unlikely/doubtful (3) +Justification: While the directional reasoning is sound, the uncertainty about the specific +location makes this answer questionable. + +Use the confidence rubric: sure/certain (9), likely/probably (7), possible/maybe (5), +unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " plus a concise explanation that references spatial +relationships and map features. +""" + +SPATIALMAP_VALUE_PROMPT_SIMPLE = """You are verifying a spatial reasoning multiple-choice trace. 
+ +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Judge if the reasoning correctly applies spatial relationships (north, south, east, west, near, +far, etc.) and whether the final \\boxed{{choice}} is defensible. +Use the confidence rubric: sure/certain (9), likely/probably (7), possible/maybe (5), +unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " plus a concise explanation that references the spatial +relationships and locations. +""" + + + +# ============================================================================ +# ZEBRA LOGIC VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + + + + +ZEBRALOGIC_VALUE_PROMPT_WITH_FEWSHOT = """Evaluate a Zebra Logic puzzle solution trajectory. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Here are examples of how to evaluate Zebra Logic trajectories: + +EXAMPLE 1 - SURE: +Puzzle: Houses with colors, pets, beverages, and nationalities with clues about relationships. +Trajectory: I've systematically worked through the constraints. House 1 has British resident. Red +house owner has Panda. Coffee drinker speaks Japanese. Working through elimination, I've determined +all houses uniquely and the solution satisfies all clues without contradictions. +Analysis: Systematic constraint satisfaction with clear justification for each assignment. Solution +verifiable. +Confidence: sure (9) + +EXAMPLE 2 - LIKELY: +Trajectory: Working through the clues methodically. I've identified several definite assignments +(House 2 has Swedish resident with bird). For the remaining houses, the constraints are narrowing +down possibilities and should lead to a unique solution. +Analysis: Reasonable progress using logic, but not yet complete verification of all constraints. +Confidence: likely (7) + +EXAMPLE 3 - POSSIBLE: +Trajectory: I understand the puzzle structure. 
I'm working through clues but some deductions are +unclear to me. I think House 1 might have the British resident, but I'm not certain. +Analysis: Shows problem understanding but lacks decisive constraint application. +Confidence: possible (5) + +EXAMPLE 4 - UNLIKELY: +Trajectory: I'm trying to assign attributes to houses. House 1 has red color and Swedish resident. +House 2 has green... wait, but green is next to red. I'm getting confused by the adjacency +constraints. +Analysis: Fundamental misunderstanding of spatial/logical constraints. +Confidence: unlikely (3) + +EXAMPLE 5 - IMPOSSIBLE: +Trajectory: I'm going to assign all attributes randomly since I don't see how the clues relate to +each other. +Analysis: Abandons logical reasoning without attempting systematic constraint satisfaction. +Confidence: impossible (1) + +Rubric for Zebra Logic: +- sure (9): Complete solution derived with clear constraint verification, all assignments justified +- likely (7): Systematic progress with mostly confident deductions, minor uncertainties remain +- possible (5): Some correct deductions but missing clear constraint application +- unlikely (3): Attempting logic but making errors in constraint application or showing confusion +- impossible (1): No meaningful attempt at systematic constraint satisfaction + +Respond with "Confidence: " followed by brief justification referencing the logical +deductions and constraint satisfaction. +""" + +ZEBRALOGIC_VALUE_PROMPT_SIMPLE = """Evaluate a Zebra Logic puzzle solution trajectory. + +TASK PROMPT: +{problem} + +MODEL TRAJECTORY: +{trajectory} + +Judge if the reasoning systematically applies logical constraints and whether the solution +assignments are well-justified. 
+Use the confidence rubric: +- sure (9): Complete solution with clear constraint verification +- likely (7): Systematic progress with mostly confident deductions +- possible (5): Some correct deductions with minor gaps +- unlikely (3): Attempting logic but making constraint errors +- impossible (1): No meaningful systematic reasoning + +Respond with "Confidence: " and brief justification referencing constraint satisfaction. +""" + + + +# ============================================================================ +# VERINA (LEAN 4 CODE GENERATION) VALUE PROMPTS WITH FEW-SHOT EXAMPLES +# ============================================================================ + + + + + +VERINA_VALUE_PROMPT_WITH_FEWSHOT = """Evaluate a Lean 4 code-generation reasoning trajectory. + +PROBLEM: +{problem} + +TRAJECTORY: +{trajectory} + +You are evaluating whether this reasoning trajectory is making SOUND PROGRESS toward correct Lean 4 +code. +The trajectory may or may not contain final code - that's fine. Judge the REASONING QUALITY. + +## WHAT MAKES REASONING GOOD vs BAD + +GOOD reasoning: +- Each step adds NEW concrete information (function name, operator, edge case, type) +- Mentions CORRECT Lean 4 constructs (foldl, match, recursion, List/Array methods) +- Identifies the right functional idioms (pure, immutable, no mutation) +- Builds toward a clear implementation path + +BAD reasoning: +- REPETITION: Multiple steps saying the same thing ("use XOR", "XOR will work", "apply XOR") +- WRONG IDIOMS: Mentions imperative patterns (mutable variables, imperative iteration) +- WRONG OPERATORS: Plans to use `^` for XOR (it's exponentiation!), Python syntax +- VAGUE: "Process the list" without specifying HOW (fold? recursion? map?) +- STUCK: Goes in circles without making progress + +## SCORING RUBRIC + +Ask: "Is this reasoning on track to produce CORRECT, COMPILABLE Lean 4 code?" + +- **sure (9)**: Clear path to correct code. 
Either has code, OR reasoning identifies correct Lean 4 +constructs with enough detail to write code immediately. No wrong idioms. +- **likely (7)**: Right direction with specific Lean 4 constructs mentioned. Minor gaps but +approach is sound. +- **possible (5)**: Correct algorithm but vague on Lean 4 specifics. OR has repetition (multiple +steps, one idea). +- **unlikely (3)**: Reasoning uses WRONG patterns that won't work in pure Lean 4 function bodies +(mutable state, `^` for XOR, imperative style). +- **impossible (1)**: Wrong language, nonsense, or completely off-track. + +## FEW-SHOT EXAMPLES + +Problem: Find single number in list where all others appear twice. + +EXAMPLE 1 — SURE (9): +Trajectory: "XOR all elements. XOR cancels duplicates (a⊕a=0). In Lean 4, use List.foldl with +Nat.xor and initial value 0." +Why SURE: Correct algorithm + correct Lean 4 construct (foldl, Nat.xor). Could write code now. + +EXAMPLE 2 — SURE (9): +Trajectory: "Use foldl to accumulate XOR. Start with 0, apply Nat.xor to each element. The property +a⊕a=0 ensures only the unique element survives." +Why SURE: Same information, phrased differently but complete. Ready to implement. + +EXAMPLE 3 — LIKELY (7): +Trajectory: "XOR all elements using some kind of fold operation. Lean has foldl for lists. Need to +find the right XOR function." +Why LIKELY: Knows the approach and foldl, but hasn't pinpointed Nat.xor yet. One detail away. + +EXAMPLE 4 — POSSIBLE (5): +Trajectory: "XOR all numbers together. The duplicates will cancel out leaving the unique one." +Why POSSIBLE: Correct algorithm but NO Lean 4 specifics. How to XOR? What function? + +EXAMPLE 5 — POSSIBLE (5) — REPETITION: +Trajectory: "XOR cancels duplicates. So XORing all elements gives the answer. The XOR operation is +the key insight here." +Why POSSIBLE: Three sentences, ONE idea repeated. No progress on HOW to implement in Lean 4. 
+ +EXAMPLE 6 — UNLIKELY (3): +Trajectory: "Loop through the list and XOR each element with an accumulator using ^. Something +like: for x in nums, result = result ^ x" +Why UNLIKELY: `^` is exponentiation in Lean 4, not XOR. Mutable accumulator pattern won't work in a +pure function body. This reasoning leads to compile errors. + +EXAMPLE 7 — UNLIKELY (3): +Trajectory: "Create a mutable result variable initialized to 0. Iterate through nums, updating +result ^= x for each element." +Why UNLIKELY: Mutable variables and imperative iteration don't exist in pure Lean 4. Wrong paradigm +entirely. + +--- + +Problem: Return maximum of two natural numbers. + +EXAMPLE 8 — SURE (9): +Trajectory: "Use if-then-else: `if a >= b then a else b`. Nat has decidable ordering so comparison +works directly." +Why SURE: Exact syntax given, identifies why it works (decidable). Ready to implement. + +EXAMPLE 9 — LIKELY (7): +Trajectory: "Compare a and b with if-then-else. Return the larger one. Lean's Nat supports direct +comparison." +Why LIKELY: Right approach, right construct, just needs exact syntax. + +EXAMPLE 10 — POSSIBLE (5): +Trajectory: "Return whichever number is bigger." +Why POSSIBLE: Correct goal, but how? No Lean 4 specifics at all. + +--- + +Problem: Compute longest common subsequence length. + +EXAMPLE 11 — UNLIKELY (3): +Trajectory: "Build a 2D DP table. Use nested for loops: for i in 1..m, for j in 1..n, set dp[i][j] +based on matches." +Why UNLIKELY: Nested loops with mutable array updates don't work in pure Lean 4 function bodies. +Need functional approach (folds, recursion). + +EXAMPLE 12 — LIKELY (7): +Trajectory: "Use dynamic programming with foldl. Outer fold over rows, inner fold over columns. +Build each row as a new List, tracking previous row values." +Why LIKELY: Correct functional approach (foldl, immutable row building). Needs details but sound +path. + +--- + +## KEY PENALTIES + +1. 
**REPETITION**: If 3 steps contain only 1 unique idea → score based on that 1 idea only +2. **WRONG OPERATORS**: `^` for XOR, Python's `max()`, etc. → unlikely (3) at best +3. **IMPERATIVE PATTERNS**: mutable state, imperative iteration in pure function bodies → unlikely +(3) at best +4. **VAGUE ABSTRACTION**: "process", "iterate", "handle" without Lean specifics → possible (5) max + +## YOUR EVALUATION + +Respond with "Confidence: " followed by: +1. What NEW concrete information does each step add? (or note repetition) +2. Are the Lean 4 constructs mentioned correct? +3. Does the reasoning lead toward compilable code or toward errors? +""" + +VERINA_VALUE_PROMPT_SIMPLE = """Evaluate a Lean 4 code-generation reasoning trajectory. + +PROBLEM: +{problem} + +TRAJECTORY: +{trajectory} + +The trajectory contains step-by-step reasoning toward a Lean 4 function body. +Judge if the reasoning is on a sound path to producing valid Lean 4 that satisfies the +postcondition given the precondition. + +Key scoring factors: +- Could you write the final code RIGHT NOW from this trajectory? If yes → sure (9). +- Steps that rephrase or "verify" earlier steps without adding new specifics (syntax, types, edge +cases) count as repetition. Repetitive trajectories should NEVER score sure. +- A single high-level idea without Lean 4 specifics should score possible (5) at most. +- Concrete Lean 4 syntax, function names, and postcondition reasoning push toward likely (7) or +sure (9). + +Use the confidence rubric: +- sure (9): CODE-READY — correct Lean 4 syntax/functions AND postcondition logic, no gaps +- likely (7): Correct approach with specific Lean 4 constructs, minor gaps before code +- possible (5): Right direction but missing Lean 4 details, OR multi-step but repetitive +- unlikely (3): Fundamental issues in approach or idioms +- impossible (1): Completely off-track or wrong language + +Respond with "Confidence: " and brief justification. Note if steps are repetitive. 
+""" + + + +# ============================================================================ +# VERINA SPEC (LEAN 4 SPECIFICATION GENERATION) VALUE PROMPTS +# ============================================================================ + + + +VERINA_SPEC_VALUE_PROMPT_WITH_FEWSHOT = """ +Evaluate a Lean 4 specification-generation reasoning trajectory. + +PROBLEM: +{problem} + +TRAJECTORY: +{trajectory} + +You are evaluating whether this reasoning trajectory is making SOUND PROGRESS toward correct Lean 4 +specifications (preconditions and postconditions as Prop expressions). + +## WHAT MAKES SPEC REASONING GOOD vs BAD + +GOOD reasoning: +- Each step adds NEW concrete information (a constraint, a quantifier, a Lean type) +- Uses correct Lean 4 Prop syntax: `∧`, `∨`, `→`, `¬`, `∀`, `∃`, `True`, `False` +- Mentions real Lean 4 functions: `List.count`, `List.length`, `Array.size`, etc. +- Reasons about SOUNDNESS (rejects bad inputs/outputs) and COMPLETENESS (accepts all valid ones) +- Produces actual Prop expressions, not English descriptions + +BAD reasoning: +- NATURAL LANGUAGE: "The result must be..." instead of actual Lean syntax +- NON-EXISTENT FUNCTIONS: `lcsLength`, `maxSubarray`, made-up methods +- REPETITION: "The precondition should check X. Verify X is correct. Confirm X works." +- PYTHON SYNTAX: `result == max(a, b)`, `len(nums)`, `nums.count(x)` +- VAGUE: "The postcondition should ensure correctness" without specifics + +## CRITICAL DISTINCTION: Props vs English + +WRONG (English descriptions): +- "The input list must be non-empty" +- "The result should be the maximum value" +- "All elements must appear exactly twice" + +RIGHT (Lean Props): +- `nums.length > 0` +- `result >= a ∧ result >= b ∧ (result = a ∨ result = b)` +- `∀ x, x ∈ nums → (List.count x nums = 1 ∨ List.count x nums = 2)` + +## SCORING RUBRIC + +Ask: "Could I write compilable [PRECOND]...[/PRECOND] and [POSTCOND]...[/POSTCOND] from this +reasoning?" 
+ +- **sure (9)**: Has (or is ready to produce) valid Lean Prop syntax. Mentions correct operators (∧, +∨, ∀, ∃). Uses real Lean functions. +- **likely (7)**: Right direction with specific Lean constructs identified. Needs exact syntax but +approach is sound. +- **possible (5)**: Correct intuition about what to specify, but vague on Lean syntax. OR +repetitive reasoning. +- **unlikely (3)**: Uses English descriptions, made-up functions, or Python syntax. Won't compile. +- **impossible (1)**: Wrong language, nonsense, or completely off-track. + +## FEW-SHOT EXAMPLES + +Problem: Find single number in list where all others appear twice. + +EXAMPLE 1 — SURE (9): +Trajectory: "Precondition: exactly one element has count 1, others have count 2. Use ∃ x, x ∈ nums +∧ List.count x nums = 1 ∧ (∀ y, y ∈ nums → y ≠ x → List.count y nums = 2). Postcondition: +List.count result nums = 1." +Why SURE: Valid Lean syntax with ∃, ∀, ∧, →. Uses real function List.count. Membership guard (∈ +nums) ensures quantifiers are well-scoped. Ready to write. + +EXAMPLE 2 — LIKELY (7): +Trajectory: "Need to say exactly one element appears once. Use existential quantifier and +List.count. The result should be that unique element." +Why LIKELY: Knows the constructs (∃, List.count) but hasn't assembled exact syntax yet. + +EXAMPLE 3 — POSSIBLE (5): +Trajectory: "The precondition should check that the list has a unique element. The postcondition +should verify the result is that element." +Why POSSIBLE: Right idea but NO Lean syntax. Just English descriptions. + +EXAMPLE 4 — UNLIKELY (3): +Trajectory: "The input list must be non-empty, and there must be exactly one element that appears +exactly once in the list." +Why UNLIKELY: Pure English prose. "The input list must be..." won't compile. Needs Prop syntax. + +EXAMPLE 5 — UNLIKELY (3): +Trajectory: "Postcondition: result = nums.findUnique() where findUnique returns the single +element." +Why UNLIKELY: `findUnique()` doesn't exist. Made-up method. 
+ +--- + +Problem: Return maximum of two natural numbers. + +EXAMPLE 6 — SURE (9): +Trajectory: "Precondition: True (no restrictions on Nat). Postcondition: result >= a ∧ result >= b +∧ (result = a ∨ result = b). This is sound: rejects values not equal to a or b. Complete: actual +max satisfies all conjuncts." +Why SURE: Complete Prop with soundness/completeness reasoning. Exact syntax ready. + +EXAMPLE 7 — LIKELY (7): +Trajectory: "Postcondition needs three parts: result >= a, result >= b, and result equals one of +them. Use conjunction ∧ and disjunction ∨." +Why LIKELY: Knows the pieces and operators, just needs to assemble. + +EXAMPLE 8 — POSSIBLE (5): +Trajectory: "The postcondition should ensure the result is the maximum." +Why POSSIBLE: Right goal, but "the maximum" is English, not Lean. What IS maximum in Prop terms? + +EXAMPLE 9 — UNLIKELY (3): +Trajectory: "Postcondition: result == max(a, b)" +Why UNLIKELY: `==` is boolean equality (BEq), not propositional equality — use `=` for Props. +`max(a, b)` uses parenthesized call syntax — use `max a b` or `Nat.max a b` in Lean. + +--- + +Problem: Longest common subsequence length. + +EXAMPLE 10 — UNLIKELY (3): +Trajectory: "Postcondition: result = lcsLength a b where lcsLength computes the LCS." +Why UNLIKELY: `lcsLength` doesn't exist in Lean's standard library. Made-up function. + +EXAMPLE 11 — LIKELY (7): +Trajectory: "Need to define what LCS means. In POSTCOND_AUX, define a recursive function for +subsequences. Then postcondition says result = length of longest common one." +Why LIKELY: Acknowledges need for auxiliary definition. Approach is sound even if details missing. + +EXAMPLE 12 — POSSIBLE (5) — REPETITION: +Trajectory: "The LCS length should be computed correctly. Verify the computation is correct. +Confirm the result matches the LCS definition." +Why POSSIBLE: Three sentences, ONE idea. No actual Lean syntax or helper definition. + +--- + +## KEY PENALTIES + +1. 
**ENGLISH PROSE**: "The X must be Y" → unlikely (3) unless followed by Lean syntax +2. **MADE-UP FUNCTIONS**: Anything not in Lean stdlib without defining it → unlikely (3) +3. **PYTHON SYNTAX**: `==`, `len()`, `max()`, `.count()` without module → unlikely (3) +4. **REPETITION**: Same property phrased multiple ways → score based on unique content only +5. **NO SOUNDNESS/COMPLETENESS**: Specs without reasoning about what they accept/reject → likely +(7) max + +## YOUR EVALUATION + +Respond with "Confidence: " followed by: +1. Does the reasoning produce actual Lean Prop syntax or just English? +2. Are the Lean constructs mentioned real (List.count, ∧, ∀) or made-up? +3. Is there soundness/completeness reasoning? +4. Is there repetition? +""" + +VERINA_SPEC_VALUE_PROMPT_SIMPLE = """ +Evaluate a Lean 4 specification-generation reasoning trajectory. + +PROBLEM: +{problem} + +TRAJECTORY: +{trajectory} + +The trajectory contains step-by-step reasoning toward Lean 4 preconditions and postconditions. +Judge if the reasoning is on a sound path to producing valid, sound, and complete specifications. + +Key scoring factors: +- Could you write the final [PRECOND] and [POSTCOND] RIGHT NOW from this trajectory? If yes → sure +(9). +- Steps that rephrase or "verify" earlier steps without adding new specifics count as repetition. +Repetitive trajectories should NEVER score sure. +- A single high-level idea without Lean 4 Prop specifics should score possible (5) at most. +- Concrete Lean 4 Prop syntax (∧, ∨, ¬, ∀, ∃) and soundness/completeness reasoning push toward +likely (7) or sure (9). 
+ +Use the confidence rubric: +- sure (9): SPEC-READY — correct Lean 4 Prop syntax AND soundness+completeness reasoning, no gaps +- likely (7): Correct approach with specific Lean 4 Prop constructs, minor gaps +- possible (5): Right direction but missing Lean 4 Prop details, OR multi-step but repetitive +- unlikely (3): Fundamental issues in approach or syntax +- impossible (1): Completely off-track or wrong language + +Respond with "Confidence: " and brief justification. Note if steps are repetitive. +""" + +# ============================================================================ +# GENERIC VALUE PROMPT +# ============================================================================ + +GENERIC_VALUE_PROMPT_WITH_FEWSHOT = """ +Evaluate how close the following trajectory is to solving the problem. + +PROBLEM: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Here are examples of trajectory evaluations: + +EXAMPLE 1 (Strong progress): +Trajectory: I've broken down the problem into steps, identified key constraints, and I'm halfway +through the solution with correct logic so far. +Confidence: likely/probably (7) +Justification: Clear methodology and correct intermediate progress toward the solution. + +EXAMPLE 2 (Uncertain progress): +Trajectory: I've started the problem and my approach seems reasonable, but I'm not confident about +the next steps. +Confidence: possible/maybe (5) +Justification: Direction is sound but execution and completeness require verification. + +EXAMPLE 3 (Low chance of success): +Trajectory: I tried an approach but it seems to have led to a contradiction. +Confidence: unlikely/doubtful (3) +Justification: The approach has fundamental issues that need to be reconsidered. + +Rate the state on the scale: sure/certain (9), likely/probably (7), possible/maybe (5), +unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " and a short rationale. 
+""" + +GENERIC_VALUE_PROMPT_SIMPLE = """ +Evaluate how close the following trajectory is to solving the problem. + +PROBLEM: +{problem} + +CURRENT TRAJECTORY: +{trajectory} + +Rate the state on the scale: sure/certain (9), likely/probably (7), possible/maybe (5), +unlikely/doubtful (3), impossible/blocked (1). +Respond with "Confidence: " and a short rationale. +""" + + +def build_game24_value_prompt( + problem: str, + trajectory: str, + use_fewshot: bool = True +) -> str: + """Build game24 value prompt with or without few-shot examples.""" + if use_fewshot: + return GAME24_VALUE_PROMPT_WITH_FEWSHOT.format( + problem=problem, trajectory=trajectory + ) + else: + return GAME24_VALUE_PROMPT_SIMPLE.format( + problem=problem, trajectory=trajectory + ) + + +def build_mcq_value_prompt( + problem: str, + trajectory: str, + task_name: str, + use_fewshot: bool = True +) -> str: + """Build MCQ (maze/spatial) value prompt with or without few-shot examples.""" + if task_name.lower() == "maze": + if use_fewshot: + return MAZE_VALUE_PROMPT_WITH_FEWSHOT.format( + problem=problem, trajectory=trajectory + ) + else: + return MAZE_VALUE_PROMPT_SIMPLE.format( + problem=problem, trajectory=trajectory + ) + elif task_name.lower() in ( + "spatial", "spatialmap", "spatial reasoning" + ): + if use_fewshot: + return SPATIALMAP_VALUE_PROMPT_WITH_FEWSHOT.format( + problem=problem, trajectory=trajectory + ) + else: + return SPATIALMAP_VALUE_PROMPT_SIMPLE.format( + problem=problem, trajectory=trajectory + ) + else: + # Default MCQ template + if use_fewshot: + return MAZE_VALUE_PROMPT_WITH_FEWSHOT.format( + problem=problem, trajectory=trajectory + ) + else: + return MAZE_VALUE_PROMPT_SIMPLE.format( + problem=problem, trajectory=trajectory + ) + + +def build_generic_value_prompt( + problem: str, + trajectory: str, + use_fewshot: bool = True +) -> str: + """Build generic value prompt with or without few-shot examples.""" + if use_fewshot: + return GENERIC_VALUE_PROMPT_WITH_FEWSHOT.format( + 
problem=problem, trajectory=trajectory + ) + else: + return GENERIC_VALUE_PROMPT_SIMPLE.format( + problem=problem, trajectory=trajectory + ) + + +def build_zebralogic_value_prompt( + problem: str, + trajectory: str, + use_fewshot: bool = True +) -> str: + """Build zebralogic value prompt with or without few-shot examples.""" + if use_fewshot: + return ZEBRALOGIC_VALUE_PROMPT_WITH_FEWSHOT.format( + problem=problem, trajectory=trajectory + ) + else: + return ZEBRALOGIC_VALUE_PROMPT_SIMPLE.format( + problem=problem, trajectory=trajectory + ) + + +def build_verina_value_prompt( + problem: str, + trajectory: str, + use_fewshot: bool = True +) -> str: + """Build verina (Lean 4 code generation) value prompt + with or without few-shot examples.""" + if use_fewshot: + return VERINA_VALUE_PROMPT_WITH_FEWSHOT.format( + problem=problem, trajectory=trajectory + ) + else: + return VERINA_VALUE_PROMPT_SIMPLE.format( + problem=problem, trajectory=trajectory + ) + + +def build_verina_spec_value_prompt( + problem: str, + trajectory: str, + use_fewshot: bool = True +) -> str: + """Build verina spec (Lean 4 spec generation) value prompt + with or without few-shot examples.""" + if use_fewshot: + return VERINA_SPEC_VALUE_PROMPT_WITH_FEWSHOT.format( + problem=problem, trajectory=trajectory + ) + else: + return VERINA_SPEC_VALUE_PROMPT_SIMPLE.format( + problem=problem, trajectory=trajectory + ) + + +def build_tot_value_prompt( + task: str, + problem: str, + trajectory: str, + use_fewshot: bool = True +) -> str: + """ + Build value prompt for Tree of Thought evaluation. 
+ + Args: + task: The task type + (e.g., "game24", "maze", "spatialmap", + "zebralogic", "verina") + problem: The original problem statement + trajectory: Current partial solution + use_fewshot: Whether to use few-shot examples + (default True) + + Returns: + Formatted value prompt + """ + + if task == "game24": + return build_game24_value_prompt( + problem, trajectory, use_fewshot + ) + + + if task == "maze": + return build_mcq_value_prompt( + problem, trajectory, "maze", use_fewshot + ) + + + if task == "spatialmap": + return build_mcq_value_prompt( + problem, trajectory, "spatial reasoning", use_fewshot + ) + + + if task == "zebralogic": + return build_zebralogic_value_prompt( + problem, trajectory, use_fewshot + ) + + + if task == "verina": + return build_verina_value_prompt( + problem, trajectory, use_fewshot + ) + + + if task == "verina_spec": + return build_verina_spec_value_prompt( + problem, trajectory, use_fewshot + ) + + + return build_generic_value_prompt( + problem, trajectory, use_fewshot + ) \ No newline at end of file diff --git a/examples/TTSwithVerification/ToT/utils/verina_spec_tot_utils.py b/examples/TTSwithVerification/ToT/utils/verina_spec_tot_utils.py new file mode 100644 index 0000000..9e427c6 --- /dev/null +++ b/examples/TTSwithVerification/ToT/utils/verina_spec_tot_utils.py @@ -0,0 +1,994 @@ +import argparse +import asyncio +import json +import logging +import os +import re +import sys +import shutil +import subprocess +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Any +import numpy as np +import pandas as pd +import aiohttp +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +__all__ = [ + # Path constants + "VERINA_ROOT", + "VERINA_DATASETS_PATH", + "LEAN_PLAYGROUND_DIR", + # Classes + "BenchmarkData", + # Constants + "CODE_TEST_MSG_MARKER", + "DECIDABLE_ERR_MSG", + 
"PRECOND_TEST_MSG_MARKER",
    "POSTCOND_TEST_MSG_MARKER",
    "PLAUSIBLE_SUCCESS_MSG",
    "PLAUSIBLE_FAILED_MSG",
    "PLAUSIBLE_TEST_COMMAND",
    "parse_benchmark_lean_data",
    "load_benchmark_data_from_task_dir",
    "load_verina_dataset",
    "render_param_list",
    "strip_function_definition",
    "create_lean_file",
    "check_lean_compile",
    "render_unit_test_value",
    "render_code_unit_test",
    "build_test_lean_file",
    "parse_unit_test_results",
    "evaluate_generated_code",
    "evaluate_verina_answer",
    # Spec generation functions
    "build_spec_gen_prompt",
    "build_verina_spec_prompt",
    "extract_spec_from_response",
    "make_aux_reducible",
    "build_spec_test_lean_file",
    "render_precond_sound_test",
    "render_precond_complete_test",
    "render_postcond_complete_test",
    "render_postcond_sound_test",
    "parse_spec_test_results",
    "evaluate_generated_spec",
]

# This module lives under .../ToT/utils; the verina checkout is assumed to be
# three directory levels up from here — TODO confirm against the repo layout.
_SCRIPT_DIR = Path(__file__).parent.resolve()
VERINA_ROOT = (_SCRIPT_DIR / "../../../verina").resolve()
# Directory holding one sub-directory per benchmark task (named "verina_*").
VERINA_DATASETS_PATH = VERINA_ROOT / "datasets" / "verina"
# Scratch directory where generated .lean files are written for compilation.
LEAN_PLAYGROUND_DIR = VERINA_ROOT / "lean-playground"

class BenchmarkData:
    """Verina benchmark data structure.

    Plain container for one benchmark task as loaded from its task
    directory; see load_benchmark_data_from_task_dir for how each field
    is populated.
    """
    def __init__(self, data_id: str, description: str, signature: dict,
                 lean_data: dict, spec_desc: dict, tests: list, reject_inputs: list, metadata: dict):
        self.data_id = data_id          # unique task identifier (from task.json "id")
        self.description = description  # natural-language task description
        self.signature = signature  # {"name": str, "parameters": list, "return_type": str}
        self.lean_data = lean_data  # contains task_imports, task_aux, code, precond, postcond, proof, etc.
+ self.spec_desc = spec_desc # {"precond_desc": str, "postcond_desc": str} + self.tests = tests + self.reject_inputs = reject_inputs # For precondition completeness testing + self.metadata = metadata + + +def parse_benchmark_lean_data(raw_lean_data: str) -> dict: + """Parse a .lean file with !benchmark markers into sections""" + lines = raw_lean_data.strip().splitlines() + + lean_data = { + "task_imports": "", + "solution_imports": "", + "task_aux": "", + "solution_aux": "", + "code_aux": "", + "precond_aux": "", + "postcond_aux": "", + "proof_aux": "", + "code": "", + "precond": "True", + "postcond": "", + "proof": "sorry", + } + + current_section = None + current_content = [] + current_args = {} + + for line in lines: + if "-- !benchmark" in line: + marker_part = line.split("-- !benchmark", 1)[1].strip() + + if marker_part.startswith("@start"): + # Save previous section if any + if current_section is not None: + content = "\n".join(current_content).strip() + if current_section == "import": + import_type = current_args.get("type", "task") + if import_type == "task": + lean_data["task_imports"] += content + "\n" + elif import_type == "solution": + lean_data["solution_imports"] += content + "\n" + elif current_section in lean_data: + lean_data[current_section] = content + + # Start new section + parts = marker_part.split("@start", 1)[1].strip().split(None, 1) + current_section = parts[0].strip() + current_args = {} + current_content = [] + + if len(parts) > 1: + for arg in parts[1].strip().split(): + if "=" in arg: + key, value = arg.split("=", 1) + current_args[key] = value + + elif marker_part.startswith("@end"): + if current_section is not None: + content = "\n".join(current_content).strip() + if current_section == "import": + import_type = current_args.get("type", "task") + if import_type == "task": + lean_data["task_imports"] += content + "\n" + elif import_type == "solution": + lean_data["solution_imports"] += content + "\n" + elif current_section in 
lean_data: + lean_data[current_section] = content + current_section = None + current_content = [] + current_args = {} + else: + if current_section is not None: + current_content.append(line) + + return lean_data + + +def load_benchmark_data_from_task_dir(task_dir: Path) -> Optional[BenchmarkData]: + """Load a single benchmark task from its directory""" + task_path = task_dir / "task.json" + if not task_path.exists(): + return None + + try: + with open(task_path, "r") as f: + task_data = json.load(f) + + task_id = task_data.get("id") + if not task_id: + return None + + # Read description + desc_path = task_dir / task_data.get("description_file", "description.txt") + description = desc_path.read_text().strip() if desc_path.exists() else "" + + # Read signature + signature = task_data.get("signature", {}) + + # Read lean file + lean_path = task_dir / task_data.get("lean_file", "task.lean") + if lean_path.exists(): + lean_data = parse_benchmark_lean_data(lean_path.read_text()) + else: + lean_data = {} + + # Read spec description + spec_desc = { + "precond_desc": task_data.get("specification", {}).get("preconditions", ""), + "postcond_desc": task_data.get("specification", {}).get("postconditions", ""), + } + + # Read tests + test_path = task_dir / task_data.get("test_file", "test.json") + tests = [] + if test_path.exists(): + with open(test_path, "r") as f: + tests = json.load(f) + + # Load reject_inputs for precondition completeness testing + reject_inputs_path = task_dir / "reject_inputs.json" + reject_inputs = [] + if reject_inputs_path.exists(): + with open(reject_inputs_path, "r") as f: + reject_inputs = json.load(f) + + metadata = task_data.get("metadata", {}) + + return BenchmarkData( + data_id=task_id, + description=description, + signature=signature, + lean_data=lean_data, + spec_desc=spec_desc, + tests=tests, + reject_inputs=reject_inputs, + metadata=metadata, + ) + except Exception as e: + #logger.error(f"Error loading {task_dir}: {e}") + return None + + +def 
load_verina_dataset() -> List[BenchmarkData]: + """Load all verina benchmark tasks from the datasets directory""" + results = [] + + # Get all task directories sorted by ID + task_dirs = sorted( + [d for d in VERINA_DATASETS_PATH.glob("verina_*") if d.is_dir()], + key=lambda x: (x.name.split("_")[1], int(x.name.split("_")[-1])), + ) + + for task_dir in task_dirs: + data = load_benchmark_data_from_task_dir(task_dir) + if data: + results.append(data) + + #logger.info(f"Loaded {len(results)} verina tasks") + return results + + +def render_param_list(signature: dict) -> str: + """Render the parameter list for a function signature""" + params = signature.get("parameters", []) + rendered = "" + for param in params: + rendered += f"({param['param_name']} : {param['param_type']}) " + return rendered.strip() + + +def strip_function_definition(code: str) -> str: + """ + Strip function definition prefix if the model accidentally included it. + + The prompt asks for just the function body, but sometimes the model outputs: + def FunctionName (params) (h_precond : ...) : ReturnType := + actual_body + + We need to extract just 'actual_body' and dedent it properly. + """ + import textwrap + + code = code.strip() + + # Pattern to match Lean function definition: + # def (h_precond : ) : := + # The function body follows after := + func_def_pattern = r'^def\s+\w+\s+.*?:=[ \t]*\n?' 

    match = re.match(func_def_pattern, code, re.DOTALL)
    if match:
        # Extract everything after the :=
        body = code[match.end():]
        # Dedent to remove common leading whitespace from all lines
        body = textwrap.dedent(body).strip()
        if body:
            return body

    # No definition header found: assume the model returned just the body.
    return code


def create_lean_file(file_name: str, content: str) -> Path:
    """Create a lean file in the playground directory.

    Writes ``content`` to LEAN_PLAYGROUND_DIR/<file_name>.lean (creating the
    directory on demand) and returns the resulting path.
    """
    LEAN_PLAYGROUND_DIR.mkdir(parents=True, exist_ok=True)
    lean_file = LEAN_PLAYGROUND_DIR / f"{file_name}.lean"
    with open(lean_file, "w") as f:
        f.write(content)
    return lean_file


def check_lean_compile(lean_file: Path, timeout: int = 120) -> Tuple[bool, str]:
    """Check if the Lean file compiles successfully.

    Runs ``lake lean <file>`` from VERINA_ROOT with the given timeout in
    seconds.  Returns (ok, output) where output is combined stdout+stderr,
    the literal "TIMEOUT" on timeout, or "ERROR: ..." for any other failure.
    """
    try:
        result = subprocess.run(
            ["lake", "lean", str(lean_file)],
            check=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
            cwd=VERINA_ROOT,
        )

        output = result.stdout.decode() + "\n" + result.stderr.decode()

        if result.returncode == 0:
            return True, output
        else:
            return False, output

    except subprocess.TimeoutExpired:
        #logger.warning(f"Lean compilation timed out for {lean_file}")
        return False, "TIMEOUT"
    except Exception as e:
        #logger.error(f"Error during compilation: {e}")
        return False, f"ERROR: {e}"


# Unit Test Rendering
# Marker text printed before each rendered unit test so test output can be
# attributed to a test index; DECIDABLE_ERR_MSG is the Lean #guard failure text.
CODE_TEST_MSG_MARKER = "code_test"
DECIDABLE_ERR_MSG = "did not evaluate to `true`"


def render_unit_test_value(lean_type: str, value) -> str:
    """Convert a Python value to Lean syntax based on type."""
    if lean_type == "Bool":
        # Python True/False -> Lean true/false
        return str(value).lower()
    elif lean_type == "String":
        return f'"{value}"'
    elif lean_type == "Char":
        return f"'{value}'"
    else:
        # For Int, List, Array, etc. - use value as-is (already in Lean format from JSON)
        return str(value)


def render_code_unit_test(signature: dict, test_case: dict, test_idx: int) -> str:
    """Render a single unit test using #guard.

    Emits a ``#print "<code_test>{idx}"`` marker line followed by a ``#guard``
    comparing the function applied to the test inputs against the expected
    value.  ``(by sorry)`` discharges the precondition hypothesis argument.
    """
    func_name = signature.get("name", "solution")
    params = signature.get("parameters", [])
    return_type = signature.get("return_type", "Bool")

    rendered = f'#print "<{CODE_TEST_MSG_MARKER}>{test_idx}"\n\n'
    rendered += f"#guard {func_name}"

    for param in params:
        param_name = param["param_name"]
        param_type = param["param_type"]
        input_value = test_case["input"].get(param_name, "")
        rendered += f" ({render_unit_test_value(param_type, input_value)})"

    # Add (by sorry) to satisfy precondition hypothesis
    rendered += " (by sorry)"

    # Add expected value comparison
    expected = test_case.get("expected", "")
    rendered += f" == ({render_unit_test_value(return_type, expected)})"

    return rendered


def build_test_lean_file(data: BenchmarkData, generated_code: str, include_unit_tests: bool = True) -> str:
    """Build a complete Lean file to test the generated code.

    Assembles imports, auxiliary definitions, the precondition/postcondition
    wrappers, the generated function body, a #check (type-check) line, and —
    when ``include_unit_tests`` — one #guard test per entry in data.tests.
    """
    signature = data.signature
    func_name = signature.get("name", "solution")
    return_type = signature.get("return_type", "Bool")
    param_list = render_param_list(signature)
    params = signature.get("parameters", [])
    param_names = " ".join([f"({p['param_name']})" for p in params])

    # Indent multiline generated_code so all lines have proper indentation
    # First line gets 2 spaces from template, subsequent lines need explicit indentation
    if '\n' in generated_code:
        lines = generated_code.split('\n')
        # First line has no extra indent (template adds 2 spaces)
        # Subsequent lines need 2 spaces prepended
        indented_lines = [lines[0]] + ['  ' + line if line.strip() else line for line in lines[1:]]
        generated_code = '\n'.join(indented_lines)

    # Build imports - include both task and solution imports
    task_imports = data.lean_data.get("task_imports", "").strip()
    solution_imports = data.lean_data.get("solution_imports", "").strip()
    imports = task_imports
    if solution_imports:
        imports += "\n" + solution_imports
    if "import Mathlib" not in imports:
        imports = "import Mathlib\n" + imports

    # Build auxiliary definitions - include solution_aux which has helper functions
    solution_aux = data.lean_data.get("solution_aux", "").strip()
    task_aux = data.lean_data.get("task_aux", "").strip()
    precond_aux = data.lean_data.get("precond_aux", "").strip()
    postcond_aux = data.lean_data.get("postcond_aux", "").strip()
    code_aux = data.lean_data.get("code_aux", "").strip()

    # Build precondition
    precond = data.lean_data.get("precond", "True").strip()
    precond_name = f"{func_name}_precond"

    # Build postcondition
    postcond = data.lean_data.get("postcond", "").strip()
    postcond_name = f"{func_name}_postcond"

    lean_content = f"""{imports}

-- Solution auxiliary definitions (helper functions)
{solution_aux}

-- Task auxiliary definitions
{task_aux}

-- Precondition auxiliary definitions
{precond_aux}

@[reducible, simp]
def {precond_name} {param_list} : Prop :=
  {precond}

-- Postcondition auxiliary definitions
{postcond_aux}

-- Code auxiliary definitions
{code_aux}

def {func_name} {param_list} (h_precond : {precond_name} {param_names}) : {return_type} :=
  {generated_code}

@[reducible, simp]
def {postcond_name} {param_list} (result: {return_type}) (h_precond : {precond_name} {param_names}) : Prop :=
  {postcond}

-- Verification theorem (compilation test)
-- If this compiles, the code at least type-checks
#check {func_name}
"""

    # Add unit tests if requested
    if include_unit_tests and data.tests:
        lean_content += "\n-- Unit Tests\n"
        for idx, test_case in enumerate(data.tests):
            lean_content += "\n" + render_code_unit_test(signature, test_case, idx) + "\n"

    return lean_content


def parse_unit_test_results(compile_output: str, num_tests: int) -> Tuple[int, int, dict]:
""" + Parse the compilation output to determine which unit tests passed/failed. + + Returns: (num_passed, num_failed, test_results_dict) + """ + test_results = {} + + # If compilation succeeded with no errors, all tests passed + if "error" not in compile_output.lower(): + for idx in range(num_tests): + test_results[idx] = "pass" + return num_tests, 0, test_results + + # Parse the output to find which tests failed + # Look for markers like 0 followed by error messages + code_test_start = f"<{CODE_TEST_MSG_MARKER}>" + code_test_end = f"" + + # Split by start marker to get test sections + parts = compile_output.split(code_test_start) + + # Build a map of test index to message + test_messages = {} + for part in parts[1:]: # Skip first part (before any marker) + if code_test_end in part: + idx_str, rest = part.split(code_test_end, 1) + try: + test_idx = int(idx_str.strip()) + test_messages[test_idx] = rest + except ValueError: + continue + + num_passed = 0 + num_failed = 0 + + for idx in range(num_tests): + msg = test_messages.get(idx, "") + if DECIDABLE_ERR_MSG in msg: + test_results[idx] = "fail" + num_failed += 1 + elif "error" in msg.lower(): + # Some other error (e.g., type mismatch) - count as fail + test_results[idx] = "error" + num_failed += 1 + else: + test_results[idx] = "pass" + num_passed += 1 + + return num_passed, num_failed, test_results + +PRECOND_TEST_MSG_MARKER = "precond_test" +POSTCOND_TEST_MSG_MARKER = "postcond_test" +PLAUSIBLE_SUCCESS_MSG = "Unable to find a counter-example" +PLAUSIBLE_FAILED_MSG = "Found a counter-example!" +PLAUSIBLE_TEST_COMMAND = "plausible ( config := { numInst := 1000, maxSize := 100, numRetries := 20, randomSeed := some 42})" + + +def build_spec_gen_prompt(data: BenchmarkData) -> Tuple[str, str]: + """ + Build a prompt for Lean 4 specification generation. + Returns (system_prompt, user_prompt) + """ + system_prompt = """You are an expert Lean 4 programmer specializing in formal specifications. 
+Generate valid Lean 4 preconditions and postconditions for the function described. + +The precondition should: +- Be as permissive as possible while ensuring the function can execute correctly +- Capture constraints on input values that are necessary for correct execution + +The postcondition should: +- Be sound: Only accept correct outputs (reject any incorrect output) +- Be complete: Accept all correct outputs (don't reject valid solutions) +- Fully specify the relationship between inputs and the expected output + +Wrap your precondition in [PRECOND]...[/PRECOND] tags. +Wrap your postcondition in [POSTCOND]...[/POSTCOND] tags. +If you need auxiliary definitions for precondition, wrap them in [PRECOND_AUX]...[/PRECOND_AUX] tags. +If you need auxiliary definitions for postcondition, wrap them in [POSTCOND_AUX]...[/POSTCOND_AUX] tags. +""" + + signature = data.signature + func_name = signature.get("name", "solution") + return_type = signature.get("return_type", "Bool") + param_list = render_param_list(signature) + params = signature.get("parameters", []) + param_names_str = ' '.join([f"({p['param_name']})" for p in params]) + + # Get ground truth code to show + code = data.lean_data.get("code", "").strip() + code_aux = data.lean_data.get("code_aux", "").strip() + task_aux = data.lean_data.get("task_aux", "").strip() + + # Natural language spec descriptions if available + precond_desc = data.spec_desc.get("precond_desc", "") + postcond_desc = data.spec_desc.get("postcond_desc", "") + + spec_desc_section = "" + if precond_desc or postcond_desc: + spec_desc_section = f""" +## Specification Hints +Precondition: {precond_desc if precond_desc else "Derive from task description"} +Postcondition: {postcond_desc if postcond_desc else "Derive from task description"} +""" + + helper_section = "" + if task_aux or code_aux: + all_aux = "\n".join(filter(None, [task_aux, code_aux])) + helper_section = f""" +## Helper Definitions +```lean4 +{all_aux} +``` +""" + + code_section = 
"" + if code: + code_section = f""" +## Reference Implementation +```lean4 +def {func_name} {param_list} (h_precond : {func_name}_precond {param_names_str}) : {return_type} := + {code} +``` +""" + + user_prompt = f"""## Task +{data.description} + +## Function Signature +- Function name: {func_name} +- Parameters: {param_list} +- Return type: {return_type} + +## Expected Output Format +```lean4 +-- Precondition auxiliary (optional) +[PRECOND_AUX] +-- helper definitions for precondition +[/PRECOND_AUX] + +-- Precondition: when should the function be allowed to run? +def {func_name}_precond {param_list} : Prop := + [PRECOND] + -- your precondition here (e.g., True, or constraints on inputs) + [/PRECOND] + +-- Postcondition auxiliary (optional) +[POSTCOND_AUX] +-- helper definitions for postcondition +[/POSTCOND_AUX] + +-- Postcondition: what must be true about the result? +def {func_name}_postcond {param_list} (result: {return_type}) (h_precond : {func_name}_precond {param_names_str}) : Prop := + [POSTCOND] + -- your postcondition here + [/POSTCOND] +``` +{spec_desc_section}{helper_section}{code_section} +Generate the precondition and postcondition. Use [PRECOND]...[/PRECOND] and [POSTCOND]...[/POSTCOND] tags.""" + + return system_prompt, user_prompt + + +def build_verina_spec_prompt(data: BenchmarkData) -> str: + """Build the full prompt string for the LLM for spec generation""" + system_prompt, user_prompt = build_spec_gen_prompt(data) + return f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + + +def extract_spec_from_response(response: str) -> Dict[str, str]: + """Extract precondition and postcondition from response. + + Returns dict with keys: precond, postcond, precond_aux, postcond_aux + """ + result = { + "precond": "", + "postcond": "", + "precond_aux": "", + "postcond_aux": "", + } + + # Remove ... 
block
+    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL | re.IGNORECASE)
+
+    # Handle partial <think>
+    if not cleaned.strip() or cleaned.strip() == response.strip():
+        think_end = response.lower().rfind("</think>")
+        if think_end != -1:
+            cleaned = response[think_end + len("</think>"):]
+
+    # Extract PRECOND_AUX (take last match to get the most recent/corrected version)
+    precond_aux_matches = re.findall(r'\[PRECOND_AUX\](.*?)\[/PRECOND_AUX\]', cleaned, re.DOTALL | re.IGNORECASE)
+    if precond_aux_matches:
+        result["precond_aux"] = precond_aux_matches[-1].strip()
+
+    # Extract POSTCOND_AUX (take last match)
+    postcond_aux_matches = re.findall(r'\[POSTCOND_AUX\](.*?)\[/POSTCOND_AUX\]', cleaned, re.DOTALL | re.IGNORECASE)
+    if postcond_aux_matches:
+        result["postcond_aux"] = postcond_aux_matches[-1].strip()
+
+    # Extract PRECOND (take last match to get the most recent/corrected version)
+    precond_matches = re.findall(r'\[PRECOND\](.*?)\[/PRECOND\]', cleaned, re.DOTALL | re.IGNORECASE)
+    if precond_matches:
+        result["precond"] = precond_matches[-1].strip()
+    else:
+        precond_start_match = re.search(r'\[PRECOND\]\s*(.*)', cleaned, re.DOTALL | re.IGNORECASE)
+        if precond_start_match:
+            precond = precond_start_match.group(1).strip()
+            postcond_idx = precond.lower().find("[postcond")
+            if postcond_idx != -1:
+                precond = precond[:postcond_idx].strip()
+            result["precond"] = precond
+
+    # Extract POSTCOND (take last match to get the most recent/corrected version)
+    postcond_matches = re.findall(r'\[POSTCOND\](.*?)\[/POSTCOND\]', cleaned, re.DOTALL | re.IGNORECASE)
+    if postcond_matches:
+        result["postcond"] = postcond_matches[-1].strip()
+    else:
+        postcond_start_match = re.search(r'\[POSTCOND\]\s*(.*)', cleaned, re.DOTALL | re.IGNORECASE)
+        if postcond_start_match:
+            postcond = postcond_start_match.group(1).strip()
+            result["postcond"] = postcond
+
+    # Clean up any remaining markdown tags
+    for key in result:
+        result[key] = re.sub(r'```(lean4?)?\s*', '', result[key])
+        result[key] = 
re.sub(r'```\s*$', '', result[key]) + + return result + + +def make_aux_reducible(aux: str) -> str: + """Add @[reducible, simp] to definitions if not present""" + lines = aux.split("\n") + result = [] + for i, line in enumerate(lines): + if line.strip().startswith("def "): + if i == 0 or "@[reducible, simp]" not in lines[i-1]: + result.append("@[reducible, simp]") + result.append(line) + return "\n".join(result) + + +def build_spec_test_lean_file( + data: BenchmarkData, + generated_spec: Dict[str, str], + test_type: str = "compile" +) -> str: + """Build a Lean file to test the generated specification.""" + signature = data.signature + func_name = signature.get("name", "solution") + return_type = signature.get("return_type", "Bool") + param_list = render_param_list(signature) + params = signature.get("parameters", []) + param_names = " ".join([f"({p['param_name']})" for p in params]) + + # Build imports + task_imports = data.lean_data.get("task_imports", "").strip() + solution_imports = data.lean_data.get("solution_imports", "").strip() + imports = task_imports + if solution_imports: + imports += "\n" + solution_imports + if "import Mathlib" not in imports: + imports = "import Mathlib\n" + imports + if "import Plausible" not in imports: + imports = "import Plausible\n" + imports + + # Build auxiliary definitions + task_aux = data.lean_data.get("task_aux", "").strip() + solution_aux = data.lean_data.get("solution_aux", "").strip() + + precond_name = f"{func_name}_precond" + postcond_name = f"{func_name}_postcond" + + # Use generated spec + precond = generated_spec.get("precond", "True").strip() + postcond = generated_spec.get("postcond", "").strip() + precond_aux = generated_spec.get("precond_aux", "").strip() + postcond_aux = generated_spec.get("postcond_aux", "").strip() + + if precond_aux: + precond_aux = make_aux_reducible(precond_aux) + if postcond_aux: + postcond_aux = make_aux_reducible(postcond_aux) + + lean_content = f"""{imports} + +-- Task auxiliary 
definitions +{task_aux} + +-- Solution auxiliary definitions +{solution_aux} + +-- Generated precondition auxiliary +{precond_aux} + +@[reducible, simp] +def {precond_name} {param_list} : Prop := + {precond} + +-- Generated postcondition auxiliary +{postcond_aux} + +@[reducible, simp] +def {postcond_name} {param_list} (result: {return_type}) (h_precond : {precond_name} {param_names}) : Prop := + {postcond} + +-- Compilation check +#check {precond_name} +#check {postcond_name} +""" + + # Add tests based on test_type + if test_type == "precond_sound": + lean_content += "\n-- Precondition Soundness Tests (valid inputs should satisfy precond)\n" + for idx, test_case in enumerate(data.tests): + lean_content += render_precond_sound_test(signature, precond_name, test_case, idx) + + elif test_type == "precond_complete": + lean_content += "\n-- Precondition Completeness Tests (invalid inputs should NOT satisfy precond)\n" + for idx, reject_input in enumerate(data.reject_inputs): + lean_content += render_precond_complete_test(signature, precond_name, reject_input, idx) + + elif test_type == "postcond_sound": + lean_content += "\n-- Postcondition Soundness Tests (wrong outputs should NOT satisfy postcond)\n" + global_idx = 0 + for idx, test_case in enumerate(data.tests): + unexpected_list = test_case.get("unexpected", []) + for unexpected_idx, unexpected in enumerate(unexpected_list): + lean_content += render_postcond_sound_test( + signature, precond_name, postcond_name, test_case, global_idx, unexpected, unexpected_idx + ) + global_idx += 1 + + elif test_type == "postcond_complete": + lean_content += "\n-- Postcondition Completeness Tests (correct outputs should satisfy postcond)\n" + for idx, test_case in enumerate(data.tests): + lean_content += render_postcond_complete_test( + signature, precond_name, postcond_name, test_case, idx + ) + + return lean_content + + +def render_precond_sound_test(signature: dict, precond_name: str, test_case: dict, test_idx: int) -> str: + 
"""Render test: valid input should satisfy precondition""" + params = signature.get("parameters", []) + + rendered = f'\n#print "<{PRECOND_TEST_MSG_MARKER}_sound>{test_idx}"\n' + rendered += f"#guard decide ({precond_name}" + + for param in params: + param_name = param["param_name"] + param_type = param["param_type"] + input_value = test_case["input"].get(param_name, "") + rendered += f" ({render_unit_test_value(param_type, input_value)})" + + rendered += ")\n" + return rendered + + +def render_precond_complete_test(signature: dict, precond_name: str, reject_input: dict, test_idx: int) -> str: + """Render test: reject_input should NOT satisfy precondition""" + params = signature.get("parameters", []) + + rendered = f'\n#print "<{PRECOND_TEST_MSG_MARKER}_complete>{test_idx}"\n' + rendered += f"#guard decide (¬ ({precond_name}" + + for param in params: + param_name = param["param_name"] + param_type = param["param_type"] + input_value = reject_input.get("input", {}).get(param_name, "") + rendered += f" ({render_unit_test_value(param_type, input_value)})" + + rendered += "))\n" + return rendered + + +def render_postcond_complete_test(signature: dict, precond_name: str, postcond_name: str, test_case: dict, test_idx: int) -> str: + """Render test: expected output should satisfy postcondition""" + params = signature.get("parameters", []) + return_type = signature.get("return_type", "Bool") + + rendered = f'\n#print "<{POSTCOND_TEST_MSG_MARKER}_complete>{test_idx}"\n' + rendered += f"#guard decide ({postcond_name}" + + for param in params: + param_name = param["param_name"] + param_type = param["param_type"] + input_value = test_case["input"].get(param_name, "") + rendered += f" ({render_unit_test_value(param_type, input_value)})" + + expected = test_case.get("expected", "") + rendered += f" ({render_unit_test_value(return_type, expected)}) (by sorry))\n" + + return rendered + + +def render_postcond_sound_test( + signature: dict, precond_name: str, postcond_name: str, + 
test_case: dict, global_idx: int, unexpected: Any, unexpected_idx: int
+) -> str:
+    """Render test: unexpected output should NOT satisfy postcondition"""
+    params = signature.get("parameters", [])
+    return_type = signature.get("return_type", "Bool")
+    test_idx = global_idx
+
+    rendered = f'\n#print "<{POSTCOND_TEST_MSG_MARKER}_sound>{test_idx}</{POSTCOND_TEST_MSG_MARKER}_sound>"\n'
+    rendered += f"#guard decide (¬ ({postcond_name}"
+
+    for param in params:
+        param_name = param["param_name"]
+        param_type = param["param_type"]
+        input_value = test_case["input"].get(param_name, "")
+        rendered += f" ({render_unit_test_value(param_type, input_value)})"
+
+    rendered += f" ({render_unit_test_value(return_type, unexpected)}) (by sorry)))\n"
+
+    return rendered
+
+
+def parse_spec_test_results(compile_output: str, marker: str, num_tests: int) -> Tuple[int, int, dict]:
+    """Parse compilation output for spec test results."""
+    test_results = {}
+
+    if "error" not in compile_output.lower():
+        for idx in range(num_tests):
+            test_results[idx] = "pass"
+        return num_tests, 0, test_results
+
+    start_marker = f"<{marker}>"
+    end_marker = f"</{marker}>"
+
+    parts = compile_output.split(start_marker)
+    test_messages = {}
+
+    for part in parts[1:]:
+        if end_marker in part:
+            idx_str, rest = part.split(end_marker, 1)
+            try:
+                test_idx = int(idx_str.strip().split(",")[0])
+                test_messages[test_idx] = rest
+            except ValueError:
+                continue
+
+    num_passed = 0
+    num_failed = 0
+
+    for idx in range(num_tests):
+        msg = test_messages.get(idx, "")
+        if DECIDABLE_ERR_MSG in msg:
+            test_results[idx] = "fail"
+            num_failed += 1
+        elif "error" in msg.lower():
+            test_results[idx] = "error"
+            num_failed += 1
+        else:
+            test_results[idx] = "pass"
+            num_passed += 1
+
+    return num_passed, num_failed, test_results
+
+
+def evaluate_generated_spec(
+    data: BenchmarkData,
+    generated_spec: Dict[str, str],
+    task_idx: int
+) -> Dict[str, Any]:
+    """Evaluate the generated specification using soundness and completeness tests."""
+    result = {
+        "compiles": 
False, + "precond_sound_pass": 0, + "precond_sound_total": 0, + "precond_complete_pass": 0, + "precond_complete_total": 0, + "postcond_sound_pass": 0, + "postcond_sound_total": 0, + "postcond_complete_pass": 0, + "postcond_complete_total": 0, + "precond_correct": False, + "postcond_correct": False, + "spec_sound": False, + "spec_complete": False, + "full_spec_correct": False, + "compile_error": "", + } + + # First check if spec compiles + compile_content = build_spec_test_lean_file(data, generated_spec, "compile") + lean_file = create_lean_file(f"spec_compile_{data.data_id}_{task_idx}", compile_content) + compiles, output = check_lean_compile(lean_file) + + result["compiles"] = compiles + if not compiles: + result["compile_error"] = output[:500] + return result + + # Test precondition soundness + if data.tests: + precond_sound_content = build_spec_test_lean_file(data, generated_spec, "precond_sound") + lean_file = create_lean_file(f"spec_precond_sound_{data.data_id}_{task_idx}", precond_sound_content) + _, output = check_lean_compile(lean_file) + + result["precond_sound_total"] = len(data.tests) + passed, failed, _ = parse_spec_test_results(output, f"{PRECOND_TEST_MSG_MARKER}_sound", len(data.tests)) + result["precond_sound_pass"] = passed + + # Test precondition completeness + if data.reject_inputs: + precond_complete_content = build_spec_test_lean_file(data, generated_spec, "precond_complete") + lean_file = create_lean_file(f"spec_precond_complete_{data.data_id}_{task_idx}", precond_complete_content) + _, output = check_lean_compile(lean_file) + + result["precond_complete_total"] = len(data.reject_inputs) + passed, failed, _ = parse_spec_test_results(output, f"{PRECOND_TEST_MSG_MARKER}_complete", len(data.reject_inputs)) + result["precond_complete_pass"] = passed + + # Test postcondition completeness + if data.tests: + postcond_complete_content = build_spec_test_lean_file(data, generated_spec, "postcond_complete") + lean_file = 
create_lean_file(f"spec_postcond_complete_{data.data_id}_{task_idx}", postcond_complete_content) + _, output = check_lean_compile(lean_file) + + result["postcond_complete_total"] = len(data.tests) + passed, failed, _ = parse_spec_test_results(output, f"{POSTCOND_TEST_MSG_MARKER}_complete", len(data.tests)) + result["postcond_complete_pass"] = passed + + # Test postcondition soundness + total_unexpected = sum(len(t.get("unexpected", [])) for t in data.tests) if data.tests else 0 + if total_unexpected > 0: + postcond_sound_content = build_spec_test_lean_file(data, generated_spec, "postcond_sound") + lean_file = create_lean_file(f"spec_postcond_sound_{data.data_id}_{task_idx}", postcond_sound_content) + _, output = check_lean_compile(lean_file) + + result["postcond_sound_total"] = total_unexpected + passed, failed, _ = parse_spec_test_results(output, f"{POSTCOND_TEST_MSG_MARKER}_sound", total_unexpected) + result["postcond_sound_pass"] = passed + + # Compute combined correctness metrics + precond_sound_all_pass = (result["precond_sound_pass"] == result["precond_sound_total"] and result["precond_sound_total"] > 0) or result["precond_sound_total"] == 0 + precond_complete_all_pass = (result["precond_complete_pass"] == result["precond_complete_total"] and result["precond_complete_total"] > 0) or result["precond_complete_total"] == 0 + result["precond_correct"] = precond_sound_all_pass and precond_complete_all_pass + + postcond_sound_all_pass = (result["postcond_sound_pass"] == result["postcond_sound_total"] and result["postcond_sound_total"] > 0) or result["postcond_sound_total"] == 0 + postcond_complete_all_pass = (result["postcond_complete_pass"] == result["postcond_complete_total"] and result["postcond_complete_total"] > 0) or result["postcond_complete_total"] == 0 + result["postcond_correct"] = postcond_sound_all_pass and postcond_complete_all_pass + + result["spec_sound"] = precond_sound_all_pass and postcond_sound_all_pass + result["spec_complete"] = 
precond_complete_all_pass and postcond_complete_all_pass + result["full_spec_correct"] = result["precond_correct"] and result["postcond_correct"] + + return result + diff --git a/examples/TTSwithVerification/ToT/utils/verina_tot_utils.py b/examples/TTSwithVerification/ToT/utils/verina_tot_utils.py new file mode 100644 index 0000000..35d22ed --- /dev/null +++ b/examples/TTSwithVerification/ToT/utils/verina_tot_utils.py @@ -0,0 +1,651 @@ +import argparse +import asyncio +import json +import logging +import os +import re +import sys +import shutil +import subprocess +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import numpy as np +import pandas as pd +import aiohttp +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +__all__ = [ + # Path constants + "VERINA_ROOT", + "VERINA_DATASETS_PATH", + "LEAN_PLAYGROUND_DIR", + # Classes + "BenchmarkData", + # Constants + "CODE_TEST_MSG_MARKER", + "DECIDABLE_ERR_MSG", + # Functions + "parse_benchmark_lean_data", + "load_benchmark_data_from_task_dir", + "load_verina_dataset", + "render_param_list", + "build_verina_prompt", + "strip_function_definition", + "extract_code_from_response", + "create_lean_file", + "check_lean_compile", + "render_unit_test_value", + "render_code_unit_test", + "build_test_lean_file", + "parse_unit_test_results", + "evaluate_generated_code", +] + + + +_SCRIPT_DIR = Path(__file__).parent.resolve() +VERINA_ROOT = (_SCRIPT_DIR / "../../../verina").resolve() +VERINA_DATASETS_PATH = VERINA_ROOT / "datasets" / "verina" +LEAN_PLAYGROUND_DIR = VERINA_ROOT / "lean-playground" + + +class BenchmarkData: + """Verina benchmark data structure""" + def __init__(self, data_id: str, description: str, signature: dict, + lean_data: dict, spec_desc: dict, tests: list, metadata: dict): + self.data_id = data_id + self.description = description + self.signature = signature # 
{"name": str, "parameters": list, "return_type": str} + self.lean_data = lean_data # contains task_imports, task_aux, code, precond, postcond, proof, etc. + self.spec_desc = spec_desc # {"precond_desc": str, "postcond_desc": str} + self.tests = tests + self.metadata = metadata + + +def parse_benchmark_lean_data(raw_lean_data: str) -> dict: + """Parse a .lean file with !benchmark markers into sections""" + lines = raw_lean_data.strip().splitlines() + + lean_data = { + "task_imports": "", + "solution_imports": "", + "task_aux": "", + "solution_aux": "", + "code_aux": "", + "precond_aux": "", + "postcond_aux": "", + "proof_aux": "", + "code": "", + "precond": "True", + "postcond": "", + "proof": "sorry", + } + + current_section = None + current_content = [] + current_args = {} + + for line in lines: + if "-- !benchmark" in line: + marker_part = line.split("-- !benchmark", 1)[1].strip() + + if marker_part.startswith("@start"): + # Save previous section if any + if current_section is not None: + content = "\n".join(current_content).strip() + if current_section == "import": + import_type = current_args.get("type", "task") + if import_type == "task": + lean_data["task_imports"] += content + "\n" + elif import_type == "solution": + lean_data["solution_imports"] += content + "\n" + elif current_section in lean_data: + lean_data[current_section] = content + + # Start new section + parts = marker_part.split("@start", 1)[1].strip().split(None, 1) + current_section = parts[0].strip() + current_args = {} + current_content = [] + + if len(parts) > 1: + for arg in parts[1].strip().split(): + if "=" in arg: + key, value = arg.split("=", 1) + current_args[key] = value + + elif marker_part.startswith("@end"): + if current_section is not None: + content = "\n".join(current_content).strip() + if current_section == "import": + import_type = current_args.get("type", "task") + if import_type == "task": + lean_data["task_imports"] += content + "\n" + elif import_type == "solution": + 
lean_data["solution_imports"] += content + "\n" + elif current_section in lean_data: + lean_data[current_section] = content + current_section = None + current_content = [] + current_args = {} + else: + if current_section is not None: + current_content.append(line) + + return lean_data + + +def load_benchmark_data_from_task_dir(task_dir: Path) -> Optional[BenchmarkData]: + """Load a single benchmark task from its directory""" + task_path = task_dir / "task.json" + if not task_path.exists(): + return None + + try: + with open(task_path, "r") as f: + task_data = json.load(f) + + task_id = task_data.get("id") + if not task_id: + return None + + # Read description + desc_path = task_dir / task_data.get("description_file", "description.txt") + description = desc_path.read_text().strip() if desc_path.exists() else "" + + # Read signature + signature = task_data.get("signature", {}) + + # Read lean file + lean_path = task_dir / task_data.get("lean_file", "task.lean") + if lean_path.exists(): + lean_data = parse_benchmark_lean_data(lean_path.read_text()) + else: + lean_data = {} + + # Read spec description + spec_desc = { + "precond_desc": task_data.get("specification", {}).get("preconditions", ""), + "postcond_desc": task_data.get("specification", {}).get("postconditions", ""), + } + + # Read tests + test_path = task_dir / task_data.get("test_file", "test.json") + tests = [] + if test_path.exists(): + with open(test_path, "r") as f: + tests = json.load(f) + + metadata = task_data.get("metadata", {}) + + return BenchmarkData( + data_id=task_id, + description=description, + signature=signature, + lean_data=lean_data, + spec_desc=spec_desc, + tests=tests, + metadata=metadata, + ) + except Exception as e: + return None + + +def load_verina_dataset() -> List[BenchmarkData]: + """Load all verina benchmark tasks from the datasets directory""" + results = [] + + # Get all task directories sorted by ID + task_dirs = sorted( + [d for d in VERINA_DATASETS_PATH.glob("verina_*") if 
d.is_dir()], + key=lambda x: (x.name.split("_")[1], int(x.name.split("_")[-1])), + ) + + for task_dir in task_dirs: + data = load_benchmark_data_from_task_dir(task_dir) + if data: + results.append(data) + + return results + + +def render_param_list(signature: dict) -> str: + """Render the parameter list for a function signature""" + params = signature.get("parameters", []) + rendered = "" + for param in params: + rendered += f"({param['param_name']} : {param['param_type']}) " + return rendered.strip() + + +def build_verina_prompt(data: BenchmarkData) -> str: + """Build prompt for Verina Lean code generation (build_code_gen_prompt + build_full_prompt)""" + system_prompt = "You are an expert Lean 4 programmer. Generate valid Lean 4 code for the function body. Wrap your final code in [CODE] [/CODE] tags strictly." + + signature = data.signature + func_name = signature.get("name", "solution") + return_type = signature.get("return_type", "Bool") + param_list = render_param_list(signature) + params = signature.get("parameters", []) + + precond_name = f"{func_name}_precond" + param_names_str = ' '.join([f"({p['param_name']})" for p in params]) + + # Get auxiliary definitions (only if they exist) + solution_aux = data.lean_data.get("solution_aux", "").strip() + task_aux = data.lean_data.get("task_aux", "").strip() + code_aux = data.lean_data.get("code_aux", "").strip() + precond = data.lean_data.get("precond", "True").strip() + postcond = data.lean_data.get("postcond", "").strip() + + # Build helper section only if there are helpers + helper_section = "" + all_aux = "\n".join(filter(None, [solution_aux, task_aux, code_aux])) + if all_aux: + helper_section = f""" +## Helper Definitions +```lean4 +{all_aux} +``` +""" + + user_prompt = f"""## Task +{data.description} + +## Function Signature +```lean4 +def {func_name} {param_list} (h_precond : {precond_name} {param_names_str}) : {return_type} := + -- YOUR CODE HERE (just output this part inside [CODE] [/CODE] tags) +``` + +## 
Precondition +```lean4 +def {precond_name} {param_list} : Prop := {precond} +``` + +## Postcondition +```lean4 +def {func_name}_postcond {param_list} (result: {return_type}) : Prop := {postcond} +``` +{helper_section} +Provide ONLY the function body expression wrapped in [CODE]...[/CODE] tags.""" + + return ( + f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + f"<|im_start|>user\n{user_prompt}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + +def strip_function_definition(code: str) -> str: + """ + Strip function definition prefix if the model accidentally included it. + + The prompt asks for just the function body, but sometimes the model outputs: + def FunctionName (params) (h_precond : ...) : ReturnType := + actual_body + + We need to extract just 'actual_body' and dedent it properly. + """ + import textwrap + + code = code.strip() + + # Pattern to match Lean function definition: + # def (h_precond : ) : := + # The function body follows after := + func_def_pattern = r'^def\s+\w+\s+.*?:=[ \t]*\n?' + + match = re.match(func_def_pattern, code, re.DOTALL) + if match: + # Extract everything after the := + body = code[match.end():] + # Dedent to remove common leading whitespace from all lines + body = textwrap.dedent(body).strip() + if body: + return body + + return code + + +def extract_code_from_response(response: str) -> str: + """Extract code from the LAST [CODE]...[/CODE] tags or lean code blocks. + + Handles cases where: + 1. Response has ... reasoning block + 2. [CODE] tag exists but [/CODE] may be missing (truncated response) + 3. Code is in markdown lean blocks + 4. Model outputs [CORE] or other variants instead of [CODE] + 5. Model uses mismatched tags like [CORE]...[/CODE] + 6. Model includes full function definition instead of just the body + """ + # Step 1: Remove ... 
block entirely (case insensitive)
+    # This prevents extracting reasoning text as code
+    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL | re.IGNORECASE)
+
+    # If <think> exists but </think> doesn't match (partial), take everything after </think>
+    if not cleaned.strip() or cleaned.strip() == response.strip():
+        think_end = response.lower().rfind("</think>")
+        if think_end != -1:
+            cleaned = response[think_end + len("</think>"):]
+
+    extracted_code = None
+
+    # Step 2: Find the last closing tag [/CODE] or [/CORE] and work backwards to find opening tag
+    # This handles mismatched tags like [CORE]...[/CODE]
+    closing_pattern = r'\[/(?:CODE|CORE)\]'
+    closing_matches = list(re.finditer(closing_pattern, cleaned, re.IGNORECASE))
+
+    if closing_matches:
+        # Get position of the last closing tag
+        last_close = closing_matches[-1]
+        close_pos = last_close.start()
+
+        # Search backwards for the last opening tag before this closing tag
+        text_before_close = cleaned[:close_pos]
+        opening_pattern = r'\[(?:CODE|CORE|CORRECTED CODE)\]'
+        opening_matches = list(re.finditer(opening_pattern, text_before_close, re.IGNORECASE))
+
+        if opening_matches:
+            last_open = opening_matches[-1]
+            extracted_code = cleaned[last_open.end():close_pos].strip()
+
+    # Step 3: Try [CODE] without closing tag (truncated response) - find the LAST one
+    if extracted_code is None:
+        code_start_matches = list(re.finditer(r'\[(?:CODE|CORE|CORRECTED CODE)\]\s*', cleaned, re.DOTALL | re.IGNORECASE))
+        if code_start_matches:
+            # Get the last [CODE] tag position and extract everything after it
+            last_match = code_start_matches[-1]
+            code = cleaned[last_match.end():].strip()
+            # Remove any trailing incomplete text that looks like reasoning
+            # Stop at any line that looks like it's not code (e.g., starts with "Wait", "So", etc.)
+ lines = code.split('\n') + code_lines = [] + for line in lines: + stripped = line.strip() + # Stop if we hit obvious non-code reasoning text + if stripped and re.match(r'^(Wait|So |But |Now|Note|The |This |However|Therefore|Thus|In |Since)', stripped): + break + code_lines.append(line) + if code_lines: + extracted_code = '\n'.join(code_lines).strip() + + # Step 4: Try markdown lean code blocks (find the LAST one) + if extracted_code is None: + lean_matches = list(re.finditer(r'```lean4?\s*\n(.*?)```', cleaned, re.DOTALL | re.IGNORECASE)) + if lean_matches: + extracted_code = lean_matches[-1].group(1).strip() + + # Step 5: Try lean block without closing (truncated) - find the LAST one + if extracted_code is None: + lean_start_matches = list(re.finditer(r'```lean4?\s*\n', cleaned, re.DOTALL | re.IGNORECASE)) + if lean_start_matches: + last_match = lean_start_matches[-1] + code = cleaned[last_match.end():].strip() + # Remove trailing ``` if present + extracted_code = re.sub(r'```\s*$', '', code).strip() + + # Step 6: Last resort, return cleaned content if it looks like code + if extracted_code is None: + cleaned = cleaned.strip() + if cleaned: + # Filter out lines that look like reasoning + lines = cleaned.split('\n') + code_lines = [] + for line in lines: + stripped = line.strip() + if stripped and not re.match(r'^(Wait|So |But |Now|Note|The |This |However|Therefore|Thus|In |Since|I |We |You )', stripped): + code_lines.append(line) + if code_lines: + extracted_code = '\n'.join(code_lines).strip() + + # Step 7: Strip function definition prefix if model included it + # The prompt asks for just the body, but sometimes model outputs full "def ... 
:=" + if extracted_code: + extracted_code = strip_function_definition(extracted_code) + return extracted_code + + return "" + + +def create_lean_file(file_name: str, content: str) -> Path: + """Create a lean file in the playground directory""" + LEAN_PLAYGROUND_DIR.mkdir(parents=True, exist_ok=True) + lean_file = LEAN_PLAYGROUND_DIR / f"{file_name}.lean" + with open(lean_file, "w") as f: + f.write(content) + return lean_file + + +def check_lean_compile(lean_file: Path, timeout: int = 120) -> Tuple[bool, str]: + """Check if the Lean file compiles successfully""" + try: + result = subprocess.run( + ["lake", "lean", str(lean_file)], + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=timeout, + cwd=VERINA_ROOT, + ) + + output = result.stdout.decode() + "\n" + result.stderr.decode() + + if result.returncode == 0: + return True, output + else: + return False, output + + except subprocess.TimeoutExpired: + return False, "TIMEOUT" + except Exception as e: + return False, f"ERROR: {e}" + + +# Unit Test Rendering +CODE_TEST_MSG_MARKER = "code_test" +DECIDABLE_ERR_MSG = "did not evaluate to `true`" + + +def render_unit_test_value(lean_type: str, value) -> str: + """Convert a Python value to Lean syntax based on type""" + if lean_type == "Bool": + return str(value).lower() + elif lean_type == "String": + return f'"{value}"' + elif lean_type == "Char": + return f"'{value}'" + else: + return str(value) + + +def render_code_unit_test(signature: dict, test_case: dict, test_idx: int) -> str: + """Render a single unit test using #guard""" + func_name = signature.get("name", "solution") + params = signature.get("parameters", []) + return_type = signature.get("return_type", "Bool") + + rendered = f'#print "<{CODE_TEST_MSG_MARKER}>{test_idx}"\n\n' + rendered += f"#guard {func_name}" + + for param in params: + param_name = param["param_name"] + param_type = param["param_type"] + input_value = test_case["input"].get(param_name, "") + rendered += f" 
({render_unit_test_value(param_type, input_value)})" + + # Add (by sorry) to satisfy precondition hypothesis + rendered += " (by sorry)" + + # Add expected value comparison + expected = test_case.get("expected", "") + rendered += f" == ({render_unit_test_value(return_type, expected)})" + + return rendered + + +def build_test_lean_file(data: BenchmarkData, generated_code: str, include_unit_tests: bool = True) -> str: + """Build a complete Lean file to test the generated code""" + signature = data.signature + func_name = signature.get("name", "solution") + return_type = signature.get("return_type", "Bool") + param_list = render_param_list(signature) + params = signature.get("parameters", []) + param_names = " ".join([f"({p['param_name']})" for p in params]) + + # Indent multiline generated_code so all lines have proper indentation + # First line gets 2 spaces from template, subsequent lines need explicit indentation + if '\n' in generated_code: + lines = generated_code.split('\n') + # First line has no extra indent (template adds 2 spaces) + # Subsequent lines need 2 spaces prepended + indented_lines = [lines[0]] + [' ' + line if line.strip() else line for line in lines[1:]] + generated_code = '\n'.join(indented_lines) + + # Build imports - include both task and solution imports + task_imports = data.lean_data.get("task_imports", "").strip() + solution_imports = data.lean_data.get("solution_imports", "").strip() + imports = task_imports + if solution_imports: + imports += "\n" + solution_imports + if "import Mathlib" not in imports: + imports = "import Mathlib\n" + imports + + # Build auxiliary definitions - include solution_aux which has helper functions + solution_aux = data.lean_data.get("solution_aux", "").strip() + task_aux = data.lean_data.get("task_aux", "").strip() + precond_aux = data.lean_data.get("precond_aux", "").strip() + postcond_aux = data.lean_data.get("postcond_aux", "").strip() + code_aux = data.lean_data.get("code_aux", "").strip() + + # Build 
precondition + precond = data.lean_data.get("precond", "True").strip() + precond_name = f"{func_name}_precond" + + # Build postcondition + postcond = data.lean_data.get("postcond", "").strip() + postcond_name = f"{func_name}_postcond" + + lean_content = f"""{imports} + +-- Solution auxiliary definitions (helper functions) +{solution_aux} + +-- Task auxiliary definitions +{task_aux} + +-- Precondition auxiliary definitions +{precond_aux} + +@[reducible, simp] +def {precond_name} {param_list} : Prop := + {precond} + +-- Postcondition auxiliary definitions +{postcond_aux} + +-- Code auxiliary definitions +{code_aux} + +def {func_name} {param_list} (h_precond : {precond_name} {param_names}) : {return_type} := + {generated_code} + +@[reducible, simp] +def {postcond_name} {param_list} (result: {return_type}) (h_precond : {precond_name} {param_names}) : Prop := + {postcond} + +-- Verification theorem (compilation test) +-- If this compiles, the code at least type-checks +#check {func_name} +""" + + # Add unit tests if requested + if include_unit_tests and data.tests: + lean_content += "\n-- Unit Tests\n" + for idx, test_case in enumerate(data.tests): + lean_content += "\n" + render_code_unit_test(signature, test_case, idx) + "\n" + + return lean_content + + +def parse_unit_test_results(compile_output: str, num_tests: int) -> Tuple[int, int, dict]: + """ + Parse the compilation output to determine which unit tests passed/failed. 
+ + Returns: (num_passed, num_failed, test_results_dict) + """ + test_results = {} + + # If compilation succeeded with no errors, all tests passed + if "error" not in compile_output.lower(): + for idx in range(num_tests): + test_results[idx] = "pass" + return num_tests, 0, test_results + + # Parse the output to find which tests failed + # Look for markers like 0 followed by error messages + code_test_start = f"<{CODE_TEST_MSG_MARKER}>" + code_test_end = f"" + + # Split by start marker to get test sections + parts = compile_output.split(code_test_start) + + # Build a map of test index to message + test_messages = {} + for part in parts[1:]: # Skip first part (before any marker) + if code_test_end in part: + idx_str, rest = part.split(code_test_end, 1) + try: + test_idx = int(idx_str.strip()) + test_messages[test_idx] = rest + except ValueError: + continue + + num_passed = 0 + num_failed = 0 + + for idx in range(num_tests): + msg = test_messages.get(idx, "") + if DECIDABLE_ERR_MSG in msg: + test_results[idx] = "fail" + num_failed += 1 + elif "error" in msg.lower(): + # Some other error (e.g., type mismatch) - count as fail + test_results[idx] = "error" + num_failed += 1 + else: + test_results[idx] = "pass" + num_passed += 1 + + return num_passed, num_failed, test_results + + +def evaluate_generated_code(data: BenchmarkData, generated_code: str, task_idx: int) -> Tuple[bool, bool, str, dict]: + """ + Evaluate the generated code by compiling it with Lean and running unit tests. 
+ + Returns: (compiles, all_tests_pass, output, test_results) + """ + lean_content = build_test_lean_file(data, generated_code, include_unit_tests=True) + + # Create lean file + lean_file = create_lean_file(f"test_{data.data_id}_{task_idx}", lean_content) + + # Check compilation (which also runs unit tests via #guard) + compiles, output = check_lean_compile(lean_file) + + # Parse unit test results + num_tests = len(data.tests) if data.tests else 0 + if num_tests > 0: + num_passed, num_failed, test_results = parse_unit_test_results(output, num_tests) + all_tests_pass = (num_failed == 0) and compiles + else: + # No tests, just check compilation + test_results = {} + all_tests_pass = compiles + + return compiles, all_tests_pass, output, test_results \ No newline at end of file diff --git a/examples/TTSwithVerification/ToT/utils/zebralogic_helper.py b/examples/TTSwithVerification/ToT/utils/zebralogic_helper.py new file mode 100644 index 0000000..8a4bb0b --- /dev/null +++ b/examples/TTSwithVerification/ToT/utils/zebralogic_helper.py @@ -0,0 +1,284 @@ +""" +Prompt and dataset utilities for WildEval/ZebraLogic. +""" + +import re +import json +from collections import defaultdict +from typing import Dict, List, Tuple, Optional, Set +from importlib import resources +import datasets + + +# ============== Prompt Templates ============== + +SYSTEM_PROMPT_VANILLA = """\ +# Problem Description + +You are solving a house grid problem. You are given: +1. Features and Domains + - A fixed number of houses, indexed sequentially (e.g., House 1, House 2, …) from left to right. + - A set of features (e.g., color, name, pet, book genre). + - Each feature has a finite domain of possible values. +2. Constraint: + - Each house has exactly one value per feature. + - No two houses share the same value for the same feature. +3. 
Clues / Constraints descrbing: + - Houses and their positions + - Feature values + - Relative ordering (e.g., "next to", "to the left of", "2 houses away from") + +Solve to your best ability the arrangement of features across the houses. + +# Final Answer Format + +```json +{ + "House 1": { "feature1": "value1", "feature2": "value2", ... }, + "House 2": { "feature1": "value1", "feature2": "value2", ... }, + ... +} +``` + +Make sure to use the exact feature/value names as given in the problem description. +Make sure the JSON is valid and parsable.""" + +SYSTEM_PROMPT_STATEEXTRACT = """\ +# Problem Description + +You are solving a house grid problem. You are given: +1. Features and Domains + - A fixed number of houses, indexed sequentially (e.g., House 1, House 2, …) from left to right. + - A set of features (e.g., color, name, pet, book genre). + - Each feature has a finite domain of possible values. +2. Constraint: + - Each house has exactly one value per feature. + - No two houses share the same value for the same feature. +3. Clues / Constraints describing: + - Houses and their positions + - Feature values + - Relative ordering (e.g., "next to", "to the left of", "2 houses away from") + +# Rules for Solving + +1. Reason about the problem in text. +2. After every inference, no matter how minor or tentative, immediately report the updated partial assignments. + - Always output partial assignments frequently as the reasoning progresses, not only at major steps or when confident. + - If an inference adds, removes, or narrows even a single possibility, report it. + +# House/Feature Partial Assignment Reporting Format + +```json +{ + "House N": { "feature1": "value1", "feature2": "value2", ... }, + ... +} +``` + +Omit any unassigned features. +Make sure to use the exact feature/value names as given in the problem description. +Make sure the JSON is valid and parsable.""" + +SYSTEM_PROMPT_VANILLA = """\ +# Problem Description + +You are solving a house grid problem. 
You are given: +1. Features and Domains + - A fixed number of houses, indexed sequentially (e.g., House 1, House 2, …) from left to right. + - A set of features (e.g., color, name, pet, book genre). + - Each feature has a finite domain of possible values. +2. Constraint: + - Each house has exactly one value per feature. + - No two houses share the same value for the same feature. +3. Clues / Constraints describing: + - Houses and their positions + - Feature values + - Relative ordering (e.g., "next to", "to the left of", "2 houses away from") + +# Rules for Solving + +1. Reason about the problem in text. +2. You may receive feedback from the user if anything is wrong. Use any feedback to guide your reasoning until a complete solution is reached. +3. Do not stop responding until you've assigned each and every variable. + +# Final Answer Reporting Format + +```json +{ + "House 1": { "feature1": "value1", "feature2": "value2", ... }, + "House 2": { "feature1": "value1", "feature2": "value2", ... }, + ... +} +``` + +Make sure to use the exact feature/value names as given in the problem description. +Make sure the JSON is valid and parsable.""" + +USER_PROMPT_TEMPLATE = "{problem_text}" + +# ============== Dataset Loading ============== + +def clean_problem_text(problem_text: str, features: dict) -> str: + """Clean up problem text giving explicit feature domains.""" + desc, clues = problem_text.split('## clues:') + line0 = desc.splitlines()[0] + + feature_text = '' + for feature, values in features.items(): + values_str = ', '.join(f"'{v}'" for v in values) + feature_text += f"- '{feature}': {values_str}\n" + + return f"{line0}\n{feature_text}\n## clues:{clues}".strip() + + +def process_zebralogic_problem(problem: dict, ir_map: dict) -> dict: + """Process a raw ZebraLogic problem into the format needed by ZebraLogicProblem. + + Args: + problem: Raw problem dict from the HuggingFace dataset. + ir_map: Dict mapping problem IDs to their IR representations. 
+ + Returns: + Processed problem dict with keys: n_houses, n_features, features, clues, + clue_irs, solution_irs, solution, puzzle_clean, etc. + """ + def apply_text_replacements(problem): + replacements = [ + ['january', 'jan'], ['february', 'feb'], ['march', 'mar'], + ['august', 'aug'], ['september', 'sept'], ['f-150', 'f150'], + ['animal', 'pet'], ['loves the spaghetti eater', 'loves spaghetti'], + ['very short', 'veryshort'], ['super short', 'supershort'], + ['very tall', 'verytall'], ['super tall', 'supertall'], + ] + problem_str = json.dumps(problem) + for old, new in replacements: + problem_str = problem_str.replace(old, new) + return json.loads(problem_str) + + pid = problem['id'] + size = problem["size"] + n_houses, n_features = map(int, size.split("*")) + + problem_text = problem["puzzle"].lower() + clues_raw = re.split(r"##\s*clues\s*:", problem_text, flags=re.IGNORECASE)[1].strip().split("\n") + clues = [] + for clue in clues_raw: + clue_text_index, clue_text = clue.strip().split(". 
", 1) + clues.append({ + "text_index": int(clue_text_index.strip()), + "text": clue_text.strip() + }) + + solution = problem["solution"] + solution['header'] = [h.lower() for h in solution['header']] + solution['rows'] = [[v.lower() for v in row] for row in solution['rows']] + + features = defaultdict(list) + for row in solution['rows']: + for fname, value in zip(solution['header'], row): + if fname.lower() == 'house': + continue + features[fname].append(value) + features = dict(features) + for fname in features: + features[fname] = sorted(list(set(features[fname]))) + assert len(features[fname]) == n_houses + assert len(features) == n_features + + processed_solution = {f'House {i+1}': {} for i in range(n_houses)} + for house_i, row in enumerate(solution['rows']): + for fname, value in zip(solution['header'][1:], row[1:]): + processed_solution[f'House {house_i+1}'][fname.lower()] = value.lower() + problem['solution'] = processed_solution + + problem['puzzle'] = problem_text + problem["clues"] = clues + problem["features"] = features + problem["n_houses"] = n_houses + problem["n_features"] = n_features + problem["clue_irs"] = ir_map[pid]["clue_irs"] + problem["solution_irs"] = ir_map[pid]["solution_irs"] + problem['puzzle_clean'] = clean_problem_text(problem['puzzle'], problem['features']) + + return apply_text_replacements(problem) + +def get_zebralogic_dataset() -> list: + """Load and process the ZebraLogic dataset from HuggingFace. + + Loads WildEval/ZebraLogic grid_mode test split and processes each problem + with the IR map. + + The IR map file must be at interwhen/data/zebralogic_ir_map.json + + Returns: + List of processed problem dicts. 
+ """ + + dataset = datasets.load_dataset("WildEval/ZebraLogic", "grid_mode", split="test").to_list() + + pkg = "interwhen.data" + with resources.files(pkg).joinpath("zebralogic_ir_map.json").open("r") as f: + ir_map = json.load(f) + + # Known problematic problem IDs (unsolvable or malformed) + bad_ids = { + 'lgp-test-6x5-2', 'lgp-test-6x6-5', 'lgp-test-2x5-1', 'lgp-test-4x5-5', + 'lgp-test-2x4-6', 'lgp-test-2x6-11', 'lgp-test-4x6-35', 'lgp-test-3x5-15', + 'lgp-test-5x5-37', 'lgp-test-5x5-17', 'lgp-test-4x5-15', 'lgp-test-6x6-2', + 'lgp-test-5x6-4', 'lgp-test-5x6-2', 'lgp-test-5x5-1' + } + dataset = [p for p in dataset if p['id'] not in bad_ids] + dataset = [process_zebralogic_problem(p, ir_map) for p in dataset] + return dataset + +def extract_last_json(text): + """Extract the last JSON object from the model's output text.""" + json_text = text.split('')[-1].strip() + + # try with md tags + matches = re.findall(r'```json(.*?)```', json_text, re.DOTALL) + if matches and len(matches) > 0: + json_match = matches[-1] + return json.loads(json_match.strip()) + + # try without assuming md tags + matches = re.findall(r'\{.*\}\s*?}', json_text, re.DOTALL) + if matches and len(matches) > 0: + json_match = matches[0] + return json.loads(json_match.strip()) + + return None + +def zebra_correctness(problem: dict, candidate_solution: dict) -> Tuple[int, int, int, int]: + """Check candidate solution against ground truth. + + Args: + problem: Processed problem dict with 'solution', 'n_houses', 'n_features', 'features'. + candidate_solution: Dict mapping "House N" -> {feature: value}. 
+ + Returns: + (correct, skipped, missing, total) counts where: + correct: Number of matching assignments + skipped: Assignments for invalid houses/features + missing: Features not in candidate solution + total: Total expected assignments (n_houses * n_features) + """ + c, s = 0, 0 + t_soln = 0 + t = problem['n_houses'] * problem['n_features'] + solution = problem['solution'] + + for house in candidate_solution: + if house not in solution: + s += len(problem['features']) + continue + for fname in candidate_solution[house]: + if fname not in solution[house]: + s += 1 + continue + t_soln += 1 + if candidate_solution[house][fname] == solution[house][fname]: + c += 1 + + m = t - t_soln + return c, s, m, t