diff --git a/README.md b/README.md index 03a763e..3b79ab9 100644 --- a/README.md +++ b/README.md @@ -39,10 +39,10 @@ Running verifier-guided inference requires only a few lines of code: just specif **Set up target LLM server** ```bash python -m vllm.entrypoints.openai.api_server \ - --model Qwen/Qwen3-30B-A3B-Thinking-2507 \ - --max-model-len 65536 \ - --port 8000 \ - --tensor-parallel-size 8 + --model microsoft/Phi-4-reasoning \ + --max-model-len 32768 \ + --port 8001 \ + --tensor-parallel-size 2 ``` **Generate answer enabled with given monitors** diff --git a/examples/EarlyStopping/game24_example.py b/examples/EarlyStopping/game24_example.py index 7872527..53e8ff8 100644 --- a/examples/EarlyStopping/game24_example.py +++ b/examples/EarlyStopping/game24_example.py @@ -76,6 +76,7 @@ def init_llm_server(modelname, max_tokens=200, port=8000): "top_k": 20, "top_p": 0.95, "min_p": 0.0, + "do_sample" : True, "temperature": 0.6, "stream": True, "logprobs": 20, @@ -113,10 +114,22 @@ def count_tokens(text, tokenizer): def extract_solution(text): + # Only search for \boxed{} AFTER to avoid grabbing unverified + # expressions from inside the thinking trace. + # If model opened but never closed it (hit token limit), there is + # no final answer — return None. + if '' in text: + search_text = text[text.rfind(''):] + elif '' in text: + # Model started thinking but never finished — no verified answer + return None + else: + search_text = text + # Use a more robust extraction that handles nested braces in \boxed{} # Find \boxed{ and then match braces properly boxed_pattern = r"\\boxed\{" - matches = list(re.finditer(boxed_pattern, text)) + matches = list(re.finditer(boxed_pattern, search_text)) if not matches: return None @@ -125,14 +138,18 @@ def extract_solution(text): start = last_match.end() # Position right after \boxed{ brace_count = 1 end = start - while end < len(text) and brace_count > 0: - if text[end] == '{': + while end < len(search_text) and brace_count > 0: + if search_text[end] == '{': brace_count += 1 - elif text[end] == '}': + elif search_text[end] == '}': brace_count -= 1 end += 1 - expr = text[start:end-1].strip() # -1 to exclude the closing brace + expr = search_text[start:end-1].strip() # -1 to exclude the closing brace + + # Skip empty \boxed{} (e.g., from verifier feedback "Wrap in \boxed{}.") + if not expr: + return None # 1. Convert \frac{a}{b} to (a/b) frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" @@ -148,8 +165,16 @@ def extract_solution(text): for latex, op in replacements.items(): expr = expr.replace(latex, op) - # 3. Cleanup (remove LaTeX spacing) + # 2b. Replace Unicode math operators (QwQ frequently uses these) + expr = expr.replace('\u00d7', '*').replace('\u00f7', '/').replace('\u2212', '-') + expr = expr.replace('\u2013', '-').replace('\u2014', '-') # en-dash, em-dash + + # 3. Cleanup (remove LaTeX formatting artifacts) expr = expr.replace(r"\,", "").replace(r"\ ", "") + expr = expr.replace(r"\left", "").replace(r"\right", "") + + # 3b. Strip trailing "= " (e.g., "10 - 8/8 * 1 = 24" -> "10 - 8/8 * 1") + expr = re.sub(r'\s*=\s*[\d.]+\s*$', '', expr) # 4. Handle implicit multiplication (e.g., "(11+1)(1+1)" -> "(11+1)*(1+1)") # Insert * between: )( , )number, number(, )( @@ -183,7 +208,6 @@ def evaluate_expression(expr, expected_nums=None): except Exception: return False - def evaluate_game24_answer(answer, nums): """ Evaluate a Game24 answer and return (is_correct, expr, error_message). @@ -214,7 +238,7 @@ def evaluate_game24_answer(answer, nums): parser = argparse.ArgumentParser(description="Game of 24 step-by-step solver with monitors") parser.add_argument("--thinking", "-t", action="store_true", help="Enable chain-of-thought output") - parser.add_argument("--monitor", "-m", default = True, action="store_true", help="Enable step-by-step monitor") + parser.add_argument("--monitor", "-m", default = False, action="store_true", help="Enable step-by-step monitor") parser.add_argument("--num_examples", "-n", type=int, default=1362, help="Number of examples to run") parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logs") parser.add_argument("--main_model", type=str, default=MAIN_MODEL, help="Main model to use for generation") @@ -249,7 +273,7 @@ def evaluate_game24_answer(answer, nums): dataset = load_game24_dataset() - llm_server = init_llm_server(main_model, max_tokens=32768) + llm_server = init_llm_server(main_model, max_tokens=32768, port=8000) # Load tokenizer for accurate token counting logger.info(f"Loading tokenizer for {main_model}...") @@ -273,23 +297,23 @@ def evaluate_game24_answer(answer, nums): if args.monitor: # Use K-stable answer monitor to detect when equation stabilizes k times # monitors = (SimpleTextReplaceMonitor("IsCheck", "", async_execution=False),) - # monitors=(KstableAnswerGame24Monitor( - # name="game24_kstable", - # k=3, - # expected_nums=nums, # Validate equations use exactly these numbers - # answer_start_token="" - # ),) - monitors = ( - EATMonitor( - name="EAT_monitor", - model_name=earlystop_model, - alpha=0.2, - delta=0.02, - min_steps=4, - answer_start_token="", - async_execution=True - ), - ) + monitors=(KstableAnswerGame24Monitor( + name="game24_kstable", + k=2, + expected_nums=nums, # Validate equations use exactly these numbers + answer_start_token="" + ),) + # monitors = ( + # EATMonitor( + # name="EAT_monitor", + # model_name=earlystop_model, + # alpha=0.2, + # delta=0.02, + # min_steps=4, + # answer_start_token="", + # async_execution=True + # ), + # ) else: monitors = () @@ -297,6 +321,21 @@ def evaluate_game24_answer(answer, nums): logger.info(f"---- Example {idx+1} ----") logger.info(f"Numbers: {nums}") + # system_prompt = ( + # "You are Phi, a language model trained by Microsoft to help users. " + # "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process " + # "before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle " + # "of analysis, summarizing, exploration, reassessment, reflection, backtracing, and iteration to develop " + # "well-considered thinking process. Please structure your response into two main sections: Thought and Solution " + # "using the specified format: {Thought section} {Solution section}. In the Thought section, " + # "detail your reasoning process in steps. Each step should include detailed considerations such as analysing " + # "questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, " + # "refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, " + # "explorations, and reflections from the Thought section, systematically present the final solution that you " + # "deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed " + # "to reach the conclusion. Now, try to solve the following question through the above guidelines." + # ) + answer = asyncio.run(stream_completion( f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n", llm_server=llm_server, diff --git a/examples/EarlyStopping/maze_example.py b/examples/EarlyStopping/maze_example.py index 74c0465..8cccfda 100644 --- a/examples/EarlyStopping/maze_example.py +++ b/examples/EarlyStopping/maze_example.py @@ -28,7 +28,7 @@ def get_model_short_name(model_name: str) -> str: short_name = short_name.replace(" ", "_").replace(":", "-") return short_name -def get_output_dirs(main_model: str, base_dir: str = "../../Outputs/MazeResults"): +def get_output_dirs(main_model: str, base_dir: str = "../Outputs/MazeResults"): """Create and return output directory paths based on model name.""" model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) @@ -46,14 +46,14 @@ def get_output_dirs(main_model: str, base_dir: str = "../../Outputs/MazeResults" return dirs -def get_log_filename(main_model: str, num_examples: int, base_dir: str = "../../Outputs/MazeResults") -> str: +def get_log_filename(main_model: str, num_examples: int, base_dir: str = "../Outputs/MazeResults") -> str: """Generate log filename based on model name.""" model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) os.makedirs(output_base, exist_ok=True) return os.path.join(output_base, f"EAT_{num_examples}examples.log") -def get_token_filename(main_model: str, num_examples: int, base_dir: str = "../../Outputs/MazeResults") -> str: +def get_token_filename(main_model: str, num_examples: int, base_dir: str = "../Outputs/MazeResults") -> str: """Generate token CSV filename based on model name.""" model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) @@ -66,10 +66,10 @@ def remove_last_paragraph(s: str) -> str: logger = logging.getLogger(__name__) def load_maze_dataset(split="val"): - ds = load_dataset("microsoft/VISION_LANGUAGE", "maze", split=split) + ds = load_dataset("microsoft/VISION_LANGUAGE", "maze_text_only", split=split) return ds -def init_llm_server(modelname, max_tokens=200, port=8000): # +def init_llm_server(modelname, max_tokens=200, port=8000): url = f"http://localhost:{port}/v1/completions" payload = { "model": modelname, @@ -101,19 +101,26 @@ def build_prompt_from_example(example): #(original prompt config) return pre_prompt , description -def extract_solution(text): - matches = re.findall(r"\\boxed\{([^}]*)\}", text) - if not matches: - return None - - expr = matches[-1].strip() # take last boxed content - - # find one of A/B/C/D inside the boxed content - choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) - if not choice_match: - return None - - return choice_match.group(1).upper() +def extract_solution_mcq(text): + """Extract MCQ solution from model output.""" + # Try multiple boxed patterns + patterns = [ + r"\\boxed\{([^}]*)\}", # \boxed{...} + r"boxed\{([^}]*)\}", # boxed{...} without escape + r"\*\*([A-D])\*\*", # **A** format + r"answer[:\s]*([A-D])", # answer: A format + r"(?:^|\n)([A-D])(?:\s|$|\.)", # Standalone letter + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() + + return None def save_prompt(idx, prompt_with_answer, reason_dir): filename = os.path.join(reason_dir, f"reason_{idx}.txt") @@ -127,54 +134,30 @@ def count_tokens(text, tokenizer): return len(tokens) -def evaluate_maze_answer(answer, options, ground_truth): - """ - Evaluate a Maze MCQ answer and return (is_correct, extracted_answer, message). - - Args: - answer: Raw model output - options: Dictionary mapping option letters (A/B/C/D) to their values - ground_truth: The correct answer value - - Returns: - Tuple of (is_correct, extracted_answer, message) - """ - sol = extract_solution(answer) +def evaluate_mcq_answer(answer, options, ground_truth): + sol = extract_solution_mcq(answer) gt_sol = str(ground_truth).strip() - if not sol: return False, None, "No expression found" - sol = sol.strip() - - # Case 1: LLM returned option letter (A/B/C/D) if sol in options: if options[sol] == gt_sol: return True, sol, f"Correct: option {sol} -> {options[sol]}" - else: - return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" - - # Case 2: LLM returned the actual answer text - # First check if sol matches ground truth directly + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" if sol.lower() == gt_sol.lower(): return True, sol, f"Correct: answer text matches ground truth: {sol}" - - # Check if sol matches any option value for opt_letter, opt_value in options.items(): if sol.lower() == opt_value.lower(): if opt_value == gt_sol: return True, sol, f"Correct: answer text {sol} (option {opt_letter})" - else: - return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" - + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" return False, sol, f"Solution '{sol}' not found in options or ground truth" - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Maze problem solver with LLM and monitors") parser.add_argument("--thinking", "-t", action="store_true", help="Enable chain-of-thought output") - parser.add_argument("--monitor", "-m", default = True, action="store_true", help="Enable step-by-step monitor") + parser.add_argument("--monitor", "-m", default = False, action="store_true", help="Enable step-by-step monitor") parser.add_argument("--num_examples", "-n", type=int, default=1500, help="Number of examples to run") parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logs") parser.add_argument("--main_model", type=str, default=MAIN_MODEL, help="Main model to use for generation") @@ -207,7 +190,7 @@ def evaluate_maze_answer(answer, options, ground_truth): dataset = load_maze_dataset() - llm_server = init_llm_server(main_model, max_tokens=15000) + llm_server = init_llm_server(main_model, max_tokens=32768) # Load tokenizer for accurate token counting logger.info(f"Loading tokenizer for {main_model}...") @@ -219,7 +202,7 @@ def evaluate_maze_answer(answer, options, ground_truth): total_generated_tokens = 0 generated_token_counts = [] total = len(dataset) - indices = np.linspace(3000, total-1, N, dtype=int).tolist() + indices = np.linspace(0, total-1, N, dtype=int).tolist() for idx in indices: example = dataset[idx] @@ -268,7 +251,7 @@ def evaluate_maze_answer(answer, options, ground_truth): # Evaluate the answer gt_sol = str(example.get("ground_truth", "")).strip() - is_correct, extracted_answer, message = evaluate_maze_answer(answer, options, gt_sol) + is_correct, extracted_answer, message = evaluate_mcq_answer(answer, options, gt_sol) if extracted_answer: logger.info(f"Extracted answer: {extracted_answer}") diff --git a/examples/EarlyStopping/spatialmap_example.py b/examples/EarlyStopping/spatialmap_example.py index c3925c3..195ea3f 100644 --- a/examples/EarlyStopping/spatialmap_example.py +++ b/examples/EarlyStopping/spatialmap_example.py @@ -28,7 +28,7 @@ def get_model_short_name(model_name: str) -> str: short_name = short_name.replace(" ", "_").replace(":", "-") return short_name -def get_output_dirs(main_model: str, base_dir: str = "../../Outputs/SpatialMap_results2"): +def get_output_dirs(main_model: str, base_dir: str = "../Outputs/SpatialMap_results"): """Create and return output directory paths based on model name.""" model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) @@ -46,14 +46,14 @@ def get_output_dirs(main_model: str, base_dir: str = "../../Outputs/SpatialMap_r return dirs -def get_log_filename(main_model: str, num_examples: int, base_dir: str = "../../Outputs/SpatialMap_results2") -> str: +def get_log_filename(main_model: str, num_examples: int, base_dir: str = "../Outputs/SpatialMap_results") -> str: """Generate log filename based on model name.""" model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) os.makedirs(output_base, exist_ok=True) return os.path.join(output_base, f"EAT_{num_examples}examples.log") -def get_token_filename(main_model: str, num_examples: int, base_dir: str = "../../Outputs/SpatialMap_results2") -> str: +def get_token_filename(main_model: str, num_examples: int, base_dir: str = "../Outputs/SpatialMap_results") -> str: """Generate token CSV filename based on model name.""" model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) @@ -99,19 +99,27 @@ def build_prompt_from_example(example): description = remove_last_paragraph(description) return pre_prompt , description -def extract_solution(text): - matches = re.findall(r"\\boxed\{([^}]*)\}", text) - if not matches: - return None - - expr = matches[-1].strip() # take last boxed content - - # find one of A/B/C/D inside the boxed content - choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) - if not choice_match: - return None - - return choice_match.group(1).upper() +def extract_solution_mcq(text): + """Extract MCQ solution from model output.""" + # Try multiple boxed patterns + patterns = [ + r"\\boxed\{([^}]*)\}", # \boxed{...} + r"boxed\{([^}]*)\}", # boxed{...} without escape + r"\*\*([A-D])\*\*", # **A** format + r"answer[:\s]*([A-D])", # answer: A format + r"(?:^|\n)([A-D])(?:\s|$|\.)", # Standalone letter + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() + + return None + def save_prompt(idx, prompt_with_answer, reason_dir): filename = os.path.join(reason_dir, f"reason_{idx}.txt") @@ -125,46 +133,23 @@ def count_tokens(text, tokenizer): return len(tokens) -def evaluate_spatialmap_answer(answer, options, ground_truth): - """ - Evaluate a SpatialMap MCQ answer and return (is_correct, extracted_answer, message). - - Args: - answer: Raw model output - options: Dictionary mapping option letters (A/B/C/D) to their values - ground_truth: The correct answer value - - Returns: - Tuple of (is_correct, extracted_answer, message) - """ - sol = extract_solution(answer) +def evaluate_mcq_answer(answer, options, ground_truth): + sol = extract_solution_mcq(answer) gt_sol = str(ground_truth).strip() - if not sol: return False, None, "No expression found" - sol = sol.strip() - - # Case 1: LLM returned option letter (A/B/C/D) if sol in options: if options[sol] == gt_sol: return True, sol, f"Correct: option {sol} -> {options[sol]}" - else: - return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" - - # Case 2: LLM returned the actual answer text - # First check if sol matches ground truth directly + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" if sol.lower() == gt_sol.lower(): return True, sol, f"Correct: answer text matches ground truth: {sol}" - - # Check if sol matches any option value for opt_letter, opt_value in options.items(): if sol.lower() == opt_value.lower(): if opt_value == gt_sol: return True, sol, f"Correct: answer text {sol} (option {opt_letter})" - else: - return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" - + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" return False, sol, f"Solution '{sol}' not found in options or ground truth" @@ -172,7 +157,7 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): parser = argparse.ArgumentParser(description="SpatialMap problem solver with LLM and monitors") parser.add_argument("--thinking", "-t", action="store_true", help="Enable chain-of-thought output") - parser.add_argument("--monitor", "-m", default = True, action="store_true", help="Enable step-by-step monitor") + parser.add_argument("--monitor", "-m", default = False, action="store_true", help="Enable step-by-step monitor") parser.add_argument("--num_examples", "-n", type=int, default=1500, help="Number of examples to run") parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logs") parser.add_argument("--main_model", type=str, default=MAIN_MODEL, help="Main model to use for generation") @@ -207,7 +192,7 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): dataset = load_maze_dataset() - llm_server = init_llm_server(main_model, max_tokens=15000) + llm_server = init_llm_server(main_model, max_tokens=32768) # Load tokenizer for accurate token counting logger.info(f"Loading tokenizer for {main_model}...") @@ -268,7 +253,7 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): # Evaluate the answer gt_sol = str(example.get("ground_truth", "")).strip() - is_correct, extracted_answer, message = evaluate_spatialmap_answer(answer, options, gt_sol) + is_correct, extracted_answer, message = evaluate_mcq_answer(answer, options, gt_sol) if extracted_answer: logger.info(f"Extracted answer: {extracted_answer}") diff --git a/examples/README.md b/examples/README.md index 79f5ea0..fdbc6bd 100644 --- a/examples/README.md +++ b/examples/README.md @@ -5,9 +5,8 @@ Running verifier-guided inference requires only a few lines of code: just specif ```bash python -m vllm.entrypoints.openai.api_server \ --model Qwen/Qwen3-30B-A3B-Thinking-2507 \ - --max-model-len 65536 \ --port 8000 \ - --tensor-parallel-size 8 + --tensor-parallel-size 4 ``` **Generate answer enabled with given monitors** diff --git a/examples/TTSwithVerification/bestofk_baseline.py b/examples/TTSwithVerification/bestofk_baseline.py new file mode 100644 index 0000000..9ed0950 --- /dev/null +++ b/examples/TTSwithVerification/bestofk_baseline.py @@ -0,0 +1,1070 @@ +import asyncio +import argparse +from datetime import datetime +import json +import logging +import os +import re +import sys +import shutil +import subprocess +from multiprocessing.pool import ThreadPool +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import aiohttp +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +from interwhen.utils.zebralogic_helper import SYSTEM_PROMPT_VANILLA, USER_PROMPT_TEMPLATE, get_zebralogic_dataset, extract_last_json, zebra_correctness + +from interwhen import stream_completion +from verina_utils import * + +# ============== MODEL CONFIGURATION ============== +MAIN_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507" +# Multi-process vLLM configuration +VLLM_PORTS = [8000, 8001, 8002] # 3 instances with tensor-parallel-size 2 each +REQUEST_COUNTER = {"main": 0, "critic": 0} # Track request count for round-robin load balancing +# Verina paths +_SCRIPT_DIR = Path(__file__).parent.resolve() +VERINA_ROOT = (_SCRIPT_DIR / "../../../verina").resolve() +VERINA_DATASETS_PATH = VERINA_ROOT / "datasets" / "verina" +LEAN_PLAYGROUND_DIR = VERINA_ROOT / "lean-playground" + +logger = logging.getLogger(__name__) + +# Save the real stderr so tqdm always works even if suppress_output is active +_real_stderr = sys.stderr + +class NumpyEncoder(json.JSONEncoder): + """Custom JSON encoder that handles numpy types.""" + def default(self, obj): + if isinstance(obj, np.bool_): + return bool(obj) + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + + +@contextmanager +def suppress_output(): + """Context manager to suppress stdout and stderr.""" + with open(os.devnull, 'w') as devnull: + old_stdout = sys.stdout + old_stderr = sys.stderr + sys.stdout = devnull + sys.stderr = devnull + try: + yield + finally: + sys.stdout = old_stdout + sys.stderr = old_stderr + + +@dataclass +class SampleResult: + output: str + correct: bool + extracted: Optional[str] + message: str + tokens: int + critic_correct: Optional[bool] = None + critic_feedback: Optional[str] = None + + +def get_model_short_name(model_name: str) -> str: + short_name = model_name.split("/")[-1] + return short_name.replace(" ", "_").replace(":", "-") + + +def get_next_port(server_type: str = "main") -> int: + """Get next vLLM port in round-robin fashion.""" + global REQUEST_COUNTER + port = VLLM_PORTS[REQUEST_COUNTER[server_type] % len(VLLM_PORTS)] + REQUEST_COUNTER[server_type] += 1 + return port + + +def get_output_dirs(task: str, main_model: str, use_critic: bool, critic_early_stop: bool, base_dir: str = "../../b-pchanda/Outputs_TTS_temp/BestOfKResults"): + model_short_name = get_model_short_name(main_model) + critic_status = "on" if use_critic else "off" + earlystop_status = "on" if critic_early_stop else "off" + output_base = os.path.join(base_dir, task, model_short_name, f"critic_{critic_status}", f"earlystop_{earlystop_status}") + dirs = { + "base": output_base, + "reasoning": os.path.join(output_base, "Reasoning_output"), + "critic": os.path.join(output_base, "Critic_output") if use_critic else None, + } + for dir_path in dirs.values(): + if dir_path: + os.makedirs(dir_path, exist_ok=True) + return dirs + +def init_llm_server(modelname, max_tokens=200, port=8000, temperature=0.6, seed=42): # + url = f"http://localhost:{port}/v1/completions" + payload = { + "model": modelname, + "max_tokens": max_tokens, + "top_k": 20, + "top_p": 0.95, + "min_p": 0.0, + "do_sample" : True, + "temperature": temperature, + "stream": False, + "logprobs": 20, + "use_beam_search": False, + "prompt_cache": True, + "seed" : seed + } + headers = {"Content-Type": "application/json"} + return {"url": url, "payload": payload, "headers": headers} + + +def count_tokens(text: str, tokenizer) -> int: + """Count tokens in text, with fallback to character count.""" + try: + if not text or len(text.strip()) == 0: + return 0 + tokens = tokenizer.encode(text, add_special_tokens=False) + return len(tokens) + except Exception as e: + logger.warning(f"Tokenization failed: {e}, using character count estimate") + # Rough estimate: ~4 characters per token + return max(1, len(text) // 4) + + +def save_outputs(idx: int, outputs: List[SampleResult], best_idx: int, output_dir: str): + os.makedirs(output_dir, exist_ok=True) + filepath = os.path.join(output_dir, f"output_{idx}.txt") + with open(filepath, "w", encoding="utf-8") as f: + f.write(f"BEST_INDEX={best_idx}\n") + for i, result in enumerate(outputs): + f.write("\n" + "=" * 80 + "\n") + f.write(f"SAMPLE {i}\n") + f.write(f"CORRECT={result.correct}\n") + f.write(f"CRITIC_CORRECT={result.critic_correct}\n") + f.write(f"EXTRACTED={result.extracted}\n") + f.write(f"TOKENS={result.tokens}\n") + f.write(f"MESSAGE={result.message}\n") + if result.critic_feedback: + f.write(f"CRITIC_FEEDBACK={result.critic_feedback}\n") + f.write("\n") + f.write(result.output) + f.write("\n") + # logger.info(f"Saved outputs to {filepath}") + + +# --------------------- Game24 helpers --------------------- + +def build_game24_prompt(nums): + a, b, c, d = nums + boxed = r"\\boxed{}" + base_prompt = f""" +You are solving the Game of 24. + +You are given four numbers: {a}, {b}, {c}, {d} + +Your job is to produce a valid arithmetic expression using: +- ALL four numbers exactly once +- ONLY +, -, *, / +- The expression must evaluate to exactly 24. + +Please reason step by step, and put your final answer containing only the expression within {boxed}. +""".strip() + return base_prompt + + +def extract_solution_game24(text): + boxed_pattern = r"\\boxed\{" + matches = list(re.finditer(boxed_pattern, text)) + if not matches: + return None + last_match = matches[-1] + start = last_match.end() + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == "{": + brace_count += 1 + elif text[end] == "}": + brace_count -= 1 + end += 1 + expr = text[start:end - 1].strip() + + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + + replacements = { + r"\times": "*", + r"\cdot": "*", + r"\div": "/", + } + for latex, op in replacements.items(): + expr = expr.replace(latex, op) + + expr = expr.replace(r"\\,", "").replace(r"\\ ", "") + expr = re.sub(r"\)\s*\(", ")*(", expr) + expr = re.sub(r"\)\s*(\d)", r")*\1", expr) + expr = re.sub(r"(\d)\s*\(", r"\1*(", expr) + + return expr + + +def extract_numbers_from_expr(expr): + numbers = re.findall(r"\d+\.?\d*", expr) + return [int(float(n)) if float(n).is_integer() else float(n) for n in numbers] + + +def validate_numbers_used(expr, expected_nums): + used_nums = extract_numbers_from_expr(expr) + return sorted(used_nums) == sorted(expected_nums) + + +def evaluate_expression(expr, expected_nums=None): + try: + if expected_nums is not None and not validate_numbers_used(expr, expected_nums): + return False + value = eval(expr, {"__builtins__": None}, {}) + return abs(value - 24) < 1e-6 + except Exception: + return False + + +def evaluate_game24_answer(answer, nums): + expr = extract_solution_game24(answer) + if not expr: + return False, None, "No expression found" + if evaluate_expression(expr, expected_nums=nums): + return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" + used_nums = extract_numbers_from_expr(expr) + if sorted(used_nums) != sorted(nums): + return False, expr, f"Incorrect: Expression uses {used_nums}, expected {nums}" + return False, expr, "Expression does not evaluate to 24" + + +# --------------------- Maze/SpatialMap helpers --------------------- + +def remove_last_paragraph(s: str) -> str: + return s[:-143] if len(s) > 143 else s + + +def build_maze_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer. " + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def build_spatialmap_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer." + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def extract_solution_mcq(text): + """Extract MCQ solution from model output.""" + # Try multiple boxed patterns + patterns = [ + r"\\boxed\{([^}]*)\}", # \boxed{...} + r"boxed\{([^}]*)\}", # boxed{...} without escape + r"\*\*([A-D])\*\*", # **A** format + r"answer[:\s]*([A-D])", # answer: A format + r"(?:^|\n)([A-D])(?:\s|$|\.)", # Standalone letter + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() + + # Last resort: look for any standalone A, B, C, or D + standalone = re.findall(r"\b([ABCD])\b", text) + if standalone: + return standalone[-1].upper() + + return None + + +def extract_options_from_prompt(prompt_text, target_options): + pattern = r"\b([A-D])\.\s*(.*?)(?=\s*[A-D]\.\s*|$)" + raw = re.findall(pattern, prompt_text, flags=re.DOTALL) + options = {k: v.strip().rstrip(".") for k, v in raw} + if target_options: + options = {k: v for k, v in options.items() if k in target_options} + return options + + +def evaluate_mcq_answer(answer, options, ground_truth): + sol = extract_solution_mcq(answer) + gt_sol = str(ground_truth).strip() + if not sol: + return False, None, "No expression found" + sol = sol.strip() + if sol in options: + if options[sol] == gt_sol: + return True, sol, f"Correct: option {sol} -> {options[sol]}" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" + if sol.lower() == gt_sol.lower(): + return True, sol, f"Correct: answer text matches ground truth: {sol}" + for opt_letter, opt_value in options.items(): + if sol.lower() == opt_value.lower(): + if opt_value == gt_sol: + return True, sol, f"Correct: answer text {sol} (option {opt_letter})" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" + return False, sol, f"Solution '{sol}' not found in options or ground truth" + +# --------------------- ZebraLogic helpers --------------------- + +def evaluate_zebralogic_answer(answer, example): + """Evaluate a zebralogic answer against ground truth using zebra_correctness.""" + candidate = extract_last_json(answer) + if not candidate: + return False, None, "No valid JSON solution found" + correct, skipped, missing, total = zebra_correctness(example, candidate) + is_correct = correct == total + msg = f"Correct={correct}/{total}, skipped={skipped}, missing={missing}" + return is_correct, candidate, msg + + +def build_zebralogic_prompt(example): + system_prompt = SYSTEM_PROMPT_VANILLA + user_prompt = USER_PROMPT_TEMPLATE.format(problem_text=example['puzzle_clean']) + return system_prompt, user_prompt + +# verina helpers +def evaluate_verina_answer(output: str, data: BenchmarkData, task_idx: int) -> Tuple[bool, str, str]: + """Evaluate Verina code generation output - wrapper for best-of-k interface""" + generated_code = extract_code_from_response(output) + + if not generated_code.strip(): + return False, "", "No code extracted from response" + + compiles, all_tests_pass, compile_output, test_results = evaluate_generated_code(data, generated_code, task_idx) + + num_tests = len(data.tests) if data.tests else 0 + num_passed = sum(1 for v in test_results.values() if v == "pass") + + if compiles and all_tests_pass: + return True, generated_code, f"Code compiles and all {num_tests} tests pass" + elif compiles: + return False, generated_code, f"Compilation succeeded but {num_tests - num_passed}/{num_tests} tests failed" + else: + error_preview = compile_output[:300] if compile_output else "Unknown error" + return False, generated_code, f"Compilation failed: {error_preview}" + + +def build_full_prompt(task, example, nums=None): + if task == "game24": + prompt = build_game24_prompt(nums) + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + if task == "maze": + system_prompt, user_prompt = build_maze_prompt(example) + elif task == 'zebralogic': + system_prompt, user_prompt = build_zebralogic_prompt(example) + elif task == "verina": + return build_verina_prompt(example) + else: + system_prompt, user_prompt = build_spatialmap_prompt(example) + return ( + f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + f"<|im_start|>user\n{user_prompt}<|im_end|>\n" + f"<|im_start|>assistant\n" + ) + + +def load_dataset_for_task(task): + if task == "game24": + return load_dataset("nlile/24-game", split="train") + if task == "maze": + return load_dataset("microsoft/VISION_LANGUAGE", "maze_text_only", split="val") + if task == "spatialmap": + return load_dataset("microsoft/VISION_LANGUAGE", "spatial_map_text_only", split="val") + if task == "zebralogic": + return get_zebralogic_dataset() + if task == "verina": + return load_verina_dataset() + raise ValueError(f"Unsupported task: {task}") + + +def resolve_indices(task, dataset_len, args): + if args.indices: + return [int(x.strip()) for x in args.indices.split(",")] + if args.xrange: + parts = args.xrange.split("-") + if len(parts) == 2: + try: + start = int(parts[0].strip()) + end = int(parts[1].strip()) + return range(start, end) + except ValueError: + raise ValueError(f"Invalid xrange format: {args.xrange}. Use 'start-end'") + if args.num_examples: + max_idx = dataset_len - 1 + upper_bound = min(max_idx, 1362) if task == "game24" else min(max_idx, 1499) + return list((np.linspace(0, upper_bound, args.num_examples)).astype(int)) + # Default: use full range + start = args.start if args.start is not None else 0 + end = args.end if args.end is not None else dataset_len + return range(start, end) + + +def batch_generate_samples(prompt, llm_server, k, seed, quiet=True): + """Generate k samples using vLLM batch processing via API across multiple instances.""" + payload_template = llm_server["payload"].copy() + headers = llm_server["headers"] + + # Create k requests with different seeds + batch_payloads = [] + for i in range(k): + payload = payload_template.copy() + payload["prompt"] = prompt + payload["seed"] = seed + i + batch_payloads.append(payload) + + # Send requests to vLLM instances in parallel (true concurrency) + async def _fetch_one(session, sem, idx, url, payload): + async with sem: + try: + async with session.post(url, json=payload, headers=headers, timeout=300) as resp: + text = await resp.text() + if resp.status >= 400: + logger.warning(f"HTTP error for seed {seed + idx} on {url}: {resp.status} - {text[:200]}") + return idx, "" + try: + result = json.loads(text) + except json.JSONDecodeError: + logger.warning(f"Invalid JSON for seed {seed + idx} on {url}") + return idx, "" + except Exception as e: + logger.warning(f"Batch generation failed for seed {seed + idx} on {url}: {e}") + return idx, "" + + if "choices" in result and len(result["choices"]) > 0: + choice = result["choices"][0] + if isinstance(choice, dict): + output_text = choice.get("text") or choice.get("message", {}).get("content", "") + else: + output_text = str(choice) + if output_text and len(output_text.strip()) > 0: + return idx, output_text + logger.warning(f"Empty output for seed {seed + idx} on {url}") + return idx, "" + + logger.warning(f"No choices in response for seed {seed + idx} on {url}: {result.keys() if isinstance(result, dict) else type(result)}") + return idx, "" + + async def _run_parallel(): + sem = asyncio.Semaphore(len(VLLM_PORTS)) + async with aiohttp.ClientSession() as session: + tasks = [] + for i, payload in enumerate(batch_payloads): + port = get_next_port(server_type="main") + url = f"http://localhost:{port}/v1/completions" + tasks.append(asyncio.create_task(_fetch_one(session, sem, i, url, payload))) + results = await asyncio.gather(*tasks) + return results + + if quiet: + with suppress_output(): + results = asyncio.run(_run_parallel()) + else: + results = asyncio.run(_run_parallel()) + + outputs = [""] * k + for idx, output_text in results: + outputs[idx] = output_text + if output_text and not quiet: + print(f"[Generated sample {idx}] {len(output_text)} chars, {len(output_text.split())} words") + + return outputs + + +# --------------------- Critic model helpers --------------------- + +def build_game24_critic_prompt(nums, reasoning_output): + """Build critic prompt to evaluate Game of 24 solution and provide reasoning.""" + return f"""You are a math verifier. Evaluate the following Game of 24 solution. + +Numbers: {nums} +Target: 24 + +Student's reasoning and answer: +{reasoning_output} + +Verify: +1. Does it use ALL four numbers exactly once? +2. Does each step follow correct arithmetic? +3. Does the final expression evaluate to exactly 24? + +Respond in the following format: +VERDICT: CORRECT or INCORRECT +REASONING: Your detailed explanation + +If CORRECT, briefly explain why. +If INCORRECT, explain what went wrong and how to fix it. +""" + + +def build_zebralogic_critic_prompt(task_description, reasoning_output): + """Build critic prompt to evaluate ZebraLogic solution and provide reasoning.""" + return f"""You are an expert logic puzzle verifier. Evaluate the following ZebraLogic solution. + +Task: +{task_description} + +Student's reasoning and answer: +{reasoning_output} + +Verify: +1. Does the solution assign exactly one value per feature per house? +2. Are all constraints/clues satisfied? +3. Is the JSON output well-formed and complete? + +Respond in the following format: +VERDICT: CORRECT or INCORRECT +REASONING: Your detailed explanation + +If CORRECT, briefly explain why. +If INCORRECT, explain what went wrong and suggest corrections. +""" + + +def build_mcq_critic_prompt(task, task_description, reasoning_output): + """Build critic prompt to evaluate MCQ solution and provide reasoning.""" + task_name = "Maze" if task == "maze" else "Spatial Reasoning" + return f"""You are an expert {task_name} verifier. Evaluate the following solution. + +Task: +{task_description} + +Student's reasoning and answer: +{reasoning_output} + +Verify the correctness of the step-by-step reasoning and final answer. + +Respond in the following format: +VERDICT: CORRECT or INCORRECT +REASONING: Your detailed explanation + +If CORRECT, briefly explain why. +If INCORRECT, explain what went wrong and suggest the correct approach. +""" + +def build_verina_critic_prompt(data: BenchmarkData, reasoning_output: str) -> str: + """Build critic prompt to evaluate Verina Lean code generation and provide reasoning.""" + signature = data.signature + func_name = signature.get("name", "solution") + return_type = signature.get("return_type", "Bool") + param_list = render_param_list(signature) + + precond = data.lean_data.get("precond", "True").strip() + postcond = data.lean_data.get("postcond", "").strip() + + return f"""You are an expert Lean 4 code verifier. Evaluate the following code generation attempt. + +## Task Description +{data.description} + +## Function Signature +```lean4 +def {func_name} {param_list} (h_precond : {func_name}_precond ...) : {return_type} +``` + +## Precondition +```lean4 +{precond} +``` + +## Postcondition +```lean4 +{postcond} +``` + +## Student's Reasoning and Generated Code +{reasoning_output} + +Verify: +1. Is the generated code syntactically valid Lean 4? +2. Does it match the expected function signature and return type ({return_type})? +3. Does the logic appear to satisfy the postcondition given the precondition? +4. Are there any obvious bugs, infinite loops, or incorrect base cases? + +Respond in the following format: +VERDICT: CORRECT or INCORRECT +REASONING: Your detailed explanation + +If CORRECT, briefly explain why +If INCORRECT, explain what went wrong and suggest how to fix it. +""" + +def batch_evaluate_with_critic(outputs_df, task, example, critic_llm_server, tokenizer, nums=None, quiet=True): + """Batch evaluate outputs using vLLM API across multiple instances. Outputs_df should have columns: 'output', 'seed_idx'""" + payload_template = critic_llm_server["payload"].copy() + headers = critic_llm_server["headers"] + + async def _fetch_one(session, sem, idx, url, payload): + async with sem: + try: + async with session.post(url, json=payload, headers=headers, timeout=300) as resp: + text = await resp.text() + if resp.status >= 400: + logger.warning(f"HTTP error for critic sample {idx} on {url}: {resp.status} - {text[:200]}") + return idx, "", False + try: + result = json.loads(text) + except json.JSONDecodeError: + logger.warning(f"Invalid JSON for critic sample {idx} on {url}") + return idx, "", False + except Exception as e: + logger.warning(f"Critic evaluation failed for sample {idx} on {url}: {e}") + return idx, "", False + + if "choices" in result and len(result["choices"]) > 0: + choice = result["choices"][0] + critic_output = choice.get("text") or choice.get("message", {}).get("content", "") + else: + critic_output = "" + + is_correct = "CORRECT" in critic_output.upper() + reasoning = "" + if "REASONING:" in critic_output: + reasoning = critic_output.split("REASONING:", 1)[1].strip() + elif "VERDICT:" not in critic_output: + reasoning = critic_output + + return idx, reasoning, is_correct + + async def _run_parallel(): + sem = asyncio.Semaphore(len(VLLM_PORTS)) + async with aiohttp.ClientSession() as session: + tasks = [] + for idx, row in outputs_df.iterrows(): + output_text = row["output"] + if task == "game24": + critic_prompt = build_game24_critic_prompt(nums, output_text) + elif task == "zebralogic": + _, task_desc = build_zebralogic_prompt(example) + critic_prompt = build_zebralogic_critic_prompt(task_desc, output_text) + elif task == "verina": + critic_prompt = build_verina_critic_prompt(example, output_text) + else: + if task == "maze": + _, task_desc = build_maze_prompt(example) + else: + _, task_desc = build_spatialmap_prompt(example) + critic_prompt = build_mcq_critic_prompt(task, task_desc, output_text) + + critic_system = "You are a strict academic verifier." + full_prompt = f"<|im_start|>system\n{critic_system}<|im_end|>\n<|im_start|>user\n{critic_prompt}<|im_end|>\n<|im_start|>assistant\n" + + payload = payload_template.copy() + payload["prompt"] = full_prompt + payload["seed"] = row.get("critic_seed", idx) + + port = get_next_port(server_type="critic") + url = f"http://localhost:{port}/v1/completions" + tasks.append(asyncio.create_task(_fetch_one(session, sem, idx, url, payload))) + + return await asyncio.gather(*tasks) + + if quiet: + with suppress_output(): + results = asyncio.run(_run_parallel()) + else: + results = asyncio.run(_run_parallel()) + + rows = [] + for sample_idx, reasoning, is_correct in results: + rows.append({ + "sample_idx": sample_idx, + "critic_correct": is_correct, + "critic_feedback": reasoning, + }) + + return pd.DataFrame(rows) + + +def run_k_samples_with_critic( + prompt, + llm_server, + critic_llm_server, + k, + seed, + task, + example, + tokenizer, + eval_fn, + nums=None, + early_stop=False, + critic_feedback_baseline=False, + quiet=True, +): + """Run k samples with critic evaluation using vLLM batching.""" + # If critic_feedback_baseline, generate samples sequentially with feedback chaining + if critic_feedback_baseline: + sample_results = [] + current_prompt = prompt + + for i in range(k): + # Generate single sample + output = batch_generate_samples(current_prompt, llm_server, 1, seed + i, quiet=quiet)[0] + + # Evaluate with critic + df_critic = batch_evaluate_with_critic( + pd.DataFrame([{"output": output, "seed_idx": i}]), + task, example, critic_llm_server, tokenizer, nums=nums, quiet=quiet + ) + critic_correct = df_critic.iloc[0]["critic_correct"] if len(df_critic) > 0 else False + critic_feedback = df_critic.iloc[0]["critic_feedback"] if len(df_critic) > 0 else "" + + # Evaluate with ground truth + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=f"Critic verdict: {'CORRECT' if critic_correct else 'INCORRECT'} | {message}", + tokens=token_count, + critic_correct=critic_correct, + critic_feedback=critic_feedback, + )) + + # If critic says INCORRECT and not the last sample, add feedback to prompt + if not critic_correct and i < k - 1 and critic_feedback: + feedback_text = f"\n\nPrevious attempt was incorrect:\n{output}\n\nCritic feedback:\n{critic_feedback}\n\nPlease address the feedback and try again:" + # Insert feedback before the assistant tag + if "<|im_end|>\n<|im_start|>assistant\n" in current_prompt: + current_prompt = current_prompt.replace( + "<|im_end|>\n<|im_start|>assistant\n", + f"{feedback_text}<|im_end|>\n<|im_start|>assistant\n" + ) + else: + # Fallback: append to end of prompt + current_prompt = current_prompt + feedback_text + + return sample_results + + # Generate k samples + outputs = batch_generate_samples(prompt, llm_server, k, seed, quiet=quiet) + + # Create dataframe with outputs + df_samples = pd.DataFrame({ + "sample_idx": range(k), + "output": outputs, + "seed": [seed + i for i in range(k)], + }) + + # If early stop mode, stop at first critic-correct + if early_stop: + sample_results = [] + for idx, row in df_samples.iterrows(): + output = row["output"] + + # Evaluate with critic + df_critic = batch_evaluate_with_critic( + pd.DataFrame([{"output": output, "seed_idx": idx}]), + task, example, critic_llm_server, tokenizer, nums=nums, quiet=quiet + ) + critic_correct = df_critic.iloc[0]["critic_correct"] if len(df_critic) > 0 else False + critic_feedback = df_critic.iloc[0]["critic_feedback"] if len(df_critic) > 0 else "" + + # Evaluate with ground truth + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=f"Critic verdict: {'CORRECT' if critic_correct else 'INCORRECT'} | {message}", + tokens=token_count, + critic_correct=critic_correct, + critic_feedback=critic_feedback, + )) + + if critic_correct: + break + + return sample_results + else: + # Batch critic evaluation + df_critic = batch_evaluate_with_critic( + df_samples, task, example, critic_llm_server, tokenizer, nums=nums, quiet=quiet + ) + + # Merge critic results + df_samples = df_samples.merge(df_critic, left_index=True, right_on="sample_idx", how="left") + + # Process all results + sample_results = [] + for idx, row in df_samples.iterrows(): + output = row["output"] + critic_correct = row.get("critic_correct", False) + critic_feedback = row.get("critic_feedback", "") + + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=f"Critic verdict: {'CORRECT' if critic_correct else 'INCORRECT'} | {message}", + tokens=token_count, + critic_correct=critic_correct, + critic_feedback=critic_feedback, + )) + + return sample_results + + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Best-of-K baseline (standard CoT) for TTSwithVerification datasets") + parser.add_argument("--task", type=str, required=True, choices=["game24", "maze", "spatialmap", "zebralogic","verina"], + help="Task to run") + parser.add_argument("--k", type=int, default=1, help="Number of samples per example") + parser.add_argument("--num_examples", "-n", type=int, default=100, + help="Number of examples to run (overrides start/end)") + parser.add_argument("--indices", type=str, default=None, + help="Comma-separated indices to run") + parser.add_argument("--xrange", type=str, default=None, + help="Range of indices to run (format: 'start-end')") + parser.add_argument("--start", type=int, default=None, help="Start index") + parser.add_argument("--end", type=int, default=None, help="End index") + parser.add_argument("--main_model", type=str, default=MAIN_MODEL, help="Main model to use for generation") + parser.add_argument("--port", type=int, default=8000, help="vLLM server port") + parser.add_argument("--use_critic", action="store_true", help="Use critic model for evaluation instead of ground truth") + parser.add_argument("--critic_model", type=str, default=MAIN_MODEL, help="Critic model to use for evaluation") + parser.add_argument("--critic_port", type=int, default=8000, help="vLLM server port for critic model (default: same as main model port)") + parser.add_argument("--critic_early_stop", action="store_true", help="Stop sampling after first critic-correct trace") + parser.add_argument("--critic_feedback_baseline", action="store_true", help="Use critic feedback as a separate baseline for post-hoc correction") + parser.add_argument("--seed", type=int, default=42, help="Base random seed") + parser.add_argument("--max_tokens", type=int, default=32768, help="Max tokens for generation") + parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature") + parser.add_argument("--processes", "-p", type=int, default=1, help="Number of examples to process in parallel (default: 1, sequential)") + parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging") + args = parser.parse_args() + + log_level = logging.DEBUG if args.debug else logging.ERROR + logging.basicConfig(level=log_level, format="%(message)s") + + quiet_mode = not args.debug + + if quiet_mode: + with suppress_output(): + dataset = load_dataset_for_task(args.task) + else: + dataset = load_dataset_for_task(args.task) + indices = resolve_indices(args.task, len(dataset), args) + + llm_server = init_llm_server( + args.main_model, + max_tokens=args.max_tokens, + port=args.port, + temperature=args.temperature, + seed=args.seed, + ) + + critic_llm_server = None + if args.use_critic: + critic_llm_server = init_llm_server( + args.critic_model, + max_tokens=512, + port=args.critic_port, + temperature=0.2, + seed=args.seed, + ) + # logger.info(f"Using critic model: {args.critic_model} on port {args.critic_port}") + + # logger.info(f"Loading tokenizer for {args.main_model}...") + if quiet_mode: + with suppress_output(): + tokenizer = AutoTokenizer.from_pretrained(args.main_model, trust_remote_code=True) + else: + tokenizer = AutoTokenizer.from_pretrained(args.main_model, trust_remote_code=True) + # logger.info("Tokenizer loaded successfully.") + + output_dirs = get_output_dirs(args.task, args.main_model, args.use_critic, args.critic_early_stop) + + total_examples = 0 + total_correct = 0 + total_correct_samples = 0 + total_samples = 0 + critic_correct_samples = 0 + critic_total_samples = 0 + total_tokens = 0 + total_tokens_all_samples = 0 + results = [] + + def process_example(idx): + """Process a single example: generate k samples, evaluate, return result dict.""" + example = dataset[int(idx)] + if args.task == "game24": + nums = example["numbers"] + prompt = build_full_prompt(args.task, example, nums=nums) + eval_fn = lambda output: evaluate_game24_answer(output, nums) + options = None + + elif args.task == "zebralogic": + prompt = build_full_prompt(args.task, example) + eval_fn = lambda output, ex=example: evaluate_zebralogic_answer(output, ex) + options = None + elif args.task == "verina": + # For verina, example is a BenchmarkData object + prompt = build_full_prompt(args.task, example) + current_idx = int(idx) + current_data = example + eval_fn = lambda output, data=current_data, task_idx=current_idx: evaluate_verina_answer(output, data, task_idx) + options = None + + else: + prompt = build_full_prompt(args.task, example) + gt = str(example.get("ground_truth", "")).strip() + if gt == "Q4": + target_options = ["A", "B"] + else: + target_options = ["A", "B", "C", "D"] + if args.task == "maze": + _, user_prompt = build_maze_prompt(example) + else: + _, user_prompt = build_spatialmap_prompt(example) + options = extract_options_from_prompt(user_prompt, target_options) + eval_fn = lambda output: evaluate_mcq_answer(output, options, gt) + + #logger.info(f"---- Example {idx} ----") + + quiet_mode = not args.debug + + if args.use_critic: + sample_results = run_k_samples_with_critic( + prompt, llm_server, critic_llm_server, args.k, args.seed, + args.task, example, tokenizer, eval_fn, nums=(nums if args.task == "game24" else None), + early_stop=args.critic_early_stop, critic_feedback_baseline=args.critic_feedback_baseline, quiet=quiet_mode + ) + else: + outputs = batch_generate_samples(prompt, llm_server, args.k, args.seed, quiet=quiet_mode) + sample_results = [] + for output in outputs: + is_correct, extracted, message = eval_fn(output) + token_count = count_tokens(output, tokenizer) + sample_results.append(SampleResult( + output=output, + correct=is_correct, + extracted=extracted, + message=message, + tokens=token_count, + critic_correct=None, + )) + + if args.use_critic: + best_idx = next((i for i, r in enumerate(sample_results) if r.critic_correct), 0) + else: + best_idx = next((i for i, r in enumerate(sample_results) if r.correct), 0) + best_result = sample_results[best_idx] + any_correct = any(r.correct for r in sample_results) + correct_samples = sum(1 for r in sample_results if r.correct) + critic_correct_samples_example = sum(1 for r in sample_results if r.critic_correct) + + save_outputs(idx, sample_results, best_idx, output_dirs["reasoning"]) + + return { + "idx": int(idx), + "best_idx": best_idx, + "any_correct": any_correct, + "best_correct": best_result.correct, + "best_critic_correct": best_result.critic_correct, + "best_extracted": best_result.extracted, + "best_message": best_result.message, + "best_critic_feedback": best_result.critic_feedback, + "best_tokens": best_result.tokens, + "all_tokens": [r.tokens for r in sample_results], + "all_correct": [r.correct for r in sample_results], + "all_critic_correct": [r.critic_correct for r in sample_results], + "all_critic_feedback": [r.critic_feedback for r in sample_results], + "options": options, + "_any_correct": any_correct, + "_correct_samples": correct_samples, + "_critic_correct_samples": critic_correct_samples_example, + "_n_samples": len(sample_results), + "_best_tokens": best_result.tokens, + "_all_tokens_sum": sum(r.tokens for r in sample_results), + } + + with ThreadPool(processes=args.processes) as pool: + for result in tqdm(pool.imap_unordered(process_example, indices), total=len(indices), desc="Processing examples", unit="example", file=_real_stderr): + total_examples += 1 + if result["_any_correct"]: + total_correct += 1 + total_correct_samples += result["_correct_samples"] + total_samples += result["_n_samples"] + critic_correct_samples += result["_critic_correct_samples"] + critic_total_samples += result["_n_samples"] + total_tokens += result["_best_tokens"] + total_tokens_all_samples += result["_all_tokens_sum"] + + # Remove internal keys before appending + for k in list(result.keys()): + if k.startswith("_"): + del result[k] + results.append(result) + + accuracy = total_correct / total_examples if total_examples else 0 + avg_best_tokens = total_tokens / total_examples if total_examples else 0 + avg_all_tokens = total_tokens_all_samples / total_examples if total_examples else 0 + + summary = { + "task": args.task, + "model": args.main_model, + "k": args.k, + "use_critic": args.use_critic, + "total_examples": total_examples, + "correct": total_correct, + "correct_samples": total_correct_samples, + "total_samples": total_samples, + "critic_correct_samples": critic_correct_samples, + "critic_total_samples": critic_total_samples, + "critic_accuracy": (critic_correct_samples / critic_total_samples) if critic_total_samples else 0, + "accuracy": accuracy, + "avg_best_tokens": avg_best_tokens, + "avg_all_tokens": avg_all_tokens, + "total_tokens_best": total_tokens, + "total_tokens_all_samples": total_tokens_all_samples, + "results": results, + } + + if args.use_critic: + summary["critic_model"] = args.critic_model + summary["critic_port"] = args.critic_port + summary["critic_early_stop"] = args.critic_early_stop + summary["critic_feedback_baseline"] = args.critic_feedback_baseline + + summary_path = os.path.join(output_dirs["base"], "summary.json") + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, indent=2, cls=NumpyEncoder) + # logger.info(f"Saved summary to {summary_path}") \ No newline at end of file diff --git a/examples/TTSwithVerification/game24_stepverifier.py b/examples/TTSwithVerification/game24_stepverifier.py index a6437ea..597f909 100644 --- a/examples/TTSwithVerification/game24_stepverifier.py +++ b/examples/TTSwithVerification/game24_stepverifier.py @@ -1,33 +1,47 @@ +""" +Game of 24 experiment with thinking-phase step verification. + +Uses ThinkingPhaseStepVerifierGame24Monitor which: + - Verifies the model's intermediate expressions during via side-streams + - Injects expression extraction after + - Verifies the final \\boxed{} expression for correctness +""" + import argparse import asyncio -import csv -import json import logging import os import re import numpy as np from datasets import load_dataset -from openai import OpenAI from transformers import AutoTokenizer from interwhen import stream_completion -from interwhen.monitors import KstableAnswerGame24Monitor, StepVerifierGame24Monitor +from interwhen.monitors import ThinkingPhaseStepVerifierGame24Monitor # ============== MODEL CONFIGURATION ============== -# Change these model names to scale experiments easily -MAIN_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507" -EARLYSTOP_MODEL = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" +MAIN_MODEL = "Qwen/QwQ-32B" # ================================================= +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# Walk up to find the repo root (contains pyproject.toml), output to its parent +_dir = _SCRIPT_DIR +while _dir != os.path.dirname(_dir) and not os.path.isfile(os.path.join(_dir, "pyproject.toml")): + _dir = os.path.dirname(_dir) +_OUTPUT_ROOT = os.path.dirname(_dir) + def get_model_short_name(model_name: str) -> str: """Extract a short, filesystem-safe name from the model path.""" short_name = model_name.split("/")[-1] short_name = short_name.replace(" ", "_").replace(":", "-") return short_name -def get_output_dirs(main_model: str, base_dir: str = "../../Outputs_TTS/Gameof24results"): +def get_output_dirs(main_model: str, base_dir: str = None): """Create and return output directory paths based on model name.""" + if base_dir is None: + base_dir = os.path.join(_OUTPUT_ROOT, "Outputs_TTS", "Gameof24results") model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) @@ -37,26 +51,20 @@ def get_output_dirs(main_model: str, base_dir: str = "../../Outputs_TTS/Gameof24 "csv_saved": os.path.join(output_base, "csv_saved"), } - # Create all directories for dir_path in dirs.values(): os.makedirs(dir_path, exist_ok=True) return dirs -def get_log_filename(main_model: str, num_examples: int, base_dir: str = "../../Outputs_TTS/Gameof24_results") -> str: +def get_log_filename(main_model: str, num_examples: int, base_dir: str = None) -> str: """Generate log filename based on model name.""" + if base_dir is None: + base_dir = os.path.join(_OUTPUT_ROOT, "Outputs_TTS", "Gameof24results") model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) os.makedirs(output_base, exist_ok=True) return os.path.join(output_base, f"EAT_{num_examples}examples.log") -def get_token_filename(main_model: str, num_examples: int, base_dir: str = "../../Outputs_TTS/Gameof24_results") -> str: - """Generate token CSV filename based on model name.""" - model_short_name = get_model_short_name(main_model) - output_base = os.path.join(base_dir, model_short_name) - os.makedirs(output_base, exist_ok=True) - return os.path.join(output_base, f"EAT_{num_examples}examples.csv") - def save_prompt(idx, prompt_with_answer, reason_dir): filename = os.path.join(reason_dir, f"reason_{idx}.txt") with open(filename, "w", encoding="utf-8") as f: @@ -65,11 +73,7 @@ def save_prompt(idx, prompt_with_answer, reason_dir): logger = logging.getLogger(__name__) -def load_game24_dataset(): - ds = load_dataset("nlile/24-game", split="train") - return ds - -def init_llm_server(modelname, max_tokens=200, port=8000): +def init_llm_server(modelname, max_tokens=32768, port=8000): url = f"http://localhost:{port}/v1/completions" payload = { "model": modelname, @@ -77,159 +81,33 @@ def init_llm_server(modelname, max_tokens=200, port=8000): "top_k": 20, "top_p": 0.95, "min_p": 0.0, + "do_sample": True, "temperature": 0.6, "stream": True, "logprobs": 20, "use_beam_search": False, "prompt_cache": True, - "seed" : 42 + "seed": 42 } headers = {"Content-Type": "application/json"} return {"url": url, "payload": payload, "headers": headers} -def build_meta_prompt_from_example(nums): - """Build the system and user prompts for Game of 24 with step verification format.""" +def build_prompt(nums): a, b, c, d = nums + boxed = r"\boxed{}" + base_prompt = f""" + You are solving the Game of 24. - system_prompt = r"""You are solving the Game of 24. - -GAME RULES: -- You are given four numbers -- Use ALL four numbers exactly once -- Use ONLY the operations: +, -, *, / -- The final expression must evaluate to exactly 24 - -OUTPUT FORMAT: -You must follow this EXACT structured format for your solution: - ->Step1 -available numbers: [a, b, c, d] -suggested operation: a * b = result1 -remaining numbers: [result1, c, d] - ->Step2 -available numbers: [result1, c, d] -suggested operation: result1 + c = result2 -remaining numbers: [result2, d] - ->Step3 -available numbers: [result2, d] -suggested operation: result2 - d = 24 -remaining numbers: [24] - -> Final expression: \boxed{expression using original numbers} - -IMPORTANT RULES: -1. Each step MUST show the available numbers at the start -2. Each step MUST show the suggested operation with its result -3. Each step MUST show the remaining numbers after the operation -4. Continue until you reach exactly 24 -5. The final expression inside \boxed{} must use the ORIGINAL numbers -6. If you receive VERIFIER FEEDBACK, immediately provide a corrected step - do NOT restart your thinking - -═══════════════════════════════════════════════════════════════════════════════ -EXAMPLE 1: Numbers [2, 3, 4, 5] -═══════════════════════════════════════════════════════════════════════════════ - -### Final Answer - ->Step1 -available numbers: [2, 3, 4, 5] -suggested operation: 5 + 3 = 8 -remaining numbers: [8, 2, 4] - ->Step2 -available numbers: [8, 2, 4] -suggested operation: 8 - 2 = 6 -remaining numbers: [6, 4] - ->Step3 -available numbers: [6, 4] -suggested operation: 6 * 4 = 24 -remaining numbers: [24] - -> Final expression: \boxed{(5 + 3 - 2) * 4} - -═══════════════════════════════════════════════════════════════════════════════ -EXAMPLE 2: Numbers [1, 5, 5, 5] -═══════════════════════════════════════════════════════════════════════════════ - -### Final Answer - ->Step1 -available numbers: [1, 5, 5, 5] -suggested operation: 1 / 5 = 0.2 -remaining numbers: [0.2, 5, 5] - ->Step2 -available numbers: [0.2, 5, 5] -suggested operation: 5 - 0.2 = 4.8 -remaining numbers: [4.8, 5] - ->Step3 -available numbers: [4.8, 5] -suggested operation: 4.8 * 5 = 24 -remaining numbers: [24] - -> Final expression: \boxed{(5 - 1/5) * 5} - -═══════════════════════════════════════════════════════════════════════════════ -EXAMPLE 3: Handling Verifier Feedback - Numbers [1, 2, 6, 8] -═══════════════════════════════════════════════════════════════════════════════ - -### Final Answer - ->Step1 -available numbers: [1, 2, 6, 8] -suggested operation: 8 / 2 = 4 -remaining numbers: [4, 1, 6] - ->Step2 -available numbers: [4, 1, 6] -suggested operation: 4 - 1 = 3 -remaining numbers: [3, 6] - -[VERIFIER FEEDBACK for Step 2: - ✗ Cannot reach 24 from remaining numbers [3, 6]. This path is a dead end. -The previous steps are correct. Please provide a corrected Step 2 and continue.] - ->Step2 -available numbers: [4, 1, 6] -suggested operation: 6 - 1 = 5 -remaining numbers: [5, 4] - -[VERIFIER FEEDBACK for Step 2: - ✗ Cannot reach 24 from remaining numbers [4, 5]. This path is a dead end. -The previous steps are correct. Please provide a corrected Step 2 and continue.] - ->Step2 -available numbers: [4, 1, 6] -suggested operation: 6 * 1 = 6 -remaining numbers: [6, 4] - ->Step3 -available numbers: [6, 4] -suggested operation: 6 * 4 = 24 -remaining numbers: [24] - -> Final expression: \boxed{(8 / 2) * 6 * 1} - -═══════════════════════════════════════════════════════════════════════════════ - -Now solve the following Game of 24 problem using the EXACT same format.""" - - user_prompt = f""" -Numbers: {a}, {b}, {c}, {d} - -Find an arithmetic expression using these four numbers exactly once each with +, -, *, / that equals 24. - -Use the structured step-by-step format shown in the examples above.""" - - # Combine into a single prompt - full_prompt = f"{system_prompt}\n\n{user_prompt}" + You are given four numbers: {a}, {b}, {c}, {d} - return full_prompt + Your job is to produce a valid arithmetic expression using: + - ALL four numbers exactly once + - ONLY +, -, *, / + - The expression must evaluate to exactly 24. + + Please reason step by step, and put your final answer containing only the expression within {boxed}.""".strip() + return base_prompt def count_tokens(text: str, tokenizer) -> int: @@ -240,10 +118,22 @@ def count_tokens(text: str, tokenizer) -> int: def extract_solution(text): + # Only search for \boxed{} AFTER to avoid grabbing unverified + # expressions from inside the thinking trace. + # If model opened but never closed it (hit token limit), there is + # no final answer — return None. + if '' in text: + search_text = text[text.rfind(''):] + elif '' in text: + # Model started thinking but never finished — no verified answer + return None + else: + search_text = text + # Use a more robust extraction that handles nested braces in \boxed{} # Find \boxed{ and then match braces properly boxed_pattern = r"\\boxed\{" - matches = list(re.finditer(boxed_pattern, text)) + matches = list(re.finditer(boxed_pattern, search_text)) if not matches: return None @@ -252,14 +142,18 @@ def extract_solution(text): start = last_match.end() # Position right after \boxed{ brace_count = 1 end = start - while end < len(text) and brace_count > 0: - if text[end] == '{': + while end < len(search_text) and brace_count > 0: + if search_text[end] == '{': brace_count += 1 - elif text[end] == '}': + elif search_text[end] == '}': brace_count -= 1 end += 1 - expr = text[start:end-1].strip() # -1 to exclude the closing brace + expr = search_text[start:end-1].strip() # -1 to exclude the closing brace + + # Skip empty \boxed{} (e.g., from verifier feedback "Wrap in \boxed{}.") + if not expr: + return None # 1. Convert \frac{a}{b} to (a/b) frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" @@ -275,8 +169,16 @@ def extract_solution(text): for latex, op in replacements.items(): expr = expr.replace(latex, op) - # 3. Cleanup (remove LaTeX spacing) + # 2b. Replace Unicode math operators (QwQ frequently uses these) + expr = expr.replace('\u00d7', '*').replace('\u00f7', '/').replace('\u2212', '-') + expr = expr.replace('\u2013', '-').replace('\u2014', '-') # en-dash, em-dash + + # 3. Cleanup (remove LaTeX formatting artifacts) expr = expr.replace(r"\,", "").replace(r"\ ", "") + expr = expr.replace(r"\left", "").replace(r"\right", "") + + # 3b. Strip trailing "= " (e.g., "10 - 8/8 * 1 = 24" -> "10 - 8/8 * 1") + expr = re.sub(r'\s*=\s*[\d.]+\s*$', '', expr) # 4. Handle implicit multiplication (e.g., "(11+1)(1+1)" -> "(11+1)*(1+1)") # Insert * between: )( , )number, number(, )( @@ -322,10 +224,8 @@ def evaluate_game24_answer(answer, nums): Tuple of (is_correct, extracted_expression, error_message) """ expr = extract_solution(answer) - if not expr: return False, None, "No expression found" - if evaluate_expression(expr, expected_nums=nums): return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" else: @@ -338,22 +238,19 @@ def evaluate_game24_answer(answer, nums): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Game of 24 step-by-step solver with monitors") - parser.add_argument("--thinking", "-t", action="store_true", help="Enable chain-of-thought output") - parser.add_argument("--monitor", "-m", default = True, action="store_true", help="Enable step-by-step monitor") - parser.add_argument("--num_examples", "-n", type=int, default=1, help="Number of examples to run") + parser.add_argument("--num_examples", "-n", type=int, default=1362, help="Number of examples to run") parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logs") - parser.add_argument("--main_model", type=str, default=MAIN_MODEL, help="Main model to use for generation") - parser.add_argument("--earlystop_model", type=str, default=EARLYSTOP_MODEL, help="Model to use for early stopping") + parser.add_argument("--newline_threshold", type=int, default=20, help="Number of newlines in thinking before forcing step verification") + parser.add_argument("--max_corrections", type=int, default=3, help="Maximum number of correction attempts per example") + parser.add_argument("--warmup", type=int, default=4, help="Number of \\n to skip before starting side-chain verification") + parser.add_argument("--model", type=str, default=MAIN_MODEL, help="Main model to use for generation") + parser.add_argument("--port", type=int, default=8000, help="vLLM server port") args = parser.parse_args() - # Use models from args (allows command-line override) - main_model = args.main_model - earlystop_model = args.earlystop_model + main_model = args.model - # Setup output directories based on model name output_dirs = get_output_dirs(main_model) logfile = get_log_filename(main_model, args.num_examples) - token_filename = get_token_filename(main_model, args.num_examples) reason_dir = output_dirs["reasoning"] log_level = logging.DEBUG if args.debug else logging.INFO @@ -362,107 +259,120 @@ def evaluate_game24_answer(answer, nums): level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", handlers=[ - logging.FileHandler(logfile, mode="w"), + logging.FileHandler(logfile, mode="w"), logging.StreamHandler() ], force=True, ) logger.info(f"Main model: {main_model}") - logger.info(f"Early stop model: {earlystop_model}") logger.info(f"Output directory: {output_dirs['base']}") + logger.info(f"Newline threshold: {args.newline_threshold}") + logger.info(f"Warmup: {args.warmup}") - dataset = load_game24_dataset() + dataset = load_dataset("nlile/24-game", split="train") - llm_server = init_llm_server(main_model, max_tokens=32768) + llm_server = init_llm_server(main_model, port=args.port) - # Load tokenizer for accurate token counting logger.info(f"Loading tokenizer for {main_model}...") tokenizer = AutoTokenizer.from_pretrained(main_model, trust_remote_code=True) logger.info("Tokenizer loaded successfully.") num_correct = 0 + num_attempted = 0 # model produced a real answer (not "no solution" and not missing after ) + num_excluded = 0 # excluded from soundness (no solution or token budget exceeded) N = args.num_examples - total_reasoning_tokens = 0 - reasoning_token_counts = [] + total_generated_tokens = 0 + generated_token_counts = [] - # total = len(dataset) indices = np.linspace(0, len(dataset)-1, N, dtype=int) - for idx in indices: #for idx in indices: + for idx in indices: example = dataset[idx] nums = example["numbers"] + prompt = build_prompt(nums) + full_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" - prompt = build_meta_prompt_from_example(nums) - - if args.monitor: - # Use StepVerifierGame24Monitor to detect when equation stabilizes k times - monitors=(StepVerifierGame24Monitor( - name="game24_kstable", - answer_start_token = "", - original_numbers=nums, # Validate equations use exactly these numbers - ),) - else: - monitors = () + monitor = ThinkingPhaseStepVerifierGame24Monitor( + name="game24_verifier", + original_numbers=nums, + llm_server=llm_server, + prompt=full_prompt, + newline_threshold=args.newline_threshold, + max_corrections=args.max_corrections, + answer_start_token="", + warmup_newlines=args.warmup, + ) - logger.info(f"---- length of monitors {len(monitors)} ----") logger.info(f"---- Example {idx+1} ----") logger.info(f"Numbers: {nums}") - answer = asyncio.run(stream_completion( - f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n", - llm_server=llm_server, - monitors=monitors, - add_delay=False, - termination_requires_validation=False, - async_execution=True - )) + try: + answer = asyncio.run(stream_completion( + full_prompt, + llm_server=llm_server, + monitors=(monitor,), + add_delay=False, + termination_requires_validation=False, + async_execution=True + )) + except Exception as e: + logger.error(f"Error running example {idx}: {e}") + continue save_prompt(idx, answer, reason_dir) logger.info(f"Raw final output:\n{answer}") - reasoning_tokens = count_tokens(answer, tokenizer) - reasoning_token_counts.append(reasoning_tokens) - total_reasoning_tokens += reasoning_tokens - logger.info(f"Generated tokens in this example: {reasoning_tokens}") + generated_tokens = count_tokens(answer, tokenizer) + generated_token_counts.append(generated_tokens) + total_generated_tokens += generated_tokens + logger.info(f"Generated tokens in this example: {generated_tokens}") is_correct, expr, message = evaluate_game24_answer(answer, nums) - + # Attempted: model produced a real answer (not "no solution" and not missing after ) + gave_no_solution = (expr is not None and "no solution" in expr.strip().lower()) + no_expr_found = (expr is None) + attempted = not (gave_no_solution or no_expr_found) + if attempted: + num_attempted += 1 + else: + num_excluded += 1 + if expr: logger.info(f"Extracted expression: {expr}") logger.info(message) - if is_correct: num_correct += 1 - # Calculate final statistics - avg_reasoning_tokens = total_reasoning_tokens / N if N > 0 else 0 + avg_generated_tokens = total_generated_tokens / N if N > 0 else 0 accuracy = num_correct / N if N > 0 else 0 - + soundness = num_correct / num_attempted if num_attempted > 0 else 0 + print(f"\nFinal Accuracy: {num_correct}/{N} ({accuracy:.2%})") - print(f"Average Reasoning Tokens: {avg_reasoning_tokens:.2f}") - print(f"Total Reasoning Tokens: {total_reasoning_tokens}") - - # Save results to a text file + print(f"Soundness: {num_correct}/{num_attempted} ({soundness:.2%})") + print(f"Excluded from soundness (no solution / token budget exceeded): {num_excluded}") + print(f"Average Generated Tokens: {avg_generated_tokens:.2f}") + print(f"Total Generated Tokens: {total_generated_tokens}") + results_file = logfile.replace('.log', '_results.txt') with open(results_file, 'w') as f: f.write(f"Game of 24 Evaluation Results\n") f.write(f"{'='*50}\n\n") f.write(f"Model: {main_model}\n") - f.write(f"Number of Examples: {N}\n") - f.write(f"Monitor Enabled: {args.monitor}\n\n") + f.write(f"Number of Examples: {N}\n\n") f.write(f"Results:\n") f.write(f"---------\n") f.write(f"Correct: {num_correct}/{N}\n") - f.write(f"Accuracy: {accuracy:.2%}\n\n") - f.write(f"Reasoning Token Statistics:\n") + f.write(f"Accuracy: {accuracy:.2%}\n") + f.write(f"Soundness: {num_correct}/{num_attempted} = {soundness:.2%}\n") + f.write(f"Excluded from soundness (no solution / token budget exceeded): {num_excluded}\n\n") + f.write(f"Generated Token Statistics:\n") f.write(f"---------------------------\n") - f.write(f"Total Reasoning Tokens: {total_reasoning_tokens}\n") - f.write(f"Average Reasoning Tokens: {avg_reasoning_tokens:.2f}\n") - if reasoning_token_counts: - f.write(f"Min Reasoning Tokens: {min(reasoning_token_counts)}\n") - f.write(f"Max Reasoning Tokens: {max(reasoning_token_counts)}\n") - f.write(f"Std Dev: {np.std(reasoning_token_counts):.2f}\n") - + f.write(f"Total Generated Tokens: {total_generated_tokens}\n") + f.write(f"Average Generated Tokens: {avg_generated_tokens:.2f}\n") + if generated_token_counts: + f.write(f"Min Generated Tokens: {min(generated_token_counts)}\n") + f.write(f"Max Generated Tokens: {max(generated_token_counts)}\n") + f.write(f"Std Dev: {np.std(generated_token_counts):.2f}\n") logger.info(f"Results saved to {results_file}") print(f"Results saved to {results_file}") diff --git a/examples/TTSwithVerification/game24meta.py b/examples/TTSwithVerification/game24meta.py new file mode 100644 index 0000000..a48e110 --- /dev/null +++ b/examples/TTSwithVerification/game24meta.py @@ -0,0 +1,475 @@ +import argparse +import asyncio +import json +import logging +import os +import re +import numpy as np + +from datasets import load_dataset +from transformers import AutoTokenizer + +from interwhen import stream_completion +from interwhen.monitors import StepVerifierGame24Monitor + +# ============== MODEL CONFIGURATION ============== +MAIN_MODEL = "Qwen/QwQ-32B" +# ================================================= + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# Walk up to find the repo root (contains pyproject.toml), output to its parent +_dir = _SCRIPT_DIR +while _dir != os.path.dirname(_dir) and not os.path.isfile(os.path.join(_dir, "pyproject.toml")): + _dir = os.path.dirname(_dir) +_OUTPUT_ROOT = os.path.dirname(_dir) + +def get_model_short_name(model_name: str) -> str: + """Extract a short, filesystem-safe name from the model path.""" + short_name = model_name.split("/")[-1] + short_name = short_name.replace(" ", "_").replace(":", "-") + return short_name + +def get_output_dirs(main_model: str, base_dir: str = None): + """Create and return output directory paths based on model name.""" + if base_dir is None: + base_dir = os.path.join(_OUTPUT_ROOT, "Outputs_TTS", "Gameof24results", "metaPrompt") + model_short_name = get_model_short_name(main_model) + output_base = os.path.join(base_dir, model_short_name) + + dirs = { + "base": output_base, + "reasoning": os.path.join(output_base, "Reasoning_output"), + "csv_saved": os.path.join(output_base, "csv_saved"), + } + + # Create all directories + for dir_path in dirs.values(): + os.makedirs(dir_path, exist_ok=True) + + return dirs + +def get_log_filename(main_model: str, num_examples: int, base_dir: str = None) -> str: + """Generate log filename based on model name.""" + if base_dir is None: + base_dir = os.path.join(_OUTPUT_ROOT, "Outputs_TTS", "Gameof24results", "metaPrompt") + model_short_name = get_model_short_name(main_model) + output_base = os.path.join(base_dir, model_short_name) + os.makedirs(output_base, exist_ok=True) + return os.path.join(output_base, f"EAT_{num_examples}examples.log") + +def save_prompt(idx, prompt_with_answer, reason_dir): + filename = os.path.join(reason_dir, f"reason_{idx}.txt") + with open(filename, "w", encoding="utf-8") as f: + f.write(prompt_with_answer) + +logger = logging.getLogger(__name__) + + +def init_llm_server(modelname, max_tokens=32768, port=8000): + url = f"http://localhost:{port}/v1/completions" + payload = { + "model": modelname, + "max_tokens": max_tokens, + "top_k": 20, + "top_p": 0.95, + "min_p": 0.0, + "do_sample": True, + "temperature": 0.6, + "stream": True, + "logprobs": 20, + "use_beam_search": False, + "prompt_cache": True, + "seed": 42 + } + headers = {"Content-Type": "application/json"} + return {"url": url, "payload": payload, "headers": headers} + + +def build_meta_prompt_from_example(nums): + """Build the system and user prompts for Game of 24 with step verification format.""" + a, b, c, d = nums + + system_prompt = r"""You are solving the Game of 24. + +GAME RULES: +- You are given four numbers +- Use ALL four numbers exactly once +- Use ONLY the operations: +, -, *, / +- The final expression must evaluate to exactly 24 + +OUTPUT FORMAT: +You must follow this EXACT structured format for your solution: + +>Step1 +available numbers: [a, b, c, d] +suggested operation: a * b = result1 +remaining numbers: [result1, c, d] + +>Step2 +available numbers: [result1, c, d] +suggested operation: result1 + c = result2 +remaining numbers: [result2, d] + +>Step3 +available numbers: [result2, d] +suggested operation: result2 - d = 24 +remaining numbers: [24] + +> Final expression: \boxed{expression using original numbers} + +IMPORTANT RULES: +1. Each step MUST show the available numbers at the start +2. Each step MUST show the suggested operation with its result +3. Each step MUST show the remaining numbers after the operation +4. Continue until you reach exactly 24 +5. The final expression inside \boxed{} must use the ORIGINAL numbers +6. If you receive VERIFIER FEEDBACK, immediately provide a corrected step - do NOT restart your thinking + +═══════════════════════════════════════════════════════════════════════════════ +EXAMPLE 1: Numbers [2, 3, 4, 5] +═══════════════════════════════════════════════════════════════════════════════ + +### Final Answer + +>Step1 +available numbers: [2, 3, 4, 5] +suggested operation: 5 + 3 = 8 +remaining numbers: [8, 2, 4] + +>Step2 +available numbers: [8, 2, 4] +suggested operation: 8 - 2 = 6 +remaining numbers: [6, 4] + +>Step3 +available numbers: [6, 4] +suggested operation: 6 * 4 = 24 +remaining numbers: [24] + +> Final expression: \boxed{(5 + 3 - 2) * 4} + +═══════════════════════════════════════════════════════════════════════════════ +EXAMPLE 2: Numbers [1, 5, 5, 5] +═══════════════════════════════════════════════════════════════════════════════ + +### Final Answer + +>Step1 +available numbers: [1, 5, 5, 5] +suggested operation: 1 / 5 = 0.2 +remaining numbers: [0.2, 5, 5] + +>Step2 +available numbers: [0.2, 5, 5] +suggested operation: 5 - 0.2 = 4.8 +remaining numbers: [4.8, 5] + +>Step3 +available numbers: [4.8, 5] +suggested operation: 4.8 * 5 = 24 +remaining numbers: [24] + +> Final expression: \boxed{(5 - 1/5) * 5} + +═══════════════════════════════════════════════════════════════════════════════ +EXAMPLE 3: Handling Verifier Feedback - Numbers [1, 2, 6, 8] +═══════════════════════════════════════════════════════════════════════════════ + +### Final Answer + +>Step1 +available numbers: [1, 2, 6, 8] +suggested operation: 8 / 2 = 4 +remaining numbers: [4, 1, 6] + +>Step2 +available numbers: [4, 1, 6] +suggested operation: 4 - 1 = 3 +remaining numbers: [3, 6] + +[VERIFIER FEEDBACK for Step 2: + ✗ Cannot reach 24 from remaining numbers [3, 6]. This path is a dead end. +The previous steps are correct. Please provide a corrected Step 2 and continue.] + +>Step2 +available numbers: [4, 1, 6] +suggested operation: 6 - 1 = 5 +remaining numbers: [5, 4] + +[VERIFIER FEEDBACK for Step 2: + ✗ Cannot reach 24 from remaining numbers [4, 5]. This path is a dead end. +The previous steps are correct. Please provide a corrected Step 2 and continue.] + +>Step2 +available numbers: [4, 1, 6] +suggested operation: 6 * 1 = 6 +remaining numbers: [6, 4] + +>Step3 +available numbers: [6, 4] +suggested operation: 6 * 4 = 24 +remaining numbers: [24] + +> Final expression: \boxed{(8 / 2) * 6 * 1} + +═══════════════════════════════════════════════════════════════════════════════ + +Now solve the following Game of 24 problem using the EXACT same format.""" + + user_prompt = f""" +Numbers: {a}, {b}, {c}, {d} + +Find an arithmetic expression using these four numbers exactly once each with +, -, *, / that equals 24. + +Use the structured step-by-step format shown in the examples above.""" + + # Combine into a single prompt + full_prompt = f"{system_prompt}\n\n{user_prompt}" + + return full_prompt + + +def count_tokens(text: str, tokenizer) -> int: + """Count the total number of tokens in the generated text using the tokenizer.""" + tokens = tokenizer.encode(text, add_special_tokens=False) + return len(tokens) + + +def extract_solution(text): + + # Use a more robust extraction that handles nested braces in \boxed{} + # Find \boxed{ and then match braces properly + boxed_pattern = r"\\boxed\{" + matches = list(re.finditer(boxed_pattern, text)) + if not matches: + return None + + # Get the last \boxed{} content by matching braces + last_match = matches[-1] + start = last_match.end() # Position right after \boxed{ + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == '{': + brace_count += 1 + elif text[end] == '}': + brace_count -= 1 + end += 1 + + expr = text[start:end-1].strip() # -1 to exclude the closing brace + + # 1. Convert \frac{a}{b} to (a/b) + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + + # 2. Replace LaTeX operators + replacements = { + r"\times": "*", + r"\cdot": "*", + r"\div": "/", + } + for latex, op in replacements.items(): + expr = expr.replace(latex, op) + + # 2b. Replace Unicode math operators (QwQ frequently uses these) + expr = expr.replace('\u00d7', '*').replace('\u00f7', '/').replace('\u2212', '-') + expr = expr.replace('\u2013', '-').replace('\u2014', '-') # en-dash, em-dash + + # 3. Cleanup (remove LaTeX formatting artifacts) + expr = expr.replace(r"\,", "").replace(r"\ ", "") + expr = expr.replace(r"\left", "").replace(r"\right", "") + + # 4. Handle implicit multiplication (e.g., "(11+1)(1+1)" -> "(11+1)*(1+1)") + # Insert * between: )( , )number, number(, )( + expr = re.sub(r'\)\s*\(', ')*(', expr) # )( -> )*( + expr = re.sub(r'\)\s*(\d)', r')*\1', expr) # )number -> )*number + expr = re.sub(r'(\d)\s*\(', r'\1*(', expr) # number( -> number*( + + return expr + +def extract_numbers_from_expr(expr): + """Extract all numbers (including decimals) from an expression.""" + # Match integers and decimals + numbers = re.findall(r'\d+\.?\d*', expr) + return [int(float(n)) if float(n).is_integer() else float(n) for n in numbers] + +def validate_numbers_used(expr, expected_nums): + """Check if the expression uses exactly the given numbers (each exactly once).""" + used_nums = extract_numbers_from_expr(expr) + # Sort both lists to compare + return sorted(used_nums) == sorted(expected_nums) + +def evaluate_expression(expr, expected_nums=None): + try: + # First check if expression uses exactly the given numbers + if expected_nums is not None: + if not validate_numbers_used(expr, expected_nums): + return False + + value = eval(expr, {"__builtins__": None}, {}) + return abs(value - 24) < 1e-6 + except Exception: + return False + +def evaluate_game24_answer(answer, nums): + """ + Evaluate a Game24 answer and return (is_correct, expr, error_message). + + Args: + answer: Raw model output + nums: Expected numbers to use + + Returns: + Tuple of (is_correct, extracted_expression, error_message) + """ + expr = extract_solution(answer) + + if not expr: + return False, None, "No expression found" + + if evaluate_expression(expr, expected_nums=nums): + return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" + else: + used_nums = extract_numbers_from_expr(expr) + if sorted(used_nums) != sorted(nums): + return False, expr, f"Incorrect: Expression uses {used_nums}, expected {nums}" + else: + return False, expr, "Expression does not evaluate to 24" + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Game of 24 step-by-step solver with monitors") + parser.add_argument("--monitor", "-m", action="store_true", help="Enable step-by-step monitor") + parser.add_argument("--num_examples", "-n", type=int, default=1, help="Number of examples to run") + parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logs") + parser.add_argument("--model", type=str, default=MAIN_MODEL, help="Model to use for generation") + parser.add_argument("--port", type=int, default=8000, help="vLLM server port") + parser.add_argument("--max_corrections", type=int, default=5, + help="Maximum number of correction attempts per example") + args = parser.parse_args() + + main_model = args.model + + # Setup output directories based on model name + output_dirs = get_output_dirs(main_model) + logfile = get_log_filename(main_model, args.num_examples) + reason_dir = output_dirs["reasoning"] + + log_level = logging.DEBUG if args.debug else logging.INFO + + logging.basicConfig( + level=log_level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(logfile, mode="w"), + logging.StreamHandler() + ], + force=True, + ) + + logger.info(f"Main model: {main_model}") + logger.info(f"Output directory: {output_dirs['base']}") + + dataset = load_dataset("nlile/24-game", split="train") + + llm_server = init_llm_server(main_model, port=args.port) + + # Load tokenizer for accurate token counting + logger.info(f"Loading tokenizer for {main_model}...") + tokenizer = AutoTokenizer.from_pretrained(main_model, trust_remote_code=True) + logger.info("Tokenizer loaded successfully.") + + num_correct = 0 + N = args.num_examples + total_reasoning_tokens = 0 + reasoning_token_counts = [] + + # total = len(dataset) + indices = np.linspace(0, len(dataset)-1, N, dtype=int) + + for idx in indices: #for idx in indices: + example = dataset[idx] + nums = example["numbers"] + + prompt = build_meta_prompt_from_example(nums) + + if args.monitor: + monitors = (StepVerifierGame24Monitor( + name="game24_verifier", + answer_start_token="", + original_numbers=nums, + max_corrections=args.max_corrections, + ),) + else: + monitors = () + + logger.info(f"---- Example {idx+1} ----") + logger.info(f"Numbers: {nums}") + + full_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + try: + answer = asyncio.run(stream_completion( + full_prompt, + llm_server=llm_server, + monitors=monitors, + add_delay=False, + termination_requires_validation=False, + async_execution=True + )) + except Exception as e: + logger.error(f"Error running example {idx}: {e}") + import traceback + traceback.print_exc() + continue + + save_prompt(idx, answer, reason_dir) + logger.info(f"Raw final output:\n{answer}") + + reasoning_tokens = count_tokens(answer, tokenizer) + reasoning_token_counts.append(reasoning_tokens) + total_reasoning_tokens += reasoning_tokens + logger.info(f"Generated tokens in this example: {reasoning_tokens}") + + is_correct, expr, message = evaluate_game24_answer(answer, nums) + + if expr: + logger.info(f"Extracted expression: {expr}") + logger.info(message) + + if is_correct: + num_correct += 1 + + # Calculate final statistics + avg_reasoning_tokens = total_reasoning_tokens / N if N > 0 else 0 + accuracy = num_correct / N if N > 0 else 0 + + print(f"\nFinal Accuracy: {num_correct}/{N} ({accuracy:.2%})") + print(f"Average Reasoning Tokens: {avg_reasoning_tokens:.2f}") + print(f"Total Reasoning Tokens: {total_reasoning_tokens}") + + # Save results to a text file + results_file = logfile.replace('.log', '_results.txt') + with open(results_file, 'w') as f: + f.write(f"Game of 24 Evaluation Results\n") + f.write(f"{'='*50}\n\n") + f.write(f"Model: {main_model}\n") + f.write(f"Number of Examples: {N}\n") + f.write(f"Monitor Enabled: {args.monitor}\n\n") + f.write(f"Results:\n") + f.write(f"---------\n") + f.write(f"Correct: {num_correct}/{N}\n") + f.write(f"Accuracy: {accuracy:.2%}\n\n") + f.write(f"Reasoning Token Statistics:\n") + f.write(f"---------------------------\n") + f.write(f"Total Reasoning Tokens: {total_reasoning_tokens}\n") + f.write(f"Average Reasoning Tokens: {avg_reasoning_tokens:.2f}\n") + if reasoning_token_counts: + f.write(f"Min Reasoning Tokens: {min(reasoning_token_counts)}\n") + f.write(f"Max Reasoning Tokens: {max(reasoning_token_counts)}\n") + f.write(f"Std Dev: {np.std(reasoning_token_counts):.2f}\n") + + logger.info(f"Results saved to {results_file}") + print(f"Results saved to {results_file}") \ No newline at end of file diff --git a/examples/TTSwithVerification/maze_stepverifier.py b/examples/TTSwithVerification/maze_stepverifier.py index ebeb1d9..9884065 100644 --- a/examples/TTSwithVerification/maze_stepverifier.py +++ b/examples/TTSwithVerification/maze_stepverifier.py @@ -1,32 +1,43 @@ """ -Maze experiment with step-by-step verification using StepVerifierMazeMonitor. +Maze experiment with thinking-phase step verification. -Uses the new monitor-based architecture that integrates with stream_completion. +Uses ThinkingPhaseStepVerifierMazeMonitor which: + - Verifies the model's traced path during via side-streams + - Injects a structured step format after (no meta-prompt needed) + - Verifies each step as the model fills in the structured template """ import argparse import asyncio +import csv import json import logging import os import re import numpy as np -from pathlib import Path from datasets import load_dataset from transformers import AutoTokenizer from interwhen import stream_completion -from interwhen.monitors import StepVerifierMazeMonitor +from interwhen.monitors import ThinkingPhaseStepVerifierMazeMonitor from interwhen.utils.maze_verifier import parse_maze_from_prompt logging.basicConfig(level=logging.INFO, format='%(message)s') logger = logging.getLogger(__name__) # ============== MODEL CONFIGURATION ============== -MAIN_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507" +MAIN_MODEL = "Qwen/QwQ-32B" # ================================================= +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# Walk up to find the repo root (contains pyproject.toml), output to its parent +_dir = _SCRIPT_DIR +while _dir != os.path.dirname(_dir) and not os.path.isfile(os.path.join(_dir, "pyproject.toml")): + _dir = os.path.dirname(_dir) +_OUTPUT_ROOT = os.path.dirname(_dir) + def get_model_short_name(model_name: str) -> str: """Extract a short, filesystem-safe name from the model path.""" @@ -35,14 +46,17 @@ def get_model_short_name(model_name: str) -> str: return short_name -def get_output_dirs(main_model: str, base_dir: str = "../../Outputs_TTS/MazeResults"): +def get_output_dirs(main_model: str, base_dir: str = None): """Create and return output directory paths based on model name.""" + if base_dir is None: + base_dir = os.path.join(_OUTPUT_ROOT, "Outputs_TTS", "MazeResults") model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) dirs = { "base": output_base, "reasoning": os.path.join(output_base, "Reasoning_output"), + "csv_saved": os.path.join(output_base, "csv_saved"), } for dir_path in dirs.values(): @@ -50,126 +64,34 @@ def get_output_dirs(main_model: str, base_dir: str = "../../Outputs_TTS/MazeResu return dirs - -def build_meta_prompt_from_example(example): - """Build prompt for maze example.""" - system_prompt = """You are a maze-solving AI. Given a maze in ASCII format, analyze it step by step. - -COORDINATE SYSTEM: -- Rows are numbered from top (row 0) to bottom -- Columns are numbered from left (col 0) to right -- Movement: UP (row decreases), DOWN (row increases), LEFT (col decreases), RIGHT (col increases) - -TURN DEFINITIONS: -- RIGHT_TURN = 90° clockwise change (e.g., DOWN→LEFT, LEFT→UP, UP→RIGHT, RIGHT→DOWN) -- LEFT_TURN = 90° counterclockwise change (e.g., DOWN→RIGHT, RIGHT→UP, UP→LEFT, LEFT→DOWN) - -RELATIVE POSITION DEFINITIONS: -- "directly to the left" = same row, E has smaller column than S -- "directly to the right" = same row, E has larger column than S -- "directly above" = same column, E has smaller row than S -- "directly below" = same column, E has larger row than S -- "top left" = E has smaller row AND smaller column than S -- "top right" = E has smaller row AND larger column than S -- "bottom left" = E has larger row AND smaller column than S -- "bottom right" = E has larger row AND larger column than S - -IMPORTANT: Follow the EXACT output format below. Do NOT use tags. - -EXAMPLE 1: Counting Right Turns -Question: How many right turns are there in the path from S to E? - ->>> LOCATE START AND EXIT: - S position: (3,5) - E position: (1,1) - ->>> STEP 1: Move DOWN from (3,5) to (4,5) - Current position: (4,5) - Previous direction: — - Current direction: DOWN - Turn type: STRAIGHT - Running count: Right=0, Left=0 - ->>> STEP 2: Move DOWN from (4,5) to (5,5) - Current position: (5,5) - Previous direction: DOWN - Current direction: DOWN - Turn type: STRAIGHT - Running count: Right=0, Left=0 - ->>> STEP 3: Move LEFT from (5,5) to (5,4) - Current position: (5,4) - Previous direction: DOWN - Current direction: LEFT - Turn type: RIGHT_TURN - Running count: Right=1, Left=0 - ->>> FINAL ANSWER: Right turns = 2 - \\boxed{C} - -EXAMPLE 2: Counting Total Turns -Question: How many total turns are there in the path from S to E? - ->>> LOCATE START AND EXIT: - S position: (3,5) - E position: (1,1) - ->>> STEP 1: Move DOWN from (3,5) to (4,5) - Current position: (4,5) - Previous direction: — - Current direction: DOWN - Turn type: STRAIGHT - Running count: Right=0, Left=0, Total=0 - -[... continue for all steps ...] - ->>> FINAL ANSWER: Total turns = 2 - \\boxed{C} - -EXAMPLE 3: Relative Position -Question: Is the exit (E) to the top left of the starting point (S)? - ->>> LOCATE START AND EXIT: - S position: (3,5) - E position: (1,1) - ->>> COMPARE POSITIONS: - Row comparison: E row (1) < S row (3) → E is ABOVE S ✓ - Col comparison: E col (1) < S col (5) → E is LEFT of S ✓ - ->>> ANALYSIS: - E is above S (smaller row): YES - E is left of S (smaller col): YES - Therefore E is at TOP LEFT of S. - ->>> ANSWER: YES, E is to the top left of S. - \\boxed{A} - -════════════════════════════════════════════════════════════════════════════════ -Now solve the following maze using the EXACT same format. First locate S and E, then trace the path step by step.""" - - # Get the maze description (trimmed to remove trailing instructions) - description = str(example.get("prompt", "")) - description_trimmed = description[:-143] if len(description) > 143 else description - - return system_prompt, description_trimmed - - -def extract_solution(text: str) -> str: - """Extract the boxed answer from the response (after ).""" - if "" in text: - answer_section = text.split("")[-1] - else: - answer_section = text - - matches = re.findall(r'\\boxed\{([^}]*)\}', answer_section) - if matches: - return matches[-1].strip() - - match = re.search(r'(?:answer|Answer)[:\s]+([A-D])', answer_section) - if match: - return match.group(1).strip() - +def remove_last_paragraph(s: str) -> str: + return s[:-143] + +def build_prompt_from_example(example): #(original prompt config) + + pre_prompt = "You are an expert problem solver. Carefully read the following multiple-choice question and think through the solution step-by-step before providing your final answer. Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + description = example.get("prompt") + description = str(description) + description = remove_last_paragraph(description) + return pre_prompt, description + + +def extract_solution_mcq(text): + """Extract MCQ solution from model output.""" + patterns = [ + r"\\boxed\{([^}]*)\}", + r"boxed\{([^}]*)\}", + r"\*\*([A-D])\*\*", + r"answer[:\s]*([A-D])", + r"(?:^|\n)([A-D])(?:\s|$|\.)", + ] + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() return None @@ -179,31 +101,15 @@ def count_tokens(text: str, tokenizer) -> int: return len(tokens) -def get_question_type_from_index(idx: int) -> str: - """Determine question type based on index range. - - Dataset structure: - - 3000-3499: right turns - - 3500-3999: total turns - - 4000-4500: relative position - """ - if idx < 3500: - return "right_turns" - elif idx < 4000: - return "total_turns" - else: - return "relative_position" - - -def init_llm_server(model_name, max_tokens=32768, port=8000): - """Initialize LLM server configuration.""" +def init_llm_server(modelname, max_tokens=32768, port=8000): url = f"http://localhost:{port}/v1/completions" payload = { - "model": model_name, + "model": modelname, "max_tokens": max_tokens, "top_k": 20, "top_p": 0.95, "min_p": 0.0, + "do_sample": True, "temperature": 0.6, "stream": True, "logprobs": 20, @@ -215,77 +121,73 @@ def init_llm_server(model_name, max_tokens=32768, port=8000): return {"url": url, "payload": payload, "headers": headers} -def save_output(idx: int, output: str, output_dir: str): - """Save output to file.""" - os.makedirs(output_dir, exist_ok=True) - filepath = os.path.join(output_dir, f"output_{idx}.txt") - with open(filepath, 'w') as f: - f.write(output) - logger.info(f"Saved output to {filepath}") +def save_prompt(idx, prompt_with_answer, reason_dir): + """Save reasoning trace to file.""" + os.makedirs(reason_dir, exist_ok=True) + filename = os.path.join(reason_dir, f"reason_{idx}.txt") + with open(filename, "w", encoding="utf-8") as f: + f.write(prompt_with_answer) + logger.info(f"Saved reasoning trace to {filename}") -def evaluate_maze_answer(answer, options, ground_truth): - """ - Evaluate a Maze MCQ answer and return (is_correct, extracted_answer, message). - - Args: - answer: Raw model output - options: Dictionary mapping option letters (A/B/C/D) to their values - ground_truth: The correct answer value - - Returns: - Tuple of (is_correct, extracted_answer, message) - """ - sol = extract_solution(answer) + +def get_log_filename(main_model: str, num_examples: int, base_dir: str = None) -> str: + """Generate log filename based on model name.""" + if base_dir is None: + base_dir = os.path.join(_OUTPUT_ROOT, "Outputs_TTS", "MazeResults") + model_short_name = get_model_short_name(main_model) + output_base = os.path.join(base_dir, model_short_name) + os.makedirs(output_base, exist_ok=True) + return os.path.join(output_base, f"EAT_{num_examples}examples.log") + + +def evaluate_mcq_answer(answer, options, ground_truth): + sol = extract_solution_mcq(answer) gt_sol = str(ground_truth).strip() - if not sol: return False, None, "No expression found" - sol = sol.strip() - - # Case 1: LLM returned option letter (A/B/C/D) if sol in options: if options[sol] == gt_sol: return True, sol, f"Correct: option {sol} -> {options[sol]}" - else: - return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" - - # Case 2: LLM returned the actual answer text - # First check if sol matches ground truth directly + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" if sol.lower() == gt_sol.lower(): return True, sol, f"Correct: answer text matches ground truth: {sol}" - - # Check if sol matches any option value for opt_letter, opt_value in options.items(): if sol.lower() == opt_value.lower(): if opt_value == gt_sol: return True, sol, f"Correct: answer text {sol} (option {opt_letter})" - else: - return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" - + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" return False, sol, f"Solution '{sol}' not found in options or ground truth" if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run maze experiments with StepVerifierMazeMonitor") + parser = argparse.ArgumentParser(description="Run maze experiments with step verification") parser.add_argument("--model", type=str, default=MAIN_MODEL, help="Model name for generation") parser.add_argument("--indices", type=str, default=None, help="Comma-separated indices to run (e.g., '3000,3500,4000')") - parser.add_argument("--start", type=int, default=3000, help="Start index") - parser.add_argument("--end", type=int, default=3010, help="End index") + parser.add_argument("--start", type=int, default=0, help="Start index") + parser.add_argument("--end", type=int, default=10, help="End index") parser.add_argument("--num_examples", "-n", type=int, default=None, help="Number of examples to run (overrides start/end)") parser.add_argument("--max_corrections", type=int, default=5, help="Maximum number of correction attempts per example") parser.add_argument("--port", type=int, default=8000, help="vLLM server port") parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging") + parser.add_argument("--newline_threshold", type=int, default=20, + help="Number of \\n in thinking before triggering side verification") + parser.add_argument("--warmup", type=int, default=0, + help="Number of \\n to skip before starting side-chain verification (warmup period)") args = parser.parse_args() + + logger.info(f"Thinking-phase verification: always on") + logger.info(f" Newline threshold: {args.newline_threshold}") + logger.info(f" Warmup: {args.warmup}") if args.debug: logging.getLogger().setLevel(logging.DEBUG) # Load dataset - dataset = load_dataset("microsoft/VISION_LANGUAGE", 'maze', split='val') + dataset = load_dataset("microsoft/VISION_LANGUAGE", 'maze_text_only', split='val') # Setup LLM server llm_server = init_llm_server(args.model, port=args.port) @@ -303,8 +205,8 @@ def evaluate_maze_answer(answer, options, ground_truth): if args.indices: indices = [int(x.strip()) for x in args.indices.split(",")] elif args.num_examples: - # Use 4499 as endpoint (4500 is out of bounds since dataset size is 4500) - indices = np.linspace(3000, 4499, args.num_examples, dtype=int) + # Use 1499 as endpoint (1500 is out of bounds since dataset size is 1500) + indices = np.linspace(0, 1499, args.num_examples, dtype=int) else: indices = range(args.start, args.end) @@ -313,10 +215,13 @@ def evaluate_maze_answer(answer, options, ground_truth): total_correct = 0 total_examples = 0 total_reasoning_tokens = 0 + num_attempted = 0 # examples where a \boxed{} answer was produced + reasoning_token_counts = [] + per_example_results = [] # list of dicts for CSV for idx in indices: example = dataset[idx] - system_prompt, user_prompt = build_meta_prompt_from_example(example) + pre_prompt, user_prompt = build_prompt_from_example(example) if str(example.get("ground_truth", "")).strip() == "Q4": target_options = ["A", "B"] else: @@ -325,8 +230,7 @@ def evaluate_maze_answer(answer, options, ground_truth): pattern = rf'\b({keys})\.\s*([A-Za-z0-9]+)\b' options = dict(re.findall(pattern, user_prompt)) - # Build full prompt - full_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n\n" + full_prompt = f"<|im_start|>system\n{pre_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" # Parse maze from prompt grid, start_pos, exit_pos = parse_maze_from_prompt(user_prompt) @@ -336,23 +240,29 @@ def evaluate_maze_answer(answer, options, ground_truth): continue # Detect question type from prompt (auto-detection) - # Falls back to index-based if no turn keywords found - question_type = StepVerifierMazeMonitor.detect_question_type(user_prompt) + question_type = ThinkingPhaseStepVerifierMazeMonitor.detect_question_type(user_prompt) logger.info(f"\n{'='*60}") logger.info(f"Example {idx} ({question_type})") logger.info(f"Maze: S={start_pos}, E={exit_pos}, grid={len(grid)}x{len(grid[0]) if grid else 0}") logger.info(f"{'='*60}") - # Create the monitor - monitor = StepVerifierMazeMonitor( - name="maze_step_verifier", - answer_start_token="", + # Always use ThinkingPhaseStepVerifierMazeMonitor: + # Phase 1 — verifies during via side-streams + # Phase 2a — injects structured step format after + # Phase 2b — verifies structured output as model fills it in + monitor = ThinkingPhaseStepVerifierMazeMonitor( + name="maze_thinking_verifier", grid=grid, start_pos=start_pos, exit_pos=exit_pos, - max_corrections=args.max_corrections, + llm_server=llm_server, + prompt=full_prompt, question_type=question_type, + newline_threshold=args.newline_threshold, + max_corrections=args.max_corrections, + answer_start_token="", + warmup_newlines=args.warmup, ) # Run with stream_completion @@ -371,12 +281,23 @@ def evaluate_maze_answer(answer, options, ground_truth): traceback.print_exc() continue + # Save reasoning trace + save_prompt(int(idx), answer, reason_dir) + logger.info(f"Raw final output:\n{answer}") + # Count generated tokens reasoning_tokens = count_tokens(answer, tokenizer) total_reasoning_tokens += reasoning_tokens + reasoning_token_counts.append(reasoning_tokens) + logger.info(f"Generated tokens in this example: {reasoning_tokens}") gt_sol = str(example.get("ground_truth", "")).strip() - is_correct, extracted_answer, message = evaluate_maze_answer(answer, options, gt_sol) + is_correct, extracted_answer, message = evaluate_mcq_answer(answer, options, gt_sol) + + # "attempted" = model produced a real \boxed{} answer (not "no solution") + attempted = (extracted_answer is not None and extracted_answer.strip().lower() != "no solution") + if attempted: + num_attempted += 1 if extracted_answer: logger.info(f"Extracted answer: {extracted_answer}") @@ -391,17 +312,30 @@ def evaluate_maze_answer(answer, options, ground_truth): 'idx': int(idx), # Convert numpy int64 to Python int 'question_type': question_type, 'correct': is_correct, + 'attempted': attempted, 'sol': extracted_answer, 'gt': gt_sol, 'reasoning_tokens': reasoning_tokens, } results.append(result) - logger.info(f"Result: sol={extracted_answer}, gt={gt_sol}, correct={is_correct}") + per_example_results.append({ + "index": int(idx), + "question_type": question_type, + "correct": is_correct, + "attempted": attempted, + "sol": extracted_answer if extracted_answer else "", + "gt": gt_sol, + "tokens": reasoning_tokens, + "message": message, + }) + + logger.info(f"Result: sol={extracted_answer}, gt={gt_sol}, correct={is_correct}, attempted={attempted}") logger.info(f"Reasoning tokens: {reasoning_tokens}") # Compute final metrics accuracy = total_correct / total_examples if total_examples > 0 else 0 + soundness = total_correct / num_attempted if num_attempted > 0 else 0 # correct / attempted avg_reasoning_tokens = total_reasoning_tokens / total_examples if total_examples > 0 else 0 logger.info(f"\n{'='*60}") @@ -409,16 +343,33 @@ def evaluate_maze_answer(answer, options, ground_truth): logger.info(f"{'='*60}") logger.info(f"Total examples: {total_examples}") logger.info(f"Correct: {total_correct}") + logger.info(f"Attempted (produced \\boxed answer): {num_attempted}/{total_examples}") logger.info(f"Accuracy: {accuracy:.4f} ({total_correct}/{total_examples})") + logger.info(f"Soundness: {soundness:.4f} ({total_correct}/{num_attempted})") logger.info(f"Total reasoning tokens: {total_reasoning_tokens}") logger.info(f"Avg reasoning tokens: {avg_reasoning_tokens:.1f}") + print(f"\nFinal Accuracy: {total_correct}/{total_examples} ({accuracy:.2%})") + print(f"Soundness: {total_correct}/{num_attempted} ({soundness:.2%})") + print(f"Average Reasoning Tokens: {avg_reasoning_tokens:.2f}") + print(f"Total Reasoning Tokens: {total_reasoning_tokens}") + + # Save per-example CSV + csv_file = os.path.join(output_dirs["csv_saved"], f"results_{total_examples}examples.csv") + with open(csv_file, 'w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=["index", "question_type", "correct", "attempted", "sol", "gt", "tokens", "message"]) + writer.writeheader() + writer.writerows(per_example_results) + logger.info(f"Per-example CSV saved to {csv_file}") + # Save summary summary = { 'model': args.model, 'total_examples': total_examples, 'correct': total_correct, + 'attempted': num_attempted, 'accuracy': accuracy, + 'soundness': soundness, 'total_reasoning_tokens': total_reasoning_tokens, 'avg_reasoning_tokens': avg_reasoning_tokens, 'max_corrections': args.max_corrections, @@ -428,4 +379,34 @@ def evaluate_maze_answer(answer, options, ground_truth): summary_path = os.path.join(output_dirs["base"], "summary.json") with open(summary_path, 'w') as f: json.dump(summary, f, indent=2) - logger.info(f"\nSaved summary to {summary_path}") \ No newline at end of file + logger.info(f"\nSaved summary to {summary_path}") + + # Save results summary to a text file + results_file = os.path.join(output_dirs["base"], f"EAT_{total_examples}examples_results.txt") + with open(results_file, 'w') as f: + f.write(f"Maze Step Verification Results\n") + f.write(f"{'='*50}\n\n") + f.write(f"Model: {args.model}\n") + f.write(f"Number of Examples: {total_examples}\n") + f.write(f"Max Corrections: {args.max_corrections}\n") + f.write(f"Newline Threshold: {args.newline_threshold}\n") + f.write(f"Warmup: {args.warmup}\n") + f.write(f"\n") + f.write(f"Results:\n") + f.write(f"---------\n") + f.write(f"Correct: {total_correct}/{total_examples}\n") + f.write(f"Accuracy: {accuracy:.2%}\n") + f.write(f"Attempted (produced \\boxed answer): {num_attempted}/{total_examples}\n") + f.write(f"Soundness (correct/attempted): {soundness:.2%}\n\n") + f.write(f"Token Statistics:\n") + f.write(f"---------------------------\n") + f.write(f"Total Tokens: {total_reasoning_tokens}\n") + f.write(f"Average Tokens: {avg_reasoning_tokens:.2f}\n") + if reasoning_token_counts: + f.write(f"Median Tokens: {float(np.median(reasoning_token_counts)):.0f}\n") + f.write(f"Min Tokens: {min(reasoning_token_counts)}\n") + f.write(f"Max Tokens: {max(reasoning_token_counts)}\n") + f.write(f"Std Dev: {np.std(reasoning_token_counts):.2f}\n") + + logger.info(f"Results saved to {results_file}") + print(f"Results saved to {results_file}") \ No newline at end of file diff --git a/examples/TTSwithVerification/mazemeta.py b/examples/TTSwithVerification/mazemeta.py new file mode 100644 index 0000000..94704af --- /dev/null +++ b/examples/TTSwithVerification/mazemeta.py @@ -0,0 +1,451 @@ +""" +Maze experiment with step-by-step verification using StepVerifierMazeMonitor. + +Uses the new monitor-based architecture that integrates with stream_completion. +""" + +import argparse +import asyncio +import json +import logging +import os +import re +import numpy as np + +from datasets import load_dataset +from transformers import AutoTokenizer + +from interwhen import stream_completion +from interwhen.monitors import StepVerifierMazeMonitor +from interwhen.utils.maze_verifier import parse_maze_from_prompt + +logging.basicConfig(level=logging.INFO, format='%(message)s') +logger = logging.getLogger(__name__) + +# ============== MODEL CONFIGURATION ============== +MAIN_MODEL = "Qwen/QwQ-32B" +# ================================================= + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# Walk up to find the repo root (contains pyproject.toml), output to its parent +_dir = _SCRIPT_DIR +while _dir != os.path.dirname(_dir) and not os.path.isfile(os.path.join(_dir, "pyproject.toml")): + _dir = os.path.dirname(_dir) +_OUTPUT_ROOT = os.path.dirname(_dir) + + +def get_model_short_name(model_name: str) -> str: + """Extract a short, filesystem-safe name from the model path.""" + short_name = model_name.split("/")[-1] + short_name = short_name.replace(" ", "_").replace(":", "-") + return short_name + + +def get_output_dirs(main_model: str, base_dir: str = None): + """Create and return output directory paths based on model name.""" + if base_dir is None: + base_dir = os.path.join(_OUTPUT_ROOT, "Outputs_TTS", "MazeResults", "metaPrompt") + model_short_name = get_model_short_name(main_model) + output_base = os.path.join(base_dir, model_short_name) + + dirs = { + "base": output_base, + "reasoning": os.path.join(output_base, "Reasoning_output"), + } + + for dir_path in dirs.values(): + os.makedirs(dir_path, exist_ok=True) + + return dirs + + +def save_prompt(idx, prompt_with_answer, reason_dir): + """Save reasoning trace to a text file.""" + filename = os.path.join(reason_dir, f"reason_{idx}.txt") + with open(filename, "w", encoding="utf-8") as f: + f.write(prompt_with_answer) + + +def build_meta_prompt_from_example(example): + """Build prompt for maze example.""" + system_prompt = """You are a maze-solving AI. Given a maze in ASCII format, analyze it step by step. + +COORDINATE SYSTEM: +- Rows are numbered from top (row 0) to bottom +- Columns are numbered from left (col 0) to right +- Movement: UP (row decreases), DOWN (row increases), LEFT (col decreases), RIGHT (col increases) + +TURN DEFINITIONS: +- RIGHT_TURN = 90° clockwise change (e.g., DOWN→LEFT, LEFT→UP, UP→RIGHT, RIGHT→DOWN) +- LEFT_TURN = 90° counterclockwise change (e.g., DOWN→RIGHT, RIGHT→UP, UP→LEFT, LEFT→DOWN) + +RELATIVE POSITION DEFINITIONS: +- "directly to the left" = same row, E has smaller column than S +- "directly to the right" = same row, E has larger column than S +- "directly above" = same column, E has smaller row than S +- "directly below" = same column, E has larger row than S +- "top left" = E has smaller row AND smaller column than S +- "top right" = E has smaller row AND larger column than S +- "bottom left" = E has larger row AND smaller column than S +- "bottom right" = E has larger row AND larger column than S + +IMPORTANT: Follow the EXACT output format below. Do NOT use tags. + +EXAMPLE 1: Counting Right Turns +Question: How many right turns are there in the path from S to E? + +>>> LOCATE START AND EXIT: + S position: (3,5) + E position: (1,1) + +>>> STEP 1: Move DOWN from (3,5) to (4,5) + Current position: (4,5) + Previous direction: — + Current direction: DOWN + Turn type: STRAIGHT + Running count: Right=0, Left=0 + +>>> STEP 2: Move DOWN from (4,5) to (5,5) + Current position: (5,5) + Previous direction: DOWN + Current direction: DOWN + Turn type: STRAIGHT + Running count: Right=0, Left=0 + +>>> STEP 3: Move LEFT from (5,5) to (5,4) + Current position: (5,4) + Previous direction: DOWN + Current direction: LEFT + Turn type: RIGHT_TURN + Running count: Right=1, Left=0 + +>>> FINAL ANSWER: Right turns = 2 + \\boxed{C} + +EXAMPLE 2: Counting Total Turns +Question: How many total turns are there in the path from S to E? + +>>> LOCATE START AND EXIT: + S position: (3,5) + E position: (1,1) + +>>> STEP 1: Move DOWN from (3,5) to (4,5) + Current position: (4,5) + Previous direction: — + Current direction: DOWN + Turn type: STRAIGHT + Running count: Right=0, Left=0, Total=0 + +[... continue for all steps ...] + +>>> FINAL ANSWER: Total turns = 2 + \\boxed{C} + +EXAMPLE 3: Relative Position +Question: Is the exit (E) to the top left of the starting point (S)? + +>>> LOCATE START AND EXIT: + S position: (3,5) + E position: (1,1) + +>>> COMPARE POSITIONS: + Row comparison: E row (1) < S row (3) → E is ABOVE S ✓ + Col comparison: E col (1) < S col (5) → E is LEFT of S ✓ + +>>> ANALYSIS: + E is above S (smaller row): YES + E is left of S (smaller col): YES + Therefore E is at TOP LEFT of S. + +>>> ANSWER: YES, E is to the top left of S. + \\boxed{A} + +════════════════════════════════════════════════════════════════════════════════ +Now solve the following maze using the EXACT same format. First locate S and E, then trace the path step by step.""" + + # Get the maze description (trimmed to remove trailing instructions) + description = str(example.get("prompt", "")) + description_trimmed = description[:-143] if len(description) > 143 else description + + return system_prompt, description_trimmed + + +def extract_solution_mcq(text): + """Extract MCQ solution from model output.""" + # Try multiple boxed patterns + patterns = [ + r"\\boxed\{([^}]*)\}", # \boxed{...} + r"boxed\{([^}]*)\}", # boxed{...} without escape + r"\*\*([A-D])\*\*", # **A** format + r"answer[:\s]*([A-D])", # answer: A format + r"(?:^|\n)([A-D])(?:\s|$|\.)", # Standalone letter + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() + + # Last resort: look for any standalone A, B, C, or D + standalone = re.findall(r"\b([ABCD])\b", text) + if standalone: + return standalone[-1].upper() + + return None + + +def count_tokens(text: str, tokenizer) -> int: + """Count the total number of tokens in the generated text using the tokenizer.""" + tokens = tokenizer.encode(text, add_special_tokens=False) + return len(tokens) + + +def get_question_type_from_index(idx: int) -> str: + """Determine question type based on index range. + + Dataset structure: + - 3000-3499: right turns + - 3500-3999: total turns + - 4000-4500: relative position + """ + if idx < 3500: + return "right_turns" + elif idx < 4000: + return "total_turns" + else: + return "relative_position" + + +def init_llm_server(model_name, max_tokens=32768, port=8000): + """Initialize LLM server configuration.""" + url = f"http://localhost:{port}/v1/completions" + payload = { + "model": model_name, + "max_tokens": max_tokens, + "top_k": 20, + "top_p": 0.95, + "min_p": 0.0, + "do_sample": True, + "temperature": 0.6, + "stream": True, + "logprobs": 20, + "use_beam_search": False, + "prompt_cache": True, + "seed": 42 + } + headers = {"Content-Type": "application/json"} + return {"url": url, "payload": payload, "headers": headers} + +def evaluate_maze_answer(answer, options, ground_truth): + """ + Evaluate a Maze MCQ answer and return (is_correct, extracted_answer, message). + + Args: + answer: Raw model output + options: Dictionary mapping option letters (A/B/C/D) to their values + ground_truth: The correct answer value + + Returns: + Tuple of (is_correct, extracted_answer, message) + """ + sol = extract_solution_mcq(answer) + gt_sol = str(ground_truth).strip() + + if not sol: + return False, None, "No expression found" + + sol = sol.strip() + + # Case 1: LLM returned option letter (A/B/C/D) + if sol in options: + if options[sol] == gt_sol: + return True, sol, f"Correct: option {sol} -> {options[sol]}" + else: + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" + + # Case 2: LLM returned the actual answer text + # First check if sol matches ground truth directly + if sol.lower() == gt_sol.lower(): + return True, sol, f"Correct: answer text matches ground truth: {sol}" + + # Check if sol matches any option value + for opt_letter, opt_value in options.items(): + if sol.lower() == opt_value.lower(): + if opt_value == gt_sol: + return True, sol, f"Correct: answer text {sol} (option {opt_letter})" + else: + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" + + return False, sol, f"Solution '{sol}' not found in options or ground truth" + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run maze experiments with StepVerifierMazeMonitor") + parser.add_argument("--model", type=str, default=MAIN_MODEL, + help="Model name for generation") + parser.add_argument("--indices", type=str, default=None, + help="Comma-separated indices to run (e.g., '3000,3500,4000')") + parser.add_argument("--start", type=int, default=3000, help="Start index") + parser.add_argument("--end", type=int, default=3010, help="End index") + parser.add_argument("--num_examples", "-n", type=int, default=None, + help="Number of examples to run (overrides start/end)") + parser.add_argument("--max_corrections", type=int, default=5, + help="Maximum number of correction attempts per example") + parser.add_argument("--port", type=int, default=8000, help="vLLM server port") + parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging") + args = parser.parse_args() + + if args.debug: + logging.getLogger().setLevel(logging.DEBUG) + + # Load dataset + dataset = load_dataset("microsoft/VISION_LANGUAGE", 'maze', split='val') + + # Setup LLM server + llm_server = init_llm_server(args.model, port=args.port) + + # Load tokenizer for accurate token counting + logger.info(f"Loading tokenizer for {args.model}...") + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + logger.info("Tokenizer loaded successfully.") + + # Setup output directory + output_dirs = get_output_dirs(args.model) + reason_dir = output_dirs["reasoning"] + + # Determine indices + if args.indices: + indices = [int(x.strip()) for x in args.indices.split(",")] + elif args.num_examples: + # Use 4499 as endpoint (4500 is out of bounds since dataset size is 4500) + indices = np.linspace(3000, 4499, args.num_examples, dtype=int) + else: + indices = range(args.start, args.end) + + # Stats tracking + results = [] + total_correct = 0 + total_examples = 0 + total_reasoning_tokens = 0 + + for idx in indices: + example = dataset[idx] + system_prompt, user_prompt = build_meta_prompt_from_example(example) + if str(example.get("ground_truth", "")).strip() == "Q4": + target_options = ["A", "B"] + else: + target_options = ["A", "B", "C", "D"] + keys = "|".join(map(re.escape, target_options)) + pattern = rf'\b({keys})\.\s*([A-Za-z0-9]+)\b' + options = dict(re.findall(pattern, user_prompt)) + + # Build full prompt with ChatML format + full_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + + # Parse maze from prompt + grid, start_pos, exit_pos = parse_maze_from_prompt(user_prompt) + + if not grid or not start_pos or not exit_pos: + logger.error(f"Could not parse maze for example {idx}") + continue + + # Detect question type from prompt (auto-detection) + # Falls back to index-based if no turn keywords found + question_type = StepVerifierMazeMonitor.detect_question_type(user_prompt) + + logger.info(f"\n{'='*60}") + logger.info(f"Example {idx} ({question_type})") + logger.info(f"Maze: S={start_pos}, E={exit_pos}, grid={len(grid)}x{len(grid[0]) if grid else 0}") + logger.info(f"{'='*60}") + + # Create the monitor + monitor = StepVerifierMazeMonitor( + name="maze_step_verifier", + answer_start_token="", + grid=grid, + start_pos=start_pos, + exit_pos=exit_pos, + max_corrections=args.max_corrections, + question_type=question_type, + ) + + # Run with stream_completion + try: + answer = asyncio.run(stream_completion( + full_prompt, + llm_server=llm_server, + monitors=(monitor), + add_delay=False, + termination_requires_validation=False, + async_execution=True + )) + except Exception as e: + logger.error(f"Error running example {idx}: {e}") + import traceback + traceback.print_exc() + continue + + # Save reasoning trace + save_prompt(idx, answer, reason_dir) + + # Count generated tokens + reasoning_tokens = count_tokens(answer, tokenizer) + total_reasoning_tokens += reasoning_tokens + + gt_sol = str(example.get("ground_truth", "")).strip() + is_correct, extracted_answer, message = evaluate_maze_answer(answer, options, gt_sol) + + if extracted_answer: + logger.info(f"Extracted answer: {extracted_answer}") + logger.info(message) + + if is_correct: + total_correct += 1 + + total_examples += 1 + # Log result + result = { + 'idx': int(idx), # Convert numpy int64 to Python int + 'question_type': question_type, + 'correct': is_correct, + 'sol': extracted_answer, + 'gt': gt_sol, + 'reasoning_tokens': reasoning_tokens, + } + results.append(result) + + logger.info(f"Result: sol={extracted_answer}, gt={gt_sol}, correct={is_correct}") + logger.info(f"Reasoning tokens: {reasoning_tokens}") + + # Compute final metrics + accuracy = total_correct / total_examples if total_examples > 0 else 0 + avg_reasoning_tokens = total_reasoning_tokens / total_examples if total_examples > 0 else 0 + + logger.info(f"\n{'='*60}") + logger.info(f"FINAL RESULTS") + logger.info(f"{'='*60}") + logger.info(f"Total examples: {total_examples}") + logger.info(f"Correct: {total_correct}") + logger.info(f"Accuracy: {accuracy:.4f} ({total_correct}/{total_examples})") + logger.info(f"Total reasoning tokens: {total_reasoning_tokens}") + logger.info(f"Avg reasoning tokens: {avg_reasoning_tokens:.1f}") + + # Save summary + summary = { + 'model': args.model, + 'total_examples': total_examples, + 'correct': total_correct, + 'accuracy': accuracy, + 'total_reasoning_tokens': total_reasoning_tokens, + 'avg_reasoning_tokens': avg_reasoning_tokens, + 'max_corrections': args.max_corrections, + 'results': results, + } + + summary_path = os.path.join(output_dirs["base"], "summary.json") + with open(summary_path, 'w') as f: + json.dump(summary, f, indent=2) + logger.info(f"\nSaved summary to {summary_path}") \ No newline at end of file diff --git a/examples/TTSwithVerification/spatialmap_stepverifier.py b/examples/TTSwithVerification/spatialmap_stepverifier.py index 4dc33ff..59d8b8a 100644 --- a/examples/TTSwithVerification/spatialmap_stepverifier.py +++ b/examples/TTSwithVerification/spatialmap_stepverifier.py @@ -1,31 +1,42 @@ """ -SpatialMap experiment with step-by-step verification using StepVerifierSpatialMapMonitor. +SpatialMap experiment with thinking-phase step verification. -Uses the new monitor-based architecture that integrates with stream_completion. +Uses ThinkingPhaseStepVerifierSpatialMapMonitor which: + - Verifies the model's directional claims during via side-streams + - Injects a structured step format after (no meta-prompt needed) + - Verifies each step as the model fills in the structured template """ import argparse import asyncio +import csv import json import logging import os import re import numpy as np -from pathlib import Path from datasets import load_dataset from transformers import AutoTokenizer from interwhen import stream_completion -from interwhen.monitors import StepVerifierSpatialMapMonitor +from interwhen.monitors import ThinkingPhaseStepVerifierSpatialMapMonitor logging.basicConfig(level=logging.INFO, format='%(message)s') logger = logging.getLogger(__name__) # ============== MODEL CONFIGURATION ============== -MAIN_MODEL = "Qwen/Qwen3-30B-A3B-Thinking-2507" +MAIN_MODEL = "Qwen/QwQ-32B" # ================================================= +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# Walk up to find the repo root (contains pyproject.toml), output to its parent +_dir = _SCRIPT_DIR +while _dir != os.path.dirname(_dir) and not os.path.isfile(os.path.join(_dir, "pyproject.toml")): + _dir = os.path.dirname(_dir) +_OUTPUT_ROOT = os.path.dirname(_dir) + def get_model_short_name(model_name: str) -> str: """Extract a short, filesystem-safe name from the model path.""" @@ -34,14 +45,17 @@ def get_model_short_name(model_name: str) -> str: return short_name -def get_output_dirs(main_model: str, base_dir: str = "../../Outputs_TTS/SpatialMapResults"): +def get_output_dirs(main_model: str, base_dir: str = None): """Create and return output directory paths based on model name.""" + if base_dir is None: + base_dir = os.path.join(_OUTPUT_ROOT, "Outputs_TTS", "SpatialMapResults") model_short_name = get_model_short_name(main_model) output_base = os.path.join(base_dir, model_short_name) dirs = { "base": output_base, "reasoning": os.path.join(output_base, "Reasoning_output"), + "csv_saved": os.path.join(output_base, "csv_saved"), } for dir_path in dirs.values(): @@ -50,11 +64,6 @@ def get_output_dirs(main_model: str, base_dir: str = "../../Outputs_TTS/SpatialM return dirs -def remove_last_paragraph(s: str) -> str: - """Remove the last instruction paragraph from the prompt.""" - return s[:-143] if len(s) > 143 else s - - def get_question_type(idx: int) -> str: """Determine question type based on index range. @@ -71,191 +80,35 @@ def get_question_type(idx: int) -> str: return "counting" -def build_meta_prompt_from_example(example): - """Build prompt with structured output format instructions.""" - - # Get the description - description = example.get("prompt") - description = str(description) - description = remove_last_paragraph(description) - - pre_prompt = """You are a spatial reasoning expert. Given a description of objects on a map and their relative positions, analyze the spatial relationships step by step. - -CRITICAL INSTRUCTION: DO NOT use abbreviations or initials for entity names. Always use the COMPLETE FULL NAME of each entity exactly as given in the problem. For example, write "Police Supply Store" not "PSS" or "PS". - -DIRECTION DEFINITIONS (Diagonal Directions): -- Northwest = up and to the left (row decreases, col decreases) -- Northeast = up and to the right (row decreases, col increases) -- Southwest = down and to the left (row increases, col decreases) -- Southeast = down and to the right (row increases, col increases) - -CARDINAL DIRECTIONS (for questions asking about North/South/East/West): -- North = directly up - requires BOTH Northwest AND Northeast relationships to be confirmed -- South = directly down - requires BOTH Southwest AND Southeast relationships to be confirmed -- West = directly left - requires BOTH Northwest AND Southwest relationships to be confirmed -- East = directly right - requires BOTH Northeast AND Southeast relationships to be confirmed - -IMPORTANT: In this dataset, only diagonal relationships (NW/NE/SW/SE) are given. An object can ONLY be in a pure cardinal direction (N/S/E/W) if BOTH required diagonal relationships exist. - -IMPORTANT RULES: -- Directions are TRANSITIVE: If A is Northwest of B, and B is Northwest of C, then A is Northwest of C. -- Directions are REVERSIBLE: If A is Northwest of B, then B is Southeast of A. -- Opposite pairs: Northwest ↔ Southeast, Northeast ↔ Southwest - -STRUCTURED OUTPUT FORMAT: - -═══════════════════════════════════════════════════════════════════════════════ -EXAMPLE 1: Direction Finding (Q0) -═══════════════════════════════════════════════════════════════════════════════ - -Map Description: -Police Supply Store is in the map. Narwhal's Novelties is to the Northwest of Police Supply Store. Coral Crafts is to the Northwest of Narwhal's Novelties. Coral Crafts is to the Northwest of Police Supply Store. Planetarium Prints is to the Southeast of Coral Crafts. Planetarium Prints is to the Northeast of Police Supply Store. Oz Oddities is to the Southwest of Planetarium Prints. Oz Oddities is to the Southwest of Police Supply Store. Ice Queen Ice Cream is to the Northwest of Planetarium Prints. Ice Queen Ice Cream is to the Southeast of Coral Crafts. - -Question: In which direction is Planetarium Prints relative to Police Supply Store? - -### Final Answer - ->>> STEP 1: PARSE RELATIONSHIPS - - Narwhal's Novelties is to the Northwest of Police Supply Store - - Coral Crafts is to the Northwest of Narwhal's Novelties - - Coral Crafts is to the Northwest of Police Supply Store - - Planetarium Prints is to the Southeast of Coral Crafts - - Planetarium Prints is to the Northeast of Police Supply Store - - Oz Oddities is to the Southwest of Planetarium Prints - - Oz Oddities is to the Southwest of Police Supply Store - - Ice Queen Ice Cream is to the Northwest of Planetarium Prints - - Ice Queen Ice Cream is to the Southeast of Coral Crafts - ->>> STEP 2: FIND DIRECT RELATIONSHIP - - Looking for: Planetarium Prints relative to Police Supply Store - - Direct relationship found: "Planetarium Prints is to the Northeast of Police Supply Store" - ->>> STEP 3: ANSWER - - Planetarium Prints is to the NORTHEAST of Police Supply Store. - ->>> FINAL ANSWER: Northeast - \\boxed{A} - -═══════════════════════════════════════════════════════════════════════════════ -EXAMPLE 2: Object Finding (Q1) -═══════════════════════════════════════════════════════════════════════════════ - -Map Description: -Quail's Quilts is in the map. Olive's Oils is to the Southeast of Quail's Quilts. Lumber's Marketplace is to the Northeast of Olive's Oils. Lumber's Marketplace is to the Northeast of Quail's Quilts. Stingray Shoes is to the Northeast of Quail's Quilts. Stingray Shoes is to the Northwest of Lumber's Marketplace. Elephant's Electronics is to the Northeast of Olive's Oils. Elephant's Electronics is to the Northeast of Lumber's Marketplace. Blossom Boutique is to the Northwest of Elephant's Electronics. Blossom Boutique is to the Southeast of Stingray Shoes. - -Question: Which object is in the Southwest of Lumber's Marketplace? - -### Final Answer - ->>> STEP 1: PARSE RELATIONSHIPS - - Olive's Oils is to the Southeast of Quail's Quilts - - Lumber's Marketplace is to the Northeast of Olive's Oils - - Lumber's Marketplace is to the Northeast of Quail's Quilts - - Stingray Shoes is to the Northeast of Quail's Quilts - - Stingray Shoes is to the Northwest of Lumber's Marketplace - - Elephant's Electronics is to the Northeast of Olive's Oils - - Elephant's Electronics is to the Northeast of Lumber's Marketplace - - Blossom Boutique is to the Northwest of Elephant's Electronics - - Blossom Boutique is to the Southeast of Stingray Shoes - ->>> STEP 2: FIND OBJECTS IN SOUTHWEST OF Lumber's Marketplace - - Using reversibility: if Lumber's Marketplace is to the Northeast of X, then X is to the Southwest of Lumber's Marketplace. - - Scanning relationships for "Lumber's Marketplace is to the Northeast of X": - - "Lumber's Marketplace is to the Northeast of Olive's Oils" → Olive's Oils is SOUTHWEST of Lumber's Marketplace ✓ - - "Lumber's Marketplace is to the Northeast of Quail's Quilts" → Quail's Quilts is SOUTHWEST of Lumber's Marketplace ✓ - - Other objects: - - Stingray Shoes is Northwest of Lumber's Marketplace → NOT Southwest - - Elephant's Electronics is Northeast of Lumber's Marketplace → NOT Southwest - - Blossom Boutique: no direct relationship to Lumber's Marketplace given - - Objects in Southwest of Lumber's Marketplace: Olive's Oils, Quail's Quilts - - Checking options: Quail's Quilts matches option D. - ->>> STEP 3: ANSWER - - Quail's Quilts is in the Southwest of Lumber's Marketplace. - ->>> FINAL ANSWER: Quail's Quilts - \\boxed{D} - -═══════════════════════════════════════════════════════════════════════════════ -EXAMPLE 3: Counting (Q2) -═══════════════════════════════════════════════════════════════════════════════ - -Map Description: -Tremor Toys is in the map. Fresh Foods is to the Northeast of Tremor Toys. Salmon Sushi is to the Northeast of Fresh Foods. Salmon Sushi is to the Northeast of Tremor Toys. Recycle Center is to the Northeast of Fresh Foods. Recycle Center is to the Southeast of Salmon Sushi. Wolf's Wardrobe is to the Southeast of Fresh Foods. Wolf's Wardrobe is to the Southeast of Tremor Toys. Mantis's Maps is to the Southeast of Salmon Sushi. Mantis's Maps is to the Southeast of Fresh Foods. - -Question: How many objects are in the Southwest of Mantis's Maps? - -### Final Answer - ->>> STEP 1: PARSE RELATIONSHIPS - - Fresh Foods is to the Northeast of Tremor Toys - - Salmon Sushi is to the Northeast of Fresh Foods - - Salmon Sushi is to the Northeast of Tremor Toys - - Recycle Center is to the Northeast of Fresh Foods - - Recycle Center is to the Southeast of Salmon Sushi - - Wolf's Wardrobe is to the Southeast of Fresh Foods - - Wolf's Wardrobe is to the Southeast of Tremor Toys - - Mantis's Maps is to the Southeast of Salmon Sushi - - Mantis's Maps is to the Southeast of Fresh Foods - ->>> STEP 2: COUNT OBJECTS IN SOUTHWEST OF Mantis's Maps - - Using reversibility: if Mantis's Maps is to the Southeast of X, then X is to the Northwest of Mantis's Maps (NOT Southwest!). - - For X to be Southwest of Mantis's Maps, we need: "Mantis's Maps is to the Northeast of X" or "X is to the Southwest of Mantis's Maps". - - Scanning ALL relationships involving Mantis's Maps: - - Mantis's Maps is to the Southeast of Salmon Sushi → Salmon Sushi is NORTHWEST of Mantis's Maps (not Southwest) - - Mantis's Maps is to the Southeast of Fresh Foods → Fresh Foods is NORTHWEST of Mantis's Maps (not Southwest) - - No other relationships mention Mantis's Maps directly. - - Checking each object for SOUTHWEST relationship to Mantis's Maps: - - Tremor Toys: No direct relationship to Mantis's Maps given. Cannot determine. - - Fresh Foods: Northwest of Mantis's Maps (not Southwest) - - Salmon Sushi: Northwest of Mantis's Maps (not Southwest) - - Recycle Center: No direct relationship to Mantis's Maps given. Cannot determine. - - Wolf's Wardrobe: No direct relationship to Mantis's Maps given. Cannot determine. - - Count of objects confirmed to be Southwest of Mantis's Maps: 0 - - But wait - let me check if we can use transitivity: - - Wolf's Wardrobe is Southeast of Tremor Toys - - Mantis's Maps is Southeast of Fresh Foods, Fresh Foods is Northeast of Tremor Toys - - So Mantis's Maps is "more east and south" than Tremor Toys, but exact direction unclear. - - Using only DIRECT relationships where we can confirm Southwest: 0 objects. - - Checking the options: If 0 is not available, we need to reconsider. - - Options available: A. 5, B. 3, C. 2, D. 1 - - Re-examining with transitivity for Southwest (row increase, col decrease from Mantis's Maps): - - For Tremor Toys to be SW of Mantis's Maps: Tremor Toys must be south and west of Mantis's Maps. - - Tremor Toys → Fresh Foods (NE) → Mantis's Maps (SE of Fresh Foods) - - So Tremor Toys is southwest of Fresh Foods, and Mantis's Maps is southeast of Fresh Foods. - - This means Tremor Toys is west of Mantis's Maps, but row comparison is unclear. - - Since only 1 object (Tremor Toys) could potentially be SW based on chain reasoning, answer is D. 1. - ->>> STEP 3: ANSWER - - There is 1 object in the Southwest of Mantis's Maps. - ->>> FINAL ANSWER: 1 - \\boxed{D} - -═══════════════════════════════════════════════════════════════════════════════ - -REMINDER: Use the COMPLETE FULL NAME of each entity. DO NOT abbreviate or use initials. - -Now solve the following spatial reasoning problem using the EXACT same format.""" - - return pre_prompt, description +def build_simple_prompt(example): + """Build a prompt matching spatialmap_example.py.""" + pre_prompt = "You are an expert problem solver. Carefully read the following multiple-choice question and think through the solution step-by-step before providing your final answer. Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + description = str(example.get("prompt", "")) + description_trimmed = description[:-143] if len(description) > 143 else description + return pre_prompt, description_trimmed def extract_solution(text: str) -> str: """Extract the boxed answer from the response (after ).""" + patterns = [ + r"\\boxed\{([^}]*)\}", + r"boxed\{([^}]*)\}", + r"\*\*([A-D])\*\*", + r"answer[:\s]*([A-D])", + r"(?:^|\n)([A-D])(?:\s|$|\.)", + ] if "" in text: answer_section = text.split("")[-1] else: answer_section = text - - matches = re.findall(r'\\boxed\{([^}]*)\}', answer_section) - if matches: - return matches[-1].strip() - - match = re.search(r'(?:answer|Answer)[:\s]+([A-D])', answer_section) - if match: - return match.group(1).strip() - + answer_section = re.sub(r'.*?', '', answer_section, flags=re.DOTALL) + for pattern in patterns: + matches = re.findall(pattern, answer_section, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() return None @@ -265,7 +118,7 @@ def count_tokens(text: str, tokenizer) -> int: return len(tokens) -def init_llm_server(model_name, max_tokens=32768, port=8000): +def init_llm_server(model_name, max_tokens=20480, port=8000): """Initialize LLM server configuration.""" url = f"http://localhost:{port}/v1/completions" payload = { @@ -274,6 +127,7 @@ def init_llm_server(model_name, max_tokens=32768, port=8000): "top_k": 20, "top_p": 0.95, "min_p": 0.0, + "do_sample": True, "temperature": 0.6, "stream": True, "logprobs": 20, @@ -285,13 +139,14 @@ def init_llm_server(model_name, max_tokens=32768, port=8000): return {"url": url, "payload": payload, "headers": headers} -def save_output(idx: int, output: str, output_dir: str): - """Save output to file.""" - os.makedirs(output_dir, exist_ok=True) - filepath = os.path.join(output_dir, f"output_{idx}.txt") - with open(filepath, 'w') as f: - f.write(output) - logger.info(f"Saved output to {filepath}") +def save_prompt(idx, prompt_with_answer, reason_dir): + """Save reasoning trace to file.""" + os.makedirs(reason_dir, exist_ok=True) + filename = os.path.join(reason_dir, f"reason_{idx}.txt") + with open(filename, "w", encoding="utf-8") as f: + f.write(prompt_with_answer) + logger.info(f"Saved reasoning trace to {filename}") + def evaluate_spatialmap_answer(answer, options, ground_truth): """ @@ -307,32 +162,20 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): """ sol = extract_solution(answer) gt_sol = str(ground_truth).strip() - if not sol: return False, None, "No expression found" - sol = sol.strip() - - # Case 1: LLM returned option letter (A/B/C/D) if sol in options: if options[sol] == gt_sol: return True, sol, f"Correct: option {sol} -> {options[sol]}" - else: - return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" - - # Case 2: LLM returned the actual answer text - # First check if sol matches ground truth directly + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" if sol.lower() == gt_sol.lower(): return True, sol, f"Correct: answer text matches ground truth: {sol}" - - # Check if sol matches any option value for opt_letter, opt_value in options.items(): if sol.lower() == opt_value.lower(): if opt_value == gt_sol: return True, sol, f"Correct: answer text {sol} (option {opt_letter})" - else: - return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" - + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" return False, sol, f"Solution '{sol}' not found in options or ground truth" @@ -350,7 +193,15 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): help="Maximum number of correction attempts per example") parser.add_argument("--port", type=int, default=8000, help="vLLM server port") parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging") + parser.add_argument("--newline_threshold", type=int, default=20, + help="Number of \\n in thinking before triggering side verification") + parser.add_argument("--warmup", type=int, default=0, + help="Number of \\n to skip before starting side-chain verification (warmup period)") args = parser.parse_args() + + logger.info(f"Thinking-phase verification: always on") + logger.info(f" Newline threshold: {args.newline_threshold}") + logger.info(f" Warmup: {args.warmup}") if args.debug: logging.getLogger().setLevel(logging.DEBUG) @@ -385,6 +236,9 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): total_correct = 0 total_examples = 0 total_reasoning_tokens = 0 + num_attempted = 0 # examples where a \boxed{} answer was produced (not "no solution") + reasoning_token_counts = [] + per_example_results = [] # list of dicts for CSV # Per-type stats stats_by_type = { @@ -395,32 +249,38 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): for idx in indices: example = dataset[idx] - system_prompt, user_prompt = build_meta_prompt_from_example(example) + pre_prompt, description_trimmed = build_simple_prompt(example) if str(example.get("ground_truth", "")).strip() == "Q4": - target_options = ["A", "B"] + target_options = ["A", "B"] else: - target_options = ["A", "B", "C", "D"] + target_options = ["A", "B", "C", "D"] keys = "|".join(map(re.escape, target_options)) pattern = r'\b([A-D])\.\s*(.*?)(?=\s*[A-D]\.|$)' - raw = re.findall(pattern, user_prompt, flags=re.DOTALL) + raw = re.findall(pattern, description_trimmed, flags=re.DOTALL) options = {k: v.strip().rstrip(".") for k, v in raw} - - # Determine question type + question_type = get_question_type(idx) - - # Build full prompt - full_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n\n" + + full_prompt = f"<|im_start|>system\n{pre_prompt}<|im_end|>\n<|im_start|>user\n{description_trimmed}<|im_end|>\n<|im_start|>assistant\n" logger.info(f"\n{'='*60}") logger.info(f"Example {idx} ({question_type})") logger.info(f"{'='*60}") - # Create the monitor with the problem text - monitor = StepVerifierSpatialMapMonitor.from_prompt( - problem_text=user_prompt, + # Always use ThinkingPhaseStepVerifierSpatialMapMonitor: + # Phase 1 — verifies during via side-streams + # Phase 2a — injects structured step format after + # Phase 2b — verifies structured output as model fills it in + monitor = ThinkingPhaseStepVerifierSpatialMapMonitor( + name="spatialmap_thinking_verifier", + problem_text=description_trimmed, + llm_server=llm_server, + prompt=full_prompt, + newline_threshold=args.newline_threshold, max_corrections=args.max_corrections, - name="spatialmap_step_verifier" + answer_start_token="", + warmup_newlines=args.warmup, ) logger.info(f"Z3 solver initialized with {len(monitor.z3_solver.parsed_relations)} relations") @@ -441,14 +301,25 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): traceback.print_exc() continue + # Save reasoning trace + save_prompt(int(idx), answer, reason_dir) + logger.info(f"Raw final output:\n{answer}") + # Count generated tokens reasoning_tokens = count_tokens(answer, tokenizer) total_reasoning_tokens += reasoning_tokens + reasoning_token_counts.append(reasoning_tokens) + logger.info(f"Generated tokens in this example: {reasoning_tokens}") # Evaluate the answer gt_sol = str(example.get("ground_truth", "")).strip() is_correct, extracted_answer, message = evaluate_spatialmap_answer(answer, options, gt_sol) + # "attempted" = model produced a real \boxed{} answer (not "no solution") + attempted = (extracted_answer is not None and extracted_answer.strip().lower() != "no solution") + if attempted: + num_attempted += 1 + if extracted_answer: logger.info(f"Extracted answer: {extracted_answer}") logger.info(message) @@ -459,14 +330,13 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): total_examples += 1 stats_by_type[question_type]["total"] += 1 - # Save output - save_output(idx, answer, reason_dir) # Log result result = { 'idx': int(idx), 'question_type': question_type, 'correct': is_correct, + 'attempted': attempted, 'sol': extracted_answer, 'gt': gt_sol, 'reasoning_tokens': reasoning_tokens, @@ -475,12 +345,26 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): } results.append(result) - logger.info(f"Result: sol={extracted_answer}, gt={gt_sol}, correct={is_correct}") + per_example_results.append({ + "index": int(idx), + "question_type": question_type, + "correct": is_correct, + "attempted": attempted, + "sol": extracted_answer if extracted_answer else "", + "gt": gt_sol, + "tokens": reasoning_tokens, + "num_relations": len(monitor.z3_solver.parsed_relations), + "verified_claims": len(monitor.verified_claims), + "message": message, + }) + + logger.info(f"Result: sol={extracted_answer}, gt={gt_sol}, correct={is_correct}, attempted={attempted}") logger.info(f"Verified claims: {len(monitor.verified_claims)}") logger.info(f"Reasoning tokens: {reasoning_tokens}") # Compute final metrics accuracy = total_correct / total_examples if total_examples > 0 else 0 + soundness = total_correct / num_attempted if num_attempted > 0 else 0 # correct / attempted avg_reasoning_tokens = total_reasoning_tokens / total_examples if total_examples > 0 else 0 logger.info(f"\n{'='*60}") @@ -488,7 +372,9 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): logger.info(f"{'='*60}") logger.info(f"Total examples: {total_examples}") logger.info(f"Correct: {total_correct}") + logger.info(f"Attempted (produced \\boxed answer): {num_attempted}/{total_examples}") logger.info(f"Accuracy: {accuracy:.4f} ({total_correct}/{total_examples})") + logger.info(f"Soundness: {soundness:.4f} ({total_correct}/{num_attempted})") logger.info(f"Total reasoning tokens: {total_reasoning_tokens}") logger.info(f"Avg reasoning tokens: {avg_reasoning_tokens:.1f}") @@ -499,12 +385,27 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): acc = stats["correct"] / stats["total"] logger.info(f" {qtype}: {acc:.4f} ({stats['correct']}/{stats['total']})") + print(f"\nFinal Accuracy: {total_correct}/{total_examples} ({accuracy:.2%})") + print(f"Soundness: {total_correct}/{num_attempted} ({soundness:.2%})") + print(f"Average Reasoning Tokens: {avg_reasoning_tokens:.2f}") + print(f"Total Reasoning Tokens: {total_reasoning_tokens}") + + # Save per-example CSV + csv_file = os.path.join(output_dirs["csv_saved"], f"results_{total_examples}examples.csv") + with open(csv_file, 'w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=["index", "question_type", "correct", "attempted", "sol", "gt", "tokens", "num_relations", "verified_claims", "message"]) + writer.writeheader() + writer.writerows(per_example_results) + logger.info(f"Per-example CSV saved to {csv_file}") + # Save summary summary = { 'model': args.model, 'total_examples': total_examples, 'correct': total_correct, + 'attempted': num_attempted, 'accuracy': accuracy, + 'soundness': soundness, 'total_reasoning_tokens': total_reasoning_tokens, 'avg_reasoning_tokens': avg_reasoning_tokens, 'max_corrections': args.max_corrections, @@ -516,3 +417,39 @@ def evaluate_spatialmap_answer(answer, options, ground_truth): with open(summary_path, 'w') as f: json.dump(summary, f, indent=2) logger.info(f"\nSaved summary to {summary_path}") + + # Save results summary to a text file + results_file = os.path.join(output_dirs["base"], f"EAT_{total_examples}examples_results.txt") + with open(results_file, 'w') as f: + f.write(f"SpatialMap Step Verification Results\n") + f.write(f"{'='*50}\n\n") + f.write(f"Model: {args.model}\n") + f.write(f"Number of Examples: {total_examples}\n") + f.write(f"Max Corrections: {args.max_corrections}\n") + f.write(f"Newline Threshold: {args.newline_threshold}\n") + f.write(f"Warmup: {args.warmup}\n") + f.write(f"\n") + f.write(f"Results:\n") + f.write(f"---------\n") + f.write(f"Correct: {total_correct}/{total_examples}\n") + f.write(f"Accuracy: {accuracy:.2%}\n") + f.write(f"Attempted (produced \\boxed answer): {num_attempted}/{total_examples}\n") + f.write(f"Soundness (correct/attempted): {soundness:.2%}\n\n") + f.write(f"Per-type Breakdown:\n") + f.write(f"---------------------------\n") + for qtype, stats in stats_by_type.items(): + if stats["total"] > 0: + acc = stats["correct"] / stats["total"] + f.write(f" {qtype}: {acc:.2%} ({stats['correct']}/{stats['total']})\n") + f.write(f"\nToken Statistics:\n") + f.write(f"---------------------------\n") + f.write(f"Total Tokens: {total_reasoning_tokens}\n") + f.write(f"Average Tokens: {avg_reasoning_tokens:.2f}\n") + if reasoning_token_counts: + f.write(f"Median Tokens: {float(np.median(reasoning_token_counts)):.0f}\n") + f.write(f"Min Tokens: {min(reasoning_token_counts)}\n") + f.write(f"Max Tokens: {max(reasoning_token_counts)}\n") + f.write(f"Std Dev: {np.std(reasoning_token_counts):.2f}\n") + + logger.info(f"Results saved to {results_file}") + print(f"Results saved to {results_file}") diff --git a/examples/TTSwithVerification/spatialmeta.py b/examples/TTSwithVerification/spatialmeta.py new file mode 100644 index 0000000..dda2a5d --- /dev/null +++ b/examples/TTSwithVerification/spatialmeta.py @@ -0,0 +1,527 @@ +""" +SpatialMap experiment with step-by-step verification using StepVerifierSpatialMapMonitor. + +Uses the new monitor-based architecture that integrates with stream_completion. +""" + +import argparse +import asyncio +import json +import logging +import os +import re +import numpy as np + +from datasets import load_dataset +from transformers import AutoTokenizer + +from interwhen import stream_completion +from interwhen.monitors import StepVerifierSpatialMapMonitor + +logging.basicConfig(level=logging.INFO, format='%(message)s') +logger = logging.getLogger(__name__) + +# ============== MODEL CONFIGURATION ============== +MAIN_MODEL = "Qwen/QwQ-32B" +# ================================================= + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# Walk up to find the repo root (contains pyproject.toml), output to its parent +_dir = _SCRIPT_DIR +while _dir != os.path.dirname(_dir) and not os.path.isfile(os.path.join(_dir, "pyproject.toml")): + _dir = os.path.dirname(_dir) +_OUTPUT_ROOT = os.path.dirname(_dir) + + +def get_model_short_name(model_name: str) -> str: + """Extract a short, filesystem-safe name from the model path.""" + short_name = model_name.split("/")[-1] + short_name = short_name.replace(" ", "_").replace(":", "-") + return short_name + + +def get_output_dirs(main_model: str, base_dir: str = None): + """Create and return output directory paths based on model name.""" + if base_dir is None: + base_dir = os.path.join(_OUTPUT_ROOT, "Outputs_TTS", "SpatialMapResults", "metaPrompt") + model_short_name = get_model_short_name(main_model) + output_base = os.path.join(base_dir, model_short_name) + + dirs = { + "base": output_base, + "reasoning": os.path.join(output_base, "Reasoning_output"), + } + + for dir_path in dirs.values(): + os.makedirs(dir_path, exist_ok=True) + + return dirs + + +def remove_last_paragraph(s: str) -> str: + """Remove the last instruction paragraph from the prompt.""" + return s[:-143] if len(s) > 143 else s + + +def save_prompt(idx, prompt_with_answer, reason_dir): + """Save reasoning trace to a text file.""" + os.makedirs(reason_dir, exist_ok=True) + filename = os.path.join(reason_dir, f"reason_{idx}.txt") + with open(filename, "w", encoding="utf-8") as f: + f.write(prompt_with_answer) + + +def get_question_type(idx: int) -> str: + """Determine question type based on index range. + + Dataset structure (1500 examples total): + - 0-499: Q0 (direction finding) + - 500-999: Q1 (object finding) + - 1000-1499: Q2 (counting) + """ + if idx < 500: + return "direction" + elif idx < 1000: + return "object" + else: + return "counting" + + +def build_meta_prompt_from_example(example): + """Build prompt with structured output format instructions.""" + + # Get the description + description = example.get("prompt") + description = str(description) + description = remove_last_paragraph(description) + + pre_prompt = """You are a spatial reasoning expert. Given a description of objects on a map and their relative positions, analyze the spatial relationships step by step. + +CRITICAL INSTRUCTION: DO NOT use abbreviations or initials for entity names. Always use the COMPLETE FULL NAME of each entity exactly as given in the problem. For example, write "Police Supply Store" not "PSS" or "PS". + +DIRECTION DEFINITIONS (Diagonal Directions): +- Northwest = up and to the left (row decreases, col decreases) +- Northeast = up and to the right (row decreases, col increases) +- Southwest = down and to the left (row increases, col decreases) +- Southeast = down and to the right (row increases, col increases) + +CARDINAL DIRECTIONS (for questions asking about North/South/East/West): +- North = directly up - requires BOTH Northwest AND Northeast relationships to be confirmed +- South = directly down - requires BOTH Southwest AND Southeast relationships to be confirmed +- West = directly left - requires BOTH Northwest AND Southwest relationships to be confirmed +- East = directly right - requires BOTH Northeast AND Southeast relationships to be confirmed + +IMPORTANT: In this dataset, only diagonal relationships (NW/NE/SW/SE) are given. An object can ONLY be in a pure cardinal direction (N/S/E/W) if BOTH required diagonal relationships exist. + +IMPORTANT RULES: +- Directions are TRANSITIVE: If A is Northwest of B, and B is Northwest of C, then A is Northwest of C. +- Directions are REVERSIBLE: If A is Northwest of B, then B is Southeast of A. +- Opposite pairs: Northwest ↔ Southeast, Northeast ↔ Southwest + +STRUCTURED OUTPUT FORMAT: + +═══════════════════════════════════════════════════════════════════════════════ +EXAMPLE 1: Direction Finding (Q0) +═══════════════════════════════════════════════════════════════════════════════ + +Map Description: +Police Supply Store is in the map. Narwhal's Novelties is to the Northwest of Police Supply Store. Coral Crafts is to the Northwest of Narwhal's Novelties. Coral Crafts is to the Northwest of Police Supply Store. Planetarium Prints is to the Southeast of Coral Crafts. Planetarium Prints is to the Northeast of Police Supply Store. Oz Oddities is to the Southwest of Planetarium Prints. Oz Oddities is to the Southwest of Police Supply Store. Ice Queen Ice Cream is to the Northwest of Planetarium Prints. Ice Queen Ice Cream is to the Southeast of Coral Crafts. + +Question: In which direction is Planetarium Prints relative to Police Supply Store? + +### Final Answer + +>>> STEP 1: PARSE RELATIONSHIPS + - Narwhal's Novelties is to the Northwest of Police Supply Store + - Coral Crafts is to the Northwest of Narwhal's Novelties + - Coral Crafts is to the Northwest of Police Supply Store + - Planetarium Prints is to the Southeast of Coral Crafts + - Planetarium Prints is to the Northeast of Police Supply Store + - Oz Oddities is to the Southwest of Planetarium Prints + - Oz Oddities is to the Southwest of Police Supply Store + - Ice Queen Ice Cream is to the Northwest of Planetarium Prints + - Ice Queen Ice Cream is to the Southeast of Coral Crafts + +>>> STEP 2: FIND DIRECT RELATIONSHIP + - Looking for: Planetarium Prints relative to Police Supply Store + - Direct relationship found: "Planetarium Prints is to the Northeast of Police Supply Store" + +>>> STEP 3: ANSWER + - Planetarium Prints is to the NORTHEAST of Police Supply Store. + +>>> FINAL ANSWER: Northeast + \\boxed{A} + +═══════════════════════════════════════════════════════════════════════════════ +EXAMPLE 2: Object Finding (Q1) +═══════════════════════════════════════════════════════════════════════════════ + +Map Description: +Quail's Quilts is in the map. Olive's Oils is to the Southeast of Quail's Quilts. Lumber's Marketplace is to the Northeast of Olive's Oils. Lumber's Marketplace is to the Northeast of Quail's Quilts. Stingray Shoes is to the Northeast of Quail's Quilts. Stingray Shoes is to the Northwest of Lumber's Marketplace. Elephant's Electronics is to the Northeast of Olive's Oils. Elephant's Electronics is to the Northeast of Lumber's Marketplace. Blossom Boutique is to the Northwest of Elephant's Electronics. Blossom Boutique is to the Southeast of Stingray Shoes. + +Question: Which object is in the Southwest of Lumber's Marketplace? + +### Final Answer + +>>> STEP 1: PARSE RELATIONSHIPS + - Olive's Oils is to the Southeast of Quail's Quilts + - Lumber's Marketplace is to the Northeast of Olive's Oils + - Lumber's Marketplace is to the Northeast of Quail's Quilts + - Stingray Shoes is to the Northeast of Quail's Quilts + - Stingray Shoes is to the Northwest of Lumber's Marketplace + - Elephant's Electronics is to the Northeast of Olive's Oils + - Elephant's Electronics is to the Northeast of Lumber's Marketplace + - Blossom Boutique is to the Northwest of Elephant's Electronics + - Blossom Boutique is to the Southeast of Stingray Shoes + +>>> STEP 2: FIND OBJECTS IN SOUTHWEST OF Lumber's Marketplace + - Using reversibility: if Lumber's Marketplace is to the Northeast of X, then X is to the Southwest of Lumber's Marketplace. + - Scanning relationships for "Lumber's Marketplace is to the Northeast of X": + - "Lumber's Marketplace is to the Northeast of Olive's Oils" → Olive's Oils is SOUTHWEST of Lumber's Marketplace ✓ + - "Lumber's Marketplace is to the Northeast of Quail's Quilts" → Quail's Quilts is SOUTHWEST of Lumber's Marketplace ✓ + - Other objects: + - Stingray Shoes is Northwest of Lumber's Marketplace → NOT Southwest + - Elephant's Electronics is Northeast of Lumber's Marketplace → NOT Southwest + - Blossom Boutique: no direct relationship to Lumber's Marketplace given + - Objects in Southwest of Lumber's Marketplace: Olive's Oils, Quail's Quilts + - Checking options: Quail's Quilts matches option D. + +>>> STEP 3: ANSWER + - Quail's Quilts is in the Southwest of Lumber's Marketplace. + +>>> FINAL ANSWER: Quail's Quilts + \\boxed{D} + +═══════════════════════════════════════════════════════════════════════════════ +EXAMPLE 3: Counting (Q2) +═══════════════════════════════════════════════════════════════════════════════ + +Map Description: +Tremor Toys is in the map. Fresh Foods is to the Northeast of Tremor Toys. Salmon Sushi is to the Northeast of Fresh Foods. Salmon Sushi is to the Northeast of Tremor Toys. Recycle Center is to the Northeast of Fresh Foods. Recycle Center is to the Southeast of Salmon Sushi. Wolf's Wardrobe is to the Southeast of Fresh Foods. Wolf's Wardrobe is to the Southeast of Tremor Toys. Mantis's Maps is to the Southeast of Salmon Sushi. Mantis's Maps is to the Southeast of Fresh Foods. + +Question: How many objects are in the Southwest of Mantis's Maps? + +### Final Answer + +>>> STEP 1: PARSE RELATIONSHIPS + - Fresh Foods is to the Northeast of Tremor Toys + - Salmon Sushi is to the Northeast of Fresh Foods + - Salmon Sushi is to the Northeast of Tremor Toys + - Recycle Center is to the Northeast of Fresh Foods + - Recycle Center is to the Southeast of Salmon Sushi + - Wolf's Wardrobe is to the Southeast of Fresh Foods + - Wolf's Wardrobe is to the Southeast of Tremor Toys + - Mantis's Maps is to the Southeast of Salmon Sushi + - Mantis's Maps is to the Southeast of Fresh Foods + +>>> STEP 2: COUNT OBJECTS IN SOUTHWEST OF Mantis's Maps + - Using reversibility: if Mantis's Maps is to the Southeast of X, then X is to the Northwest of Mantis's Maps (NOT Southwest!). + - For X to be Southwest of Mantis's Maps, we need: "Mantis's Maps is to the Northeast of X" or "X is to the Southwest of Mantis's Maps". + - Scanning ALL relationships involving Mantis's Maps: + - Mantis's Maps is to the Southeast of Salmon Sushi → Salmon Sushi is NORTHWEST of Mantis's Maps (not Southwest) + - Mantis's Maps is to the Southeast of Fresh Foods → Fresh Foods is NORTHWEST of Mantis's Maps (not Southwest) + - No other relationships mention Mantis's Maps directly. + - Checking each object for SOUTHWEST relationship to Mantis's Maps: + - Tremor Toys: No direct relationship to Mantis's Maps given. Cannot determine. + - Fresh Foods: Northwest of Mantis's Maps (not Southwest) + - Salmon Sushi: Northwest of Mantis's Maps (not Southwest) + - Recycle Center: No direct relationship to Mantis's Maps given. Cannot determine. + - Wolf's Wardrobe: No direct relationship to Mantis's Maps given. Cannot determine. + - Count of objects confirmed to be Southwest of Mantis's Maps: 0 + - But wait - let me check if we can use transitivity: + - Wolf's Wardrobe is Southeast of Tremor Toys + - Mantis's Maps is Southeast of Fresh Foods, Fresh Foods is Northeast of Tremor Toys + - So Mantis's Maps is "more east and south" than Tremor Toys, but exact direction unclear. + - Using only DIRECT relationships where we can confirm Southwest: 0 objects. + - Checking the options: If 0 is not available, we need to reconsider. + - Options available: A. 5, B. 3, C. 2, D. 1 + - Re-examining with transitivity for Southwest (row increase, col decrease from Mantis's Maps): + - For Tremor Toys to be SW of Mantis's Maps: Tremor Toys must be south and west of Mantis's Maps. + - Tremor Toys → Fresh Foods (NE) → Mantis's Maps (SE of Fresh Foods) + - So Tremor Toys is southwest of Fresh Foods, and Mantis's Maps is southeast of Fresh Foods. + - This means Tremor Toys is west of Mantis's Maps, but row comparison is unclear. + - Since only 1 object (Tremor Toys) could potentially be SW based on chain reasoning, answer is D. 1. + +>>> STEP 3: ANSWER + - There is 1 object in the Southwest of Mantis's Maps. + +>>> FINAL ANSWER: 1 + \\boxed{D} + +═══════════════════════════════════════════════════════════════════════════════ + +REMINDER: Use the COMPLETE FULL NAME of each entity. DO NOT abbreviate or use initials. + +Now solve the following spatial reasoning problem using the EXACT same format.""" + + return pre_prompt, description + + +def extract_solution(text: str) -> str: + """Extract the boxed answer from the response (after ).""" + if "" in text: + answer_section = text.split("")[-1] + else: + answer_section = text + + matches = re.findall(r'\\boxed\{([^}]*)\}', answer_section) + if matches: + return matches[-1].strip() + + match = re.search(r'(?:answer|Answer)[:\s]+([A-D])', answer_section) + if match: + return match.group(1).strip() + + return None + + +def count_tokens(text: str, tokenizer) -> int: + """Count the total number of tokens in the generated text using the tokenizer.""" + tokens = tokenizer.encode(text, add_special_tokens=False) + return len(tokens) + + +def init_llm_server(model_name, max_tokens=32768, port=8000): + """Initialize LLM server configuration.""" + url = f"http://localhost:{port}/v1/completions" + payload = { + "model": model_name, + "max_tokens": max_tokens, + "top_k": 20, + "top_p": 0.95, + "min_p": 0.0, + "do_sample": True, + "temperature": 0.6, + "stream": True, + "logprobs": 20, + "use_beam_search": False, + "prompt_cache": True, + "seed": 42 + } + headers = {"Content-Type": "application/json"} + return {"url": url, "payload": payload, "headers": headers} + +def evaluate_spatialmap_answer(answer, options, ground_truth): + """ + Evaluate a SpatialMap MCQ answer and return (is_correct, extracted_answer, message). + + Args: + answer: Raw model output + options: Dictionary mapping option letters (A/B/C/D) to their values + ground_truth: The correct answer value + + Returns: + Tuple of (is_correct, extracted_answer, message) + """ + sol = extract_solution(answer) + gt_sol = str(ground_truth).strip() + + if not sol: + return False, None, "No expression found" + + sol = sol.strip() + + # Case 1: LLM returned option letter (A/B/C/D) + if sol in options: + if options[sol] == gt_sol: + return True, sol, f"Correct: option {sol} -> {options[sol]}" + else: + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" + + # Case 2: LLM returned the actual answer text + # First check if sol matches ground truth directly + if sol.lower() == gt_sol.lower(): + return True, sol, f"Correct: answer text matches ground truth: {sol}" + + # Check if sol matches any option value + for opt_letter, opt_value in options.items(): + if sol.lower() == opt_value.lower(): + if opt_value == gt_sol: + return True, sol, f"Correct: answer text {sol} (option {opt_letter})" + else: + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" + + return False, sol, f"Solution '{sol}' not found in options or ground truth" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run SpatialMap experiments with StepVerifierSpatialMapMonitor") + parser.add_argument("--model", type=str, default=MAIN_MODEL, + help="Model name for generation") + parser.add_argument("--indices", type=str, default=None, + help="Comma-separated indices to run (e.g., '0,100,200')") + parser.add_argument("--start", type=int, default=0, help="Start index") + parser.add_argument("--end", type=int, default=1500, help="End index") + parser.add_argument("--num_examples", "-n", type=int, default=None, + help="Number of examples to run (overrides start/end)") + parser.add_argument("--max_corrections", type=int, default=5, + help="Maximum number of correction attempts per example") + parser.add_argument("--port", type=int, default=8001, help="vLLM server port") + parser.add_argument("--debug", "-d", action="store_true", help="Enable debug logging") + args = parser.parse_args() + + if args.debug: + logging.getLogger().setLevel(logging.DEBUG) + + # Load dataset (spatial_map_text_only has 1500 examples) + dataset = load_dataset("microsoft/VISION_LANGUAGE", 'spatial_map_text_only', split='val') + + # Setup LLM server + llm_server = init_llm_server(args.model, port=args.port) + + # Load tokenizer for accurate token counting + logger.info(f"Loading tokenizer for {args.model}...") + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + logger.info("Tokenizer loaded successfully.") + + # Setup output directory + output_dirs = get_output_dirs(args.model) + reason_dir = output_dirs["reasoning"] + + # Determine indices + max_idx = len(dataset) - 1 + if args.indices: + indices = [int(x.strip()) for x in args.indices.split(",")] + elif args.num_examples: + # Sample evenly across all 1500 examples (0-1499) + indices = np.linspace(0, min(max_idx, 1499), args.num_examples, dtype=int) + else: + indices = range(args.start, min(args.end, max_idx + 1)) + + # Stats tracking + results = [] + total_correct = 0 + total_examples = 0 + total_reasoning_tokens = 0 + + # Per-type stats + stats_by_type = { + "direction": {"total": 0, "correct": 0}, + "object": {"total": 0, "correct": 0}, + "counting": {"total": 0, "correct": 0}, + } + + for idx in indices: + example = dataset[idx] + system_prompt, user_prompt = build_meta_prompt_from_example(example) + if str(example.get("ground_truth", "")).strip() == "Q4": + target_options = ["A", "B"] + else: + target_options = ["A", "B", "C", "D"] + keys = "|".join(map(re.escape, target_options)) + pattern = r'\b([A-D])\.\s*(.*?)(?=\s*[A-D]\.|$)' + raw = re.findall(pattern, user_prompt, flags=re.DOTALL) + + options = {k: v.strip().rstrip(".") for k, v in raw} + + # Determine question type + question_type = get_question_type(idx) + + # Build full prompt with ChatML format + full_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" + + logger.info(f"\n{'='*60}") + logger.info(f"Example {idx} ({question_type})") + logger.info(f"{'='*60}") + + # Create the monitor with the problem text + monitor = StepVerifierSpatialMapMonitor.from_prompt( + problem_text=user_prompt, + max_corrections=args.max_corrections, + name="spatialmap_step_verifier" + ) + + logger.info(f"Z3 solver initialized with {len(monitor.z3_solver.parsed_relations)} relations") + + # Run with stream_completion + try: + answer = asyncio.run(stream_completion( + full_prompt, + llm_server=llm_server, + monitors=(monitor,), + add_delay=False, + termination_requires_validation=False, + async_execution=True + )) + except Exception as e: + logger.error(f"Error running example {idx}: {e}") + import traceback + traceback.print_exc() + continue + + # Count generated tokens + reasoning_tokens = count_tokens(answer, tokenizer) + total_reasoning_tokens += reasoning_tokens + + # Evaluate the answer + gt_sol = str(example.get("ground_truth", "")).strip() + is_correct, extracted_answer, message = evaluate_spatialmap_answer(answer, options, gt_sol) + + if extracted_answer: + logger.info(f"Extracted answer: {extracted_answer}") + logger.info(message) + + if is_correct: + total_correct += 1 + stats_by_type[question_type]["correct"] += 1 + + total_examples += 1 + stats_by_type[question_type]["total"] += 1 + # Save output + save_prompt(idx, answer, reason_dir) + + # Log result + result = { + 'idx': int(idx), + 'question_type': question_type, + 'correct': is_correct, + 'sol': extracted_answer, + 'gt': gt_sol, + 'reasoning_tokens': reasoning_tokens, + 'num_relations': len(monitor.z3_solver.parsed_relations), + 'verified_claims': len(monitor.verified_claims), + } + results.append(result) + + logger.info(f"Result: sol={extracted_answer}, gt={gt_sol}, correct={is_correct}") + logger.info(f"Verified claims: {len(monitor.verified_claims)}") + logger.info(f"Reasoning tokens: {reasoning_tokens}") + + # Compute final metrics + accuracy = total_correct / total_examples if total_examples > 0 else 0 + avg_reasoning_tokens = total_reasoning_tokens / total_examples if total_examples > 0 else 0 + + logger.info(f"\n{'='*60}") + logger.info(f"FINAL RESULTS") + logger.info(f"{'='*60}") + logger.info(f"Total examples: {total_examples}") + logger.info(f"Correct: {total_correct}") + logger.info(f"Accuracy: {accuracy:.4f} ({total_correct}/{total_examples})") + logger.info(f"Total reasoning tokens: {total_reasoning_tokens}") + logger.info(f"Avg reasoning tokens: {avg_reasoning_tokens:.1f}") + + # Per-type breakdown + logger.info(f"\nPer-type breakdown:") + for qtype, stats in stats_by_type.items(): + if stats["total"] > 0: + acc = stats["correct"] / stats["total"] + logger.info(f" {qtype}: {acc:.4f} ({stats['correct']}/{stats['total']})") + + # Save summary + summary = { + 'model': args.model, + 'total_examples': total_examples, + 'correct': total_correct, + 'accuracy': accuracy, + 'total_reasoning_tokens': total_reasoning_tokens, + 'avg_reasoning_tokens': avg_reasoning_tokens, + 'max_corrections': args.max_corrections, + 'stats_by_type': stats_by_type, + 'results': results, + } + + summary_path = os.path.join(output_dirs["base"], "summary.json") + with open(summary_path, 'w') as f: + json.dump(summary, f, indent=2) + logger.info(f"\nSaved summary to {summary_path}") \ No newline at end of file diff --git a/examples/TTSwithVerification/tot_baseline.py b/examples/TTSwithVerification/tot_baseline.py new file mode 100644 index 0000000..58b2171 --- /dev/null +++ b/examples/TTSwithVerification/tot_baseline.py @@ -0,0 +1,784 @@ + #!/usr/bin/env python3 +"""Command-line Tree-of-Thought baseline runner for interwhen datasets.""" + +import argparse +import asyncio +import json +import logging +import os +import re +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from tqdm.asyncio import tqdm_asyncio + +import httpx +import numpy as np +from datasets import load_dataset + +from interwhen.tree_of_thought import ( + SearchMethod, + ToTSearchConfig, + TreeOfThoughtSearch, + build_tot_problem, + # build_verina_synthesis_prompt, + # build_verina_spec_synthesis_prompt, +) +from interwhen.utils.zebralogic_helper import extract_last_json, zebra_correctness +# from verina_utils import ( +# load_verina_dataset, +# extract_code_from_response, +# evaluate_generated_code, +# ) +# from verina_spec_utils import ( +# load_verina_dataset as load_verina_spec_dataset, +# extract_spec_from_response, +# evaluate_generated_spec, +# ) + +LOGGER = logging.getLogger("tot_baseline") + + +# ============== Helper Functions ============== + +def remove_last_paragraph(s: str) -> str: + return s[:-143] if len(s) > 143 else s + + +def build_maze_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer. " + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def build_spatialmap_prompt(example): + pre_prompt = ( + "You are an expert problem solver. Carefully read the following multiple-choice question " + "and think through the solution step-by-step before providing your final answer." + "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:" + ) + description = remove_last_paragraph(str(example.get("prompt"))) + return pre_prompt, description + + +def extract_solution_game24(text): + boxed_pattern = r"\\boxed\{" + matches = list(re.finditer(boxed_pattern, text)) + if not matches: + return None + last_match = matches[-1] + start = last_match.end() + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == "{": + brace_count += 1 + elif text[end] == "}": + brace_count -= 1 + end += 1 + expr = text[start:end - 1].strip() + + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + + replacements = { + r"\times": "*", + r"\cdot": "*", + r"\div": "/", + } + for latex, op in replacements.items(): + expr = expr.replace(latex, op) + + expr = expr.replace(r"\\,", "").replace(r"\\ ", "") + expr = re.sub(r"\)\s*\(", ")*(", expr) + expr = re.sub(r"\)\s*(\d)", r")*\1", expr) + expr = re.sub(r"(\d)\s*\(", r"\1*(", expr) + + return expr + + +def extract_numbers_from_expr(expr): + numbers = re.findall(r"\d+\.?\d*", expr) + return [int(float(n)) if float(n).is_integer() else float(n) for n in numbers] + + +def validate_numbers_used(expr, expected_nums): + used_nums = extract_numbers_from_expr(expr) + return sorted(used_nums) == sorted(expected_nums) + + +def evaluate_expression(expr, expected_nums=None): + try: + if expected_nums is not None and not validate_numbers_used(expr, expected_nums): + return False + value = eval(expr, {"__builtins__": None}, {}) + return abs(value - 24) < 1e-6 + except Exception: + return False + + +def evaluate_game24_answer(answer, nums): + expr = extract_solution_game24(answer) + if not expr: + return False, None, "No expression found" + if evaluate_expression(expr, expected_nums=nums): + return True, expr, "Correct solution (evaluates to 24 using exactly the given numbers)" + used_nums = extract_numbers_from_expr(expr) + if sorted(used_nums) != sorted(nums): + return False, expr, f"Incorrect: Expression uses {used_nums}, expected {nums}" + return False, expr, "Expression does not evaluate to 24" + + +def extract_solution_mcq(text): + """Extract MCQ solution from model output.""" + patterns = [ + r"\\boxed\{([^}]*)\}", + r"boxed\{([^}]*)\}", + r"\*\*([A-D])\*\*", + r"answer[:\s]*([A-D])", + r"(?:^|\n)([A-D])(?:\s|$|\.)", + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + if matches: + expr = matches[-1].strip() + choice_match = re.search(r"\b([ABCD])\b", expr, flags=re.IGNORECASE) + if choice_match: + return choice_match.group(1).upper() + + standalone = re.findall(r"\b([ABCD])\b", text) + if standalone: + return standalone[-1].upper() + + return None + + +def extract_options_from_prompt(prompt_text, target_options): + pattern = r"\b([A-D])\.\s*(.*?)(?=\s*[A-D]\.\s*|$)" + raw = re.findall(pattern, prompt_text, flags=re.DOTALL) + options = {k: v.strip().rstrip(".") for k, v in raw} + if target_options: + options = {k: v for k, v in options.items() if k in target_options} + return options + + +def evaluate_mcq_answer(answer, options, ground_truth): + sol = extract_solution_mcq(answer) + gt_sol = str(ground_truth).strip() + if not sol: + return False, None, "No expression found" + sol = sol.strip() + if sol in options: + if options[sol] == gt_sol: + return True, sol, f"Correct: option {sol} -> {options[sol]}" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{options[sol]}' (option {sol})" + if sol.lower() == gt_sol.lower(): + return True, sol, f"Correct: answer text matches ground truth: {sol}" + for opt_letter, opt_value in options.items(): + if sol.lower() == opt_value.lower(): + if opt_value == gt_sol: + return True, sol, f"Correct: answer text {sol} (option {opt_letter})" + return False, sol, f"Incorrect: expected '{gt_sol}', got '{opt_value}' (option {opt_letter})" + return False, sol, f"Solution '{sol}' not found in options or ground truth" + + +def extract_solution_zebralogic(text): + """Extract JSON solution from ZebraLogic model output.""" + if not text: + return None + + def _try_parse(candidate: str): + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + # Unwrap if it's a wrapper with "answer" key + if "answer" in parsed and isinstance(parsed["answer"], dict): + inner = parsed["answer"] + if any(re.match(r"^house\s*\d+$", str(k).strip(), flags=re.IGNORECASE) for k in inner.keys()): + return inner + return parsed + except json.JSONDecodeError: + return None + return None + + # Try to extract JSON from code blocks + patterns = [ + r"```json\s*({.*?})\s*```", # Markdown code block + r"```\s*({.*?})\s*```", # Generic code block + r"({\s*['\"]House\s*\d+['\"].*?})", # Direct JSON starting with House + ] + + for pattern in patterns: + matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE) + if matches: + json_str = matches[-1].strip() + solution = _try_parse(json_str) + if solution is not None: + return solution + + # Try parsing entire last large JSON-like structure + try: + # Find potential JSON starting with { + json_match = re.search(r"({\s*(?:['\"]House|['{\"\[])+[\s\S]*})", text) + if json_match: + json_str = json_match.group(1) + solution = _try_parse(json_str) + if solution is not None: + return solution + except (json.JSONDecodeError, AttributeError): + pass + + # Last-chance extraction: parse top-level JSON object spans and keep the + # last one that parses and looks like a house assignment dictionary. + stack = [] + spans = [] + for idx, ch in enumerate(text): + if ch == "{": + stack.append(idx) + elif ch == "}" and stack: + start = stack.pop() + if not stack: + spans.append((start, idx + 1)) + for start, end in reversed(spans): + candidate = text[start:end] + solution = _try_parse(candidate) + if solution is not None: + # Handle wrapped solution with "answer" key + if isinstance(solution, dict) and "answer" in solution: + answer = solution["answer"] + if isinstance(answer, dict) and any( + re.match(r"^house\s*\d+$", str(key).strip(), flags=re.IGNORECASE) + for key in answer.keys() + ): + return answer + # Direct house keys + if any( + re.match(r"^house\s*\d+$", str(key).strip(), flags=re.IGNORECASE) + for key in solution.keys() + ): + return solution + + return None + + +async def _request_zebralogic_json(prompt: str, llm_server: Dict[str, Any]) -> str: + """Submit a strict-JSON request for ZebraLogic and return raw model content.""" + payload = dict(llm_server["payload"]) + payload["temperature"] = 0.0 + payload["messages"] = [ + { + "role": "system", + "content": ( + "You solve Zebra Logic puzzles and MUST return strictly valid JSON only. " + "No markdown fences. No explanation. No extra text." + ), + }, + { + "role": "user", + "content": prompt, + }, + ] + payload["response_format"] = {"type": "json_object"} + # Clamp max_tokens so input + output fits within model context window + # max_ctx = llm_server.get("max_context_length", 40960) + # msg_text = " ".join(m["content"] for m in payload["messages"]) + # est_input_tokens = len(msg_text) // 3 # conservative: ~3 chars/token + # available = max_ctx - est_input_tokens - 200 # 200 token safety margin + # if available < payload.get("max_tokens", 0): + # payload["max_tokens"] = max(512, available) + async with httpx.AsyncClient(timeout=120.0) as client: + response = await client.post( + llm_server["url"], + headers=llm_server["headers"], + json=payload, + ) + response.raise_for_status() + body = response.json() + return body["choices"][0]["message"]["content"].strip() + + +async def finalize_zebralogic_json(problem: str, trajectory: str, llm_server: Dict[str, Any]) -> str: + """Ask the model to convert an existing trajectory into strict final JSON only.""" + prompt = ( + "Convert the reasoning into the final Zebra Logic answer JSON.\n" + "Output ONLY valid JSON (no markdown, no explanation).\n" + "Use exact feature/value names from the puzzle.\n\n" + "PUZZLE:\n" + f"{problem}\n\n" + "REASONING:\n" + f"{trajectory}\n" + ) + return await _request_zebralogic_json(prompt, llm_server) + + +async def solve_zebralogic_json_direct(problem: str, llm_server: Dict[str, Any]) -> str: + """Directly solve ZebraLogic and return strict final JSON.""" + prompt = ( + "Solve the Zebra Logic puzzle and provide the final house assignments.\n" + "Output ONLY valid JSON with keys like 'House 1', 'House 2', etc.\n" + "Use exact feature/value names from the puzzle text.\n\n" + "PUZZLE:\n" + f"{problem}\n" + ) + return await _request_zebralogic_json(prompt, llm_server) + + +def _raw_example_to_zebra_problem(example): + """Convert a raw HuggingFace ZebraLogic example into the processed format + that zebra_correctness expects (matching process_zebralogic_problem output).""" + solution = example.get("solution", {}) + header = solution.get("header", []) + rows = solution.get("rows", []) + size = example.get("size", "") + n_houses, n_features = map(int, size.split("*")) + + # Build processed solution dict: {"House 1": {"feature": "value", ...}, ...} + processed_solution = {} + features = {} + for house_i, row in enumerate(rows): + house_dict = {} + for fname, value in zip(header[1:], row[1:]): + fname_l = fname.lower() + val_l = value.lower() + house_dict[fname_l] = val_l + features.setdefault(fname_l, set()).add(val_l) + processed_solution[f"House {house_i + 1}"] = house_dict + features = {k: sorted(v) for k, v in features.items()} + + return { + "solution": processed_solution, + "n_houses": n_houses, + "n_features": n_features, + "features": features, + } + + +def evaluate_zebralogic_answer(answer, example): + """Evaluate ZebraLogic solution against ground truth using zebra_correctness.""" + candidate = extract_last_json(answer) + if candidate is None: + # Fallback: try the older extraction for non-standard formats + candidate = extract_solution_zebralogic(answer) + if candidate is None: + return False, None, "Could not extract valid JSON solution" + + # Lowercase candidate keys/values to match processed ground truth + normed_candidate = {} + for house_key, attrs in candidate.items(): + house_match = re.search(r"House\s*(\d+)", house_key, re.IGNORECASE) + hk = f"House {house_match.group(1)}" if house_match else house_key + if isinstance(attrs, dict): + normed_candidate[hk] = {k.lower(): v.lower() if isinstance(v, str) else v + for k, v in attrs.items()} + else: + normed_candidate[hk] = attrs + + problem = _raw_example_to_zebra_problem(example) + correct, skipped, missing, total = zebra_correctness(problem, normed_candidate) + is_correct = correct == total + msg = f"Correct={correct}/{total}, skipped={skipped}, missing={missing}" + return is_correct, normed_candidate, msg + + +def load_dataset_for_task(task): + if task == "game24": + return load_dataset("nlile/24-game", split="train") + if task == "maze": + return load_dataset("microsoft/VISION_LANGUAGE", "maze_text_only", split="val") + if task == "spatialmap": + return load_dataset("microsoft/VISION_LANGUAGE", "spatial_map_text_only", split="val") + if task == "zebralogic": + return load_dataset("WildEval/ZebraLogic", name="grid_mode", split="test") + # if task == "verina": + # return load_verina_dataset() + # if task == "verina_spec": + # return load_verina_spec_dataset() + raise ValueError(f"Unsupported task: {task}") + + +def resolve_indices(task, dataset_len, args): + if args.indices: + return [int(x.strip()) for x in args.indices.split(",")] + if args.xrange: + parts = args.xrange.split("-") + if len(parts) == 2: + try: + start = int(parts[0].strip()) + end = int(parts[1].strip()) + return list(range(start, end)) + except ValueError: + raise ValueError(f"Invalid xrange format: {args.xrange}. Use 'start-end'") + if args.num_examples: + return list(np.linspace(0, dataset_len - 1, args.num_examples, dtype=int)) + start = args.start if args.start is not None else 0 + end = args.end if args.end is not None else dataset_len + return list(range(start, end)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run Tree-of-Thought search on a subset of the supported tasks", + ) + parser.add_argument("--task", choices=["game24", "maze", "spatialmap", "zebralogic", "verina", "verina_spec"], required=True) + parser.add_argument("--k", type=int, default=1, help="Unused placeholder to mirror other baselines") + parser.add_argument("--num_examples", "-n", type=int, default=None) + parser.add_argument("--indices", type=str, default=None) + parser.add_argument("--xrange", type=str, default=None) + parser.add_argument("--start", type=int, default=None) + parser.add_argument("--end", type=int, default=None) + parser.add_argument("--model", default="Qwen/QwQ-32B") + parser.add_argument("--llm_url", default="http://localhost:{port}/v1/chat/completions") + parser.add_argument( + "--ports", + default="8000,8001,8002,8003", + help="Comma-separated list of vLLM ports to round-robin across", + ) + parser.add_argument("--temperature", type=float, default=0.6) + parser.add_argument("--top_p", type=float, default=0.95) + parser.add_argument("--top_k", type=int, default=20) + parser.add_argument("--max_tokens", type=int, default=8192) + # parser.add_argument("--max_context_length", type=int, default=40960) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--search_method", choices=["bfs", "dfs", "beam"], default="beam") + parser.add_argument("--branching_factor", type=int, default=4) + parser.add_argument("--max_depth", type=int, default=1) + parser.add_argument("--beam_width", type=int, default=2) + parser.add_argument("--sure_threshold", type=float, default=0.9) + parser.add_argument("--likely_threshold", type=float, default=0.5) + parser.add_argument("--impossible_threshold", type=float, default=0.2) + parser.add_argument("--max_candidates_per_level", type=int, default=3) + parser.add_argument("--early_termination", action="store_true") + parser.add_argument("--no_cache", action="store_true") + parser.add_argument( + "--concurrency", + type=int, + default=1, + help="Maximum number of ToT examples to run concurrently", + ) + parser.add_argument( + "--output_dir", + default="/workspace/vishak/ToT/interwhen/examples/TTSwithVerification/outputs/dummy", + help="Directory to store per-example JSON logs and summary", + ) + parser.add_argument("--log_level", default="INFO") + parser.add_argument("--summary_file",default="summary.json") + parser.add_argument("--log_file", default="tot_baseline.log") + return parser.parse_args() + + +def parse_port_list(port_str: str) -> List[int]: + return [int(p.strip()) for p in port_str.split(",") if p.strip()] + + +def build_llm_server(args: argparse.Namespace, port: int) -> Dict[str, Any]: + payload = { + "model": args.model, + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": args.top_k, + "max_tokens": args.max_tokens, + "stream": False, + "seed": args.seed, + } + return { + "url": args.llm_url.format(port=port), + "headers": {"content-type": "application/json"}, + "payload": payload, + # "max_context_length": args.max_context_length, + } + + +def build_tot_config(args: argparse.Namespace) -> ToTSearchConfig: + method = SearchMethod[args.search_method.upper()] + return ToTSearchConfig( + branching_factor=args.branching_factor, + max_depth=args.max_depth, + search_method=method, + beam_width=args.beam_width, + sure_threshold=args.sure_threshold, + likely_threshold=args.likely_threshold, + impossible_threshold=args.impossible_threshold, + early_termination=args.early_termination, + cache_evaluations=not args.no_cache, + max_candidates_per_level=args.max_candidates_per_level, + ) + + +def ensure_output_dir(base_dir: str, task: str) -> Path: + path = Path(base_dir).expanduser().resolve() / task + path.mkdir(parents=True, exist_ok=True) + return path + + +def evaluate_verina_answer(output, example, idx): + """Evaluate Verina code generation output for ToT.""" + generated_code = extract_code_from_response(output) + if not generated_code.strip(): + return False, "", "No code extracted from response" + compiles, all_tests_pass, compile_output, test_results = evaluate_generated_code( + example, generated_code, idx, + ) + num_tests = len(example.tests) if example.tests else 0 + num_passed = sum(1 for v in test_results.values() if v == "pass") + if compiles and all_tests_pass: + return True, generated_code, f"Code compiles and all {num_tests} tests pass" + elif compiles: + return False, generated_code, f"Compilation succeeded but {num_tests - num_passed}/{num_tests} tests failed" + else: + error_preview = compile_output[:300] if compile_output else "Unknown error" + return False, generated_code, f"Compilation failed: {error_preview}" + + +def evaluate_verina_spec_answer(output, example, idx): + """Evaluate Verina spec generation output for ToT.""" + generated_spec = extract_spec_from_response(output) + if not generated_spec.get("precond") and not generated_spec.get("postcond"): + return False, generated_spec, "No spec extracted from response" + eval_result = evaluate_generated_spec(example, generated_spec, idx) + if eval_result["full_spec_correct"]: + return True, generated_spec, "Spec compiles and all soundness/completeness tests pass" + elif eval_result["compiles"]: + msg_parts = [] + if eval_result["precond_sound_total"] > 0: + msg_parts.append(f"precond_sound={eval_result['precond_sound_pass']}/{eval_result['precond_sound_total']}") + if eval_result["precond_complete_total"] > 0: + msg_parts.append(f"precond_complete={eval_result['precond_complete_pass']}/{eval_result['precond_complete_total']}") + if eval_result["postcond_sound_total"] > 0: + msg_parts.append(f"postcond_sound={eval_result['postcond_sound_pass']}/{eval_result['postcond_sound_total']}") + if eval_result["postcond_complete_total"] > 0: + msg_parts.append(f"postcond_complete={eval_result['postcond_complete_pass']}/{eval_result['postcond_complete_total']}") + return False, generated_spec, f"Compilation succeeded but tests: {', '.join(msg_parts)}" + else: + error_preview = eval_result.get("compile_error", "")[:300] + return False, generated_spec, f"Compilation failed: {error_preview}" + + +def prepare_eval(task: str, example: Dict[str, Any], idx: int = 0) -> Tuple: + if task == "game24": + nums = list(example.get("numbers", [])) + return (lambda output: evaluate_game24_answer(output, nums), {"numbers": nums}) + if task == "zebralogic": + # Pass raw example so zebra_correctness can evaluate against processed solution + gt_problem = _raw_example_to_zebra_problem(example) + meta = { + "ground_truth_sample": str(example.get("solution", {}))[:100], + "ground_truth_solution": gt_problem["solution"], + } + return (lambda output: evaluate_zebralogic_answer(output, example), meta) + # if task == "verina": + # meta = {"data_id": example.data_id} + # return (lambda output: evaluate_verina_answer(output, example, idx), meta) + # if task == "verina_spec": + # meta = {"data_id": example.data_id} + # return (lambda output: evaluate_verina_spec_answer(output, example, idx), meta) + gt = str(example.get("ground_truth", "")).strip() + target_options = ["A", "B"] if gt == "Q4" else ["A", "B", "C", "D"] + if task == "maze": + _, user_prompt = build_maze_prompt(example) + else: + _, user_prompt = build_spatialmap_prompt(example) + options = extract_options_from_prompt(user_prompt, target_options) + meta = {"options": options, "ground_truth": gt} + return (lambda output: evaluate_mcq_answer(output, options, gt), meta) + + +async def run_single_example( + idx: int, + task: str, + example: Dict[str, Any], + tot_config: ToTSearchConfig, + llm_server: Dict[str, Any], +) -> Dict[str, Any]: + eval_fn, eval_meta = prepare_eval(task, example, idx) + nums = example.get("numbers") if hasattr(example, "get") else None + problem = build_tot_problem(task, example, nums=nums) + tot = TreeOfThoughtSearch(tot_config) + search_result = await tot.search(task, problem, llm_server) + best_traj = search_result.get("best_trajectory", "") + best_value = search_result.get("best_value", 0.0) + + # Verina: if the trajectory already contains [CODE], evaluate directly + # (consistent with other tasks that embed answers in trajectories). + # Only fall back to a synthesis LLM call if no [CODE] block is present. + synthesized_code = None + synthesized_spec = None + if task == "verina": + a = 2 + # if task == "verina" and best_traj.strip(): + # # if re.search(r'\[CODE\]', best_traj, re.IGNORECASE): + # # # Code already in trajectory — evaluate directly like other tasks + # # is_correct, extracted, message = eval_fn(best_traj) + # # else: + # # Pure reasoning trajectory — synthesize code + # try: + # synthesis_prompt = build_verina_synthesis_prompt(problem, best_traj) + # synthesized_code = await tot._call_llm_streaming(llm_server, synthesis_prompt) + # is_correct, extracted, message = eval_fn(synthesized_code) + # except Exception as exc: + # LOGGER.warning("Verina synthesis failed for index %s: %s", idx, exc) + # is_correct, extracted, message = eval_fn(best_traj) + # elif task == "verina_spec" and best_traj.strip(): + # # has_precond = bool(re.search(r'\[PRECOND\]', best_traj, re.IGNORECASE)) + # # has_postcond = bool(re.search(r'\[POSTCOND\]', best_traj, re.IGNORECASE)) + # # if has_precond and has_postcond: + # # # Spec already in trajectory — evaluate directly + # # is_correct, extracted, message = eval_fn(best_traj) + # # else: + # # Pure reasoning trajectory — synthesize spec + # try: + # synthesis_prompt = build_verina_spec_synthesis_prompt(problem, best_traj) + # synthesized_spec = await tot._call_llm_streaming(llm_server, synthesis_prompt) + # is_correct, extracted, message = eval_fn(synthesized_spec) + # except Exception as exc: + # LOGGER.warning("Verina spec synthesis failed for index %s: %s", idx, exc) + # is_correct, extracted, message = eval_fn(best_traj) + else: + is_correct, extracted, message = eval_fn(best_traj) + + # ZebraLogic often ends with partial reasoning trajectories; add strict-JSON + # recovery passes before scoring. + finalized_answer = None + direct_answer = None + if task == "zebralogic" and (not is_correct): + try: + finalized_answer = await finalize_zebralogic_json(problem, best_traj, llm_server) + final_is_correct, final_extracted, final_message = eval_fn(finalized_answer) + if final_extracted is not None or final_is_correct: + is_correct = final_is_correct + extracted = final_extracted + message = final_message + best_traj = finalized_answer + except Exception as exc: # pragma: no cover + LOGGER.warning("ZebraLogic finalization failed for index %s: %s", idx, exc) + + if task == "zebralogic" and (not is_correct): + try: + direct_answer = await solve_zebralogic_json_direct(problem, llm_server) + direct_is_correct, direct_extracted, direct_message = eval_fn(direct_answer) + if direct_extracted is not None or direct_is_correct: + is_correct = direct_is_correct + extracted = direct_extracted + message = direct_message + best_traj = direct_answer + except Exception as exc: # pragma: no cover + LOGGER.warning("ZebraLogic direct solve failed for index %s: %s", idx, exc) + + return { + "index": int(idx), + "best_value": best_value, + "best_trajectory": best_traj, + "raw_best_trajectory": search_result.get("best_trajectory", ""), + "synthesized_code": synthesized_code, + "synthesized_spec": synthesized_spec, + "finalized_answer": finalized_answer, + "direct_answer": direct_answer, + "search_stats": search_result.get("search_stats", {}), + "decision_tree": search_result.get("decision_tree", []), + "correct": bool(is_correct), + "extracted": extracted, + "message": message, + "evaluation_meta": eval_meta, + } + + +async def run_tot_baseline(args: argparse.Namespace) -> None: + log_level = getattr(logging, args.log_level.upper(), logging.INFO) + root_logger = logging.getLogger() + root_logger.setLevel(log_level) + # Send all logs to a file instead of stdout/stderr (keeps tqdm clean) + log_file = Path(args.output_dir) / args.log_file + log_file.parent.mkdir(parents=True, exist_ok=True) + fh = logging.FileHandler(str(log_file), mode="a") + fh.setLevel(log_level) + fh.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")) + root_logger.addHandler(fh) + # Remove default stderr handler so logs don't clobber the progress bar + for h in root_logger.handlers[:]: + if isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler): + root_logger.removeHandler(h) + dataset = load_dataset_for_task(args.task) + indices = resolve_indices(args.task, len(dataset), args) + output_dir = ensure_output_dir(args.output_dir, args.task) + tot_config = build_tot_config(args) + ports = parse_port_list(args.ports) + if not ports: + raise ValueError("At least one port must be specified via --ports") + concurrency = max(1, args.concurrency) + port_lock = asyncio.Lock() + port_index = {"value": 0} + + async def next_port() -> int: + async with port_lock: + port = ports[port_index["value"] % len(ports)] + port_index["value"] += 1 + return port + + semaphore = asyncio.Semaphore(concurrency) + + async def process_index(idx: int) -> Dict[str, Any]: + async with semaphore: + example = dataset[int(idx)] + port = await next_port() + llm_server = build_llm_server(args, port) + LOGGER.info("Running ToT on example %s via port %s", idx, port) + try: + record = await run_single_example(idx, args.task, example, tot_config, llm_server) + except Exception as exc: # pragma: no cover + LOGGER.exception("Failed example %s", idx) + record = { + "index": int(idx), + "error": str(exc), + "best_trajectory": "", + "correct": False, + } + example_path = output_dir / f"example_{idx}.json" + with example_path.open("w", encoding="utf-8") as handle: + json.dump(record, handle, indent=2) + return record + + processed = await tqdm_asyncio.gather( + *[process_index(idx) for idx in indices], + desc="ToT examples", + ) + + total = len(processed) + correct = sum(1 for r in processed if r.get("correct")) + summary = { + "task": args.task, + "model": args.model, + "total_examples": total, + "correct": correct, + "accuracy": (correct / total) if total else 0.0, + "search_method": args.search_method, + "config": { + "branching_factor": args.branching_factor, + "max_depth": args.max_depth, + "beam_width": args.beam_width, + "sure_threshold": args.sure_threshold, + "likely_threshold": args.likely_threshold, + "impossible_threshold": args.impossible_threshold, + "max_candidates_per_level": args.max_candidates_per_level, + "early_termination": args.early_termination, + "cache_evaluations": not args.no_cache, + "ports": ports, + "concurrency": concurrency, + }, + } + summary_path = output_dir / args.summary_file + with summary_path.open("w", encoding="utf-8") as handle: + json.dump(summary, handle, indent=2) + LOGGER.info("Accuracy %.2f (%d/%d)", summary["accuracy"], correct, total) + + +if __name__ == "__main__": + import time + st = time.time() + asyncio.run(run_tot_baseline(parse_args())) + et = time.time() + print(f"Total execution time: {et - st:.2f} seconds") \ No newline at end of file diff --git a/interwhen/interject.py b/interwhen/interject.py index c8e4fa6..f424aca 100644 --- a/interwhen/interject.py +++ b/interwhen/interject.py @@ -35,7 +35,11 @@ async def stream_completion(prompt, prev_text = "", llm_server=None, monitors=[] break else: # Obtain the current token (text chunk) - chunk = json.loads(data)["choices"][0]["text"] + try: + chunk = json.loads(data)["choices"][0]["text"] + except (json.JSONDecodeError, KeyError, IndexError) as e: + logger.debug(f"Skipping malformed SSE data: {data!r} ({e})") + continue # If any event is already set, break immediately (we don't want more chunks) if stop_event.is_set(): logger.info(f'\n[Early stop already triggered, ignoring chunk: {chunk}]') @@ -71,6 +75,8 @@ async def stream_completion(prompt, prev_text = "", llm_server=None, monitors=[] corrected_text = await monitors[0].fix(generated_text, stop_info) if stop_info["feedback"] == "\nthe answer is \\boxed{no solution}": return corrected_text # No solution found, return no solution ie soundness is 100% is it doesnt pass the verifer + if stop_info.get("phase") == "final_answer_correct": + return corrected_text # Expression verified correct, stop generation return await stream_completion(prompt, prev_text=corrected_text, llm_server=llm_server, monitors=monitors, add_delay=add_delay, num_calls_index=num_calls_index+1, termination_requires_validation=termination_requires_validation, async_execution=async_execution) return generated_text \ No newline at end of file diff --git a/interwhen/monitors/__init__.py b/interwhen/monitors/__init__.py index ad3e095..054f4d7 100644 --- a/interwhen/monitors/__init__.py +++ b/interwhen/monitors/__init__.py @@ -2,4 +2,7 @@ from .k_stable import KstableAnswerMCQMonitor, KstableAnswerGame24Monitor from .stepVerifier import StepVerifierGame24Monitor, StepVerifierMazeMonitor, StepVerifierSpatialMapMonitor from .earlyStopping import EATMonitor, DEERMonitor +from .thinkingPhaseVerifierGame24 import ThinkingPhaseStepVerifierGame24Monitor +from .thinkingPhaseVerifierMaze import ThinkingPhaseStepVerifierMazeMonitor +from .thinkingPhaseVerifierSpatialMap import ThinkingPhaseStepVerifierSpatialMapMonitor from .zebralogic_monitor import ZebraLogicMonitor diff --git a/interwhen/monitors/_common.py b/interwhen/monitors/_common.py new file mode 100644 index 0000000..1c047fd --- /dev/null +++ b/interwhen/monitors/_common.py @@ -0,0 +1,51 @@ +""" +Shared utilities for thinking-phase verifier monitors. +""" + +import re +from typing import Optional + + +def find_complete_boxed(text: str) -> Optional[object]: + """Find a complete \\boxed{...} in text, handling nested braces. + + Unlike ``re.search(r'\\boxed\\{[^}]+\\}', text)`` this correctly + handles LaTeX like ``\\boxed{12\\frac{1}{2}}`` where the naive + ``[^}]+`` pattern would stop at the first ``}``. + + Returns a match-like object with ``.start()`` and ``.end()`` + spanning the full ``\\boxed{...}`` (including the outer braces), + or ``None`` if no complete boxed expression is found. + """ + idx = 0 + while idx < len(text): + pos = text.find(r'\boxed{', idx) + if pos == -1: + return None + # Start counting braces from after '\boxed{' + brace_start = pos + len(r'\boxed{') + depth = 1 + i = brace_start + while i < len(text) and depth > 0: + if text[i] == '{': + depth += 1 + elif text[i] == '}': + depth -= 1 + i += 1 + if depth == 0: + match_start = pos + match_end = i # i is right after the closing '}' + content = text[brace_start:i - 1].strip() + if content: + class _BoxedMatch: + def __init__(self, s, e): + self._start, self._end = s, e + def start(self): + return self._start + def end(self): + return self._end + def group(self, n=0): + return text[self._start:self._end] + return _BoxedMatch(match_start, match_end) + idx = pos + 1 + return None diff --git a/interwhen/monitors/stepVerifier.py b/interwhen/monitors/stepVerifier.py index ca5fc4b..a5f1262 100644 --- a/interwhen/monitors/stepVerifier.py +++ b/interwhen/monitors/stepVerifier.py @@ -1,18 +1,27 @@ import re +import logging from typing import List, Tuple, Optional, Set, Dict from .base import VerifyMonitor +from ._common import find_complete_boxed from ..utils.game24_verifier import parse_step, verify_step, format_feedback from ..utils.maze_verifier import ( Direction, parse_direction, parse_maze_step, verify_maze_step, verify_locate_section, format_maze_feedback, format_locate_feedback, - parse_maze_from_prompt + parse_maze_from_prompt, compute_relative_direction, ) from ..utils.spatialmap_verifier import ( SpatialMapZ3Solver, parse_directional_claims_from_text, - extract_step2_claims, verify_spatialmap_step, format_spatialmap_feedback + extract_step2_claims, verify_spatialmap_step, format_spatialmap_feedback, + parse_counting_question, parse_model_count_from_answer, + parse_direction_question, parse_object_question, + parse_model_boxed_answer, + get_possible_directions, get_consistent_object_options, + get_possible_count_range, ) - +logger = logging.getLogger(__name__) + + class StepVerifierGame24Monitor(VerifyMonitor): """ Step-by-step Game of 24 verifier monitor. @@ -295,7 +304,8 @@ def __init__( exit_pos: Tuple[int, int], max_corrections: int = 5, question_type: str = "right_turns", # "right_turns", "total_turns", "relative_position" - async_execution: bool = True + async_execution: bool = True, + prompt: str = None, ): super().__init__(name) self.async_execution = async_execution @@ -305,6 +315,7 @@ def __init__( self.exit_pos = exit_pos self.max_corrections = max_corrections self.question_type = question_type + self.prompt = prompt @staticmethod def detect_question_type(prompt: str) -> str: @@ -375,7 +386,8 @@ def from_prompt( exit_pos=exit_pos, max_corrections=max_corrections, question_type=question_type, - async_execution=async_execution + async_execution=async_execution, + prompt=prompt, ) def _count_feedback_blocks(self, text: str) -> int: @@ -603,18 +615,72 @@ async def verify(self, chunk: str, token_index: int, event, event_info: dict): return chunk, feedback + def _verify_relative_position_answer(self, boxed_answer: str) -> Tuple[bool, Optional[str]]: + """Verify a relative-position boxed answer (A=Yes / B=No). + + Parses the question from ``self.prompt`` to determine the asked + direction, computes the true relative direction of E from S, + and checks whether the model's Yes/No answer is correct. + + Returns ``(is_correct, feedback_or_None)``. + """ + if self.prompt is None: + return True, None + + answer_map = {"A": "Yes", "B": "No"} + model_yn = answer_map.get(boxed_answer.strip().upper()) + if model_yn is None: + return True, None + + m = re.search( + r'Is the exit \(E\)\s+(.*?)\s+(?:of\s+)?the starting point \(S\)', + self.prompt, re.IGNORECASE, + ) + if not m: + return True, None + + asked_raw = m.group(1).strip().lower() + asked_raw = re.sub(r',.*', '', asked_raw).strip() + + actual = compute_relative_direction(self.start_pos, self.exit_pos) + + direction_keywords = { + "directly to the left": {"west"}, + "directly to the right": {"east"}, + "directly above": {"north"}, + "directly below": {"south"}, + "to the top left": {"northwest"}, + "to the top right": {"northeast"}, + "to the bottom left": {"southwest"}, + "to the bottom right": {"southeast"}, + } + + expected_dirs = direction_keywords.get(asked_raw) + if expected_dirs is None: + return True, None + + expected_yn = "Yes" if actual in expected_dirs else "No" + + if model_yn == expected_yn: + return True, None + + feedback = ( + f"\n\n[VERIFIER FEEDBACK for relative position:\n" + f" ✗ Your answer {boxed_answer} ({model_yn}) is incorrect.\n" + f" IMPORTANT: In this task, \"{asked_raw}\" means the GENERAL " + f"COMPASS DIRECTION, NOT immediate adjacency. It asks whether E " + f"is in the {actual} direction from S, regardless of distance or " + f"walls between them.]\n\n" + ) + return False, feedback + async def _verify_relative_position(self, chunk: str, token_index: int, event, event_info: dict): """ Verify relative position answer. - For relative_position questions (Yes/No format), we only verify: + For relative_position questions (Yes/No format), we verify: 1. The LOCATE section (S and E positions are correctly identified) - - We do NOT verify the final Yes/No answer because: - - The question asks "Is E to the [direction] of S?" - - We don't have the question text here to know what direction was asked - - The comparison logic (row/col arithmetic) is straightforward - - If LOCATE is correct, the model should get the answer right + 2. The boxed answer (A=Yes / B=No) against the computed direction """ # Check LOCATE section for correct S and E positions locate_valid, locate_errors, locate_found = self._check_locate_section(chunk) @@ -629,8 +695,27 @@ async def _verify_relative_position(self, chunk: str, token_index: int, event, e event.set() return chunk, feedback - # For relative_position, we don't verify the final Yes/No answer - # Just let it complete once LOCATE is verified + # Check for boxed answer and verify it + if '' in chunk: + text_after_think = chunk.split("")[-1] + boxed_match = find_complete_boxed(text_after_think) + if boxed_match: + boxed_text = text_after_think[boxed_match.start():boxed_match.end()] + # Extract the letter from \boxed{X} + inner = re.search(r'\\boxed\{([^}]*)\}', boxed_text) + if inner: + boxed_answer = inner.group(1).strip() + is_correct, rp_feedback = self._verify_relative_position_answer(boxed_answer) + if not is_correct and rp_feedback: + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = rp_feedback + event_info["correction_index"] = token_index + event_info["errors"] = [f"Wrong relative position answer: {boxed_answer}"] + event_info["failed_step"] = None + event.set() + return chunk, rp_feedback + return chunk, None async def fix(self, generated_text: str, event_info: dict, fix_method=None) -> str: @@ -739,7 +824,7 @@ def _step_extractor_relative_position( 3. LOCATE section is complete and analysis has started (verify LOCATE) """ # Check for boxed answer first (highest priority) - boxed_match = re.search(r'\\boxed\{[^}]+\}', text) + boxed_match = find_complete_boxed(text) if boxed_match: # Found answer, verify it (include full text up to boxed answer) end_pos = text_start_in_generated + boxed_match.end() @@ -811,6 +896,42 @@ def __init__( # Track verified claims to avoid re-checking self.verified_claims: Set[Tuple[str, str, str]] = set() + + # ---- question-type detection (consistent with ThinkingPhaseVerifier) ---- + self._counting_question = parse_counting_question(problem_text) + self._counting_options: Dict[str, str] = {} + _opts_text = re.split(r'\nFirst,', problem_text, maxsplit=1)[0] + if self._counting_question: + raw_opts = re.findall( + r'([A-D])\.\s*(.+?)\s*(?=[A-D]\.|$)', + _opts_text, flags=re.DOTALL, + ) + self._counting_options = { + k: v.strip().rstrip(".") for k, v in raw_opts + } + + self._direction_question = parse_direction_question(problem_text) + self._object_question = parse_object_question(problem_text) + + # Generic MCQ options (for direction & object Qs too) + if not self._counting_options: + raw_opts = re.findall( + r'([A-D])\.\s*(.+?)\s*(?=[A-D]\.|$)', + _opts_text, flags=re.DOTALL, + ) + self._mcq_options: Dict[str, str] = { + k: v.strip().rstrip(".") for k, v in raw_opts + } + else: + self._mcq_options = dict(self._counting_options) + + # Retry limits for final-answer verification + self._max_final_answer_retries = 3 + self._direction_feedback_count = 0 + self._object_feedback_count = 0 + self._diag_count_feedback_count = 0 + self._count_feedback_given = False + self._count_feedback_blocks_count = 0 @classmethod def from_prompt( @@ -861,8 +982,13 @@ def _extract_new_claims(self, chunk: str) -> List[Dict]: # Only look at text after the last feedback text_to_check = text_after_think[last_feedback_end:] + # Get full entity names from Z3 solver for abbreviation resolution + entity_names = list({ + k[:-2] for k in self.z3_solver.entities if k.endswith('_x') + }) + # Extract claims from STEP 2 in the latest attempt only - all_claims = extract_step2_claims(text_to_check) + all_claims = extract_step2_claims(text_to_check, entity_names=entity_names) # Filter to only new claims (not yet verified) new_claims = [] @@ -929,7 +1055,282 @@ async def verify(self, chunk: str, token_index: int, event, event_info: dict): return chunk, feedback - # All claims valid + # All claims valid — check for boxed answer (final answer verification) + if '' in chunk: + text_after_think = chunk.split("")[-1] + feedback_pattern = re.compile(r'\[VERIFIER FEEDBACK[^\]]*\]\s*', re.DOTALL) + last_feedback_end = 0 + for match in feedback_pattern.finditer(text_after_think): + last_feedback_end = match.end() + recent_text = text_after_think[last_feedback_end:] + + boxed_match = find_complete_boxed(recent_text) + if boxed_match: + # --- Direction-question verification --- + if ( + self._direction_question + and num_corrections < self.max_corrections + and self._direction_feedback_count < self._max_final_answer_retries + ): + model_dir_text = parse_model_boxed_answer( + recent_text, self._mcq_options + ) + if model_dir_text: + possible = get_possible_directions( + self.z3_solver, + self._direction_question["entity_a"], + self._direction_question["entity_b"], + ) + if model_dir_text not in possible: + self._direction_feedback_count += 1 + valid_options = [ + letter for letter, val in self._mcq_options.items() + if val.strip().lower().rstrip(".") in possible + ] + if len(valid_options) == 1: + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Direction error!\n" + f" '{model_dir_text.title()}' is " + f"impossible for " + f"{self._direction_question['entity_a']} " + f"relative to " + f"{self._direction_question['entity_b']} " + f"based on the given constraints.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + else: + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Direction error!\n" + f" '{model_dir_text.title()}' is " + f"impossible for " + f"{self._direction_question['entity_a']} " + f"relative to " + f"{self._direction_question['entity_b']} " + f"based on the given constraints.\n" + f" Please reconsider and choose the " + f"correct option.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = [ + f"Direction '{model_dir_text}' impossible; " + f"possible: {possible}" + ] + event_info["failed_step"] = None + event.set() + return chunk, feedback + + # --- Object-question verification --- + if ( + self._object_question + and num_corrections < self.max_corrections + and self._object_feedback_count < self._max_final_answer_retries + ): + model_obj_text = parse_model_boxed_answer( + recent_text, self._mcq_options + ) + boxed_raw = re.findall( + r'\\boxed\{([^}]*)\}', recent_text + ) + model_letter = ( + boxed_raw[-1].strip().upper() if boxed_raw else None + ) + + if model_letter: + consistent = get_consistent_object_options( + self.z3_solver, + self._object_question["direction"], + self._object_question["reference"], + self._mcq_options, + ) + if model_letter not in consistent: + self._object_feedback_count += 1 + odir = self._object_question["direction"] + oref = self._object_question["reference"] + if len(consistent) == 1: + correct_name = self._mcq_options.get( + consistent[0], consistent[0] + ) + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Object error!\n" + f" '{model_obj_text}' cannot be " + f"{odir} of {oref} based on the " + f"given constraints.\n" + f" The only consistent option is " + f"{consistent[0]}. {correct_name}.\n" + f" Please select option " + f"{consistent[0]}.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + else: + valid_names = [ + f"{l}. {self._mcq_options.get(l, l)}" + for l in consistent + ] + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Object error!\n" + f" '{model_obj_text}' cannot be " + f"{odir} of {oref} based on the " + f"given constraints.\n" + f" The consistent options are: " + f"{', '.join(valid_names)}.\n" + f" Please reconsider and choose the " + f"correct option.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = [ + f"Object '{model_obj_text}' impossible " + f"in {odir} of {oref}; " + f"consistent: {consistent}" + ] + event_info["failed_step"] = None + event.set() + return chunk, feedback + + # --- Counting-question verification --- + if ( + self._counting_question + and num_corrections < self.max_corrections + ): + direction = self._counting_question["direction"] + reference = self._counting_question["reference"] + is_cardinal = direction in ( + "north", "south", "east", "west" + ) + + if is_cardinal: + model_count = parse_model_count_from_answer( + recent_text, self._counting_options + ) + z3_count = 0 + + if ( + model_count is not None + and model_count != z3_count + ): + self._count_feedback_given = True + self._count_feedback_blocks_count += 1 + + if direction in ("north", "south"): + diag_examples = "northeast or northwest" + elif direction == "west": + diag_examples = "northwest or southwest" + else: + diag_examples = "northeast or southeast" + + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Count mismatch!\n" + f" You answered {model_count} objects " + f"'{direction}' of {reference}, but this " + f"count is incorrect.\n" + f" IMPORTANT: '{direction}' is a strict " + f"cardinal direction — it means ONLY " + f"exactly {direction}, NOT {diag_examples}." + f"\n" + f" An object that is {diag_examples.split(' or ')[0]} of " + f"{reference} is NOT {direction} of " + f"{reference}.\n" + f" Re-examine each object: is it described " + f"as being strictly '{direction} of' " + f"{reference}, or is the relationship " + f"actually a diagonal direction like " + f"{diag_examples}? Only count objects that " + f"are strictly {direction}.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = [ + f"Cardinal count mismatch: expected 0, " + f"got {model_count}" + ] + event_info["failed_step"] = None + event.set() + return chunk, feedback + + else: + if self._diag_count_feedback_count < self._max_final_answer_retries: + model_count = parse_model_count_from_answer( + recent_text, self._counting_options + ) + count_range = get_possible_count_range( + self.z3_solver, reference, direction + ) + + if ( + model_count is not None + and count_range is not None + ): + min_c, max_c = count_range + + if not (min_c <= model_count <= max_c): + self._diag_count_feedback_count += 1 + valid_opts = [] + for opt, val in ( + self._counting_options.items() + ): + try: + v = int(val) + if min_c <= v <= max_c: + valid_opts.append( + (opt, v) + ) + except (ValueError, TypeError): + pass + + if len(valid_opts) == 1: + feedback = ( + f"\n\n[VERIFIER FEEDBACK: " + f"Count error!\n" + f" {model_count} objects " + f"'{direction}' of {reference}" + f" is impossible.\n" + f" The valid count is " + f"{valid_opts[0][1]}.\n" + f" Please select option " + f"{valid_opts[0][0]}.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + else: + feedback = ( + f"\n\n[VERIFIER FEEDBACK: " + f"Count error!\n" + f" {model_count} objects " + f"'{direction}' of {reference}" + f" is impossible.\n" + f" The possible count range " + f"is [{min_c}, {max_c}].\n" + f" Please reconsider and " + f"choose the correct " + f"option.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + + if not event.is_set(): + event_info["generated_text"] = chunk + event_info["feedback"] = feedback + event_info["correction_index"] = ( + token_index + ) + event_info["errors"] = [ + f"Diagonal count " + f"{model_count} outside " + f"range [{min_c}, {max_c}]" + ] + event_info["failed_step"] = None + event.set() + return chunk, feedback + return chunk, None async def fix(self, generated_text: str, event_info: dict, fix_method=None) -> str: @@ -989,12 +1390,16 @@ def step_extractor(self, chunk: str, generated_text: str) -> Tuple[bool, Optiona return True, generated_text[:end_pos] # Check for boxed answer (trigger final verification) - boxed_match = re.search(r'\\boxed\{[^}]+\}', text) + boxed_match = find_complete_boxed(text) if boxed_match: # Verify any remaining claims before final answer new_claims = self._extract_new_claims(generated_text) if new_claims: end_pos = text_start_in_generated + boxed_match.end() return True, generated_text[:end_pos] + # Even if no new claims, boxed answer signals completion — + # trigger to allow final answer verification (direction/object/counting) + end_pos = text_start_in_generated + boxed_match.end() + return True, generated_text[:end_pos] return False, None \ No newline at end of file diff --git a/interwhen/monitors/thinkingPhaseVerifierGame24.py b/interwhen/monitors/thinkingPhaseVerifierGame24.py new file mode 100644 index 0000000..13b7af3 --- /dev/null +++ b/interwhen/monitors/thinkingPhaseVerifierGame24.py @@ -0,0 +1,608 @@ +""" +Thinking-phase verifier for Game of 24. + +Verifies expressions by forking a side-stream during the thinking phase +to ask the model about its current progress. + +Workflow +-------- +A) **DURING the thinking phase** (inside ``...``): + After a warmup period, every *N* newlines in the thinking trace: + 1. Inject `` The expression that I found till now is {`` and + stream ~20 tokens to extract the expression the model outputs. + 2. Verify the expression against Game-of-24 rules. + 3. If **wrong** -> inject error feedback into thinking trace. + 4. If **correct AND complete** -> inject early-stop message + ````. + 5. If **correct AND partial** -> no feedback, let model keep thinking. + +B) **AFTER a natural ````**: + Inject the expression extraction prompt so the model outputs its + answer expression, then verify in the same way. +""" + +import re +import json +import logging +from typing import List, Tuple, Optional +from copy import deepcopy + +import httpx + +from .base import VerifyMonitor +from ._common import find_complete_boxed +from ..utils.game24_verifier import ( + can_reach_24, is_close, format_number, safe_eval, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Prompts injected to elicit an expression from the model. +# --------------------------------------------------------------------------- + +# Injected during the thinking phase (after ) +THINKING_PHASE_EXPRESSION_PROMPT = ( + "\nThe expression that I found till now is {" +) + +# Injected after a natural to force the model to emit \boxed{expr} +FINAL_EXPRESSION_PROMPT = ( + "\nThe final expression is \\boxed" +) + + +# --------------------------------------------------------------------------- +# Expression verification helpers +# --------------------------------------------------------------------------- + +def _extract_numbers_from_expr(expr: str) -> List[float]: + """Extract all numbers (integers and decimals) from an expression string.""" + numbers = re.findall(r'\d+\.?\d*', expr) + return [int(float(n)) if float(n) == int(float(n)) else float(n) for n in numbers] + + +def _normalize_number(n) -> float: + """Normalize a number for comparison.""" + return float(n) + + +def verify_expression(expr_str: str, original_numbers: List[float]) -> Tuple[str, bool, List[str], Optional[List[float]]]: + """ + Verify an expression against the Game of 24 rules. + + Returns: + (status, is_valid, errors, unused_numbers_or_None) + - status: "complete" | "partial" | "error" + - is_valid: True if the expression is valid (no errors) + - errors: List of error messages + - unused_numbers: Numbers from original not used in expr (None if errors) + """ + errors = [] + fmt = format_number + + used_numbers = _extract_numbers_from_expr(expr_str) + if not used_numbers: + errors.append(f"No numbers found in expression: {expr_str}") + return "error", False, errors, None + + original_copy = [_normalize_number(n) for n in original_numbers] + matched_indices = [] + for used_n in used_numbers: + used_norm = _normalize_number(used_n) + found = False + for i, orig_n in enumerate(original_copy): + if i not in matched_indices and is_close(used_norm, orig_n): + matched_indices.append(i) + found = True + break + if not found: + errors.append( + f"Number {fmt(used_norm)} in expression is not available in " + f"original numbers {[fmt(n) for n in original_numbers]} " + f"(or was already used)" + ) + + if errors: + return "error", False, errors, None + + unused = [original_copy[i] for i in range(len(original_copy)) if i not in matched_indices] + + try: + value = eval(expr_str, {"__builtins__": None}, {}) + value = float(value) + except Exception as e: + errors.append(f"Cannot evaluate expression '{expr_str}': {e}") + return "error", False, errors, None + + all_used = len(unused) == 0 + + if all_used: + if not is_close(value, 24): + errors.append( + f"Expression '{expr_str}' evaluates to {fmt(value)}, not 24." + ) + return "error", False, errors, None + return "complete", True, [], [] + else: + remaining = [value] + unused + can_reach, example = can_reach_24(remaining) + if not can_reach: + remaining_str = [fmt(n) for n in remaining] + errors.append( + f"Expression '{expr_str}' evaluates to {fmt(value)}. " + f"Remaining numbers (including result) are {remaining_str}. " + f"Cannot reach 24 from these numbers. This is a dead end." + ) + return "error", False, errors, None + return "partial", True, [], unused + + +# --------------------------------------------------------------------------- +# Monitor +# --------------------------------------------------------------------------- + +class ThinkingPhaseStepVerifierGame24Monitor(VerifyMonitor): + """ + Monitor that verifies Game-of-24 expressions during and after thinking. + + During thinking: every N newlines (after warmup) -> fork a + side-stream asking for the current expression, verify it, and + give appropriate feedback. + + After natural ````: inject expression prompt, verify the + final answer. + """ + + def __init__( + self, + name: str, + original_numbers: List[int], + llm_server: dict, + prompt: str, + newline_threshold: int = 15, + max_corrections: int = 5, + answer_start_token: str = "", + async_execution: bool = True, + warmup_newlines: int = 0, + ): + super().__init__(name) + self.original_numbers = [float(x) for x in original_numbers] + self.llm_server = llm_server + self.prompt = prompt + self.newline_threshold = newline_threshold + self.max_corrections = max_corrections + self.answer_start_token = answer_start_token + self.async_execution = async_execution + self.warmup_newlines = warmup_newlines + + # ---- state ---- + self._think_phase_corrections = 0 + self._verified_expression = None # set by Phase 1 early-stop + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + @staticmethod + def _fmt(n: float) -> str: + if abs(n - round(n)) < 1e-9: + return str(int(round(n))) + return f"{n:.4f}".rstrip('0').rstrip('.') + + def _count_feedback_blocks(self, text: str) -> int: + return len(re.findall(r'\[VERIFIER FEEDBACK[^\]]*\]', text)) + + def _is_in_thinking_phase(self, generated_text: str) -> bool: + return self.answer_start_token not in generated_text + + @staticmethod + def _extract_braced_expression(text: str) -> Optional[str]: + """Extract the first expression wrapped in { } from *text*. + + Handles nested braces so that e.g. ``{(3+5)*7}`` is extracted correctly. + """ + start = text.find('{') + if start == -1: + return None + brace_count = 0 + end = start + while end < len(text): + if text[end] == '{': + brace_count += 1 + elif text[end] == '}': + brace_count -= 1 + if brace_count == 0: + break + end += 1 + if brace_count != 0: + return None + expr = text[start + 1:end].strip() + if not expr: + return None + # Basic cleanup: remove LaTeX + expr = expr.replace(r'\times', '*').replace(r'\cdot', '*').replace(r'\div', '/') + expr = expr.replace(r'\,', '').replace(r'\ ', '') + expr = expr.replace(r'\left', '').replace(r'\right', '') + # Replace Unicode math operators (QwQ frequently uses these) + expr = expr.replace('\u00d7', '*').replace('\u00f7', '/').replace('\u2212', '-') + expr = expr.replace('\u2013', '-').replace('\u2014', '-') # en-dash, em-dash + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + # Handle implicit multiplication + expr = re.sub(r'\)\s*\(', ')*(', expr) + expr = re.sub(r'\)\s*(\d)', r')*\1', expr) + expr = re.sub(r'(\d)\s*\(', r'\1*(', expr) + return expr + + @staticmethod + def _extract_boxed_expression(text: str) -> Optional[str]: + """Extract expression from \\boxed{...} in text.""" + boxed_pattern = r"\\boxed\{" + matches = list(re.finditer(boxed_pattern, text)) + if not matches: + return None + last_match = matches[-1] + start = last_match.end() + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == '{': + brace_count += 1 + elif text[end] == '}': + brace_count -= 1 + end += 1 + expr = text[start:end - 1].strip() + expr = expr.replace(r'\times', '*').replace(r'\cdot', '*').replace(r'\div', '/') + expr = expr.replace(r'\,', '').replace(r'\ ', '') + expr = expr.replace(r'\left', '').replace(r'\right', '') + expr = expr.replace('\u00d7', '*').replace('\u00f7', '/').replace('\u2212', '-') + expr = expr.replace('\u2013', '-').replace('\u2014', '-') + frac_pattern = r"\\frac\{([^{}]+)\}\{([^{}]+)\}" + while re.search(frac_pattern, expr): + expr = re.sub(frac_pattern, r"(\1/\2)", expr) + expr = re.sub(r'\)\s*\(', ')*(', expr) + expr = re.sub(r'\)\s*(\d)', r')*\1', expr) + expr = re.sub(r'(\d)\s*\(', r'\1*(', expr) + return expr + + # ------------------------------------------------------------------ + # _side_stream_expression (streams ~20 tokens to get {expr}) + # ------------------------------------------------------------------ + async def _side_stream_expression(self, text_so_far: str, max_new_tokens: int = 20) -> Optional[str]: + """ + Send ``prompt + text_so_far`` to vLLM, stream at most + *max_new_tokens* tokens, and try to extract an expression from + the output that appears inside ``{ }``. + """ + fmt = self._fmt + nums_str = ", ".join(fmt(n) for n in self.original_numbers) + logger.info( + f"[Side-stream] Starting expression extraction\n" + f" Original numbers : [{nums_str}]\n" + f" Max new tokens : {max_new_tokens}" + ) + + payload = deepcopy(self.llm_server["payload"]) + payload["prompt"] = self.prompt + text_so_far + payload["max_tokens"] = max_new_tokens + payload.pop("logprobs", None) + + generated = "" + + async with httpx.AsyncClient(timeout=None) as client: + async with client.stream( + "POST", + self.llm_server["url"], + headers=self.llm_server["headers"], + json=payload, + ) as response: + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = line[len("data: "):].strip() + if data == "[DONE]": + break + chunk = json.loads(data)["choices"][0]["text"] + generated += chunk + logger.debug(f"[Side-stream] chunk: {chunk!r}") + + if '}' in generated: + break + + full_text = "{" + generated + expr = self._extract_braced_expression(full_text) + if expr: + logger.info(f"[Side-stream] Extracted expression: {expr}") + else: + logger.info( + f"[Side-stream] No expression found in side-stream " + f"(generated {len(generated)} chars: {generated!r})" + ) + return expr + + # ------------------------------------------------------------------ + # step_extractor + # ------------------------------------------------------------------ + def step_extractor(self, chunk: str, generated_text: str): + # ===== PHASE 1: still inside ===== + if self._is_in_thinking_phase(generated_text): + if self._think_phase_corrections >= self.max_corrections: + return False, None + + total_newlines = generated_text.count('\n') + + if total_newlines < self.warmup_newlines: + return False, None + + past_warmup = total_newlines - self.warmup_newlines + if (generated_text.endswith('\n') + and past_warmup >= 0 + and past_warmup % self.newline_threshold == 0): + logger.info( + f"[step_extractor] Phase 1 trigger: \\n count={total_newlines} " + f"(warmup={self.warmup_newlines}, past_warmup={past_warmup}, " + f"threshold={self.newline_threshold})" + ) + return True, generated_text + + return False, None + + # ===== PHASE 2: after ===== + + # 2a: present but we haven't injected the expression prompt yet + if FINAL_EXPRESSION_PROMPT.strip() not in generated_text: + logger.info( + "[step_extractor] Phase 2a: detected, " + "expression prompt not yet injected." + ) + return True, generated_text + + # 2b: trigger once we see a complete \boxed{...} + think_end_pos = generated_text.find(self.answer_start_token) + len(self.answer_start_token) + text_after_think = generated_text[think_end_pos:] + + feedback_pattern = re.compile(r'\[VERIFIER FEEDBACK[^\]]*\]\s*', re.DOTALL) + last_feedback_end = 0 + for match in feedback_pattern.finditer(text_after_think): + last_feedback_end = match.end() + text = text_after_think[last_feedback_end:] + + has_boxed = find_complete_boxed(text) + if has_boxed: + return True, generated_text + + return False, None + + # ------------------------------------------------------------------ + # verify + # ------------------------------------------------------------------ + async def verify(self, step: str, token_index: int, event, event_info): + # ================================================================== + # CASE 1: Thinking phase -- side-stream expression verification + # ================================================================== + if self.answer_start_token not in step: + total_dn = step.count('\n') + logger.info( + f"[Phase 1] Thinking-phase verification triggered\n" + f" \\n count : {total_dn}\n" + f" Thinking len : {len(step)} chars" + ) + + text_with_prompt = step + "\n" + THINKING_PHASE_EXPRESSION_PROMPT + + expr_str = await self._side_stream_expression(text_with_prompt, max_new_tokens=20) + + if expr_str is None: + logger.info( + "[Phase 1] No expression extracted from side-stream. " + "Letting model continue thinking." + ) + return step, None + + status, is_valid, errors, unused = verify_expression( + expr_str, self.original_numbers + ) + + if not is_valid: + error_summary = "; ".join(errors) + self._think_phase_corrections += 1 + logger.info( + f"[Phase 1] INVALID expression '{expr_str}'\n" + f" Error(s) : {error_summary}\n" + f" Action : Inject feedback into thinking trace\n" + f" Corrections: {self._think_phase_corrections}/{self.max_corrections}" + ) + thinking_feedback = ( + f"\n\nWait, the expression {expr_str} does not work. " + f"{error_summary} " + f"I must NOT reuse {expr_str} or any expression I have already tried. " + f"Let me try a completely different combination of " + f"operations and grouping of numbers.\n" + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = thinking_feedback + event_info["correction_index"] = token_index + event_info["errors"] = errors + event_info["phase"] = "rollback_to_thinking" + event.set() + return step, thinking_feedback + + elif status == "complete": + self._verified_expression = expr_str + logger.info( + f"[Phase 1] VALID COMPLETE expression '{expr_str}' == 24\n" + f" Action: Inject early-stop message and transition to answer." + ) + early_stop_msg = ( + f"\n\nWait, the expression {expr_str} has been verified " + f"to equal 24 using all the given numbers. This will be " + f"my final answer.\n{self.answer_start_token}\n" + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = early_stop_msg + event_info["correction_index"] = token_index + event_info["phase"] = "early_stop_answer" + event_info["verified_expression"] = expr_str + event.set() + return step, early_stop_msg + + else: + unused_str = ( + "[" + ", ".join(self._fmt(n) for n in unused) + "]" + if unused else "[]" + ) + logger.info( + f"[Phase 1] VALID PARTIAL expression '{expr_str}'\n" + f" Unused numbers: {unused_str}\n" + f" Action: No error, let model keep thinking." + ) + return step, None + + # ================================================================== + # CASE 2a: present but expression prompt not yet injected + # ================================================================== + if FINAL_EXPRESSION_PROMPT.strip() not in step: + logger.info( + "[Phase 2a] Natural detected. " + "Injecting expression extraction prompt." + ) + prompt_text = FINAL_EXPRESSION_PROMPT + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = prompt_text + event_info["correction_index"] = token_index + event_info["phase"] = "inject_expression_prompt" + event.set() + return step, prompt_text + + # ================================================================== + # CASE 2b: After + expression prompt -- verify final answer + # ================================================================== + + num_corrections = self._count_feedback_blocks(step) + if num_corrections >= self.max_corrections: + fb = "\nthe answer is \\boxed{no solution}" + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = fb + event_info["correction_index"] = token_index + event_info["errors"] = ["Max corrections reached"] + event_info["phase"] = "standard_verify" + event.set() + return step, fb + + think_end_pos = step.find(self.answer_start_token) + len(self.answer_start_token) + text_after_think = step[think_end_pos:] + feedback_pattern = re.compile(r'\[VERIFIER FEEDBACK[^\]]*\]\s*', re.DOTALL) + last_feedback_end = 0 + for match in feedback_pattern.finditer(text_after_think): + last_feedback_end = match.end() + recent_text = text_after_think[last_feedback_end:] + + expr_str = self._extract_boxed_expression(recent_text) + if expr_str is not None: + logger.info(f"[Phase 2b] Extracted expression from \\boxed: '{expr_str}'") + + if expr_str is None: + return step, None + + status, is_valid, errors, unused = verify_expression( + expr_str, self.original_numbers + ) + + if is_valid and status == "complete": + logger.info(f"[Phase 2b] Final expression '{expr_str}' is correct (= 24)") + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = "" + event_info["correction_index"] = token_index + event_info["phase"] = "final_answer_correct" + event_info["verified_expression"] = expr_str + event.set() + return step, None + + if is_valid and status == "partial": + used_numbers = _extract_numbers_from_expr(expr_str) + errors = [ + f"Expression '{expr_str}' only uses {len(used_numbers)} of " + f"{len(self.original_numbers)} numbers. After , " + f"a COMPLETE expression using ALL numbers is required." + ] + + if not errors: + errors = [f"Expression '{expr_str}' is not a valid solution."] + + error_summary = "; ".join(errors) + logger.info(f"[Phase 2b] Final expression FAILED: {error_summary}") + + orig_display = [int(n) if n == int(n) else n for n in self.original_numbers] + nums_str = ", ".join(str(n) for n in orig_display) + feedback = ( + f"\n[VERIFIER FEEDBACK:\n" + f" The expression {expr_str} is incorrect. {error_summary}\n" + f" Do NOT reuse {expr_str} or any previously tried expression.\n" + f" Try a completely different approach. Use ALL four numbers " + f"{nums_str} exactly once, " + f"evaluating to 24. Wrap in \\boxed{{}}. ]\n" + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = errors + event_info["phase"] = "standard_verify" + event.set() + return step, feedback + + # ------------------------------------------------------------------ + # fix + # ------------------------------------------------------------------ + async def fix(self, generated_text: str, event_info: dict, fix_method=None): + phase = event_info.get("phase", "standard_verify") + + if phase == "rollback_to_thinking": + base_text = event_info["generated_text"] + result = base_text.rstrip() + event_info["feedback"] + logger.info( + f"[fix] Phase: rollback_to_thinking\n" + f" -> Appended error feedback into trace.\n" + f" -> Think-phase corrections: {self._think_phase_corrections}/{self.max_corrections}" + ) + return result + + if phase == "early_stop_answer": + base_text = event_info["generated_text"] + result = base_text.rstrip() + event_info["feedback"] + logger.info( + f"[fix] Phase: early_stop_answer\n" + f" -> Verified expression passed. Injecting early-stop + .\n" + f" -> Model will now generate the final answer." + ) + return result + + if phase == "final_answer_correct": + expr = event_info.get("verified_expression", "?") + logger.info( + f"[fix] Phase: final_answer_correct\n" + f" -> Final expression '{expr}' verified correct. Stopping generation." + ) + return event_info["generated_text"] + + if phase == "inject_expression_prompt": + logger.info( + f"[fix] Phase: inject_expression_prompt\n" + f" -> Natural detected.\n" + f" -> Appending expression extraction prompt." + ) + return event_info["generated_text"] + event_info["feedback"] + + # standard_verify + errors = event_info.get("errors", []) + error_summary = "; ".join(errors) if errors else "unknown" + logger.info( + f"[fix] Phase: standard_verify\n" + f" -> Expression failed: {error_summary}\n" + f" -> Appending [VERIFIER FEEDBACK] so model retries." + ) + return event_info["generated_text"] + event_info["feedback"] diff --git a/interwhen/monitors/thinkingPhaseVerifierMaze.py b/interwhen/monitors/thinkingPhaseVerifierMaze.py new file mode 100644 index 0000000..1664a27 --- /dev/null +++ b/interwhen/monitors/thinkingPhaseVerifierMaze.py @@ -0,0 +1,878 @@ +""" +Thinking-phase verifier for Maze tasks. + +Verifies maze path-tracing by forking a side-stream during the +thinking phase to ask the model about its current traced path. + +Workflow +-------- +A) **DURING the thinking phase** (inside ``...``): + After a warmup period, every *N* newlines in the thinking trace: + 1. Inject a first-person prompt to extract the traced path steps. + 2. Parse and verify each step against the maze grid. + 3. If **errors** -> inject feedback into thinking trace. + 4. If **path reaches E** -> inject early-stop + ```` + + structured format. + 5. If **partial but correct** -> no feedback, keep thinking. + +B) **AFTER ````**: + Phase 2a: Inject structured step format template. + Phase 2b: Verify each step as the model fills in the template. + Once ``\\boxed{}`` appears, stop generation. +""" + +import re +import json +import logging +from typing import Tuple, Optional +from copy import deepcopy + +import httpx + +from .base import VerifyMonitor +from ._common import find_complete_boxed +from ..utils.maze_verifier import ( + Direction, parse_direction, get_expected_turn_type, + parse_maze_from_prompt, parse_maze_step, verify_maze_step, + verify_locate_section, format_maze_feedback, format_locate_feedback, + DIRECTION_DELTAS, compute_relative_direction, +) + +logger = logging.getLogger(__name__) + + +# ===================================================================== +# Maze Thinking-Phase Prompts +# ===================================================================== + + +def _build_maze_format_block(question_type: str) -> str: + """ + Build the ... block that describes the structured + output template. Re-used by both the side-stream (Phase 1) and + the post- injection (Phase 2a). + """ + if question_type == "relative_position": + return ( + "\n" + ">>> LOCATE START AND EXIT (0-indexed, top-left is (0,0)):\n" + " S position: (row, col)\n" + " E position: (row, col)\n" + "\n" + ">>> COMPARE POSITIONS:\n" + " Row comparison: E row (r) vs S row (r) → E is ABOVE/BELOW S\n" + " Col comparison: E col (c) vs S col (c) → E is LEFT/RIGHT of S\n" + "\n" + ">>> FINAL ANSWER:\n" + " \\boxed{LETTER}\n" + "" + ) + else: + count_line = " Running count: Right=0, Left=0" + if question_type == "total_turns": + count_line = " Running count: Right=0, Left=0, Total=0" + + return ( + "\n" + ">>> LOCATE START AND EXIT (0-indexed, top-left is (0,0)):\n" + " S position: (row, col)\n" + " E position: (row, col)\n" + "\n" + ">>> STEP 1: Move DOWN from (r1, c1) to (r2, c2)\n" + " Current position: (r2, c2)\n" + " Previous direction: —\n" + " Current direction: DOWN\n" + " Turn type: STRAIGHT\n" + f"{count_line}\n" + "\n" + "[... continue for all steps until reaching E ...]\n" + "\n" + ">>> FINAL ANSWER:\n" + " \\boxed{LETTER}\n" + "" + ) + + +def _build_maze_thinking_phase_prompt(question_type: str) -> str: + """ + Build the side-stream prompt injected during the thinking phase. + + Written in the LLM's own first-person thinking voice so it blends + naturally with the ```` trace. Includes the ```` + block and the starting marker so the model begins filling in. + """ + format_block = _build_maze_format_block(question_type) + return ( + "\n\nLet me output the current steps I have traced so far " + "through the maze in the following format:\n" + f"{format_block}\n" + ">>> LOCATE START AND EXIT (0-indexed, top-left is (0,0)):\n" + ) + + +def _build_maze_structured_prompt(question_type: str) -> str: + """ + Build the structured format prompt injected after . + + This is analogous to Game24's step format injection — it gives the + model a template to fill in so we can parse and verify each step. + Written in the LLM's own voice so it reads naturally. + """ + format_block = _build_maze_format_block(question_type) + return ( + "\nLet me trace the step by step solution through the maze " + "in the following format:\n" + f"{format_block}\n" + ">>> LOCATE START AND EXIT (0-indexed, top-left is (0,0)):\n" + ) + + +# ===================================================================== +# ThinkingPhaseStepVerifierMazeMonitor +# ===================================================================== + +class ThinkingPhaseStepVerifierMazeMonitor(VerifyMonitor): + """ + Monitor that verifies maze path-tracing during and after thinking. + + **No meta-prompt required** — works with a plain user prompt containing + just the maze and question. Structure is injected by this monitor + after ```` (natural or early-stop), exactly like Game24 + injects its step format. + + Phase 1 – During ``...``: + Every N newlines (after warmup), fork a side-stream that + injects ```` + a structured step prompt, stream ~300 + tokens, parse and verify each step against the maze grid. + + Phase 2a – ```` detected, structured prompt not yet injected: + Inject the structured step-by-step format template so the model + fills it in (LOCATE → STEPs → FINAL ANSWER → ``\\boxed{}``). + + Phase 2b – Structured prompt injected, model is generating: + Verify each completed step as it appears. Once ``\\boxed{}`` + appears, signal completion. + """ + + def __init__( + self, + name: str, + grid: list, + start_pos: tuple, + exit_pos: tuple, + llm_server: dict, + prompt: str, + question_type: str = "right_turns", + newline_threshold: int = 10, + max_corrections: int = 5, + answer_start_token: str = "", + async_execution: bool = True, + warmup_newlines: int = 0, + ): + super().__init__(name) + self.grid = grid + self.start_pos = start_pos + self.exit_pos = exit_pos + self.llm_server = llm_server + self.prompt = prompt + self.question_type = question_type + self.newline_threshold = newline_threshold + self.max_corrections = max_corrections + self.answer_start_token = answer_start_token + self.async_execution = async_execution + self.warmup_newlines = warmup_newlines + + # Build the structured prompt that will be injected after + self._structured_prompt = _build_maze_structured_prompt(question_type) + # Build the thinking-phase side-stream prompt (in LLM's own voice) + self._thinking_phase_prompt = _build_maze_thinking_phase_prompt(question_type) + # A unique marker to detect whether we already injected it + self._structured_marker = ">>> LOCATE START AND EXIT (0-indexed, top-left is (0,0)):" + + # ---- state ---- + self._think_phase_corrections = 0 + self._verified_path_complete = False # True if path reaches E + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + def _count_feedback_blocks(self, text: str) -> int: + return len(re.findall(r'\[VERIFIER FEEDBACK[^\]]*\]', text)) + + def _is_in_thinking_phase(self, generated_text: str) -> bool: + return self.answer_start_token not in generated_text + + def _structured_prompt_injected(self, generated_text: str) -> bool: + """Check if structured format was already injected after .""" + if self.answer_start_token not in generated_text: + return False + after_think = generated_text.split(self.answer_start_token, 1)[1] + return self._structured_marker in after_think + + @staticmethod + def detect_question_type(prompt: str) -> str: + """Auto-detect question type from prompt text.""" + prompt_lower = prompt.lower() + if "right turn" in prompt_lower or "right-turn" in prompt_lower: + return "right_turns" + if "left turn" in prompt_lower or "left-turn" in prompt_lower: + return "total_turns" + if "total" in prompt_lower and "turn" in prompt_lower: + return "total_turns" + if "turn" in prompt_lower: + return "right_turns" + return "relative_position" + + def _verify_relative_position_answer(self, boxed_answer: str) -> Tuple[bool, Optional[str]]: + """Verify a relative-position boxed answer (A=Yes / B=No). + + Parses the question from ``self.prompt`` to determine the asked + direction, computes the true relative direction of E from S, + and checks whether the model's Yes/No answer is correct. + + Returns ``(is_correct, feedback_or_None)``. + """ + answer_map = {"A": "Yes", "B": "No"} + model_yn = answer_map.get(boxed_answer.strip().upper()) + if model_yn is None: + return True, None + + m = re.search( + r'Is the exit \(E\)\s+(.*?)\s+(?:of\s+)?the starting point \(S\)', + self.prompt, re.IGNORECASE, + ) + if not m: + return True, None + + asked_raw = m.group(1).strip().lower() + asked_raw = re.sub(r',.*', '', asked_raw).strip() + + actual = compute_relative_direction(self.start_pos, self.exit_pos) + + direction_keywords = { + "directly to the left": {"west"}, + "directly to the right": {"east"}, + "directly above": {"north"}, + "directly below": {"south"}, + "to the top left": {"northwest"}, + "to the top right": {"northeast"}, + "to the bottom left": {"southwest"}, + "to the bottom right": {"southeast"}, + } + + expected_dirs = direction_keywords.get(asked_raw) + if expected_dirs is None: + return True, None + + expected_yn = "Yes" if actual in expected_dirs else "No" + + if model_yn == expected_yn: + return True, None + + sr, sc = self.start_pos + er, ec = self.exit_pos + correct_letter = 'A' if expected_yn == 'Yes' else 'B' + feedback = ( + f"\n\n[VERIFIER FEEDBACK for relative position:\n" + f" ✗ Your answer {boxed_answer} ({model_yn}) is incorrect.\n" + f" IMPORTANT: In this task, \"{asked_raw}\" means the GENERAL " + f"COMPASS DIRECTION, NOT immediate adjacency. It asks whether E " + f"is in the {actual} direction from S, regardless of distance or " + f"walls between them.]\n\n" + ) + return False, feedback + + # ------------------------------------------------------------------ + # _parse_steps_from_text – parse structured steps from side-stream + # ------------------------------------------------------------------ + def _parse_steps_from_text(self, text: str): + """ + Parse all structured maze steps from text. + + Returns list of parsed step dicts. + """ + steps = [] + + step_pattern = re.compile( + r'>>>\s*STEP\s+(\d+):\s*Move\s+\w+\s+from\s+\([^)]+\)\s+to\s+\([^)]+\).*?' + r'Running count:\s*Right\s*=\s*\d+\s*,\s*Left\s*=\s*\d+[^\n]*', + re.IGNORECASE | re.DOTALL + ) + + for match in step_pattern.finditer(text): + parsed = parse_maze_step(match.group(0)) + if parsed: + steps.append(parsed) + + return steps + + def _verify_all_steps(self, steps): + """ + Verify a sequence of parsed maze steps against the grid. + + Returns: + (all_valid, first_error_step_num, errors, final_pos, final_dir, + right_count, left_count, total_count) + """ + pos = self.start_pos + direction = Direction.NONE + right_count = 0 + left_count = 0 + total_count = 0 + + for step in steps: + is_valid, errors, state = verify_maze_step( + step=step, + grid=self.grid, + expected_from_pos=pos, + prev_direction=direction, + expected_right_count=right_count, + expected_left_count=left_count, + expected_total_count=total_count, + ) + + if not is_valid: + return (False, step.get('step_num', 0), errors, + pos, direction, right_count, left_count, total_count) + + pos = state['new_pos'] + direction = state['new_direction'] + right_count = state['new_right'] + left_count = state['new_left'] + total_count = state['new_total'] + + return (True, None, [], pos, direction, + right_count, left_count, total_count) + + # ------------------------------------------------------------------ + # _side_stream_maze_steps – streams tokens to get traced path + # ------------------------------------------------------------------ + async def _side_stream_maze_steps(self, text_so_far: str, max_new_tokens: int = 300) -> str: + """ + Send ``prompt + text_so_far`` to vLLM, stream at most + *max_new_tokens* tokens, and return the generated text. + """ + logger.info( + f"[Maze Side-stream] Starting path extraction\n" + f" Maze: S={self.start_pos}, E={self.exit_pos}\n" + f" Max new tokens: {max_new_tokens}" + ) + + payload = deepcopy(self.llm_server["payload"]) + payload["prompt"] = self.prompt + text_so_far + payload["max_tokens"] = max_new_tokens + payload.pop("logprobs", None) + + generated = "" + + async with httpx.AsyncClient(timeout=None) as client: + async with client.stream( + "POST", + self.llm_server["url"], + headers=self.llm_server["headers"], + json=payload, + ) as response: + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = line[len("data: "):].strip() + if data == "[DONE]": + break + chunk = json.loads(data)["choices"][0]["text"] + generated += chunk + logger.debug(f"[Maze Side-stream] chunk: {chunk!r}") + + if '\\boxed' in generated or '>>> FINAL ANSWER' in generated: + break + + logger.info( + f"[Maze Side-stream] Generated {len(generated)} chars" + ) + return generated + + # ------------------------------------------------------------------ + # _extract_boxed_answer + # ------------------------------------------------------------------ + @staticmethod + def _extract_boxed_answer(text: str) -> Optional[str]: + """Extract the content of the last \\boxed{...} in text.""" + matches = list(re.finditer(r'\\boxed\{', text)) + if not matches: + return None + last_match = matches[-1] + start = last_match.end() + brace_count = 1 + end = start + while end < len(text) and brace_count > 0: + if text[end] == '{': + brace_count += 1 + elif text[end] == '}': + brace_count -= 1 + end += 1 + return text[start:end - 1].strip() + + # ------------------------------------------------------------------ + # step_extractor + # ------------------------------------------------------------------ + def step_extractor(self, chunk: str, generated_text: str): + # ===== PHASE 1: still inside ===== + if self._is_in_thinking_phase(generated_text): + if self._think_phase_corrections >= self.max_corrections: + return False, None + + total_newlines = generated_text.count('\n') + + if total_newlines < self.warmup_newlines: + return False, None + + past_warmup = total_newlines - self.warmup_newlines + if (generated_text.endswith('\n') + and past_warmup >= 0 + and past_warmup % self.newline_threshold == 0): + logger.info( + f"[Maze step_extractor] Phase 1 trigger: \\n count={total_newlines} " + f"(warmup={self.warmup_newlines}, past_warmup={past_warmup}, " + f"threshold={self.newline_threshold})" + ) + return True, generated_text + + return False, None + + # ===== PHASE 2: after ===== + + # 2a: structured prompt not yet injected → trigger immediately + if not self._structured_prompt_injected(generated_text): + logger.info( + "[Maze step_extractor] Phase 2a: detected, " + "structured prompt not yet injected." + ) + return True, generated_text + + # 2b: structured prompt injected — verify steps / boxed answer + think_end_pos = generated_text.find(self.answer_start_token) + len(self.answer_start_token) + text_after_think = generated_text[think_end_pos:] + + last_marker_pos = text_after_think.rfind(self._structured_marker) + if last_marker_pos >= 0: + model_output_start = last_marker_pos + len(self._structured_marker) + text_after_think = text_after_think[model_output_start:] + text_start_offset = think_end_pos + model_output_start + else: + text_start_offset = think_end_pos + + feedback_pattern = re.compile(r'\[VERIFIER FEEDBACK[^\]]*\]\s*', re.DOTALL) + last_feedback_end = 0 + for match in feedback_pattern.finditer(text_after_think): + last_feedback_end = match.end() + text = text_after_think[last_feedback_end:] + text_start = text_start_offset + last_feedback_end + + if self.question_type in ("right_turns", "total_turns"): + step_pattern = re.compile( + r'(>>>\s*STEP\s+(\d+):\s*Move\s+\w+\s+from\s+\([^)]+\)\s+to\s+\([^)]+\).*?' + r'Running count:\s*Right\s*=\s*\d+\s*,\s*Left\s*=\s*\d+[^\n]*)', + re.IGNORECASE | re.DOTALL + ) + all_steps = list(step_pattern.finditer(text)) + + if all_steps: + last_step = all_steps[-1] + text_after = text[last_step.end():] + next_step = re.search(r'>>>\s*STEP\s+\d+', text_after, re.IGNORECASE) + if not next_step: + end_pos = text_start + last_step.end() + return True, generated_text[:end_pos] + return False, None + + locate_pattern = re.compile( + r'(LOCATE START AND EXIT.*?E position:\s*\([^)]+\))', + re.IGNORECASE | re.DOTALL + ) + locate_match = locate_pattern.search(text) + if locate_match: + step1_start = re.search(r'>>>\s*STEP\s+1', text[locate_match.end():], re.IGNORECASE) + if step1_start: + end_pos = text_start + locate_match.end() + return True, generated_text[:end_pos] + + boxed = find_complete_boxed(text) + if boxed: + end_pos = text_start + boxed.end() + return True, generated_text[:end_pos] + + return False, None + + # ------------------------------------------------------------------ + # verify + # ------------------------------------------------------------------ + async def verify(self, step: str, token_index: int, event, event_info): + # ================================================================== + # CASE 1: Thinking phase – side-stream path verification + # ================================================================== + if self.answer_start_token not in step: + total_dn = step.count('\n') + logger.info( + f"[Maze Phase 1] Thinking-phase verification triggered\n" + f" \\n count : {total_dn}\n" + f" Thinking len : {len(step)} chars" + ) + + text_with_prompt = step + self._thinking_phase_prompt + + side_output = await self._side_stream_maze_steps( + text_with_prompt, max_new_tokens=300 + ) + + if not side_output or len(side_output.strip()) < 20: + logger.info( + "[Maze Phase 1] Insufficient output from side-stream. " + "Letting model continue thinking." + ) + return step, None + + full_side_text = ( + ">>> LOCATE START AND EXIT (0-indexed, top-left is (0,0)):\n" + side_output + ) + + locate_valid, locate_errors = verify_locate_section( + full_side_text, self.start_pos, self.exit_pos + ) + + if not locate_valid: + self._think_phase_corrections += 1 + error_summary = "; ".join(locate_errors) + logger.info( + f"[Maze Phase 1] LOCATE section errors: {error_summary}\n" + f" Action: Inject feedback into thinking trace\n" + f" Corrections: {self._think_phase_corrections}/{self.max_corrections}" + ) + thinking_feedback = ( + f"\n\nWait, I think I have the wrong positions. " + f"{error_summary}. " + f"Let me re-examine the maze grid carefully to find S and E.\n" + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = thinking_feedback + event_info["correction_index"] = token_index + event_info["errors"] = locate_errors + event_info["phase"] = "rollback_to_thinking" + event.set() + return step, thinking_feedback + + steps = self._parse_steps_from_text(full_side_text) + + if not steps: + logger.info( + "[Maze Phase 1] No structured steps found in side-stream. " + "Letting model continue thinking." + ) + return step, None + + (all_valid, err_step_num, errors, final_pos, + final_dir, r_count, l_count, t_count) = self._verify_all_steps(steps) + + if not all_valid: + error_summary = "; ".join(errors) + self._think_phase_corrections += 1 + logger.info( + f"[Maze Phase 1] INVALID step {err_step_num}\n" + f" Error(s) : {error_summary}\n" + f" Action : Inject feedback into thinking trace\n" + f" Corrections: {self._think_phase_corrections}/{self.max_corrections}" + ) + thinking_feedback = ( + f"\n\nWait, I made an error at Step {err_step_num}. " + f"{error_summary}. " + f"Let me re-trace the path more carefully from the correct position.\n" + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = thinking_feedback + event_info["correction_index"] = token_index + event_info["errors"] = errors + event_info["phase"] = "rollback_to_thinking" + event.set() + return step, thinking_feedback + + # All steps valid — check if path is complete (reached E) + if final_pos == self.exit_pos: + self._verified_path_complete = True + logger.info( + f"[Maze Phase 1] VALID COMPLETE path to E={self.exit_pos}\n" + f" Steps: {len(steps)}, Right={r_count}, Left={l_count}, Total={t_count}\n" + f" Action: Inject early-stop + + structured format." + ) + early_stop_msg = ( + f"\n\nWait, I have successfully traced the path from " + f"S={self.start_pos} to E={self.exit_pos} with " + f"{len(steps)} steps. " + f"Right turns={r_count}, Left turns={l_count}, " + f"Total turns={t_count}. " + f"This path has been verified as correct. " + f"Let me give the final answer.\n" + f"{self.answer_start_token}" + f"{self._structured_prompt}" + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = early_stop_msg + event_info["correction_index"] = token_index + event_info["phase"] = "early_stop_answer" + event_info["verified_counts"] = { + "right": r_count, + "left": l_count, + "total": t_count, + "steps": len(steps), + } + event.set() + return step, early_stop_msg + + else: + logger.info( + f"[Maze Phase 1] VALID PARTIAL path\n" + f" Current pos: {final_pos}, Target: {self.exit_pos}\n" + f" Steps so far: {len(steps)}\n" + f" Action: No error, let model keep thinking." + ) + return step, None + + # ================================================================== + # CASE 2a: present but structured prompt not yet injected + # ================================================================== + if not self._structured_prompt_injected(step): + logger.info( + "[Maze Phase 2a] detected. " + "Injecting structured step format." + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = self._structured_prompt + event_info["correction_index"] = token_index + event_info["phase"] = "inject_structured_prompt" + event.set() + return step, self._structured_prompt + + # ================================================================== + # CASE 2b: Structured prompt injected — verify output + # ================================================================== + + num_corrections = self._count_feedback_blocks(step) + if num_corrections >= self.max_corrections: + fb = "\nthe answer is \\boxed{no solution}" + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = fb + event_info["correction_index"] = token_index + event_info["errors"] = ["Max corrections reached"] + event_info["phase"] = "standard_verify" + event.set() + return step, fb + + think_end_pos = step.find(self.answer_start_token) + len(self.answer_start_token) + text_after_think = step[think_end_pos:] + + last_marker_pos = text_after_think.rfind(self._structured_marker) + if last_marker_pos >= 0: + text_after_think = text_after_think[last_marker_pos:] + + feedback_pattern = re.compile(r'\[VERIFIER FEEDBACK[^\]]*\]\s*', re.DOTALL) + last_feedback_end = 0 + for match in feedback_pattern.finditer(text_after_think): + last_feedback_end = match.end() + recent_text = text_after_think[last_feedback_end:] + + # --- Verify LOCATE section --- + locate_match = re.search(r'LOCATE START AND EXIT', recent_text, re.IGNORECASE) + if locate_match: + step1_start = re.search(r'>>>\s*STEP\s+1', recent_text, re.IGNORECASE) + if step1_start or '\\boxed' in recent_text: + if step1_start: + locate_text = recent_text[locate_match.start():step1_start.start()] + else: + locate_text = recent_text[locate_match.start():] + is_valid, loc_errors = verify_locate_section( + locate_text, self.start_pos, self.exit_pos + ) + if not is_valid: + feedback = format_locate_feedback(loc_errors) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = loc_errors + event_info["phase"] = "standard_verify" + event.set() + return step, feedback + + # --- Verify structured steps --- + if self.question_type in ("right_turns", "total_turns"): + step_pattern = re.compile( + r'(>>>\s*STEP\s+(\d+):\s*Move\s+\w+\s+from\s+\([^)]+\)\s+to\s+\([^)]+\).*?' + r'Running count:[^\n]+)', + re.IGNORECASE | re.DOTALL + ) + recent_step_matches = list(step_pattern.finditer(recent_text)) + + if recent_step_matches: + last_match = recent_step_matches[-1] + last_step_text = last_match.group(0) + last_step_num = int(last_match.group(2)) + parsed = parse_maze_step(last_step_text) + + if parsed: + all_full_matches = list(step_pattern.finditer(text_after_think)) + state = self._get_state_before_step_phase2( + text_after_think, last_step_num, all_full_matches + ) + + is_valid, errors, new_state = verify_maze_step( + step=parsed, + grid=self.grid, + expected_from_pos=state['position'], + prev_direction=state['direction'], + expected_right_count=state['right_count'], + expected_left_count=state['left_count'], + expected_total_count=state['total_count'], + ) + + if not is_valid: + feedback = format_maze_feedback(errors, last_step_num) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = errors + event_info["phase"] = "standard_verify" + event.set() + return step, feedback + + # --- Check for boxed answer --- + boxed_answer = self._extract_boxed_answer(recent_text) + if boxed_answer is not None: + logger.info(f"[Maze Phase 2b] Extracted boxed answer: {boxed_answer}") + + if self.question_type == "relative_position": + is_correct, rp_feedback = self._verify_relative_position_answer(boxed_answer) + if not is_correct and rp_feedback: + logger.info( + f"[Maze Phase 2b] Relative position answer '{boxed_answer}' is INCORRECT." + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = rp_feedback + event_info["correction_index"] = token_index + event_info["errors"] = [f"Wrong relative position answer: {boxed_answer}"] + event_info["phase"] = "standard_verify" + event.set() + return step, rp_feedback + + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = "" + event_info["correction_index"] = token_index + event_info["phase"] = "final_answer_correct" + event.set() + return step, None + + return step, None + + # ------------------------------------------------------------------ + # _get_state_before_step_phase2 – reconstruct state for Phase 2 + # ------------------------------------------------------------------ + def _get_state_before_step_phase2(self, text: str, target_step_num: int, + all_step_matches: list) -> dict: + """Reconstruct state before a given step from Phase 2 structured output. + + When a step number appears multiple times (original + corrections after + verifier feedback), only the LAST occurrence of each step number is used, + so that corrected steps override earlier invalid ones. + """ + state = { + 'position': self.start_pos, + 'direction': Direction.NONE, + 'right_count': 0, + 'left_count': 0, + 'total_count': 0, + } + + last_by_num = {} + for match in all_step_matches: + step_num = int(match.group(2)) + if step_num >= target_step_num: + continue + last_by_num[step_num] = match + + for step_num in sorted(last_by_num.keys()): + parsed = parse_maze_step(last_by_num[step_num].group(0)) + if not parsed: + continue + + direction = parsed['direction'] + to_pos = parsed['to_pos'] + + turn_type = get_expected_turn_type(state['direction'], direction) + if turn_type == 'RIGHT_TURN': + state['right_count'] += 1 + state['total_count'] += 1 + elif turn_type == 'LEFT_TURN': + state['left_count'] += 1 + state['total_count'] += 1 + + state['position'] = to_pos + state['direction'] = direction + + return state + + # ------------------------------------------------------------------ + # fix + # ------------------------------------------------------------------ + async def fix(self, generated_text: str, event_info: dict, fix_method=None): + """Apply the appropriate fix depending on the phase.""" + phase = event_info.get("phase", "standard_verify") + + if phase == "rollback_to_thinking": + base_text = event_info["generated_text"] + result = base_text.rstrip() + event_info["feedback"] + logger.info( + f"[Maze fix] Phase: rollback_to_thinking\n" + f" -> Appended error feedback into trace.\n" + f" -> Think-phase corrections: {self._think_phase_corrections}/{self.max_corrections}" + ) + return result + + if phase == "early_stop_answer": + base_text = event_info["generated_text"] + result = base_text.rstrip() + event_info["feedback"] + counts = event_info.get("verified_counts", {}) + logger.info( + f"[Maze fix] Phase: early_stop_answer\n" + f" -> Path verified: {counts.get('steps', '?')} steps, " + f"R={counts.get('right', '?')}, L={counts.get('left', '?')}, " + f"T={counts.get('total', '?')}\n" + f" -> Injecting early-stop + + structured format." + ) + return result + + if phase == "inject_structured_prompt": + logger.info( + "[Maze fix] Phase: inject_structured_prompt\n" + " -> Appending structured step format after ." + ) + return event_info["generated_text"] + event_info["feedback"] + + if phase == "final_answer_correct": + logger.info( + f"[Maze fix] Phase: final_answer_correct\n" + f" -> Stopping generation." + ) + return event_info["generated_text"] + + # standard_verify + errors = event_info.get("errors", []) + error_summary = "; ".join(errors) if errors else "unknown" + logger.info( + f"[Maze fix] Phase: standard_verify\n" + f" -> Error: {error_summary}\n" + f" -> Appending [VERIFIER FEEDBACK] so model retries." + ) + return event_info["generated_text"] + event_info["feedback"] diff --git a/interwhen/monitors/thinkingPhaseVerifierSpatialMap.py b/interwhen/monitors/thinkingPhaseVerifierSpatialMap.py new file mode 100644 index 0000000..2c4d0e7 --- /dev/null +++ b/interwhen/monitors/thinkingPhaseVerifierSpatialMap.py @@ -0,0 +1,1023 @@ +""" +Thinking-phase verifier for SpatialMap tasks. + +Verifies spatial-map directional claims by forking a side-stream during +the thinking phase. Uses Z3 constraint solving to check whether +directional claims (e.g. "A is northeast of B") are consistent with +the stated problem constraints. + +Workflow +-------- +A) **DURING the thinking phase** (inside ``...``): + After a warmup period, every *N* newlines in the thinking trace: + 1. Inject a first-person prompt to extract parsed and derived + spatial relationships (STEP 1 pre-filled, STEP 2 generated). + 2. Parse directional claims from STEP 2 output. + 3. Verify each claim using a Z3 solver. + 4. If **errors** -> inject feedback into thinking trace. + 5. If **all valid** -> no feedback, keep thinking. + +B) **AFTER ````**: + Phase 2a: Inject structured step format template. + Phase 2b: Verify directional claims and final answer + (direction / object / counting questions) as model fills template. + Once ``\\boxed{}`` appears, stop generation. +""" + +import re +import json +import logging +from typing import Dict, List, Set, Tuple, Optional +from copy import deepcopy + +import httpx + +from .base import VerifyMonitor +from ._common import find_complete_boxed +from ..utils.spatialmap_verifier import ( + SpatialMapZ3Solver, extract_step2_claims, + parse_directional_claims_from_text, + parse_counting_question, parse_model_count_from_answer, + parse_direction_question, parse_object_question, + parse_model_boxed_answer, + get_possible_directions, get_consistent_object_options, + get_possible_count_range, + verify_spatialmap_step, format_spatialmap_feedback, +) + +logger = logging.getLogger(__name__) + + +# ===================================================================== +# SpatialMap Thinking-Phase Prompts +# ===================================================================== + + +def _build_spatialmap_format_block() -> str: + """ + Build the ``...`` block that describes the structured + output template for SpatialMap tasks. + + Re-used by both the side-stream (Phase 1) and the post-```` + injection (Phase 2a). + """ + return ( + "\n" + ">>> STEP 1: PARSE RELATIONSHIPS\n" + " - [Full Name A] is to the [direction] of [Full Name B]\n" + " - [Full Name C] is to the [direction] of [Full Name D]\n" + " [... list ALL given relationships using FULL names exactly as in the question ...]\n" + " (NO abbreviations, NO short forms, NO parenthetical aliases like 'Police Supply Store (PSS)')\n" + "\n" + ">>> STEP 2: ANALYZE SPATIAL RELATIONSHIPS\n" + " - Looking for: [target relationship / direction / count]\n" + " - [Full Name A] is to the [direction] of [Full Name B]\n" + " - [Full Name C] is to the [direction] of [Full Name D]\n" + " [... list each derived relationship as a structured claim using FULL names ...]\n" + " (Each claim MUST be in the form: '[Full Name] is to the [direction] of [Full Name]')\n" + "\n" + ">>> STEP 3: ANSWER\n" + " - [state conclusion]\n" + "\n" + ">>> FINAL ANSWER: [answer text]\n" + " \\boxed{LETTER}\n" + "" + ) + + +def _build_spatialmap_thinking_phase_prompt( + parsed_relations: List[Dict], +) -> str: + """ + Build the side-stream prompt injected during the thinking phase. + + Pre-fills STEP 1 with the known parsed relations (from the Z3 solver) + so the model jumps directly to STEP 2 analysis, maximising the chance + of producing verifiable directional claims within the token budget. + + Written in the LLM's own first-person thinking voice so it blends + naturally with the ```` trace. + """ + # Pre-fill STEP 1 from the ground-truth parsed relations + step1_lines = [] + for rel in parsed_relations: + step1_lines.append( + f" - {rel['A']} is to the {rel['direction']} of {rel['B']}" + ) + step1_body = "\n".join(step1_lines) if step1_lines else " (none)" + + return ( + "\n\nLet me organize what I have so far. I will list the given " + "relationships in STEP 1, then in STEP 2 I will state every " + "spatial claim I have derived.\n" + "IMPORTANT: I must use the FULL object names exactly as given in the question " + "(no abbreviations, no short forms, no aliases, no partial names, no parenthetical aliases like 'Store (S)').\n" + "Every claim must be in the form: '[Full Name] is to the [direction] of [Full Name]'\n" + "For direction I will use the full word: northeast, northwest, southeast, southwest, north, south, east, or west.\n\n" + ">>> STEP 1: PARSE RELATIONSHIPS (given)\n" + f"{step1_body}\n\n" + ">>> STEP 2: ANALYZE SPATIAL RELATIONSHIPS (derived)\n" + "Based on my analysis so far, the derived relationships are:\n" + ) + + +def _build_spatialmap_structured_prompt() -> str: + """ + Build the structured format prompt injected after ````. + + Analogous to the maze's structured format injection — gives the + model a template to fill in so we can parse and verify each step. + """ + format_block = _build_spatialmap_format_block() + return ( + "\nLet me solve this step by step using the structured format.\n" + "IMPORTANT: I must use the FULL names of all objects exactly as they appear in the question. " + "NO abbreviations, NO short forms, NO parenthetical aliases.\n\n" + f"{format_block}\n" + ">>> STEP 1: PARSE RELATIONSHIPS\n" + ) + + +# ===================================================================== +# ThinkingPhaseStepVerifierSpatialMapMonitor +# ===================================================================== + + +class ThinkingPhaseStepVerifierSpatialMapMonitor(VerifyMonitor): + """ + Monitor that verifies spatial-map directional claims during and after + thinking. + + **No meta-prompt required** — works with a plain user prompt containing + just the map description and question. Structure is injected by this + monitor after ```` (natural or early-stop), exactly like the + Maze monitor injects its step format. + + Phase 1 – During ``...``: + Every N newlines (after warmup), fork a side-stream that + injects a structured step prompt, stream tokens, parse directional + claims from STEP 2, and verify them against Z3. + + Phase 2a – ```` detected, structured prompt not yet injected: + Inject the structured step-by-step format template so the model + fills it in (STEP 1 → STEP 2 → STEP 3 → FINAL ANSWER → ``\\boxed{}``). + + Phase 2b – Structured prompt injected, model is generating: + Verify directional claims in STEP 2 as they appear. Once + ``\\boxed{}`` appears, signal completion. + """ + + def __init__( + self, + name: str, + problem_text: str, + llm_server: dict, + prompt: str, + newline_threshold: int = 15, + max_corrections: int = 5, + answer_start_token: str = "", + async_execution: bool = True, + warmup_newlines: int = 0, + ): + super().__init__(name) + self.problem_text = problem_text + self.llm_server = llm_server + self.prompt = prompt + self.newline_threshold = newline_threshold + self.max_corrections = max_corrections + self.answer_start_token = answer_start_token + self.async_execution = async_execution + self.warmup_newlines = warmup_newlines + + # Initialize Z3 solver with problem constraints + self.z3_solver = SpatialMapZ3Solver(problem_text) + + # Build prompts for injection + self._structured_prompt = _build_spatialmap_structured_prompt() + self._thinking_phase_prompt = _build_spatialmap_thinking_phase_prompt( + self.z3_solver.parsed_relations, + ) + # Marker to detect if structured prompt was already injected + self._structured_marker = ">>> STEP 1: PARSE RELATIONSHIPS" + + # ---- state ---- + self._think_phase_corrections = 0 + self.verified_claims: Set[Tuple[str, str, str]] = set() + + # ---- counting-question verification ---- + self._counting_question = parse_counting_question(problem_text) + self._counting_options: Dict[str, str] = {} + # Strip trailing instruction paragraph for clean option parsing + _opts_text = re.split(r'\nFirst,', problem_text, maxsplit=1)[0] + if self._counting_question: + # Parse MCQ options from problem text (e.g., "A. 5\nB. 3\nC. 0\nD. 1") + raw_opts = re.findall( + r'([A-D])\.\s*(.+?)\s*(?=[A-D]\.|$)', + _opts_text, flags=re.DOTALL, + ) + self._counting_options = { + k: v.strip().rstrip(".") for k, v in raw_opts + } + logger.info( + f"[SpatialMap] Counting question detected: " + f"direction={self._counting_question['direction']}, " + f"reference={self._counting_question['reference']}, " + f"options={self._counting_options}" + ) + self._count_feedback_given = False + self._count_feedback_blocks_count = 0 # tracks cardinal count retry attempts + + # ---- direction-question verification ---- + self._direction_question = parse_direction_question(problem_text) + if self._direction_question: + logger.info( + f"[SpatialMap] Direction question detected: " + f"entity_a={self._direction_question['entity_a']}, " + f"entity_b={self._direction_question['entity_b']}" + ) + + # ---- object-question verification ---- + self._object_question = parse_object_question(problem_text) + if self._object_question: + logger.info( + f"[SpatialMap] Object question detected: " + f"direction={self._object_question['direction']}, " + f"reference={self._object_question['reference']}" + ) + + # ---- Generic MCQ options (for direction & object Qs too) ---- + if not self._counting_options: + raw_opts = re.findall( + r'([A-D])\.\s*(.+?)\s*(?=[A-D]\.|$)', + _opts_text, flags=re.DOTALL, + ) + self._mcq_options: Dict[str, str] = { + k: v.strip().rstrip(".") for k, v in raw_opts + } + else: + self._mcq_options = dict(self._counting_options) + + # Allow multiple retries for final-answer verification + self._max_final_answer_retries = 3 + self._direction_feedback_count = 0 + self._object_feedback_count = 0 + self._diag_count_feedback_count = 0 + + @classmethod + def from_prompt( + cls, + problem_text: str, + llm_server: dict, + prompt: str, + newline_threshold: int = 15, + max_corrections: int = 5, + warmup_newlines: int = 0, + name: str = "spatialmap_thinking_verifier", + ) -> "ThinkingPhaseStepVerifierSpatialMapMonitor": + """ + Convenience factory method. + """ + return cls( + name=name, + problem_text=problem_text, + llm_server=llm_server, + prompt=prompt, + newline_threshold=newline_threshold, + max_corrections=max_corrections, + warmup_newlines=warmup_newlines, + ) + + # ------------------------------------------------------------------ + # helpers + # ------------------------------------------------------------------ + def _count_feedback_blocks(self, text: str) -> int: + return len(re.findall(r'\[VERIFIER FEEDBACK[^\]]*\]', text)) + + def _is_in_thinking_phase(self, generated_text: str) -> bool: + return self.answer_start_token not in generated_text + + def _structured_prompt_injected(self, generated_text: str) -> bool: + """Check if structured format was already injected after .""" + if self.answer_start_token not in generated_text: + return False + after_think = generated_text.split(self.answer_start_token, 1)[1] + return self._structured_marker in after_think + + def _extract_new_claims(self, text: str) -> List[Dict]: + """ + Extract new (not yet verified) directional claims from STEP 2 of + the most recent attempt (after last feedback block). + """ + feedback_pattern = re.compile(r'\[VERIFIER FEEDBACK[^\]]*\]', re.DOTALL) + last_feedback_end = 0 + for match in feedback_pattern.finditer(text): + last_feedback_end = match.end() + + text_to_check = text[last_feedback_end:] + + # Get full entity names from Z3 solver for abbreviation resolution + entity_names = list({ + k[:-2] for k in self.z3_solver.entities if k.endswith('_x') + }) + + all_claims = extract_step2_claims(text_to_check, entity_names=entity_names) + + new_claims = [] + for claim in all_claims: + claim_key = (claim['A'], claim['direction'], claim['B']) + if claim_key not in self.verified_claims: + new_claims.append(claim) + + return new_claims + + # ------------------------------------------------------------------ + # _side_stream_spatialmap – streams tokens to get analysis + # ------------------------------------------------------------------ + async def _side_stream_spatialmap(self, text_so_far: str, max_new_tokens: int = 400) -> str: + """ + Send ``prompt + text_so_far`` to vLLM, stream at most + *max_new_tokens* tokens, and return the generated text. + + ``text_so_far`` is expected to end with the structured spatial map + prompt so the model outputs its analysis steps. + """ + logger.info( + f"[SpatialMap Side-stream] Starting analysis extraction\n" + f" Relations: {len(self.z3_solver.parsed_relations)}\n" + f" Max new tokens: {max_new_tokens}" + ) + + payload = deepcopy(self.llm_server["payload"]) + payload["prompt"] = self.prompt + text_so_far + payload["max_tokens"] = max_new_tokens + payload.pop("logprobs", None) + + generated = "" + + async with httpx.AsyncClient(timeout=None) as client: + async with client.stream( + "POST", + self.llm_server["url"], + headers=self.llm_server["headers"], + json=payload, + ) as response: + async for line in response.aiter_lines(): + if line.startswith("data: "): + data = line[len("data: "):].strip() + if data == "[DONE]": + break + chunk = json.loads(data)["choices"][0]["text"] + generated += chunk + logger.debug(f"[SpatialMap Side-stream] chunk: {chunk!r}") + + # Stop if we see FINAL ANSWER or \boxed + if '\\boxed' in generated or '>>> FINAL ANSWER' in generated: + break + + logger.info( + f"[SpatialMap Side-stream] Generated {len(generated)} chars" + ) + return generated + + # ------------------------------------------------------------------ + # step_extractor + # ------------------------------------------------------------------ + def step_extractor(self, chunk: str, generated_text: str): + """ + Phase 1 (thinking): trigger at every newline_threshold multiple + (after warmup). + Phase 2 (after ): trigger on structured steps or boxed + answer. + """ + # ===== PHASE 1: still inside ===== + if self._is_in_thinking_phase(generated_text): + if self._think_phase_corrections >= self.max_corrections: + return False, None + + total_newlines = generated_text.count('\n') + + if total_newlines < self.warmup_newlines: + return False, None + + past_warmup = total_newlines - self.warmup_newlines + if (generated_text.endswith('\n') + and past_warmup >= 0 + and past_warmup % self.newline_threshold == 0): + logger.info( + f"[SpatialMap step_extractor] Phase 1 trigger: \\n count={total_newlines} " + f"(warmup={self.warmup_newlines}, past_warmup={past_warmup}, " + f"threshold={self.newline_threshold})" + ) + return True, generated_text + + return False, None + + # ===== PHASE 2: after ===== + + # 2a: structured prompt not yet injected → trigger immediately + if not self._structured_prompt_injected(generated_text): + logger.info( + "[SpatialMap step_extractor] Phase 2a: detected, " + "structured prompt not yet injected." + ) + return True, generated_text + + # 2b: structured prompt injected — verify STEP 2 claims / boxed answer + think_end_pos = generated_text.find(self.answer_start_token) + len(self.answer_start_token) + text_after_think = generated_text[think_end_pos:] + + # Strip out the injected ... template so we only + # look at actual model output (which starts after the last marker). + last_marker_pos = text_after_think.rfind(self._structured_marker) + if last_marker_pos >= 0: + model_output_start = last_marker_pos + len(self._structured_marker) + text_after_think = text_after_think[model_output_start:] + text_start_offset = think_end_pos + model_output_start + else: + text_start_offset = think_end_pos + + # Skip past feedback blocks + feedback_pattern = re.compile(r'\[VERIFIER FEEDBACK[^\]]*\]\s*', re.DOTALL) + last_feedback_end = 0 + for match in feedback_pattern.finditer(text_after_think): + last_feedback_end = match.end() + text = text_after_think[last_feedback_end:] + text_start = text_start_offset + last_feedback_end + + # Check for STEP 2 section with claims + step2_pattern = re.compile( + r'>>>\s*STEP\s*2[:\s].*?(?=>>>\s*STEP\s*3|>>>\s*FINAL|\\boxed|$)', + re.DOTALL | re.IGNORECASE + ) + step2_match = step2_pattern.search(text) + + if step2_match: + # Check if STEP 3 or FINAL has started (STEP 2 is complete) + text_after_step2 = text[step2_match.end():] + step3_or_final = re.search( + r'>>>\s*(STEP\s*3|FINAL)', + text_after_step2, + re.IGNORECASE + ) + + if step3_or_final: + new_claims = self._extract_new_claims(text) + if new_claims: + end_pos = text_start + step2_match.end() + return True, generated_text[:end_pos] + + # Check for boxed answer (trigger final verification) + boxed_match = find_complete_boxed(text) + if boxed_match: + new_claims = self._extract_new_claims(text) + if new_claims: + end_pos = text_start + boxed_match.end() + return True, generated_text[:end_pos] + # Even if no new claims, boxed answer signals completion + end_pos = text_start + boxed_match.end() + return True, generated_text[:end_pos] + + return False, None + + # ------------------------------------------------------------------ + # verify + # ------------------------------------------------------------------ + async def verify(self, step: str, token_index: int, event, event_info): + """ + Case 1 -- still in thinking (no ): + Fork side-stream, parse claims, verify with Z3. + Case 2 -- after : + 2a: Inject structured prompt. + 2b: Verify STEP 2 claims and/or final answer. + """ + + # ================================================================== + # CASE 1: Thinking phase – side-stream verification + # ================================================================== + if self.answer_start_token not in step: + total_dn = step.count('\n') + logger.info( + f"[SpatialMap Phase 1] Thinking-phase verification triggered\n" + f" \\n count : {total_dn}\n" + f" Thinking len : {len(step)} chars" + ) + + # Build text with injected prompt for analysis extraction + text_with_prompt = step + self._thinking_phase_prompt + + # Side-stream: get analysis from the model + side_output = await self._side_stream_spatialmap( + text_with_prompt, max_new_tokens=800 + ) + + if not side_output or len(side_output.strip()) < 20: + logger.info( + "[SpatialMap Phase 1] Insufficient output from side-stream. " + "Letting model continue thinking." + ) + return step, None + + # Parse directional claims directly from the side-stream output. + # The prompt pre-fills STEP 1 and ends at ">>> STEP 2:", so the + # model's output is already STEP 2 content — no header to search for. + entity_names = list({ + k[:-2] for k in self.z3_solver.entities if k.endswith('_x') + }) + claims = parse_directional_claims_from_text( + side_output, entity_names=entity_names + ) + + logger.info( + f"[SpatialMap Phase 1] Parsed {len(claims)} claims from side-stream.\n" + f" Side-stream output (first 500 chars): {side_output[:500]!r}" + ) + + if not claims: + logger.info( + "[SpatialMap Phase 1] No directional claims found in side-stream. " + "Letting model continue thinking." + ) + return step, None + + # Verify each claim against Z3 + for claim in claims: + claim_key = (claim['A'], claim['direction'], claim['B']) + if claim_key in self.verified_claims: + continue + + is_valid, errors = verify_spatialmap_step( + claim=claim, + z3_solver=self.z3_solver, + add_if_valid=True, + ) + self.verified_claims.add(claim_key) + + if not is_valid: + self._think_phase_corrections += 1 + error_summary = "; ".join(errors) + logger.info( + f"[SpatialMap Phase 1] INVALID claim: " + f"{claim['A']} is {claim['direction']} of {claim['B']}\n" + f" Error(s) : {error_summary}\n" + f" Corrections: {self._think_phase_corrections}/{self.max_corrections}" + ) + thinking_feedback = ( + f"\n\nWait, I think I made an error in my spatial reasoning. " + f"{error_summary}. " + f"Let me re-examine the relationships more carefully.\n" + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = thinking_feedback + event_info["correction_index"] = token_index + event_info["errors"] = errors + event_info["phase"] = "rollback_to_thinking" + event.set() + return step, thinking_feedback + + # All claims valid + logger.info( + f"[SpatialMap Phase 1] All {len(claims)} claims valid. " + f"Letting model continue thinking." + ) + return step, None + + # ================================================================== + # CASE 2a: present but structured prompt not yet injected + # ================================================================== + if not self._structured_prompt_injected(step): + logger.info( + "[SpatialMap Phase 2a] detected. " + "Injecting structured step format." + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = self._structured_prompt + event_info["correction_index"] = token_index + event_info["phase"] = "inject_structured_prompt" + event.set() + return step, self._structured_prompt + + # ================================================================== + # CASE 2b: Structured prompt injected — verify output + # ================================================================== + num_corrections = self._count_feedback_blocks(step) + if num_corrections >= self.max_corrections: + fb = "\nthe answer is \\boxed{no solution}" + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = fb + event_info["correction_index"] = token_index + event_info["errors"] = ["Max corrections reached"] + event_info["phase"] = "standard_verify" + event.set() + return step, fb + + think_end_pos = step.find(self.answer_start_token) + len(self.answer_start_token) + text_after_think = step[think_end_pos:] + + # Strip the injected template — only look at model output after marker + last_marker_pos = text_after_think.rfind(self._structured_marker) + if last_marker_pos >= 0: + text_after_think = text_after_think[last_marker_pos:] + + feedback_pattern = re.compile(r'\[VERIFIER FEEDBACK[^\]]*\]\s*', re.DOTALL) + last_feedback_end = 0 + for match in feedback_pattern.finditer(text_after_think): + last_feedback_end = match.end() + recent_text = text_after_think[last_feedback_end:] + + # --- Verify STEP 2 claims --- + new_claims = self._extract_new_claims(recent_text) + + for claim in new_claims: + claim_key = (claim['A'], claim['direction'], claim['B']) + + is_valid, errors = verify_spatialmap_step( + claim=claim, + z3_solver=self.z3_solver, + add_if_valid=True, + ) + self.verified_claims.add(claim_key) + + if not is_valid: + feedback = format_spatialmap_feedback(errors, claim) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = errors + event_info["failed_step"] = claim + event_info["phase"] = "standard_verify" + event.set() + return step, feedback + + # --- Check for boxed answer --- + boxed_match = find_complete_boxed(recent_text) + if boxed_match: + + # ========================================================== + # Direction-question verification + # ========================================================== + if ( + self._direction_question + and num_corrections < self.max_corrections + and self._direction_feedback_count < self._max_final_answer_retries + ): + model_dir_text = parse_model_boxed_answer( + recent_text, self._mcq_options + ) + if model_dir_text: + possible = get_possible_directions( + self.z3_solver, + self._direction_question["entity_a"], + self._direction_question["entity_b"], + ) + logger.info( + f"[SpatialMap Phase 2b] Direction check: " + f"model={model_dir_text}, possible={possible}" + ) + if model_dir_text not in possible: + self._direction_feedback_count += 1 + # Find which MCQ options are consistent + valid_options = [ + letter for letter, val in self._mcq_options.items() + if val.strip().lower().rstrip(".") in possible + ] + if len(valid_options) == 1: + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Direction error!\n" + f" '{model_dir_text.title()}' is " + f"impossible for " + f"{self._direction_question['entity_a']} " + f"relative to " + f"{self._direction_question['entity_b']} " + f"based on the given constraints.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + else: + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Direction error!\n" + f" '{model_dir_text.title()}' is " + f"impossible for " + f"{self._direction_question['entity_a']} " + f"relative to " + f"{self._direction_question['entity_b']} " + f"based on the given constraints.\n" + f" Please reconsider and choose the " + f"correct option.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = [ + f"Direction '{model_dir_text}' impossible; " + f"possible: {possible}" + ] + event_info["phase"] = "standard_verify" + event.set() + return step, feedback + + # ========================================================== + # Object-question verification + # ========================================================== + if ( + self._object_question + and num_corrections < self.max_corrections + and self._object_feedback_count < self._max_final_answer_retries + ): + model_obj_text = parse_model_boxed_answer( + recent_text, self._mcq_options + ) + boxed_raw = re.findall( + r'\\boxed\{([^}]*)\}', recent_text + ) + model_letter = ( + boxed_raw[-1].strip().upper() if boxed_raw else None + ) + + if model_letter: + consistent = get_consistent_object_options( + self.z3_solver, + self._object_question["direction"], + self._object_question["reference"], + self._mcq_options, + ) + logger.info( + f"[SpatialMap Phase 2b] Object check: " + f"model={model_letter}, " + f"consistent_options={consistent}" + ) + if model_letter not in consistent: + self._object_feedback_count += 1 + odir = self._object_question["direction"] + oref = self._object_question["reference"] + if len(consistent) == 1: + correct_name = self._mcq_options.get( + consistent[0], consistent[0] + ) + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Object error!\n" + f" '{model_obj_text}' cannot be " + f"{odir} of {oref} based on the " + f"given constraints.\n" + f" The only consistent option is " + f"{consistent[0]}. {correct_name}.\n" + f" Please select option " + f"{consistent[0]}.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + else: + valid_names = [ + f"{l}. {self._mcq_options.get(l, l)}" + for l in consistent + ] + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Object error!\n" + f" '{model_obj_text}' cannot be " + f"{odir} of {oref} based on the " + f"given constraints.\n" + f" The consistent options are: " + f"{', '.join(valid_names)}.\n" + f" Please reconsider and choose the " + f"correct option.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = [ + f"Object '{model_obj_text}' impossible " + f"in {odir} of {oref}; " + f"consistent: {consistent}" + ] + event_info["phase"] = "standard_verify" + event.set() + return step, feedback + + # ========================================================== + # Counting-question verification (cardinal + diagonal) + # ========================================================== + if ( + self._counting_question + and num_corrections < self.max_corrections + ): + direction = self._counting_question["direction"] + reference = self._counting_question["reference"] + is_cardinal = direction in ( + "north", "south", "east", "west" + ) + + if is_cardinal: + # --- Cardinal: GT is always 0 --- + # All spatial constraints in this dataset are diagonal + # (NE, NW, SE, SW), so no object can be strictly + # north/south/east/west of another. The answer is + # always 0. + model_count = parse_model_count_from_answer( + recent_text, self._counting_options + ) + z3_count = 0 + + logger.info( + f"[SpatialMap Phase 2b] Cardinal count check: " + f"model={model_count}, expected={z3_count}, " + f"direction={direction}, reference={reference}" + ) + + if ( + model_count is not None + and model_count != z3_count + ): + self._count_feedback_given = True + count_corrections = self._count_feedback_blocks_count + self._count_feedback_blocks_count = count_corrections + 1 + + # Build direction-specific examples of what does NOT count + if direction in ("north", "south"): + diag_examples = "northeast or northwest" + elif direction == "west": + diag_examples = "northwest or southwest" + else: # east + diag_examples = "northeast or southeast" + + feedback = ( + f"\n\n[VERIFIER FEEDBACK: Count mismatch!\n" + f" You answered {model_count} objects " + f"'{direction}' of {reference}, but this " + f"count is incorrect.\n" + f" IMPORTANT: '{direction}' is a strict " + f"cardinal direction — it means ONLY " + f"exactly {direction}, NOT {diag_examples}." + f"\n" + f" An object that is {diag_examples.split(' or ')[0]} of " + f"{reference} is NOT {direction} of " + f"{reference}.\n" + f" Re-examine each object: is it described " + f"as being strictly '{direction} of' " + f"{reference}, or is the relationship " + f"actually a diagonal direction like " + f"{diag_examples}? Only count objects that " + f"are strictly {direction}.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + + logger.info( + f"[SpatialMap Phase 2b] Cardinal count " + f"mismatch: model={model_count}, " + f"expected=0. Injecting feedback " + f"(attempt={'1st' if not self._count_feedback_given else '2nd'})." + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = token_index + event_info["errors"] = [ + f"Cardinal count mismatch: expected 0, " + f"got {model_count}" + ] + event_info["phase"] = "standard_verify" + event.set() + return step, feedback + + else: + # --- Diagonal: use Z3 range check --- + if self._diag_count_feedback_count < self._max_final_answer_retries: + model_count = parse_model_count_from_answer( + recent_text, self._counting_options + ) + count_range = get_possible_count_range( + self.z3_solver, reference, direction + ) + + if ( + model_count is not None + and count_range is not None + ): + min_c, max_c = count_range + logger.info( + f"[SpatialMap Phase 2b] Diagonal count " + f"check: model={model_count}, " + f"range=[{min_c}, {max_c}], " + f"direction={direction}, " + f"reference={reference}" + ) + + if not (min_c <= model_count <= max_c): + self._diag_count_feedback_count += 1 + # Find valid MCQ options + valid_opts = [] + for opt, val in ( + self._counting_options.items() + ): + try: + v = int(val) + if min_c <= v <= max_c: + valid_opts.append( + (opt, v) + ) + except (ValueError, TypeError): + pass + + if len(valid_opts) == 1: + feedback = ( + f"\n\n[VERIFIER FEEDBACK: " + f"Count error!\n" + f" {model_count} objects " + f"'{direction}' of {reference}" + f" is impossible.\n" + f" The valid count is " + f"{valid_opts[0][1]}.\n" + f" Please select option " + f"{valid_opts[0][0]}.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + else: + feedback = ( + f"\n\n[VERIFIER FEEDBACK: " + f"Count error!\n" + f" {model_count} objects " + f"'{direction}' of {reference}" + f" is impossible.\n" + f" The possible count range " + f"is [{min_c}, {max_c}].\n" + f" Please reconsider and " + f"choose the correct " + f"option.]\n\n" + f">>> STEP 3: ANSWER\n" + ) + + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = feedback + event_info["correction_index"] = ( + token_index + ) + event_info["errors"] = [ + f"Diagonal count " + f"{model_count} outside " + f"range [{min_c}, {max_c}]" + ] + event_info["phase"] = ( + "standard_verify" + ) + event.set() + return step, feedback + + logger.info( + f"[SpatialMap Phase 2b] Boxed answer found. Stopping." + ) + if not event.is_set(): + event_info["generated_text"] = step + event_info["feedback"] = "" + event_info["correction_index"] = token_index + event_info["phase"] = "final_answer_correct" + event.set() + return step, None + + # All claims valid, no boxed yet + return step, None + + # ------------------------------------------------------------------ + # fix + # ------------------------------------------------------------------ + async def fix(self, generated_text: str, event_info: dict, fix_method=None): + """Apply the appropriate fix depending on the phase.""" + phase = event_info.get("phase", "standard_verify") + + if phase == "rollback_to_thinking": + base_text = event_info["generated_text"] + result = base_text.rstrip() + event_info["feedback"] + logger.info( + f"[SpatialMap fix] Phase: rollback_to_thinking\n" + f" -> Appended error feedback into trace.\n" + f" -> Think-phase corrections: {self._think_phase_corrections}/{self.max_corrections}" + ) + return result + + if phase == "inject_structured_prompt": + logger.info( + "[SpatialMap fix] Phase: inject_structured_prompt\n" + " -> Appending structured step format after ." + ) + return event_info["generated_text"] + event_info["feedback"] + + if phase == "final_answer_correct": + logger.info( + "[SpatialMap fix] Phase: final_answer_correct\n" + " -> Stopping generation." + ) + return event_info["generated_text"] + + # standard_verify + errors = event_info.get("errors", []) + error_summary = "; ".join(errors) if errors else "unknown" + logger.info( + f"[SpatialMap fix] Phase: standard_verify\n" + f" -> Error: {error_summary}\n" + f" -> Appending [VERIFIER FEEDBACK] so model retries." + ) + return event_info["generated_text"] + event_info["feedback"] diff --git a/interwhen/utils/maze_verifier.py b/interwhen/utils/maze_verifier.py index 19226c9..1fea92d 100644 --- a/interwhen/utils/maze_verifier.py +++ b/interwhen/utils/maze_verifier.py @@ -46,11 +46,44 @@ class Direction(Enum): } +# Map alternative direction names (cardinal) to the canonical enum names +_DIRECTION_ALIASES = { + 'NORTH': 'UP', + 'SOUTH': 'DOWN', + 'EAST': 'RIGHT', + 'WEST': 'LEFT', +} + +# Reverse mapping: enum name -> cardinal name (for feedback messages) +_CARDINAL_NAMES = { + Direction.UP: 'NORTH', + Direction.DOWN: 'SOUTH', + Direction.LEFT: 'WEST', + Direction.RIGHT: 'EAST', + Direction.NONE: 'NONE', +} + + +def cardinal_name(d: Direction) -> str: + """Return the cardinal compass name for a Direction enum value. + + Used in feedback messages so that the model (which often thinks in + NORTH/SOUTH/EAST/WEST terms) can understand corrections. + """ + return _CARDINAL_NAMES.get(d, d.name) + + def parse_direction(dir_str: str) -> Direction: - """Parse direction string to Direction enum.""" + """Parse direction string to Direction enum. + + Accepts canonical names (UP/DOWN/LEFT/RIGHT) **and** cardinal names + (NORTH/SOUTH/EAST/WEST). + """ dir_str = dir_str.strip().upper() if dir_str in ['—', '-', 'NONE', '']: return Direction.NONE + # Resolve cardinal aliases + dir_str = _DIRECTION_ALIASES.get(dir_str, dir_str) try: return Direction[dir_str] except KeyError: @@ -68,6 +101,35 @@ def get_expected_turn_type(prev_dir: Direction, curr_dir: Direction) -> str: return 'UNKNOWN' +# Map common turn type variations to canonical names +_TURN_TYPE_ALIASES = { + 'RIGHT': 'RIGHT_TURN', + 'RIGHT TURN': 'RIGHT_TURN', + 'RIGHT_TURN': 'RIGHT_TURN', + 'RIGHTTURN': 'RIGHT_TURN', + 'LEFT': 'LEFT_TURN', + 'LEFT TURN': 'LEFT_TURN', + 'LEFT_TURN': 'LEFT_TURN', + 'LEFTTURN': 'LEFT_TURN', + 'STRAIGHT': 'STRAIGHT', + 'NONE': 'STRAIGHT', + 'NO TURN': 'STRAIGHT', + 'NO_TURN': 'STRAIGHT', + 'NOTURN': 'STRAIGHT', +} + + +def normalize_turn_type(turn_str: str) -> str: + """Normalize a claimed turn type string to canonical form. + + Accepts common variations such as ``RIGHT``, ``RIGHT TURN``, + ``RIGHT_TURN``, ``RIGHTTURN`` (case-insensitive) and maps them to + the canonical ``RIGHT_TURN`` / ``LEFT_TURN`` / ``STRAIGHT``. + """ + turn_str = turn_str.strip().upper() + return _TURN_TYPE_ALIASES.get(turn_str, turn_str) + + def parse_maze_from_prompt(prompt: str) -> Tuple[List[List[str]], Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: """ Parse maze from prompt. Returns (grid, start_pos, exit_pos). @@ -80,10 +142,22 @@ def parse_maze_from_prompt(prompt: str) -> Tuple[List[List[str]], Optional[Tuple for line in lines: stripped = line.strip() - if stripped.startswith('#') and all(c in '#XSEX ' for c in stripped): - in_maze = True - current_maze.append(stripped) - elif in_maze: + # Some dataset entries glue the last maze row to description text, + # e.g. "#######, where the symbols ...". Strip everything from the + # first character that isn't a valid maze cell. + if stripped.startswith('#'): + maze_part = "" + for ch in stripped: + if ch in '# XSEX': + maze_part += ch + else: + break + maze_part = maze_part.rstrip() + if maze_part and all(c in '#XSEX ' for c in maze_part): + in_maze = True + current_maze.append(maze_part) + continue + if in_maze: if current_maze: all_mazes.append(current_maze) current_maze = [] @@ -162,10 +236,16 @@ def parse_maze_step(step_text: str) -> Optional[Dict[str, Any]]: else: result['claimed_curr_dir'] = None - # Extract turn type - turn_match = re.search(r'Turn type:\s*(\S+)', step_text) + # Extract turn type (handle multi-word like 'RIGHT TURN', 'LEFT_TURN', etc.) + # Strip parenthetical comments like 'RIGHT (DOWN → LEFT is a RIGHT turn)' + turn_match = re.search(r'Turn type:\s*(.+)', step_text) if turn_match: - result['claimed_turn'] = turn_match.group(1).upper() + turn_raw = turn_match.group(1).strip() + # Remove parenthetical comments: "RIGHT (DOWN → LEFT ...)" → "RIGHT" + turn_raw = re.sub(r'\s*\(.*', '', turn_raw) + # Also strip trailing punctuation/whitespace + turn_raw = turn_raw.strip().rstrip(':') + result['claimed_turn'] = normalize_turn_type(turn_raw) else: result['claimed_turn'] = None @@ -239,7 +319,10 @@ def verify_maze_step( if expected_delta: actual_delta = (to_pos[0] - from_pos[0], to_pos[1] - from_pos[1]) if actual_delta != expected_delta: - errors.append(f"Move {direction.name} doesn't match delta {actual_delta}, expected {expected_delta}") + errors.append( + f"Move {cardinal_name(direction)} from {from_pos} to {to_pos} has delta {actual_delta}, " + f"but {cardinal_name(direction)} should have delta {expected_delta} (row_change, col_change)" + ) # 3. Verify to_pos is walkable (not a wall) if 0 <= to_pos[0] < len(grid) and 0 <= to_pos[1] < len(grid[0]): @@ -256,7 +339,18 @@ def verify_maze_step( # 5. Verify turn type expected_turn = get_expected_turn_type(prev_direction, direction) if claimed_turn is not None and claimed_turn != expected_turn: - errors.append(f"Turn type {claimed_turn} should be {expected_turn} (prev={prev_direction.name}, curr={direction.name})") + prev_card = cardinal_name(prev_direction) + curr_card = cardinal_name(direction) + if expected_turn == 'RIGHT_TURN': + clock_desc = "clockwise (RIGHT turn)" + elif expected_turn == 'LEFT_TURN': + clock_desc = "counterclockwise (LEFT turn)" + else: + clock_desc = "no turn (STRAIGHT)" + errors.append( + f"Turn type {claimed_turn} should be {expected_turn}. " + f"Going from {prev_card} to {curr_card} is a {clock_desc} rotation." + ) # 6. Calculate expected counts after this step new_right = expected_right_count @@ -272,11 +366,11 @@ def verify_maze_step( # 7. Verify running counts if claimed_right is not None and claimed_right != new_right: - errors.append(f"Right count {claimed_right} should be {new_right}") + errors.append(f"Right turn count {claimed_right} should be {new_right}") if claimed_left is not None and claimed_left != new_left: - errors.append(f"Left count {claimed_left} should be {new_left}") + errors.append(f"Left turn count {claimed_left} should be {new_left}") if claimed_total is not None and claimed_total != new_total: - errors.append(f"Total count {claimed_total} should be {new_total}") + errors.append(f"Total turn count {claimed_total} should be {new_total}") # Update state for next step state['new_pos'] = to_pos @@ -321,7 +415,12 @@ def format_maze_feedback(errors: List[str], step_num: int) -> str: feedback = f"\n\n[VERIFIER FEEDBACK for Step {step_num}:\n" for err in errors: feedback += f" ✗ {err}\n" - feedback += "Please correct this step and continue.]\n\n" + feedback += ( + "IMPORTANT: Clockwise on a compass is NORTH→EAST→SOUTH→WEST→NORTH. " + "A RIGHT turn = 90° clockwise; a LEFT turn = 90° counterclockwise. " + "For example: SOUTH→WEST is RIGHT (clockwise), SOUTH→EAST is LEFT (counterclockwise). " + "Please correct this step and continue.]\n\n" + ) return feedback @@ -333,7 +432,13 @@ def format_locate_feedback(errors: List[str]) -> str: feedback = "\n\n[VERIFIER FEEDBACK for LOCATE section:\n" for err in errors: feedback += f" ✗ {err}\n" - feedback += "Please correct the start/exit positions and continue.]\n\n" + feedback += ( + "IMPORTANT: Coordinates are 0-indexed (row 0, col 0 is the top-left corner). " + "Do NOT use 1-indexed coordinates. " + "For example, if S is in the first row and first open column, " + "that is (0, 1) not (1, 1) or (1, 2).\n" + "Please correct the start/exit positions and continue.]\n\n" + ) return feedback diff --git a/interwhen/utils/spatialmap_verifier.py b/interwhen/utils/spatialmap_verifier.py index 2ea94e4..23e0c8f 100644 --- a/interwhen/utils/spatialmap_verifier.py +++ b/interwhen/utils/spatialmap_verifier.py @@ -13,7 +13,7 @@ import re from typing import Dict, List, Tuple, Optional, Set -from z3 import Solver, Real, And, sat +from z3 import Solver, Real, And, Not, sat, unsat class SpatialMapZ3Solver: @@ -176,8 +176,149 @@ def check_with_new_constraint(self, ir: Dict) -> bool: def is_satisfiable(self) -> bool: return self.solver.check() == sat + def count_objects_in_direction( + self, reference: str, direction: str + ) -> Optional[int]: + """ + Count how many entities are in a **strict** direction from *reference*. + + For cardinal directions the semantics are strict: + - "north" → same x, higher y (but see note below) + - "south" → same x, lower y + - "east" → higher x, same y + - "west" → lower x, same y + + However, since every constraint in the SpatialMap dataset is diagonal + (NE/NW/SE/SW), no two objects can share an x- or y-coordinate. + Therefore the strict-cardinal count is always **0** whenever the + problem only has diagonal constraints — which is exactly the + ground-truth expectation. + + For diagonal directions: + - "northeast" → higher x AND higher y + - "northwest" → lower x AND higher y + - "southeast" → higher x AND lower y + - "southwest" → lower x AND lower y + + Returns the count, or ``None`` if the solver cannot determine it + (e.g. reference entity not found). + """ + direction = direction.lower().strip() + + # Resolve the reference entity's variable names + ref_x_key = f"{reference}_x" + ref_y_key = f"{reference}_y" + if ref_x_key not in self.entities: + # Try fuzzy match — dataset names may differ in whitespace + for key in self.entities: + if key.endswith("_x") and reference.lower() in key.lower(): + ref_x_key = key + ref_y_key = key.replace("_x", "_y") + reference = key[:-2] + break + else: + return None + + ref_x = self.entities[ref_x_key] + ref_y = self.entities[ref_y_key] + + # Collect all other entity names (unique base names) + all_entities = set() + for key in self.entities: + if key.endswith("_x"): + ename = key[:-2] + if ename != reference: + all_entities.add(ename) + + # Determine x/y constraints for the direction + is_cardinal = direction in ("north", "south", "east", "west") + + # Since all given constraints are strictly diagonal, any pair of + # objects cannot share the same x- or y-coordinate. Cardinal + # directions require an exact match on one axis, which is impossible. + if is_cardinal: + return 0 + + # For diagonal directions, check each entity with Z3 + count = 0 + for ename in all_entities: + e_x = self.entities[f"{ename}_x"] + e_y = self.entities[f"{ename}_y"] + + if direction == "northeast": + constraint = And(e_x > ref_x, e_y > ref_y) + elif direction == "northwest": + constraint = And(e_x < ref_x, e_y > ref_y) + elif direction == "southeast": + constraint = And(e_x > ref_x, e_y < ref_y) + elif direction == "southwest": + constraint = And(e_x < ref_x, e_y < ref_y) + else: + continue + + # Check if this entity MUST be in that direction + # (i.e. the negation is unsatisfiable) + self.solver.push() + self.solver.add(Not(constraint)) + must_be = self.solver.check() == unsat + self.solver.pop() + + if must_be: + count += 1 -def parse_directional_claims_from_text(text: str) -> List[Dict]: + return count + + +def parse_counting_question(problem_text: str) -> Optional[Dict]: + """ + If the problem asks a *counting* question ("How many objects are in + the X of Y?"), return a dict with the direction and reference entity. + + Returns ``None`` for non-counting questions. + """ + m = re.search( + r'How many objects are in the (\w+) of ([^?]+?)\?', + problem_text, + re.IGNORECASE, + ) + if not m: + return None + return { + "direction": m.group(1).strip().lower(), + "reference": m.group(2).strip().rstrip("."), + } + + +def parse_model_count_from_answer(text_after_think: str, options: dict = None) -> Optional[int]: + """ + Extract the numeric count the model chose from its ``\\boxed{}`` answer. + + Looks for ``\\boxed{LETTER}`` then maps through *options* to get the + numeric value. Falls back to extracting a number directly. + """ + boxed = re.findall(r'\\boxed\{([^}]*)\}', text_after_think) + if not boxed: + return None + answer = boxed[-1].strip() + + # If options mapping is provided, resolve letter → value + if options and answer in options: + try: + return int(options[answer]) + except (ValueError, TypeError): + return None + + # Try direct numeric + try: + return int(answer) + except (ValueError, TypeError): + return None + + +def parse_directional_claims_from_text( + text: str, + entity_names: Optional[List[str]] = None, +) -> List[Dict]: """ Parse directional claims from model output text. @@ -185,15 +326,85 @@ def parse_directional_claims_from_text(text: str) -> List[Dict]: - "X is to the northwest of Y" - "X is NORTHWEST of Y" - "X is northwest of Y" (affirmative claims) + - "X is NW of Y" (abbreviated directions) + - "[X] is to the northwest of [Y]" (bracket-wrapped names) + + If *entity_names* is provided, single-letter or short abbreviations + in parsed claims will be resolved to the closest full entity name. + Parenthetical aliases like '(L)', '(Mo)' are stripped before parsing. Returns list of IR dicts: [{"A": ..., "direction": ..., "B": ...}, ...] """ + # Build abbreviation → full-name map from entity_names. + # When multiple entities share the same abbreviation, mark it as + # ambiguous (map to None) so we don't silently pick the wrong one. + abbrev_to_full: Dict[str, Optional[str]] = {} + if entity_names: + for name in entity_names: + words = re.split(r"[\s']+", name) + capitals = [w[0] for w in words if w and w[0].isupper()] + candidates: List[str] = [] + if capitals: + candidates.append(capitals[0]) # e.g. "M" + if len(capitals) >= 2: + candidates.append(''.join(capitals[:2])) # e.g. "MG" + candidates.append(''.join(capitals)) # e.g. "MGM" + first_word = words[0] if words else '' + if len(first_word) >= 2: + candidates.append(first_word[:2]) # e.g. "Mi" + if first_word: + candidates.append(first_word) # e.g. "Miniature" + + for abbr in candidates: + if abbr in abbrev_to_full: + if abbrev_to_full[abbr] != name: + # Ambiguous — mark as None so we skip it + abbrev_to_full[abbr] = None + else: + abbrev_to_full[abbr] = name + + # Remove ambiguous entries + abbrev_to_full = {k: v for k, v in abbrev_to_full.items() if v is not None} + + # Expand abbreviated directions before parsing + abbrev_map = { + 'NW': 'northwest', 'NE': 'northeast', + 'SW': 'southwest', 'SE': 'southeast', + } + expanded_text = text + for abbr, full in abbrev_map.items(): + # Replace standalone abbreviations like "is NE of" → "is northeast of" + expanded_text = re.sub( + rf'\b{abbr}\b(?=\s+of\b)', full, expanded_text + ) + + # Strip square brackets around entity names: [Foo Bar] → Foo Bar + expanded_text = re.sub(r'\[([A-Z][A-Za-z\'\s]*?)\]', r'\1', expanded_text) + + # Strip parenthetical aliases like (L), (M), (Mo), (IQC) — but not + # coordinate tuples like (0,0) or (a, b) + expanded_text = re.sub(r'\s*\([A-Z][A-Za-z]{0,3}\)', '', expanded_text) + claims = [] # Pattern: "X is (to the) DIRECTION of Y" - pattern = r"([A-Z][A-Za-z'][A-Za-z'\s]*?)\s+is\s+(?:to\s+the\s+)?(northwest|northeast|southwest|southeast|north|south|east|west)\s+of\s+([A-Z][A-Za-z'][A-Za-z'\s]*?)(?:\.|,|\s*[→✓✗]|\s*$|\s+(?:and|so|which|therefore|thus|but|\())" + # Terminators include ⇒ for arrow-style claims. + # Entity capture allows single uppercase letters (resolved via abbrev map) + # or multi-word names starting with uppercase. + entity_pat = r"([A-Z][A-Za-z'][A-Za-z'\s]*?|[A-Z][a-z]?)" + pattern = ( + entity_pat + + r"\s+is\s+(?:to\s+the\s+)?" + r"(northwest|northeast|southwest|southeast|north|south|east|west)" + r"\s+of\s+" + + entity_pat + + r"(?:\.|,|;|:|\s*[→⇒✓✗]|\s*\n|\s*$" + r"|\s+(?:and|so|which|therefore|thus|but|since|because|while|whereas" + r"|however|hence|then|for|as|meaning|indicating|implying|suggesting" + r"|confirming|\())" + ) - matches = re.finditer(pattern, text, re.IGNORECASE) + matches = re.finditer(pattern, expanded_text, re.IGNORECASE) for match in matches: entity_a = match.group(1).strip() @@ -204,12 +415,19 @@ def parse_directional_claims_from_text(text: str) -> List[Dict]: entity_a = re.sub(r'[,\.\!\?]+$', '', entity_a).strip() entity_b = re.sub(r'[,\.\!\?]+$', '', entity_b).strip() + # Resolve abbreviations to full names if entity_names provided + if abbrev_to_full: + if entity_a in abbrev_to_full: + entity_a = abbrev_to_full[entity_a] + if entity_b in abbrev_to_full: + entity_b = abbrev_to_full[entity_b] + # Skip if entities look like fragments, pronouns, or are too short skip_words = {'then', 'if', 'so', 'thus', 'therefore', 'it', 'this', 'that', 'which', 'what', 'where', 'when', 'also', 'not', 'the', 'a', 'an'} if entity_a.lower() in skip_words or entity_b.lower() in skip_words: continue - if len(entity_a) < 3 or len(entity_b) < 3: + if len(entity_a) < 2 or len(entity_b) < 2: continue if not entity_a[0].isupper(): continue @@ -223,7 +441,10 @@ def parse_directional_claims_from_text(text: str) -> List[Dict]: return claims -def extract_step2_claims(answer_text: str) -> List[Dict]: +def extract_step2_claims( + answer_text: str, + entity_names: Optional[List[str]] = None, +) -> List[Dict]: """ Extract directional claims specifically from STEP 2 of the answer. @@ -243,7 +464,7 @@ def extract_step2_claims(answer_text: str) -> List[Dict]: return [] step2_text = match.group(0) - return parse_directional_claims_from_text(step2_text) + return parse_directional_claims_from_text(step2_text, entity_names=entity_names) def verify_spatialmap_step( @@ -257,11 +478,16 @@ def verify_spatialmap_step( Args: claim: {"A": entity1, "direction": direction, "B": entity2} z3_solver: The Z3 solver with known constraints - add_if_valid: If True, add the claim to the solver if it's valid + add_if_valid: If True, add the claim to the solver **only if it + is entailed** (i.e. its negation is UNSAT). Merely + satisfiable claims are accepted but NOT committed to the + solver so they cannot over-constrain future checks. Returns: (is_valid, errors) """ + from z3 import Not as Z3Not, unsat as z3unsat + errors = [] is_consistent = z3_solver.check_with_new_constraint(claim) @@ -274,7 +500,17 @@ def verify_spatialmap_step( return False, errors if add_if_valid: - z3_solver.apply_ir(claim) + # Only commit the claim if it is *entailed* (negation is UNSAT). + # This prevents merely-satisfiable-but-unproven claims from + # over-constraining the solver and blocking valid solutions later. + compiled = z3_solver.compile_constraint(claim) + if compiled is not None: + z3_solver.solver.push() + z3_solver.solver.add(Z3Not(compiled)) + is_entailed = z3_solver.solver.check() == z3unsat + z3_solver.solver.pop() + if is_entailed: + z3_solver.apply_ir(claim) return True, [] @@ -294,10 +530,280 @@ def format_spatialmap_feedback(errors: List[str], claim: Optional[Dict] = None) return feedback +# --------------------------------------------------------------------------- +# Direction-question helpers +# --------------------------------------------------------------------------- + +def parse_direction_question(problem_text: str) -> Optional[Dict]: + """ + If the problem asks a *direction* question + ("In which direction is X relative to Y?"), + return ``{"entity_a": X, "entity_b": Y}``. + + Returns ``None`` for non-direction questions. + """ + m = re.search( + r'In which direction is (.+?) relative to (.+?)\?', + problem_text, + re.IGNORECASE, + ) + if not m: + return None + return { + "entity_a": m.group(1).strip(), + "entity_b": m.group(2).strip(), + } + + +def parse_object_question(problem_text: str) -> Optional[Dict]: + """ + If the problem asks an *object* question + ("Which object is in the [direction] of [entity]?"), + return ``{"direction": ..., "reference": ...}``. + + Returns ``None`` for non-object questions. + """ + m = re.search( + r'Which object is (?:located )?(?:to the |in the )' + r'(northeast|northwest|southeast|southwest|north|south|east|west)' + r' of (.+?)\?', + problem_text, + re.IGNORECASE, + ) + if not m: + return None + return { + "direction": m.group(1).strip().lower(), + "reference": m.group(2).strip().rstrip("."), + } + + +def parse_model_boxed_answer( + text_after_think: str, options: Dict[str, str] +) -> Optional[str]: + """ + Extract the text value the model chose from its ``\\boxed{}`` answer. + Maps letter → option text using *options* dict. + Returns the raw option text (lowercase stripped) or None. + """ + boxed = re.findall(r'\\boxed\{([^}]*)\}', text_after_think) + if not boxed: + return None + answer = boxed[-1].strip().upper() + if answer in options: + return options[answer].strip().lower().rstrip(".") + # Try the raw value + return answer.lower() + + +def get_possible_directions( + solver: SpatialMapZ3Solver, + entity_a: str, + entity_b: str, +) -> List[str]: + """ + Return the list of diagonal directions (NE/NW/SE/SW) that are + *satisfiable* for entity_a relative to entity_b under the current + constraints. + + ``entity_a`` and ``entity_b`` are matched fuzzily against solver + entity names. + """ + from z3 import And as Z3And, sat as z3sat + + def _find(name): + nl = name.lower() + for k in solver.entities: + if k.endswith('_x') and k[:-2].lower() == nl: + return k[:-2] + for k in solver.entities: + if k.endswith('_x') and (nl in k[:-2].lower() or k[:-2].lower() in nl): + return k[:-2] + return None + + ba = _find(entity_a) + bb = _find(entity_b) + if not ba or not bb: + return ['northeast', 'northwest', 'southeast', 'southwest'] + + ax = solver.entities[f'{ba}_x'] + ay = solver.entities[f'{ba}_y'] + bx = solver.entities[f'{bb}_x'] + by = solver.entities[f'{bb}_y'] + + dir_constraints = { + 'northeast': Z3And(ax > bx, ay > by), + 'northwest': Z3And(ax < bx, ay > by), + 'southeast': Z3And(ax > bx, ay < by), + 'southwest': Z3And(ax < bx, ay < by), + } + + possible = [] + for dname, dc in dir_constraints.items(): + solver.solver.push() + solver.solver.add(dc) + if solver.solver.check() == z3sat: + possible.append(dname) + solver.solver.pop() + + return possible if possible else ['northeast', 'northwest', 'southeast', 'southwest'] + + +def get_consistent_object_options( + solver: SpatialMapZ3Solver, + direction: str, + reference: str, + options: Dict[str, str], +) -> List[str]: + """ + For an *object* question, return the list of MCQ letters whose entity + *could* be in ``direction`` of ``reference`` (Z3-satisfiable). + + Letters whose entities cannot be found in the solver are kept as + "possible" (benefit of the doubt). + """ + from z3 import And as Z3And, sat as z3sat + + def _find(name): + nl = name.lower() + for k in solver.entities: + if k.endswith('_x') and k[:-2].lower() == nl: + return k[:-2] + for k in solver.entities: + if k.endswith('_x') and (nl in k[:-2].lower() or k[:-2].lower() in nl): + return k[:-2] + return None + + ref_base = _find(reference) + if not ref_base: + return list(options.keys()) # can't check, keep all + + rx = solver.entities[f'{ref_base}_x'] + ry = solver.entities[f'{ref_base}_y'] + + dfunc = { + 'northeast': lambda ox, oy: Z3And(ox > rx, oy > ry), + 'northwest': lambda ox, oy: Z3And(ox < rx, oy > ry), + 'southeast': lambda ox, oy: Z3And(ox > rx, oy < ry), + 'southwest': lambda ox, oy: Z3And(ox < rx, oy < ry), + }.get(direction.lower()) + if not dfunc: + return list(options.keys()) + + consistent = [] + for letter, opt_name in options.items(): + opt_base = _find(opt_name.strip().rstrip('.')) + if not opt_base: + consistent.append(letter) # can't verify, assume possible + continue + ox = solver.entities[f'{opt_base}_x'] + oy = solver.entities[f'{opt_base}_y'] + solver.solver.push() + solver.solver.add(dfunc(ox, oy)) + if solver.solver.check() == z3sat: + consistent.append(letter) + solver.solver.pop() + + return consistent + + +def get_possible_count_range( + solver: SpatialMapZ3Solver, + reference: str, + direction: str, +) -> Optional[Tuple[int, int]]: + """ + Compute the *[min, max]* range of how many entities could be in + ``direction`` of ``reference`` across all satisfying assignments. + + Uses Z3 must-be / can-be checks per entity: + - *must_be*: negation is UNSAT → entity is ALWAYS in that direction + - *can_be*: adding constraint is SAT → entity COULD be there + + min = count(must_be), max = count(must_be) + count(maybe) + + Returns ``None`` if the reference entity cannot be found. + """ + from z3 import And as Z3And, Not as Z3Not, sat as z3sat, unsat as z3unsat + + direction = direction.lower().strip() + if direction in ('north', 'south', 'east', 'west'): + return (0, 0) # cardinal → always 0 with diagonal-only constraints + + def _find(name): + nl = name.lower() + for k in solver.entities: + if k.endswith('_x') and k[:-2].lower() == nl: + return k[:-2] + for k in solver.entities: + if k.endswith('_x') and (nl in k[:-2].lower() or k[:-2].lower() in nl): + return k[:-2] + return None + + ref_base = _find(reference) + if not ref_base: + return None + + rx = solver.entities[f'{ref_base}_x'] + ry = solver.entities[f'{ref_base}_y'] + + others = [ + k[:-2] for k in solver.entities + if k.endswith('_x') and k[:-2] != ref_base + ] + + dfunc = { + 'northeast': lambda ox, oy: Z3And(ox > rx, oy > ry), + 'northwest': lambda ox, oy: Z3And(ox < rx, oy > ry), + 'southeast': lambda ox, oy: Z3And(ox > rx, oy < ry), + 'southwest': lambda ox, oy: Z3And(ox < rx, oy < ry), + }.get(direction) + if not dfunc: + return None + + must_count = 0 + maybe_count = 0 + + for ename in others: + ex = solver.entities[f'{ename}_x'] + ey = solver.entities[f'{ename}_y'] + c = dfunc(ex, ey) + + # Can it be in that direction? + solver.solver.push() + solver.solver.add(c) + can_be = solver.solver.check() == z3sat + solver.solver.pop() + + if not can_be: + continue + + # Must it be? + solver.solver.push() + solver.solver.add(Z3Not(c)) + must_be = solver.solver.check() == z3unsat + solver.solver.pop() + + if must_be: + must_count += 1 + else: + maybe_count += 1 + + return (must_count, must_count + maybe_count) + + # Export __all__ = [ 'SpatialMapZ3Solver', 'parse_directional_claims_from_text', + 'parse_counting_question', + 'parse_model_count_from_answer', + 'parse_direction_question', + 'parse_object_question', + 'parse_model_boxed_answer', + 'get_possible_directions', + 'get_consistent_object_options', + 'get_possible_count_range', 'extract_step2_claims', 'verify_spatialmap_step', 'format_spatialmap_feedback',