Conversation
Pull request overview
Adds a Tree-of-Thought (ToT) baseline implementation under examples/TTSwithVerification/ToT/, including dataset-specific prompts, verina (Lean) evaluation helpers, and ZebraLogic utilities to run and score ToT searches.
Changes:
- Introduces a ToT search engine with propose/evaluate/search loops and dataset-specific prompt builders.
- Adds a CLI runner (`tot_baseline.py`) to execute ToT across multiple tasks and write per-example logs and summaries.
- Adds helper modules for ZebraLogic parsing/scoring and Verina code/spec compilation-based evaluation.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| examples/TTSwithVerification/ToT/utils/zebralogic_helper.py | ZebraLogic dataset processing + JSON extraction + correctness scoring helpers |
| examples/TTSwithVerification/ToT/utils/verina_tot_utils.py | Verina Lean code extraction and compile/unit-test evaluation utilities |
| examples/TTSwithVerification/ToT/utils/verina_spec_tot_utils.py | Verina spec generation extraction + compile-based soundness/completeness testing utilities |
| examples/TTSwithVerification/ToT/utils/value_prompts.py | Few-shot value prompts for ToT “value” scoring across tasks |
| examples/TTSwithVerification/ToT/utils/tree_of_thought.py | Core ToT search implementation and prompt construction helpers |
| examples/TTSwithVerification/ToT/utils/bla.pypy | Duplicate/alternate copy of ToT implementation (currently unclear purpose) |
| examples/TTSwithVerification/ToT/tot_baseline.py | CLI baseline runner wiring datasets, ToT search, synthesis, and evaluation |
| ToT_problems.txt | Added notes/scratch content related to ToT prompting |
````python
SYSTEM_PROMPT_VANILLA = """\
# Problem Description

You are solving a house grid problem. You are given:
1. Features and Domains
- A fixed number of houses, indexed sequentially (e.g., House 1, House 2, …) from left to right.
- A set of features (e.g., color, name, pet, book genre).
- Each feature has a finite domain of possible values.
2. Constraint:
- Each house has exactly one value per feature.
- No two houses share the same value for the same feature.
3. Clues / Constraints describing:
- Houses and their positions
- Feature values
- Relative ordering (e.g., "next to", "to the left of", "2 houses away from")

# Rules for Solving

1. Reason about the problem in text.
2. You may receive feedback from the user if anything is wrong. Use any feedback to guide your reasoning until a complete solution is reached.
3. Do not stop responding until you've assigned each and every variable.

# Final Answer Reporting Format

```json
{
  "House 1": { "feature1": "value1", "feature2": "value2", ... },
  "House 2": { "feature1": "value1", "feature2": "value2", ... },
  ...
}
```

Make sure to use the exact feature/value names as given in the problem description.
Make sure the JSON is valid and parsable."""
````
SYSTEM_PROMPT_VANILLA is defined twice in this module (earlier and again here), so the later assignment silently overwrites the first prompt. If both prompts are needed, give them distinct constant names; otherwise remove the duplicate to avoid using the wrong system prompt at runtime.
```python
clue_text_index, clue_text = clue.strip().split(". ", 1)
clues.append({
    "text_index": int(clue_text_index.strip()),
    "text": clue_text.strip()
```
clue.strip().split('. ', 1) will raise ValueError if a clue line is missing the expected ". " separator (or is empty). Since you already filter out known-bad IDs later, it would be safer to guard this parse (e.g., skip malformed lines or fall back to a different split) to avoid crashing on unexpected dataset formatting.
Suggested change:

```diff
-clue_text_index, clue_text = clue.strip().split(". ", 1)
-clues.append({
-    "text_index": int(clue_text_index.strip()),
-    "text": clue_text.strip()
+clue_line = clue.strip()
+if not clue_line:
+    # Skip empty or whitespace-only lines
+    continue
+index_part, sep, text_part = clue_line.partition(". ")
+if not sep:
+    # Line does not contain the expected ". " separator; skip as malformed
+    continue
+try:
+    text_index = int(index_part.strip())
+except ValueError:
+    # Index part is not an integer; skip as malformed
+    continue
+clues.append({
+    "text_index": text_index,
+    "text": text_part.strip()
```
```python
2. Constraint:
- Each house has exactly one value per feature.
- No two houses share the same value for the same feature.
3. Clues / Constraints descrbing:
```
Typo in the prompt text: "descrbing" → "describing".
Suggested change:

```diff
-3. Clues / Constraints descrbing:
+3. Clues / Constraints describing:
```
```python
import datasets
```
Importing datasets at module import time will raise ModuleNotFoundError in environments that install only the base pyproject.toml dependencies (it isn't listed there). To keep this utility usable without extra deps, move the import into get_zebralogic_dataset() (or wrap it in a try/except with a clear error explaining which extra to install).
Suggested change:

```diff
-import datasets
+try:
+    import datasets
+except ModuleNotFoundError:
+    class _DatasetsImportErrorProxy:
+        def __getattr__(self, name):
+            raise ModuleNotFoundError(
+                "The 'datasets' package is required to use the ZebraLogic dataset "
+                "utilities. Please install the appropriate extra, e.g.: "
+                "`pip install wildbench[dataset]`."
+            )
+    datasets = _DatasetsImportErrorProxy()
```
```python
async def dfs(node: TreeNode, depth: int):
    nonlocal best_terminal, best_value

    if depth >= self.config.max_depth:
        return

    # Generate proposals
    proposals = await self.propose_next_steps(
        task,
        problem,
        node.trajectory,
        llm_server,
        self.config.branching_factor
    )
    node.proposals = proposals

    for prop in proposals:
        new_trajectory = f"{node.trajectory}\n{prop}" if node.trajectory else prop
        value = await self.evaluate_state(task, problem, new_trajectory, llm_server)

        child = TreeNode(
            trajectory=new_trajectory,
            depth=depth + 1,
            value=value,
            parent=node,
        )
        node.children.append(child)

        if value > best_candidate_value:
            best_candidate = child
            best_candidate_value = value
        self.search_stats["total_nodes_in_tree"] += 1
```
In _dfs_search, the nested dfs() assigns to best_candidate / best_candidate_value but doesn't declare them nonlocal, so Python will treat them as locals and if value > best_candidate_value will raise UnboundLocalError. Add nonlocal best_candidate, best_candidate_value (or refactor to store these on self) inside dfs().
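The failure mode is easy to reproduce outside the PR. A minimal standalone sketch (the `search_*` functions and values here are illustrative, not the PR's `_dfs_search`):

```python
def search_broken():
    best = 0
    def visit(value):
        # The assignment below makes `best` local to visit(), so the
        # comparison reads it before assignment and raises UnboundLocalError.
        if value > best:
            best = value
    try:
        visit(1)
        return "ok"
    except UnboundLocalError:
        return "UnboundLocalError"

def search_fixed():
    best = 0
    def visit(value):
        nonlocal best  # bind to the enclosing function's variable
        if value > best:
            best = value
    for v in (3, 1, 5):
        visit(v)
    return best

print(search_broken())  # → UnboundLocalError
print(search_fixed())   # → 5
```

The same rule applies to `best_candidate` and `best_candidate_value`: any name assigned inside the nested coroutine must be declared `nonlocal` to update the enclosing scope.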
```python
import numpy as np
import pandas as pd
import aiohttp
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
```
This module imports many heavy/optional dependencies (numpy, pandas, aiohttp, datasets, transformers, etc.) that are not used anywhere in the file, and some are not in the base pyproject.toml dependencies. This increases import time and can break users who don't have the full environment. Remove unused imports and/or defer optional imports to the functions that need them with clear error messages.
Suggested change:

```diff
-import numpy as np
-import pandas as pd
-import aiohttp
-from datasets import load_dataset
-from tqdm import tqdm
-from transformers import AutoTokenizer
```
```python
"""
Tree of Thought implementation for interwhen-style streaming completion.

Implements proper ToT search using:
1. Propose function to generate candidate next steps
2. Value function to evaluate intermediate states
3. Search algorithm (BFS/DFS/beam) to explore the tree
4. Integrated with interwhen's async streaming architecture
"""
```
This file appears to be a near-duplicate of tree_of_thought.py but with an unusual .pypy extension. Keeping a duplicate copy will cause maintenance drift and confusion about which implementation is authoritative. If this is not intentionally used, consider removing it (or renaming/adding a clear purpose and ensuring it's referenced).
```python
    return json.loads(json_match.strip())

# try without assuming md tags
matches = re.findall(r'\{.*\}\s*?}', json_text, re.DOTALL)
```
The fallback regex for extracting JSON (re.findall(r'\{.*\}\s*?}', ...)) appears to have an extra closing brace and will commonly capture invalid substrings, leading to json.loads failures or partial parses. Consider replacing this with a balanced-brace scan (like you do elsewhere) or a non-greedy {...} match that returns a complete JSON object.
Suggested change:

```diff
-matches = re.findall(r'\{.*\}\s*?}', json_text, re.DOTALL)
+matches = re.findall(r'\{.*?\}', json_text, re.DOTALL)
```
```python
def evaluate_expression(expr, expected_nums=None):
    try:
        if expected_nums is not None and not validate_numbers_used(expr, expected_nums):
            return False
        value = eval(expr, {"__builtins__": None}, {})
        return abs(value - 24) < 1e-6
    except Exception:
        return False
```
evaluate_expression() uses Python eval() on model-produced text. Even with __builtins__ removed, eval is still dangerous (e.g., CPU/memory bombs, object construction via literals) and can be exploited if this script is run on untrusted outputs. Prefer a safe expression parser/evaluator (e.g., ast.parse with a whitelist of nodes/operators) or a dedicated math expression evaluator.
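As a sketch of the suggested alternative: a minimal `ast`-based whitelist evaluator that permits only numeric literals, unary minus, and `+ - * /` with parentheses. The names `safe_eval_arith` and `evaluate_expression_safe` are illustrative, not the PR's code:

```python
import ast
import operator

# Whitelisted binary operators for arithmetic expressions.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}

def safe_eval_arith(expr: str) -> float:
    """Evaluate an arithmetic expression, rejecting any non-whitelisted AST node."""
    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -walk(node.operand)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        raise ValueError(f"disallowed expression node: {type(node).__name__}")
    return walk(ast.parse(expr, mode="eval"))

def evaluate_expression_safe(expr, target=24, tol=1e-6):
    try:
        return abs(safe_eval_arith(expr) - target) < tol
    except Exception:
        # Syntax errors, disallowed nodes, and division by zero all score False.
        return False
```

With this, `evaluate_expression_safe("(8/(3-8/3))")` scores the classic Game-of-24 expression as correct, while a payload such as `"__import__('os').system(...)"` hits the `ast.Call` branch and is rejected before anything executes.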
```python
import numpy as np
import pandas as pd
import aiohttp
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
```
Same as verina_tot_utils.py: there are many unused heavy imports at module import time (numpy, pandas, aiohttp, datasets, transformers, etc.), and they aren't part of the base package dependencies. Consider removing unused imports and deferring optional ones to call sites to avoid ModuleNotFoundError and reduce startup cost.
Suggested change:

```diff
-import numpy as np
-import pandas as pd
-import aiohttp
-from datasets import load_dataset
-from tqdm import tqdm
-from transformers import AutoTokenizer
+from tqdm import tqdm
```
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 13 comments.
```python
# ============== Prompt Templates ==============

SYSTEM_PROMPT_VANILLA = """\
```
SYSTEM_PROMPT_VANILLA is defined twice; the second definition overwrites the first one at import time, which makes the earlier template unreachable and can silently change runtime behavior. Rename one of them (e.g., SYSTEM_PROMPT_VANILLA_FINAL vs SYSTEM_PROMPT_VANILLA_INTERACTIVE) or delete the unintended duplicate.
Suggested change:

```diff
-SYSTEM_PROMPT_VANILLA = """\
+SYSTEM_PROMPT_VANILLA_INTERACTIVE = """\
```
```python
Make sure to use the exact feature/value names as given in the problem description.
Make sure the JSON is valid and parsable."""

SYSTEM_PROMPT_VANILLA = """\
```
SYSTEM_PROMPT_VANILLA is defined twice; the second definition overwrites the first one at import time, which makes the earlier template unreachable and can silently change runtime behavior. Rename one of them (e.g., SYSTEM_PROMPT_VANILLA_FINAL vs SYSTEM_PROMPT_VANILLA_INTERACTIVE) or delete the unintended duplicate.
```python
def clean_problem_text(problem_text: str, features: dict) -> str:
    """Clean up problem text giving explicit feature domains."""
    desc, clues = problem_text.split('## clues:')
```
clean_problem_text splits on the exact substring '## clues:', but elsewhere (in process_zebralogic_problem) clue extraction is done via a case/whitespace-tolerant regex. If the dataset contains variants like ## Clues: or extra whitespace, this will raise ValueError and break dataset processing. Use the same regex split approach here (or otherwise normalize and validate) to make the parser consistent and robust.
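A sketch of the tolerant split described above, assuming the section marker may vary in case and surrounding whitespace (`split_clues` is an illustrative helper name, not the PR's function):

```python
import re

def split_clues(problem_text: str):
    """Split a ZebraLogic problem into description and clues, tolerating
    marker variants such as '## Clues:' or extra whitespace around it."""
    parts = re.split(r'##\s*clues\s*:', problem_text, maxsplit=1, flags=re.IGNORECASE)
    if len(parts) != 2:
        raise ValueError("problem text has no '## clues:' section")
    desc, clues = parts
    return desc.strip(), clues.strip()

print(split_clues("Houses...\n## Clues: \n1. A is left of B")[1])  # → 1. A is left of B
```

Using the same regex in both `clean_problem_text` and `process_zebralogic_problem` keeps the two parsers consistent.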
```python
# try with md tags
matches = re.findall(r'```json(.*?)```', json_text, re.DOTALL)
if matches and len(matches) > 0:
    json_match = matches[-1]
    return json.loads(json_match.strip())

# try without assuming md tags
matches = re.findall(r'\{.*\}\s*?}', json_text, re.DOTALL)
if matches and len(matches) > 0:
    json_match = matches[0]
    return json.loads(json_match.strip())
```
The fallback regex r'\{.*\}\s*?}' requires an extra trailing } and will fail to match many valid JSON objects (and may also overmatch across multiple objects due to .*). This can cause extract_last_json to return None even when valid JSON is present. Prefer a balanced-braces scan (like the later approach used in extract_solution_zebralogic in tot_baseline.py) and attempt parsing candidates from the end.
Suggested change:

```diff
-# try with md tags
-matches = re.findall(r'```json(.*?)```', json_text, re.DOTALL)
-if matches and len(matches) > 0:
-    json_match = matches[-1]
-    return json.loads(json_match.strip())
-# try without assuming md tags
-matches = re.findall(r'\{.*\}\s*?}', json_text, re.DOTALL)
-if matches and len(matches) > 0:
-    json_match = matches[0]
-    return json.loads(json_match.strip())
+# First, try to extract from a fenced ```json code block.
+matches = re.findall(r'```json(.*?)```', json_text, re.DOTALL)
+if matches:
+    json_match = matches[-1]
+    try:
+        return json.loads(json_match.strip())
+    except json.JSONDecodeError:
+        # Fall through to brace-based extraction if fenced block is malformed.
+        pass
+
+# Fallback: scan for balanced braces and try candidates from the end.
+text_len = len(json_text)
+brace_starts = [i for i, ch in enumerate(json_text) if ch == '{']
+for start in reversed(brace_starts):
+    depth = 0
+    end = None
+    for i in range(start, text_len):
+        ch = json_text[i]
+        if ch == '{':
+            depth += 1
+        elif ch == '}':
+            depth -= 1
+            if depth == 0:
+                end = i + 1
+                break
+        # If depth ever goes negative, this start is invalid.
+        if depth < 0:
+            break
+    if end is None:
+        continue
+    candidate = json_text[start:end].strip()
+    try:
+        return json.loads(candidate)
+    except json.JSONDecodeError:
+        # Try the next earlier '{' if this substring isn't valid JSON.
+        continue
```
```python
# try without assuming md tags
matches = re.findall(r'\{.*\}\s*?}', json_text, re.DOTALL)
if matches and len(matches) > 0:
    json_match = matches[0]
    return json.loads(json_match.strip())
```
The fallback regex r'\{.*\}\s*?}' requires an extra trailing } and will fail to match many valid JSON objects (and may also overmatch across multiple objects due to .*). This can cause extract_last_json to return None even when valid JSON is present. Prefer a balanced-braces scan (like the later approach used in extract_solution_zebralogic in tot_baseline.py) and attempt parsing candidates from the end.
Suggested change:

```diff
-# try without assuming md tags
-matches = re.findall(r'\{.*\}\s*?}', json_text, re.DOTALL)
-if matches and len(matches) > 0:
-    json_match = matches[0]
-    return json.loads(json_match.strip())
+# try without assuming md tags: scan for a balanced-brace JSON object from the end
+brace_depth = 0
+end_idx: Optional[int] = None
+for i in range(len(json_text) - 1, -1, -1):
+    ch = json_text[i]
+    if ch == '}':
+        if brace_depth == 0:
+            # potential end of a JSON object
+            end_idx = i + 1
+        brace_depth += 1
+    elif ch == '{':
+        if brace_depth > 0:
+            brace_depth -= 1
+            if brace_depth == 0 and end_idx is not None:
+                start_idx = i
+                candidate = json_text[start_idx:end_idx]
+                try:
+                    return json.loads(candidate.strip())
+                except json.JSONDecodeError:
+                    # continue searching for an earlier balanced object
+                    end_idx = None
+    # ignore other characters
```
```python
best_candidate_value = float('-inf')

async def dfs(node: TreeNode, depth: int):
    nonlocal best_terminal, best_value
```
In _dfs_search, the nested dfs() assigns to best_candidate and best_candidate_value without declaring them nonlocal. In Python this makes them local variables within dfs(), causing an UnboundLocalError when evaluating if value > best_candidate_value:. Declare nonlocal best_candidate, best_candidate_value inside dfs() (and include them in the existing nonlocal statement) or refactor best-candidate tracking out of the nested function.
Suggested change:

```diff
-nonlocal best_terminal, best_value
+nonlocal best_terminal, best_value, best_candidate, best_candidate_value
```
```python
node.children.append(child)
next_queue.append(child)
self.search_stats["total_nodes_in_tree"] += 1
```
BFS limits the next frontier via next_queue[:max_candidates_per_level] without sorting/filtering by node value, so which nodes survive depends on proposal generation order rather than evaluation score. This can prune the best candidates and materially degrade search quality. Sort next_queue by child.value (descending) before slicing, or keep all nodes for true BFS (and only prune by threshold).
Suggested addition:

```python
# Limit the next frontier to the highest-value candidates
next_queue.sort(
    key=lambda node: node.value if node.value is not None else float("-inf"),
    reverse=True,
)
```
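For illustration, here is score-based frontier pruning in isolation; plain dicts stand in for the PR's `TreeNode`, and `prune_frontier` is a hypothetical helper name:

```python
def prune_frontier(nodes, keep):
    """Keep the `keep` highest-value nodes; nodes with value None sort last."""
    return sorted(
        nodes,
        key=lambda n: n["value"] if n["value"] is not None else float("-inf"),
        reverse=True,
    )[:keep]

frontier = [{"id": "a", "value": 0.2}, {"id": "b", "value": None},
            {"id": "c", "value": 0.9}, {"id": "d", "value": 0.5}]
print([n["id"] for n in prune_frontier(frontier, 2)])  # → ['c', 'd']
```

Without the sort, a plain `next_queue[:max_candidates_per_level]` would keep `'a'` and `'b'` purely because they were proposed first, discarding the higher-scoring candidates.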
```python
    return sorted(used_nums) == sorted(expected_nums)


def evaluate_expression(expr, expected_nums=None):
```
Using Python eval() on model-produced strings is unsafe even with {"__builtins__": None}; sandbox escapes via object/dunder traversal are a known risk. Replace this with a safe arithmetic evaluator (e.g., parse with ast.parse and only allow numeric literals + BinOp over Add/Sub/Mult/Div + parentheses) to prevent arbitrary code execution.
```python
        value = eval(expr, {"__builtins__": None}, {})
        return abs(value - 24) < 1e-6
```
Using Python eval() on model-produced strings is unsafe even with {"__builtins__": None}; sandbox escapes via object/dunder traversal are a known risk. Replace this with a safe arithmetic evaluator (e.g., parse with ast.parse and only allow numeric literals + BinOp over Add/Sub/Mult/Div + parentheses) to prevent arbitrary code execution.
```python
pre_prompt = (
    "You are an expert problem solver. Carefully read the following multiple-choice question "
    "and think through the solution step-by-step before providing your final answer. "
    "Provide your final answer option by enclosing it within \\boxed{A/B/C/D}.:"
```
The prompt text ends with `".:"`, which looks like an accidental punctuation typo and may reduce instruction clarity. Consider changing it to a single `.` or `:` (consistent with the other prompts).
ToT baseline code push