diff --git a/.gitignore b/.gitignore index b7faf403..43e0af15 100644 --- a/.gitignore +++ b/.gitignore @@ -198,6 +198,9 @@ cython_debug/ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data # refer to https://docs.cursor.com/context/ignore-files +.*/Outputs_TTS/ +Outputs_TTS/ +Outputs_TTS_temp/ .cursorignore .cursorindexingignore diff --git a/examples/TTSwithVerification/MULTIPROCESS_README.md b/examples/TTSwithVerification/MULTIPROCESS_README.md new file mode 100644 index 00000000..4a64612d --- /dev/null +++ b/examples/TTSwithVerification/MULTIPROCESS_README.md @@ -0,0 +1,67 @@ +# Multi-Process vLLM Setup for Best-of-K Baseline + +This directory contains scripts and code for running the best-of-K baseline with multi-process vLLM serving. + +## Setup + +### 1. Start vLLM with 4 processes (2 GPUs each) + +```bash +bash start_vllm_multiprocess.sh +``` + +This launches 4 vLLM OpenAI-compatible API servers: +- **Process 1**: GPUs 0-1, Port 8000 +- **Process 2**: GPUs 2-3, Port 8001 +- **Process 3**: GPUs 4-5, Port 8002 +- **Process 4**: GPUs 6-7, Port 8003 + +Each process uses `tensor-parallel-size 2` for distributed inference. + +### 2. Run the baseline + +In a separate terminal: + +```bash +# Test with 1 example +python bestofk_baseline.py --task game24 --num_examples 1 --k 4 --use_critic + +# Run on maze dataset +python bestofk_baseline.py --task maze --num_examples 10 --k 4 + +# Run on spatialmap dataset +python bestofk_baseline.py --task spatialmap --num_examples 5 --k 4 +``` + +Or use the test script: +```bash +bash run_multiprocess_test.sh game24 5 +``` + +## Load Balancing + +- Requests are distributed **round-robin** across the 4 vLLM instances +- Each generation request goes to the next available port (8000 → 8001 → 8002 → 8003 → 8000 ...) 
+- Critic evaluation requests use separate round-robin tracking (independent counter) +- This ensures even load distribution across all 4 GPU pairs + +## Stopping vLLM + +```bash +pkill -9 -f "vllm.entrypoints.openai.api_server" +``` + +## Configuration + +Edit `start_vllm_multiprocess.sh` to change: +- `MODEL`: Model name (default: `Qwen/QwQ-32B`) +- `MAX_TOKENS`: Maximum sequence length (default: 8192) +- `GPU_MEMORY`: GPU memory utilization (default: 0.4) +- `TENSOR_PARALLEL`: Must be ≤ 2 for this 8-GPU setup + +## Benefits + +- **Better throughput**: 4 independent processes handle requests in parallel +- **Fault tolerance**: If one process crashes, others continue +- **GPU utilization**: Balanced load across all 8 GPUs (2 GPUs per process) +- **Reduced latency**: Each process has dedicated GPU resources diff --git a/examples/TTSwithVerification/README.md b/examples/TTSwithVerification/README.md index ed7fbb4e..a2c52dd4 100644 --- a/examples/TTSwithVerification/README.md +++ b/examples/TTSwithVerification/README.md @@ -156,6 +156,39 @@ The Z3 solver handles diagonal directions (`Northwest`, `Northeast`, `Southwest` --- +# Best-of-K Baseline + +A simple best-of-K baseline that generates K independent reasoning traces per example and selects the best based on: +1. **Ground-truth matching** (default): Greedy selection of first correct answer among K samples +2. **Critic model evaluation** (optional): Use a separate critic LLM to evaluate correctness without access to ground truth + +This baseline demonstrates that with sufficient sampling, even simple CoT can achieve good performance. 
+ +## Usage + +```bash +# Best-of-K with ground-truth evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 10 --k 4 + +# Best-of-K with critic model evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 10 --k 4 --use_critic --critic_model Qwen/Qwen3-30B-A3B-Thinking-2507 --critic_port 8001 +``` + +### Parameters + +| Argument | Description | Default | +|----------|-------------|---------| +| `--task` | Task: `game24`, `maze`, or `spatialmap` | required | +| `--k` | Number of samples per example | `4` | +| `--use_critic` | Use critic model for evaluation instead of ground truth | `False` | +| `--critic_model` | Model to use for critic evaluation | MAIN_MODEL | +| `--critic_port` | vLLM server port for critic model | `8001` | +| `--num_examples`, `-n` | Number of examples to run | varies | +| `--main_model` | Model for generation | `Qwen/Qwen3-30B-A3B-Thinking-2507` | +| `--port` | vLLM server port for main model | `8000` | + +--- + ## Example Scripts Each script runs a full evaluation: loading a dataset, building structured prompts, running inference with step verification, and computing accuracy/token statistics. 
@@ -169,6 +202,14 @@ python ./examples/TTSwithVerification/maze_stepverifier.py -n 1 # SpatialMap with step verification python ./examples/TTSwithVerification/spatialmap_stepverifier.py -n 1 + +# Best-of-K baseline (standard CoT, no monitors) +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 1 --k 4 +python ./examples/TTSwithVerification/bestofk_baseline.py --task maze -n 1 --k 4 +python ./examples/TTSwithVerification/bestofk_baseline.py --task spatialmap -n 1 --k 4 + +# Best-of-K with critic model evaluation +python ./examples/TTSwithVerification/bestofk_baseline.py --task game24 -n 1 --k 4 --use_critic ``` ### Common arguments diff --git a/examples/TTSwithVerification/bestofk_baseline.py b/examples/TTSwithVerification/bestofk_baseline.py new file mode 100644 index 00000000..eb6aef3d --- /dev/null +++ b/examples/TTSwithVerification/bestofk_baseline.py @@ -0,0 +1,1023 @@ +import argparse +import asyncio +from datetime import datetime +import json +import logging +import os +import re +import sys +import shutil +import subprocess +from multiprocessing.pool import ThreadPool +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import aiohttp +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +from interwhen.utils.zebralogic_helper import SYSTEM_PROMPT_VANILLA, USER_PROMPT_TEMPLATE, get_zebralogic_dataset, extract_last_json, zebra_correctness + +from interwhen import stream_completion +from verina_utils import * + +# ============== MODEL CONFIGURATION ============== +MAIN_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507" +# Multi-process vLLM configuration +VLLM_PORTS = [8000, 8001, 8002] # 3 instances with tensor-parallel-size 2 each +REQUEST_COUNTER = {"main": 0, "critic": 0} # Track request count for round-robin load balancing +# Verina paths +_SCRIPT_DIR = 
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy scalars/arrays to builtin types."""

    # (numpy type, converter) pairs tried in order; anything unmatched is
    # deferred to the base encoder, which raises TypeError as usual.
    _CONVERTERS = (
        (np.bool_, bool),
        (np.integer, int),
        (np.floating, float),
        (np.ndarray, lambda arr: arr.tolist()),
    )

    def default(self, obj):
        """Return a JSON-serializable builtin for numpy *obj*, else defer."""
        for np_type, convert in self._CONVERTERS:
            if isinstance(obj, np_type):
                return convert(obj)
        return super().default(obj)
def init_llm_server(modelname, max_tokens=200, port=8000, temperature=0.6, seed=42):
    """Build the request template (url, payload, headers) for a vLLM
    /v1/completions endpoint on localhost.

    The caller fills in ``payload["prompt"]`` (and usually overrides
    ``payload["seed"]``) per request.
    """
    # Fixed sampling knobs; only model/max_tokens/temperature/seed vary per call.
    # NOTE(review): "do_sample", "use_beam_search" and "prompt_cache" are not
    # standard OpenAI-completions parameters — presumably vLLM ignores unknown
    # fields; confirm against the server version in use.
    sampling_defaults = {
        "top_k": 20,
        "top_p": 0.95,
        "min_p": 0.0,
        "do_sample": True,
        "stream": False,
        "logprobs": 20,
        "use_beam_search": False,
        "prompt_cache": True,
    }
    request_body = dict(
        sampling_defaults,
        model=modelname,
        max_tokens=max_tokens,
        temperature=temperature,
        seed=seed,
    )
    return {
        "url": f"http://localhost:{port}/v1/completions",
        "payload": request_body,
        "headers": {"Content-Type": "application/json"},
    }
def extract_solution_game24(text):
    """Extract the arithmetic expression from the last ``\\boxed{...}`` in
    *text* and normalize its LaTeX into a plain arithmetic string.

    Returns the cleaned expression, or ``None`` when no ``\\boxed{`` exists.
    """
    boxed_pattern = r"\\boxed\{"
    matches = list(re.finditer(boxed_pattern, text))
    if not matches:
        return None
    # Brace-match from the last \boxed{ so nested braces (e.g. \frac{}{})
    # are captured in full rather than stopping at the first '}'.
    last_match = matches[-1]
    start = last_match.end()
    brace_count = 1
    end = start
    while end < len(text) and brace_count > 0:
        if text[end] == "{":
            brace_count += 1
        elif text[end] == "}":
            brace_count -= 1
        end += 1
    expr = text[start:end - 1].strip()

    # \frac{a}{b} -> (a/b); loop until no (non-nested) fractions remain.
    frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}"
    while re.search(frac_pattern, expr):
        expr = re.sub(frac_pattern, r"(\1/\2)", expr)

    # LaTeX operators -> arithmetic operators.
    replacements = {
        r"\times": "*",
        r"\cdot": "*",
        r"\div": "/",
    }
    for latex, op in replacements.items():
        expr = expr.replace(latex, op)

    # LaTeX spacing macros. BUGFIX: the original only stripped the
    # double-escaped forms (r"\\," is backslash-backslash-comma), which never
    # occur in normally escaped model output (cf. r"\times" above, which is a
    # single backslash). Also strip the single-backslash forms "\," and "\ ";
    # longest first so the double-backslash form is consumed whole.
    for spacing in ("\\\\,", "\\\\ ", "\\,", "\\ "):
        expr = expr.replace(spacing, "")

    # Insert implicit multiplication: ")(", ")3", "3(" -> explicit "*".
    expr = re.sub(r"\)\s*\(", ")*(", expr)
    expr = re.sub(r"\)\s*(\d)", r")*\1", expr)
    expr = re.sub(r"(\d)\s*\(", r"\1*(", expr)

    return expr
# Length of the fixed trailer chopped off every VISION_LANGUAGE prompt.
# NOTE(review): assumes the final paragraph is always exactly 143 chars —
# confirm against the dataset if prompt formatting ever changes.
_TRAILER_LEN = 143


def remove_last_paragraph(s: str) -> str:
    """Drop the fixed-length trailer from *s*; short strings pass through."""
    if len(s) <= _TRAILER_LEN:
        return s
    return s[:len(s) - _TRAILER_LEN]
def extract_solution_mcq(text):
    """Pull an A/B/C/D answer out of model output.

    Tries progressively looser patterns (\\boxed{...}, unescaped boxed{...},
    **X** emphasis, "answer: X", a letter alone at a line start), always
    preferring the LAST occurrence, then falls back to the last standalone
    capital letter anywhere. Returns the upper-cased letter or ``None``.
    """
    # Ordered from most to least explicit; a pattern only wins if its last
    # match actually contains a standalone A-D letter.
    candidate_patterns = (
        r"\\boxed\{([^}]*)\}",          # \boxed{...}
        r"boxed\{([^}]*)\}",            # boxed{...} without escape
        r"\*\*([A-D])\*\*",             # **A** format
        r"answer[:\s]*([A-D])",         # answer: A format
        r"(?:^|\n)([A-D])(?:\s|$|\.)",  # standalone letter
    )

    for pat in candidate_patterns:
        hits = re.findall(pat, text, re.IGNORECASE)
        if not hits:
            continue
        candidate = hits[-1].strip()
        letter = re.search(r"\b([ABCD])\b", candidate, flags=re.IGNORECASE)
        if letter:
            return letter.group(1).upper()

    # Last resort: any standalone A, B, C, or D (case-sensitive, as before).
    leftovers = re.findall(r"\b([ABCD])\b", text)
    if leftovers:
        return leftovers[-1].upper()

    return None
def evaluate_zebralogic_answer(answer, example):
    """Score a ZebraLogic answer against ground truth.

    Extracts the last JSON object from *answer* and grades it with
    ``zebra_correctness``. Returns ``(is_correct, extracted_candidate,
    message)`` to match the other ``evaluate_*`` helpers.
    """
    candidate = extract_last_json(answer)
    if not candidate:
        # No parseable JSON at all -> wrong, with nothing extracted.
        return False, None, "No valid JSON solution found"

    n_correct, n_skipped, n_missing, n_total = zebra_correctness(example, candidate)
    summary = f"Correct={n_correct}/{n_total}, skipped={n_skipped}, missing={n_missing}"
    # Only a fully correct assignment counts.
    return n_correct == n_total, candidate, summary
def resolve_indices(task, dataset_len, args):
    """Decide which dataset indices to run.

    Precedence: explicit ``--indices`` > ``--xrange`` ('start-end') >
    ``--num_examples`` (evenly spaced via linspace, capped per task) >
    ``--start``/``--end`` (defaulting to the full dataset).

    Returns a list or range of integer indices.
    Raises ValueError for any malformed ``--xrange``.
    """
    if args.indices:
        return [int(x.strip()) for x in args.indices.split(",")]

    if args.xrange:
        parts = args.xrange.split("-")
        # BUGFIX: previously an xrange with != 2 parts (e.g. "5" or "1-2-3")
        # was silently ignored and fell through to --num_examples/default;
        # now every malformed value raises, like the non-integer case below.
        if len(parts) != 2:
            raise ValueError(f"Invalid xrange format: {args.xrange}. Use 'start-end'")
        try:
            return range(int(parts[0].strip()), int(parts[1].strip()))
        except ValueError:
            raise ValueError(f"Invalid xrange format: {args.xrange}. Use 'start-end'")

    if args.num_examples:
        # Evenly spaced sample over the valid index range; caps (1362 for
        # game24, 1499 otherwise) bound the usable portion of each dataset.
        max_idx = dataset_len - 1
        upper_bound = min(max_idx, 1362) if task == "game24" else min(max_idx, 1499)
        return list(np.linspace(0, upper_bound, args.num_examples).astype(int))

    # Default: full (or --start/--end bounded) range.
    start = args.start if args.start is not None else 0
    end = args.end if args.end is not None else dataset_len
    return range(start, end)
def build_game24_critic_prompt(nums, reasoning_output):
    """Build critic prompt to evaluate Game of 24 solution and provide reasoning."""
    header = (
        "You are a math verifier. Evaluate the following Game of 24 solution.\n"
        f"\nNumbers: {nums}\nTarget: 24\n"
        f"\nStudent's reasoning and answer:\n{reasoning_output}\n"
    )
    checklist = (
        "\nVerify:\n"
        "1. Does it use ALL four numbers exactly once?\n"
        "2. Does each step follow correct arithmetic?\n"
        "3. Does the final expression evaluate to exactly 24?\n"
    )
    format_spec = (
        "\nRespond in the following format:\n"
        "VERDICT: CORRECT or INCORRECT\n"
        "REASONING: Your detailed explanation\n"
        "\nIf CORRECT, briefly explain why.\n"
        "If INCORRECT, explain what went wrong and how to fix it.\n"
    )
    return header + checklist + format_spec
def build_mcq_critic_prompt(task, task_description, reasoning_output):
    """Build critic prompt to evaluate MCQ solution and provide reasoning."""
    # Any non-maze MCQ task is treated as spatial reasoning.
    domain = "Maze" if task == "maze" else "Spatial Reasoning"
    prompt_lines = [
        f"You are an expert {domain} verifier. Evaluate the following solution.",
        "",
        "Task:",
        task_description,
        "",
        "Student's reasoning and answer:",
        reasoning_output,
        "",
        "Verify the correctness of the step-by-step reasoning and final answer.",
        "",
        "Respond in the following format:",
        "VERDICT: CORRECT or INCORRECT",
        "REASONING: Your detailed explanation",
        "",
        "If CORRECT, briefly explain why.",
        "If INCORRECT, explain what went wrong and suggest the correct approach.",
        "",
    ]
    return "\n".join(prompt_lines)
def batch_evaluate_with_critic(outputs_df, task, example, critic_llm_server, tokenizer, nums=None, quiet=True):
    """Batch-evaluate candidate outputs with the critic LLM.

    Sends one critic request per row of *outputs_df* (required column:
    'output'; optional 'critic_seed', else the row index seeds the critic),
    fanning requests across the configured vLLM instances via the "critic"
    round-robin counter.

    Returns a DataFrame with columns: sample_idx, critic_correct,
    critic_feedback.
    """
    payload_template = critic_llm_server["payload"].copy()
    headers = critic_llm_server["headers"]

    async def _fetch_one(session, sem, idx, url, payload):
        async with sem:
            try:
                async with session.post(url, json=payload, headers=headers, timeout=300) as resp:
                    text = await resp.text()
                    if resp.status >= 400:
                        logger.warning(f"HTTP error for critic sample {idx} on {url}: {resp.status} - {text[:200]}")
                        return idx, "", False
                    try:
                        result = json.loads(text)
                    except json.JSONDecodeError:
                        logger.warning(f"Invalid JSON for critic sample {idx} on {url}")
                        return idx, "", False
            except Exception as e:
                logger.warning(f"Critic evaluation failed for sample {idx} on {url}: {e}")
                return idx, "", False

            if "choices" in result and len(result["choices"]) > 0:
                choice = result["choices"][0]
                critic_output = choice.get("text") or choice.get("message", {}).get("content", "")
            else:
                critic_output = ""

            # BUGFIX: the previous check was `"CORRECT" in critic_output.upper()`,
            # which is also true for "INCORRECT" (it contains "CORRECT"), so every
            # non-empty verdict was scored as correct. Parse the VERDICT line
            # explicitly; fall back to a word-boundary match (\b cannot match
            # inside "INCORRECT") when no VERDICT line is present.
            verdict = re.search(r"VERDICT:\s*(INCORRECT|CORRECT)", critic_output, re.IGNORECASE)
            if verdict:
                is_correct = verdict.group(1).upper() == "CORRECT"
            else:
                is_correct = re.search(r"\bCORRECT\b", critic_output.upper()) is not None

            reasoning = ""
            if "REASONING:" in critic_output:
                reasoning = critic_output.split("REASONING:", 1)[1].strip()
            elif "VERDICT:" not in critic_output:
                # Free-form reply with no structure: keep it all as feedback.
                reasoning = critic_output

            return idx, reasoning, is_correct

    async def _run_parallel():
        # At most one in-flight request per vLLM instance.
        sem = asyncio.Semaphore(len(VLLM_PORTS))
        async with aiohttp.ClientSession() as session:
            tasks = []
            for idx, row in outputs_df.iterrows():
                output_text = row["output"]
                # Build the task-specific critic prompt.
                if task == "game24":
                    critic_prompt = build_game24_critic_prompt(nums, output_text)
                elif task == "zebralogic":
                    _, task_desc = build_zebralogic_prompt(example)
                    critic_prompt = build_zebralogic_critic_prompt(task_desc, output_text)
                elif task == "verina":
                    critic_prompt = build_verina_critic_prompt(example, output_text)
                else:
                    if task == "maze":
                        _, task_desc = build_maze_prompt(example)
                    else:
                        _, task_desc = build_spatialmap_prompt(example)
                    critic_prompt = build_mcq_critic_prompt(task, task_desc, output_text)

                critic_system = "You are a strict academic verifier."
                full_prompt = f"<|im_start|>system\n{critic_system}<|im_end|>\n<|im_start|>user\n{critic_prompt}<|im_end|>\n<|im_start|>assistant\n"

                payload = payload_template.copy()
                payload["prompt"] = full_prompt
                payload["seed"] = row.get("critic_seed", idx)

                port = get_next_port(server_type="critic")
                url = f"http://localhost:{port}/v1/completions"
                tasks.append(asyncio.create_task(_fetch_one(session, sem, idx, url, payload)))

            return await asyncio.gather(*tasks)

    if quiet:
        with suppress_output():
            results = asyncio.run(_run_parallel())
    else:
        results = asyncio.run(_run_parallel())

    rows = [
        {"sample_idx": sample_idx, "critic_correct": is_correct, "critic_feedback": reasoning}
        for sample_idx, reasoning, is_correct in results
    ]
    return pd.DataFrame(rows)
"sample_idx": range(k), + "output": outputs, + "seed": [seed + i for i in range(k)], + }) + + # If early stop mode, stop at first critic-correct + if early_stop: + sample_results = [] + for idx, row in df_samples.iterrows(): + output = row["output"] + + # Evaluate with critic + df_critic = batch_evaluate_with_critic( + pd.DataFrame([{"output": output, "seed_idx": idx}]), + task, example, critic_llm_server, tokenizer, nums=nums, quiet=quiet + ) + critic_correct = df_critic.iloc[0]["critic_correct"] if len(df_critic) > 0 else False + critic_feedback = df_critic.iloc[0]["critic_feedback"] if len(df_critic) > 0 else "" + + # Evaluate with ground truth + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=f"Critic verdict: {'CORRECT' if critic_correct else 'INCORRECT'} | {message}", + tokens=token_count, + critic_correct=critic_correct, + critic_feedback=critic_feedback, + )) + + if critic_correct: + break + + return sample_results + else: + # Batch critic evaluation + df_critic = batch_evaluate_with_critic( + df_samples, task, example, critic_llm_server, tokenizer, nums=nums, quiet=quiet + ) + + # Merge critic results + df_samples = df_samples.merge(df_critic, left_index=True, right_on="sample_idx", how="left") + + # Process all results + sample_results = [] + for idx, row in df_samples.iterrows(): + output = row["output"] + critic_correct = row.get("critic_correct", False) + critic_feedback = row.get("critic_feedback", "") + + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=f"Critic verdict: {'CORRECT' if critic_correct else 'INCORRECT'} | {message}", + tokens=token_count, + critic_correct=critic_correct, + critic_feedback=critic_feedback, + )) + + return 
sample_results + + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Best-of-K baseline (standard CoT) for TTSwithVerification datasets") + parser.add_argument("--task", type=str, required=True, choices=["game24", "maze", "spatialmap", "zebralogic","verina"], + help="Task to run") + parser.add_argument("--k", type=int, default=1, help="Number of samples per example") + parser.add_argument("--num_examples", "-n", type=int, default=100, + help="Number of examples to run (overrides start/end)") + parser.add_argument("--indices", type=str, default=None, + help="Comma-separated indices to run") + parser.add_argument("--xrange", type=str, default=None, + help="Range of indices to run (format: 'start-end')") + parser.add_argument("--start", type=int, default=None, help="Start index") + parser.add_argument("--end", type=int, default=None, help="End index") + parser.add_argument("--main_model", type=str, default=MAIN_MODEL, help="Main model to use for generation") + parser.add_argument("--port", type=int, default=8000, help="vLLM server port") + parser.add_argument("--use_critic", action="store_true", help="Use critic model for evaluation instead of ground truth") + parser.add_argument("--critic_model", type=str, default=MAIN_MODEL, help="Critic model to use for evaluation") + parser.add_argument("--critic_port", type=int, default=8000, help="vLLM server port for critic model (default: same as main model port)") + parser.add_argument("--critic_early_stop", action="store_true", help="Stop sampling after first critic-correct trace") + parser.add_argument("--critic_feedback_baseline", action="store_true", help="Use critic feedback as a separate baseline for post-hoc correction") + parser.add_argument("--seed", type=int, default=42, help="Base random seed") + parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for generation") + parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature") + 
parser.add_argument("--processes", "-p", type=int, default=1, help="Number of examples to process in parallel (default: 1, sequential)") + parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging") + args = parser.parse_args() + + log_level = logging.DEBUG if args.debug else logging.ERROR + logging.basicConfig(level=log_level, format="%(message)s") + + quiet_mode = not args.debug + + if quiet_mode: + with suppress_output(): + dataset = load_dataset_for_task(args.task) + else: + dataset = load_dataset_for_task(args.task) + indices = resolve_indices(args.task, len(dataset), args) + + llm_server = init_llm_server( + args.main_model, + max_tokens=args.max_tokens, + port=args.port, + temperature=args.temperature, + seed=args.seed, + ) + + critic_llm_server = None + if args.use_critic: + critic_llm_server = init_llm_server( + args.critic_model, + max_tokens=512, + port=args.critic_port, + temperature=0.2, + seed=args.seed, + ) + # logger.info(f"Using critic model: {args.critic_model} on port {args.critic_port}") + + # logger.info(f"Loading tokenizer for {args.main_model}...") + if quiet_mode: + with suppress_output(): + tokenizer = AutoTokenizer.from_pretrained(args.main_model, trust_remote_code=True) + else: + tokenizer = AutoTokenizer.from_pretrained(args.main_model, trust_remote_code=True) + # logger.info("Tokenizer loaded successfully.") + + output_dirs = get_output_dirs(args.task, args.main_model, args.use_critic, args.critic_early_stop) + + total_examples = 0 + total_correct = 0 + total_correct_samples = 0 + total_samples = 0 + critic_correct_samples = 0 + critic_total_samples = 0 + total_tokens = 0 + total_tokens_all_samples = 0 + results = [] + + def process_example(idx): + """Process a single example: generate k samples, evaluate, return result dict.""" + example = dataset[int(idx)] + if args.task == "game24": + nums = example["numbers"] + prompt = build_full_prompt(args.task, example, nums=nums) + eval_fn = lambda output: 
evaluate_game24_answer(output, nums) + options = None + + elif args.task == "zebralogic": + prompt = build_full_prompt(args.task, example) + eval_fn = lambda output, ex=example: evaluate_zebralogic_answer(output, ex) + options = None + elif args.task == "verina": + # For verina, example is a BenchmarkData object + prompt = build_full_prompt(args.task, example) + current_idx = int(idx) + current_data = example + eval_fn = lambda output, data=current_data, task_idx=current_idx: evaluate_verina_answer(output, data, task_idx) + options = None + + else: + prompt = build_full_prompt(args.task, example) + gt = str(example.get("ground_truth", "")).strip() + if gt == "Q4": + target_options = ["A", "B"] + else: + target_options = ["A", "B", "C", "D"] + if args.task == "maze": + _, user_prompt = build_maze_prompt(example) + else: + _, user_prompt = build_spatialmap_prompt(example) + options = extract_options_from_prompt(user_prompt, target_options) + eval_fn = lambda output: evaluate_mcq_answer(output, options, gt) + + #logger.info(f"---- Example {idx} ----") + + quiet_mode = not args.debug + + if args.use_critic: + sample_results = run_k_samples_with_critic( + prompt, llm_server, critic_llm_server, args.k, args.seed, + args.task, example, tokenizer, eval_fn, nums=(nums if args.task == "game24" else None), + early_stop=args.critic_early_stop, quiet=quiet_mode + ) + else: + outputs = batch_generate_samples(prompt, llm_server, args.k, args.seed, quiet=quiet_mode) + sample_results = [] + for output in outputs: + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=message, + tokens=token_count, + critic_correct=None, + )) + + if args.use_critic: + best_idx = next((i for i, r in enumerate(sample_results) if r.critic_correct), 0) + else: + best_idx = next((i for i, r in enumerate(sample_results) if r.correct), 0) + best_result = 
sample_results[best_idx] + any_correct = any(r.correct for r in sample_results) + correct_samples = sum(1 for r in sample_results if r.correct) + critic_correct_samples_example = sum(1 for r in sample_results if r.critic_correct) + + save_outputs(idx, sample_results, best_idx, output_dirs["reasoning"]) + + return { + "idx": int(idx), + "best_idx": best_idx, + "any_correct": any_correct, + "best_correct": best_result.correct, + "best_critic_correct": best_result.critic_correct, + "best_extracted": best_result.extracted, + "best_message": best_result.message, + "best_critic_feedback": best_result.critic_feedback, + "best_tokens": best_result.tokens, + "all_tokens": [r.tokens for r in sample_results], + "all_correct": [r.correct for r in sample_results], + "all_critic_correct": [r.critic_correct for r in sample_results], + "all_critic_feedback": [r.critic_feedback for r in sample_results], + "options": options, + "_any_correct": any_correct, + "_correct_samples": correct_samples, + "_critic_correct_samples": critic_correct_samples_example, + "_n_samples": len(sample_results), + "_best_tokens": best_result.tokens, + "_all_tokens_sum": sum(r.tokens for r in sample_results), + } + + with ThreadPool(processes=args.processes) as pool: + for result in tqdm(pool.imap_unordered(process_example, indices), total=len(indices), desc="Processing examples", unit="example", file=_real_stderr): + total_examples += 1 + if result["_any_correct"]: + total_correct += 1 + total_correct_samples += result["_correct_samples"] + total_samples += result["_n_samples"] + critic_correct_samples += result["_critic_correct_samples"] + critic_total_samples += result["_n_samples"] + total_tokens += result["_best_tokens"] + total_tokens_all_samples += result["_all_tokens_sum"] + + # Remove internal keys before appending + for k in list(result.keys()): + if k.startswith("_"): + del result[k] + results.append(result) + + accuracy = total_correct / total_examples if total_examples else 0 + 
avg_best_tokens = total_tokens / total_examples if total_examples else 0 + avg_all_tokens = total_tokens_all_samples / total_examples if total_examples else 0 + + summary = { + "task": args.task, + "model": args.main_model, + "k": args.k, + "use_critic": args.use_critic, + "total_examples": total_examples, + "correct": total_correct, + "correct_samples": total_correct_samples, + "total_samples": total_samples, + "critic_correct_samples": critic_correct_samples, + "critic_total_samples": critic_total_samples, + "critic_accuracy": (critic_correct_samples / critic_total_samples) if critic_total_samples else 0, + "accuracy": accuracy, + "avg_best_tokens": avg_best_tokens, + "avg_all_tokens": avg_all_tokens, + "total_tokens_best": total_tokens, + "total_tokens_all_samples": total_tokens_all_samples, + "results": results, + } + + if args.use_critic: + summary["critic_model"] = args.critic_model + summary["critic_port"] = args.critic_port + summary["critic_early_stop"] = args.critic_early_stop + summary["critic_feedback_baseline"] = args.critic_feedback_baseline + + summary_path = os.path.join(output_dirs["base"], "summary.json") + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, cls=NumpyEncoder) + # logger.info(f"Saved summary to {summary_path}") diff --git a/examples/TTSwithVerification/check_tokens.py b/examples/TTSwithVerification/check_tokens.py new file mode 100644 index 00000000..4380c194 --- /dev/null +++ b/examples/TTSwithVerification/check_tokens.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +"""Check what payload is actually being sent.""" +import requests +import json + +url = "http://localhost:8000/v1/completions" +payload = { + "model": "Qwen/QwQ-32B", + "prompt": "What is 2+2?", + "max_tokens": 50, + "temperature": 0.0, + "stream": False, +} + +print("Payload being sent:") +print(json.dumps(payload, indent=2)) + +resp = requests.post(url, json=payload, timeout=10) +result = resp.json() +output = 
result["choices"][0]["text"] + +print(f"\nResponse usage:") +print(f" Tokens generated: {result['usage'].get('completion_tokens', '?')}") +print(f" Total tokens: {result['usage'].get('total_tokens', '?')}") + +print(f"\nOutput ({len(output)} chars):") +print(repr(output)) diff --git a/examples/TTSwithVerification/cleanup_vllm.sh b/examples/TTSwithVerification/cleanup_vllm.sh new file mode 100755 index 00000000..f552ec11 --- /dev/null +++ b/examples/TTSwithVerification/cleanup_vllm.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Cleanup script to kill all vLLM processes and Python instances + +echo "Stopping all vLLM processes..." +pkill -9 -f "vllm.entrypoints.openai.api_server" + +echo "Stopping Python processes..." +pkill -9 -f "bestofk_baseline.py" + +sleep 2 + +echo "Verifying all processes stopped..." +if pgrep -f "vllm.entrypoints.openai.api_server" > /dev/null; then + echo "WARNING: Some vLLM processes still running" +else + echo "✓ All vLLM processes stopped" +fi + +if pgrep -f "bestofk_baseline.py" > /dev/null; then + echo "WARNING: Some Python processes still running" +else + echo "✓ All Python processes stopped" +fi + +echo "Cleanup complete" diff --git a/examples/TTSwithVerification/run_experiments.py b/examples/TTSwithVerification/run_experiments.py new file mode 100644 index 00000000..a52952b0 --- /dev/null +++ b/examples/TTSwithVerification/run_experiments.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Job scheduler for running bestofk_baseline.py experiments sequentially. 
+""" +import subprocess +import sys +import time +from datetime import datetime +from pathlib import Path + +# Base command +BASE_CMD = "python /data/b-pchanda/interwhen/examples/TTSwithVerification/bestofk_baseline.py" + +# Define your experiment configurations +EXPERIMENTS = [ + # Maze experiments + { + "name": "maze_k4_critic", + "args": "--task maze --k 4 --use_critic" + }, + { + "name": "maze_k4_critic_earlystop", + "args": "--task maze --k 4 --use_critic --critic_early_stop" + }, + { + "name": "maze_k4_no_critic", + "args": "--task maze --k 4" + }, + + # Game24 experiments + { + "name": "game24_k4_critic", + "args": "--task game24 --k 4 --use_critic" + }, + { + "name": "game24_k4_critic_earlystop", + "args": "--task game24 --k 4 --use_critic --critic_early_stop" + }, + { + "name": "game24_k4_no_critic", + "args": "--task game24 --k 4" + }, + + # Spatialmap experiments + { + "name": "spatialmap_k4_critic", + "args": "--task spatialmap --k 4 --use_critic" + }, + { + "name": "spatialmap_k4_critic_earlystop", + "args": "--task spatialmap --k 4 --use_critic --critic_early_stop" + }, + { + "name": "spatialmap_k4_no_critic", + "args": "--task spatialmap --k 4" + }, +] + + +def run_experiment(exp_config, exp_num, total_exps): + """Run a single experiment.""" + name = exp_config["name"] + args = exp_config["args"] + + print("\n" + "="*80) + print(f"Experiment [{exp_num}/{total_exps}]: {name}") + print(f"Command: {BASE_CMD} {args}") + print("="*80) + + start_time = time.time() + start_ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + try: + # Run the command + result = subprocess.run( + f"{BASE_CMD} {args}", + shell=True, + capture_output=False, # Show output in real-time + text=True + ) + + elapsed = time.time() - start_time + end_ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + if result.returncode == 0: + print(f"\n✓ Experiment '{name}' completed successfully") + print(f" Started: {start_ts}") + print(f" Finished: {end_ts}") + print(f" Duration: 
{elapsed:.1f}s ({elapsed/60:.1f} min)") + return True, elapsed + else: + print(f"\n✗ Experiment '{name}' failed with exit code {result.returncode}") + print(f" Duration: {elapsed:.1f}s") + return False, elapsed + + except KeyboardInterrupt: + print(f"\n\n⚠ Experiment '{name}' interrupted by user") + raise + except Exception as e: + elapsed = time.time() - start_time + print(f"\n✗ Experiment '{name}' failed with exception: {e}") + print(f" Duration: {elapsed:.1f}s") + return False, elapsed + + +def main(): + """Run all experiments sequentially.""" + print("="*80) + print("JOB SCHEDULER - Running experiments sequentially") + print("="*80) + print(f"Total experiments: {len(EXPERIMENTS)}") + print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + + results = [] + total_start = time.time() + + try: + for i, exp in enumerate(EXPERIMENTS, 1): + success, duration = run_experiment(exp, i, len(EXPERIMENTS)) + results.append({ + "name": exp["name"], + "success": success, + "duration": duration + }) + + # Brief pause between experiments + if i < len(EXPERIMENTS): + print("\nWaiting 5 seconds before next experiment...") + time.sleep(5) + + except KeyboardInterrupt: + print("\n\n⚠ Job scheduler interrupted by user") + + finally: + # Print summary + total_elapsed = time.time() - total_start + print("\n" + "="*80) + print("SUMMARY") + print("="*80) + + successful = sum(1 for r in results if r["success"]) + failed = len(results) - successful + + print(f"\nCompleted: {len(results)}/{len(EXPERIMENTS)} experiments") + print(f"Successful: {successful}") + print(f"Failed: {failed}") + print(f"\nTotal time: {total_elapsed:.1f}s ({total_elapsed/60:.1f} min, {total_elapsed/3600:.2f} hrs)") + + print("\nDetailed results:") + for i, r in enumerate(results, 1): + status = "✓" if r["success"] else "✗" + print(f" {i}. 
{status} {r['name']:40s} - {r['duration']:.1f}s ({r['duration']/60:.1f} min)") + + if failed > 0: + print("\nFailed experiments:") + for i, r in enumerate(results, 1): + if not r["success"]: + print(f" {i}. {r['name']}") + + sys.exit(0 if failed == 0 else 1) + + +if __name__ == "__main__": + main() diff --git a/examples/TTSwithVerification/start_vllm_multiprocess.sh b/examples/TTSwithVerification/start_vllm_multiprocess.sh new file mode 100755 index 00000000..601918a2 --- /dev/null +++ b/examples/TTSwithVerification/start_vllm_multiprocess.sh @@ -0,0 +1,80 @@ +#!/bin/bash + +# Start 3 vLLM processes with explicit GPU assignment +# Process 1: GPUs 0-1, Port 8000, TP=2 +# Process 2: GPUs 2-3, Port 8001, TP=2 +# Process 3: GPUs 4-5, Port 8002, TP=2 + +MODEL="Qwen/Qwen3-30B-A3B-Thinking-2507" +GPU_MEMORY=0.4 +TENSOR_PARALLEL=2 + +echo "Killing any existing vLLM processes..." +pkill -9 -f "vllm.entrypoints.openai.api_server" +sleep 2 + +echo "Starting 3 vLLM processes..." + +# Process 1 - GPUs 0,1 +( + export CUDA_VISIBLE_DEVICES=0,1 + python -m vllm.entrypoints.openai.api_server \ + --model $MODEL \ + --port 8000 \ + --tensor-parallel-size $TENSOR_PARALLEL \ + --gpu-memory-utilization $GPU_MEMORY \ + --disable-log-requests \ + > /tmp/vllm_8000.log 2>&1 +) & +PID1=$! +echo "Started Process 1 (GPUs 0-1, Port 8000) - PID: $PID1" + +sleep 5 + +# Process 2 - GPUs 2,3 +( + export CUDA_VISIBLE_DEVICES=2,3 + python -m vllm.entrypoints.openai.api_server \ + --model $MODEL \ + --port 8001 \ + --tensor-parallel-size $TENSOR_PARALLEL \ + --gpu-memory-utilization $GPU_MEMORY \ + --disable-log-requests \ + > /tmp/vllm_8001.log 2>&1 +) & +PID2=$! 
+echo "Started Process 2 (GPUs 2-3, Port 8001) - PID: $PID2" + +sleep 5 + +# Process 3 - GPUs 4,5 +( + export CUDA_VISIBLE_DEVICES=4,5 + python -m vllm.entrypoints.openai.api_server \ + --model $MODEL \ + --port 8002 \ + --tensor-parallel-size $TENSOR_PARALLEL \ + --gpu-memory-utilization $GPU_MEMORY \ + --disable-log-requests \ + > /tmp/vllm_8002.log 2>&1 +) & +PID3=$! +echo "Started Process 3 (GPUs 4-5, Port 8002) - PID: $PID3" + +echo "" +echo "All 3 vLLM processes started successfully." +echo "Process PIDs: $PID1 $PID2 $PID3" +echo "" +echo "Log files:" +echo " /tmp/vllm_8000.log - Process 1" +echo " /tmp/vllm_8001.log - Process 2" +echo " /tmp/vllm_8002.log - Process 3" +echo "" +echo "To stop all processes, run:" +echo " pkill -9 -f 'vllm.entrypoints.openai.api_server'" +echo "" +echo "Waiting for processes to initialize (this may take 60-120 seconds)..." +echo "" + +# Wait for all processes +wait $PID1 $PID2 $PID3 diff --git a/examples/TTSwithVerification/test_maze_extraction.py b/examples/TTSwithVerification/test_maze_extraction.py new file mode 100644 index 00000000..bd5d4632 --- /dev/null +++ b/examples/TTSwithVerification/test_maze_extraction.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +"""Test maze/spatialmap extraction with real model output.""" +import requests +import json +import re +from datasets import load_dataset + +def extract_solution_mcq(text): + """Current extraction function from baseline.""" + patterns = [ + r"\\boxed\{([^}]*)\}", # \boxed{...} + r"boxed\{([^}]*)\}", # boxed{...} without escape + r"\*\*([A-D])\*\*", # **A** format + r"answer[:\s]*([A-D])", # answer: A format + r"(?:^|\n)([A-D])(?:\s|$|\.)", # Standalone letter + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() + + standalone = re.findall(r"\b([ABCD])\b", text) + if 
standalone: + return standalone[-1].upper() + + return None + +# Load maze dataset +print("Loading maze dataset...") +dataset = load_dataset("microsoft/VISION_LANGUAGE", "maze_text_only", split="val") +example = dataset[0] + +# Build prompt like baseline does +pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer. " + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" +) +prompt_text = str(example.get("prompt", ""))[:500] # First 500 chars + +full_prompt = f"<|im_start|>system\nYou are helpful.\n<|im_end|>\n<|im_start|>user\n{pre_prompt}\n\n{prompt_text}\n<|im_end|>\n<|im_start|>assistant\n" + +print(f"\nFull prompt:\n{full_prompt[:300]}...\n") + +# Test on port 8000 +url = "http://localhost:8000/v1/completions" +payload = { + "model": "Qwen/QwQ-32B", + "prompt": full_prompt, + "max_tokens": 500, + "temperature": 0.6, + "stream": False, +} + +print("Requesting model output...") +resp = requests.post(url, json=payload, timeout=30) +if resp.status_code == 200: + result = resp.json() + output = result["choices"][0].get("text", "") + print(f"\nModel output ({len(output)} chars):") + print("="*60) + print(output) + print("="*60) + + extracted = extract_solution_mcq(output) + print(f"\nExtraction result: {extracted}") + + # Try alternative patterns + print("\nTrying alternative extraction patterns:") + if "\\boxed{" in output: + print(" - Contains \\boxed{ pattern") + if r"\boxed{" in output or r"\\boxed{" in output: + print(" - Contains escaped boxed") + if re.search(r"[Aa]nswer[:\s]*([A-D])", output): + print(" - Found 'answer: X' pattern") + response_letters = re.findall(r"\b[A-D]\b", output) + if response_letters: + print(f" - Found standalone letters: {response_letters}") + +else: + print(f"Error: {resp.status_code}") + print(resp.text[:500]) diff --git a/examples/TTSwithVerification/test_vllm_api.py 
b/examples/TTSwithVerification/test_vllm_api.py new file mode 100644 index 00000000..eaf88430 --- /dev/null +++ b/examples/TTSwithVerification/test_vllm_api.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +"""Test vLLM API responses directly.""" +import requests +import json + +# Test each port +for port in [8000, 8001, 8002, 8003]: + print(f"\n{'='*60}") + print(f"Testing port {port}...") + print('='*60) + + url = f"http://localhost:{port}/v1/completions" + + # Test: simple completion + payload = { + "model": "Qwen/QwQ-32B", + "prompt": "What is 2+2? Answer:", + "max_tokens": 50, + "temperature": 0.0, + "stream": False, + } + + try: + resp = requests.post(url, json=payload, timeout=10) + print(f"Status: {resp.status_code}") + + if resp.status_code == 200: + result = resp.json() + print(f"Response keys: {result.keys()}") + if "choices" in result and len(result["choices"]) > 0: + choice = result["choices"][0] + text = choice.get("text", "") + print(f"Generated text: {repr(text[:100])}") + print(f"Text length: {len(text)}") + else: + print(f"No choices in response: {result}") + else: + print(f"Error response: {resp.text[:200]}") + except Exception as e: + print(f"Request failed: {e}") diff --git a/examples/TTSwithVerification/verina_utils.py b/examples/TTSwithVerification/verina_utils.py new file mode 100644 index 00000000..84db2338 --- /dev/null +++ b/examples/TTSwithVerification/verina_utils.py @@ -0,0 +1,649 @@ +import argparse +import asyncio +import json +import logging +import os +import re +import sys +import shutil +import subprocess +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple +import numpy as np +import pandas as pd +import aiohttp +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +__all__ = [ + # Path constants + "VERINA_ROOT", + "VERINA_DATASETS_PATH", + "LEAN_PLAYGROUND_DIR", + # Classes + "BenchmarkData", + 
# Constants + "CODE_TEST_MSG_MARKER", + "DECIDABLE_ERR_MSG", + # Functions + "parse_benchmark_lean_data", + "load_benchmark_data_from_task_dir", + "load_verina_dataset", + "render_param_list", + "build_verina_prompt", + "strip_function_definition", + "extract_code_from_response", + "create_lean_file", + "check_lean_compile", + "render_unit_test_value", + "render_code_unit_test", + "build_test_lean_file", + "parse_unit_test_results", + "evaluate_generated_code", +] + +_SCRIPT_DIR = Path(__file__).parent.resolve() +VERINA_ROOT = (_SCRIPT_DIR / "../../../verina").resolve() +VERINA_DATASETS_PATH = VERINA_ROOT / "datasets" / "verina" +LEAN_PLAYGROUND_DIR = VERINA_ROOT / "lean-playground" + + +class BenchmarkData: + """Verina benchmark data structure""" + def __init__(self, data_id: str, description: str, signature: dict, + lean_data: dict, spec_desc: dict, tests: list, metadata: dict): + self.data_id = data_id + self.description = description + self.signature = signature # {"name": str, "parameters": list, "return_type": str} + self.lean_data = lean_data # contains task_imports, task_aux, code, precond, postcond, proof, etc. 
+ self.spec_desc = spec_desc # {"precond_desc": str, "postcond_desc": str} + self.tests = tests + self.metadata = metadata + + +def parse_benchmark_lean_data(raw_lean_data: str) -> dict: + """Parse a .lean file with !benchmark markers into sections""" + lines = raw_lean_data.strip().splitlines() + + lean_data = { + "task_imports": "", + "solution_imports": "", + "task_aux": "", + "solution_aux": "", + "code_aux": "", + "precond_aux": "", + "postcond_aux": "", + "proof_aux": "", + "code": "", + "precond": "True", + "postcond": "", + "proof": "sorry", + } + + current_section = None + current_content = [] + current_args = {} + + for line in lines: + if "-- !benchmark" in line: + marker_part = line.split("-- !benchmark", 1)[1].strip() + + if marker_part.startswith("@start"): + # Save previous section if any + if current_section is not None: + content = "\n".join(current_content).strip() + if current_section == "import": + import_type = current_args.get("type", "task") + if import_type == "task": + lean_data["task_imports"] += content + "\n" + elif import_type == "solution": + lean_data["solution_imports"] += content + "\n" + elif current_section in lean_data: + lean_data[current_section] = content + + # Start new section + parts = marker_part.split("@start", 1)[1].strip().split(None, 1) + current_section = parts[0].strip() + current_args = {} + current_content = [] + + if len(parts) > 1: + for arg in parts[1].strip().split(): + if "=" in arg: + key, value = arg.split("=", 1) + current_args[key] = value + + elif marker_part.startswith("@end"): + if current_section is not None: + content = "\n".join(current_content).strip() + if current_section == "import": + import_type = current_args.get("type", "task") + if import_type == "task": + lean_data["task_imports"] += content + "\n" + elif import_type == "solution": + lean_data["solution_imports"] += content + "\n" + elif current_section in lean_data: + lean_data[current_section] = content + current_section = None + 
current_content = [] + current_args = {} + else: + if current_section is not None: + current_content.append(line) + + return lean_data + + +def load_benchmark_data_from_task_dir(task_dir: Path) -> Optional[BenchmarkData]: + """Load a single benchmark task from its directory""" + task_path = task_dir / "task.json" + if not task_path.exists(): + return None + + try: + with open(task_path, "r") as f: + task_data = json.load(f) + + task_id = task_data.get("id") + if not task_id: + return None + + # Read description + desc_path = task_dir / task_data.get("description_file", "description.txt") + description = desc_path.read_text().strip() if desc_path.exists() else "" + + # Read signature + signature = task_data.get("signature", {}) + + # Read lean file + lean_path = task_dir / task_data.get("lean_file", "task.lean") + if lean_path.exists(): + lean_data = parse_benchmark_lean_data(lean_path.read_text()) + else: + lean_data = {} + + # Read spec description + spec_desc = { + "precond_desc": task_data.get("specification", {}).get("preconditions", ""), + "postcond_desc": task_data.get("specification", {}).get("postconditions", ""), + } + + # Read tests + test_path = task_dir / task_data.get("test_file", "test.json") + tests = [] + if test_path.exists(): + with open(test_path, "r") as f: + tests = json.load(f) + + metadata = task_data.get("metadata", {}) + + return BenchmarkData( + data_id=task_id, + description=description, + signature=signature, + lean_data=lean_data, + spec_desc=spec_desc, + tests=tests, + metadata=metadata, + ) + except Exception as e: + return None + + +def load_verina_dataset() -> List[BenchmarkData]: + """Load all verina benchmark tasks from the datasets directory""" + results = [] + + # Get all task directories sorted by ID + task_dirs = sorted( + [d for d in VERINA_DATASETS_PATH.glob("verina_*") if d.is_dir()], + key=lambda x: (x.name.split("_")[1], int(x.name.split("_")[-1])), + ) + + for task_dir in task_dirs: + data = 
load_benchmark_data_from_task_dir(task_dir) + if data: + results.append(data) + + return results + + +def render_param_list(signature: dict) -> str: + """Render the parameter list for a function signature""" + params = signature.get("parameters", []) + rendered = "" + for param in params: + rendered += f"({param['param_name']} : {param['param_type']}) " + return rendered.strip() + + +def build_verina_prompt(data: BenchmarkData) -> str: + """Build prompt for Verina Lean code generation (build_code_gen_prompt + build_full_prompt)""" + system_prompt = "You are an expert Lean 4 programmer. Generate valid Lean 4 code for the function body. Wrap your final code in [CODE] [/CODE] tags strictly." + + signature = data.signature + func_name = signature.get("name", "solution") + return_type = signature.get("return_type", "Bool") + param_list = render_param_list(signature) + params = signature.get("parameters", []) + + precond_name = f"{func_name}_precond" + param_names_str = ' '.join([f"({p['param_name']})" for p in params]) + + # Get auxiliary definitions (only if they exist) + solution_aux = data.lean_data.get("solution_aux", "").strip() + task_aux = data.lean_data.get("task_aux", "").strip() + code_aux = data.lean_data.get("code_aux", "").strip() + precond = data.lean_data.get("precond", "True").strip() + postcond = data.lean_data.get("postcond", "").strip() + + # Build helper section only if there are helpers + helper_section = "" + all_aux = "\n".join(filter(None, [solution_aux, task_aux, code_aux])) + if all_aux: + helper_section = f""" +## Helper Definitions +```lean4 +{all_aux} +``` +""" + + user_prompt = f"""## Task +{data.description} + +## Function Signature +```lean4 +def {func_name} {param_list} (h_precond : {precond_name} {param_names_str}) : {return_type} := + -- YOUR CODE HERE (just output this part inside [CODE] [/CODE] tags) +``` + +## Precondition +```lean4 +def {precond_name} {param_list} : Prop := {precond} +``` + +## Postcondition +```lean4 +def 
{func_name}_postcond {param_list} (result: {return_type}) : Prop := {postcond} +``` +{helper_section} +Provide ONLY the function body expression wrapped in [CODE]...[/CODE] tags.""" + + return ( + f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + f"<|im_start|>user\n{user_prompt}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + +def strip_function_definition(code: str) -> str: + """ + Strip function definition prefix if the model accidentally included it. + + The prompt asks for just the function body, but sometimes the model outputs: + def FunctionName (params) (h_precond : ...) : ReturnType := + actual_body + + We need to extract just 'actual_body' and dedent it properly. + """ + import textwrap + + code = code.strip() + + # Pattern to match Lean function definition: + # def (h_precond : ) : := + # The function body follows after := + func_def_pattern = r'^def\s+\w+\s+.*?:=[ \t]*\n?' + + match = re.match(func_def_pattern, code, re.DOTALL) + if match: + # Extract everything after the := + body = code[match.end():] + # Dedent to remove common leading whitespace from all lines + body = textwrap.dedent(body).strip() + if body: + return body + + return code + + +def extract_code_from_response(response: str) -> str: + """Extract code from the LAST [CODE]...[/CODE] tags or lean code blocks. + + Handles cases where: + 1. Response has ... reasoning block + 2. [CODE] tag exists but [/CODE] may be missing (truncated response) + 3. Code is in markdown lean blocks + 4. Model outputs [CORE] or other variants instead of [CODE] + 5. Model uses mismatched tags like [CORE]...[/CODE] + 6. Model includes full function definition instead of just the body + """ + # Step 1: Remove ... 
block entirely (case insensitive) +    # This prevents extracting reasoning text as code +    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL | re.IGNORECASE) + +    # If </think> exists but <think>...</think> doesn't match (partial), take everything after </think> +    if not cleaned.strip() or cleaned.strip() == response.strip(): +        think_end = response.lower().rfind("</think>") +        if think_end != -1: +            cleaned = response[think_end + len("</think>"):] + +    extracted_code = None + +    # Step 2: Find the LAST closing tag [/CODE] or [/CORE] and work backwards to find opening tag +    # This handles mismatched tags like [CORE]...[/CODE] +    closing_pattern = r'\[/(?:CODE|CORE)\]' +    closing_matches = list(re.finditer(closing_pattern, cleaned, re.IGNORECASE)) + +    if closing_matches: +        # Get position of the last closing tag +        last_close = closing_matches[-1] +        close_pos = last_close.start() + +        # Search backwards for the last opening tag before this closing tag +        text_before_close = cleaned[:close_pos] +        opening_pattern = r'\[(?:CODE|CORE|CORRECTED CODE)\]' +        opening_matches = list(re.finditer(opening_pattern, text_before_close, re.IGNORECASE)) + +        if opening_matches: +            last_open = opening_matches[-1] +            extracted_code = cleaned[last_open.end():close_pos].strip() + +    # Step 3: Try [CODE] without closing tag (truncated response) - find the LAST one +    if extracted_code is None: +        code_start_matches = list(re.finditer(r'\[(?:CODE|CORE|CORRECTED CODE)\]\s*', cleaned, re.DOTALL | re.IGNORECASE)) +        if code_start_matches: +            # Get the last [CODE] tag position and extract everything after it +            last_match = code_start_matches[-1] +            code = cleaned[last_match.end():].strip() +            # Remove any trailing incomplete text that looks like reasoning +            # Stop at any line that looks like it's not code (e.g., starts with "Wait", "So", etc.) 
+ lines = code.split('\n') + code_lines = [] + for line in lines: + stripped = line.strip() + # Stop if we hit obvious non-code reasoning text + if stripped and re.match(r'^(Wait|So |But |Now|Note|The |This |However|Therefore|Thus|In |Since)', stripped): + break + code_lines.append(line) + if code_lines: + extracted_code = '\n'.join(code_lines).strip() + + # Step 4: Try markdown lean code blocks (find the LAST one) + if extracted_code is None: + lean_matches = list(re.finditer(r'```lean4?\s*\n(.*?)```', cleaned, re.DOTALL | re.IGNORECASE)) + if lean_matches: + extracted_code = lean_matches[-1].group(1).strip() + + # Step 5: Try lean block without closing (truncated) - find the LAST one + if extracted_code is None: + lean_start_matches = list(re.finditer(r'```lean4?\s*\n', cleaned, re.DOTALL | re.IGNORECASE)) + if lean_start_matches: + last_match = lean_start_matches[-1] + code = cleaned[last_match.end():].strip() + # Remove trailing ``` if present + extracted_code = re.sub(r'```\s*$', '', code).strip() + + # Step 6: Last resort, return cleaned content if it looks like code + if extracted_code is None: + cleaned = cleaned.strip() + if cleaned: + # Filter out lines that look like reasoning + lines = cleaned.split('\n') + code_lines = [] + for line in lines: + stripped = line.strip() + if stripped and not re.match(r'^(Wait|So |But |Now|Note|The |This |However|Therefore|Thus|In |Since|I |We |You )', stripped): + code_lines.append(line) + if code_lines: + extracted_code = '\n'.join(code_lines).strip() + + # Step 7: Strip function definition prefix if model included it + # The prompt asks for just the body, but sometimes model outputs full "def ... 
:=" + if extracted_code: + extracted_code = strip_function_definition(extracted_code) + return extracted_code + + return "" + + +def create_lean_file(file_name: str, content: str) -> Path: + """Create a lean file in the playground directory""" + LEAN_PLAYGROUND_DIR.mkdir(parents=True, exist_ok=True) + lean_file = LEAN_PLAYGROUND_DIR / f"{file_name}.lean" + with open(lean_file, "w") as f: + f.write(content) + return lean_file + + +def check_lean_compile(lean_file: Path, timeout: int = 120) -> Tuple[bool, str]: + """Check if the Lean file compiles successfully""" + try: + result = subprocess.run( + ["lake", "lean", str(lean_file)], + check=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=timeout, + cwd=VERINA_ROOT, + ) + + output = result.stdout.decode() + "\n" + result.stderr.decode() + + if result.returncode == 0: + return True, output + else: + return False, output + + except subprocess.TimeoutExpired: + return False, "TIMEOUT" + except Exception as e: + return False, f"ERROR: {e}" + + +# Unit Test Rendering +CODE_TEST_MSG_MARKER = "code_test" +DECIDABLE_ERR_MSG = "did not evaluate to `true`" + + +def render_unit_test_value(lean_type: str, value) -> str: + """Convert a Python value to Lean syntax based on type""" + if lean_type == "Bool": + return str(value).lower() + elif lean_type == "String": + return f'"{value}"' + elif lean_type == "Char": + return f"'{value}'" + else: + return str(value) + + +def render_code_unit_test(signature: dict, test_case: dict, test_idx: int) -> str: + """Render a single unit test using #guard""" + func_name = signature.get("name", "solution") + params = signature.get("parameters", []) + return_type = signature.get("return_type", "Bool") + + rendered = f'#print "<{CODE_TEST_MSG_MARKER}>{test_idx}"\n\n' + rendered += f"#guard {func_name}" + + for param in params: + param_name = param["param_name"] + param_type = param["param_type"] + input_value = test_case["input"].get(param_name, "") + rendered += f" 
({render_unit_test_value(param_type, input_value)})" + + # Add (by sorry) to satisfy precondition hypothesis + rendered += " (by sorry)" + + # Add expected value comparison + expected = test_case.get("expected", "") + rendered += f" == ({render_unit_test_value(return_type, expected)})" + + return rendered + + +def build_test_lean_file(data: BenchmarkData, generated_code: str, include_unit_tests: bool = True) -> str: + """Build a complete Lean file to test the generated code""" + signature = data.signature + func_name = signature.get("name", "solution") + return_type = signature.get("return_type", "Bool") + param_list = render_param_list(signature) + params = signature.get("parameters", []) + param_names = " ".join([f"({p['param_name']})" for p in params]) + + # Indent multiline generated_code so all lines have proper indentation + # First line gets 2 spaces from template, subsequent lines need explicit indentation + if '\n' in generated_code: + lines = generated_code.split('\n') + # First line has no extra indent (template adds 2 spaces) + # Subsequent lines need 2 spaces prepended + indented_lines = [lines[0]] + [' ' + line if line.strip() else line for line in lines[1:]] + generated_code = '\n'.join(indented_lines) + + # Build imports - include both task and solution imports + task_imports = data.lean_data.get("task_imports", "").strip() + solution_imports = data.lean_data.get("solution_imports", "").strip() + imports = task_imports + if solution_imports: + imports += "\n" + solution_imports + if "import Mathlib" not in imports: + imports = "import Mathlib\n" + imports + + # Build auxiliary definitions - include solution_aux which has helper functions + solution_aux = data.lean_data.get("solution_aux", "").strip() + task_aux = data.lean_data.get("task_aux", "").strip() + precond_aux = data.lean_data.get("precond_aux", "").strip() + postcond_aux = data.lean_data.get("postcond_aux", "").strip() + code_aux = data.lean_data.get("code_aux", "").strip() + + # Build 
precondition + precond = data.lean_data.get("precond", "True").strip() + precond_name = f"{func_name}_precond" + + # Build postcondition + postcond = data.lean_data.get("postcond", "").strip() + postcond_name = f"{func_name}_postcond" + + lean_content = f"""{imports} + +-- Solution auxiliary definitions (helper functions) +{solution_aux} + +-- Task auxiliary definitions +{task_aux} + +-- Precondition auxiliary definitions +{precond_aux} + +@[reducible, simp] +def {precond_name} {param_list} : Prop := + {precond} + +-- Postcondition auxiliary definitions +{postcond_aux} + +-- Code auxiliary definitions +{code_aux} + +def {func_name} {param_list} (h_precond : {precond_name} {param_names}) : {return_type} := + {generated_code} + +@[reducible, simp] +def {postcond_name} {param_list} (result: {return_type}) (h_precond : {precond_name} {param_names}) : Prop := + {postcond} + +-- Verification theorem (compilation test) +-- If this compiles, the code at least type-checks +#check {func_name} +""" + + # Add unit tests if requested + if include_unit_tests and data.tests: + lean_content += "\n-- Unit Tests\n" + for idx, test_case in enumerate(data.tests): + lean_content += "\n" + render_code_unit_test(signature, test_case, idx) + "\n" + + return lean_content + + +def parse_unit_test_results(compile_output: str, num_tests: int) -> Tuple[int, int, dict]: + """ + Parse the compilation output to determine which unit tests passed/failed. 
+ + Returns: (num_passed, num_failed, test_results_dict) + """ + test_results = {} + + # If compilation succeeded with no errors, all tests passed + if "error" not in compile_output.lower(): + for idx in range(num_tests): + test_results[idx] = "pass" + return num_tests, 0, test_results + + # Parse the output to find which tests failed + # Look for markers like 0 followed by error messages + code_test_start = f"<{CODE_TEST_MSG_MARKER}>" + code_test_end = f"" + + # Split by start marker to get test sections + parts = compile_output.split(code_test_start) + + # Build a map of test index to message + test_messages = {} + for part in parts[1:]: # Skip first part (before any marker) + if code_test_end in part: + idx_str, rest = part.split(code_test_end, 1) + try: + test_idx = int(idx_str.strip()) + test_messages[test_idx] = rest + except ValueError: + continue + + num_passed = 0 + num_failed = 0 + + for idx in range(num_tests): + msg = test_messages.get(idx, "") + if DECIDABLE_ERR_MSG in msg: + test_results[idx] = "fail" + num_failed += 1 + elif "error" in msg.lower(): + # Some other error (e.g., type mismatch) - count as fail + test_results[idx] = "error" + num_failed += 1 + else: + test_results[idx] = "pass" + num_passed += 1 + + return num_passed, num_failed, test_results + + +def evaluate_generated_code(data: BenchmarkData, generated_code: str, task_idx: int) -> Tuple[bool, bool, str, dict]: + """ + Evaluate the generated code by compiling it with Lean and running unit tests. 
+ + Returns: (compiles, all_tests_pass, output, test_results) + """ + lean_content = build_test_lean_file(data, generated_code, include_unit_tests=True) + + # Create lean file + lean_file = create_lean_file(f"test_{data.data_id}_{task_idx}", lean_content) + + # Check compilation (which also runs unit tests via #guard) + compiles, output = check_lean_compile(lean_file) + + # Parse unit test results + num_tests = len(data.tests) if data.tests else 0 + if num_tests > 0: + num_passed, num_failed, test_results = parse_unit_test_results(output, num_tests) + all_tests_pass = (num_failed == 0) and compiles + else: + # No tests, just check compilation + test_results = {} + all_tests_pass = compiles + + return compiles, all_tests_pass, output, test_results