Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@ __pycache__/
# results
results/

# agent run logs (data_agent traces, llm_calls, exec_tool_work_dir, etc.)
query_*/query*/logs/
# executor temp scripts (also removed after each run in ExecTool)
query_*/query*/logs/**/tmp_code_*.py

# dependencies
requirements.txt
Miniconda3-*.sh
Anaconda3-*.sh

# scripts
python_script/
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,14 @@ AZURE_API_KEY=
AZURE_API_VERSION=
GEMINI_API_KEY=
TOGETHER_API_KEY=
OPENROUTER_API_KEY=
```

Currently, we support
- Microsoft Azure API (for GPT models)
- Google Gemini API (for Gemini models)
- Together.AI API (for Kimi and Qwen models)
- OpenRouter API (for OpenRouter model IDs, e.g. `openrouter/google/gemini-2.5-pro`)

If you want to use a model not yet supported by default, you may register it in [DataAgent.py](./common_scaffold/DataAgent.py):
```python
Expand Down
37 changes: 29 additions & 8 deletions common_scaffold/DataAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time
import logging
from openai import AzureOpenAI, OpenAI
from openai import BadRequestError
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageToolCall
from dotenv import load_dotenv
from common_scaffold.prompts import prompt_builder
Expand Down Expand Up @@ -73,12 +74,18 @@ def __init__(
self.logger.info(f"\tmax_iterations: {self.max_iterations}")
self.llm_call_count = 0
load_dotenv()
if "gpt" in deployment_name.lower():
use_openrouter = bool(os.getenv("OPENROUTER_API_KEY")) and "/" in deployment_name and not deployment_name.lower().startswith("gemini")
if "gpt" in deployment_name.lower() and not use_openrouter:
self.client = AzureOpenAI(
api_key=os.getenv("AZURE_API_KEY"),
api_version=os.getenv("AZURE_API_VERSION"),
azure_endpoint=os.getenv("AZURE_API_BASE")
)
elif deployment_name.lower().startswith("openrouter/") or "openrouter" in deployment_name.lower() or use_openrouter:
self.client = OpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
)
elif "gemini" in deployment_name.lower():
self.client = OpenAI(
api_key=os.getenv("GEMINI_API_KEY"),
Expand Down Expand Up @@ -191,13 +198,27 @@ def call_llm(self):
start = time.time()
response = None
for attempt in range(3):
try:
response = self.client.chat.completions.create(
model=self.deployment_name,
messages=self.messages,
tools=[tool.get_spec() for tool in self.tools.values()],
timeout=600,
)
try:
tool_specs = [tool.get_spec() for tool in self.tools.values()]
try:
response = self.client.chat.completions.create(
model=self.deployment_name,
messages=self.messages,
tools=tool_specs,
tool_choice="required",
timeout=600,
)
except BadRequestError as e:
# Some providers reject tool_choice="required"; retry without it.
if "tool_choice" in str(e).lower():
response = self.client.chat.completions.create(
model=self.deployment_name,
messages=self.messages,
tools=tool_specs,
timeout=600,
)
else:
raise
break
except Exception as e:
response = None
Expand Down
9 changes: 8 additions & 1 deletion common_scaffold/prompts/prompt_builder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

PREVIEW_LENGTH = 10000

GPT_TOOL_CALL_INSTRUCTIONS = """2. Inside execute_python code you may read storage entries directly as variables using the provided key names. You should directly use the key names as variable names in your code, e.g., if the tool call id is "call_1", you can access its result via the variable `var_call_1` in your code, without quotes or other modifications."""
Expand Down Expand Up @@ -76,7 +78,12 @@

def init_messages(user_query: str, db_description: str, deployment_name: str, system_prompt: str=SYSTEM_PROMPT) -> list[dict]:
system_prompt_suffix = ""
if "gemini" in deployment_name.lower():
use_openrouter = bool(os.getenv("OPENROUTER_API_KEY")) and "/" in deployment_name and not deployment_name.lower().startswith("gemini")
if "openrouter" in deployment_name.lower() or use_openrouter:
# OpenRouter model IDs vary widely (and may include hyphens/colons), so
# use the safest access pattern for tool-call result variables.
tool_call_instructions = GEMINI_TOOL_CALL_INSTRUCTIONS
elif "gemini" in deployment_name.lower():
tool_call_instructions = GEMINI_TOOL_CALL_INSTRUCTIONS
if deployment_name.lower() == "gemini-2.5-flash":
tool_call_instructions = GEMINI_25FLASH_TOOL_CALL_INSTRUCTIONS
Expand Down
97 changes: 53 additions & 44 deletions common_scaffold/tools/ExecTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@
import json


def _unlink_tmp_code_scripts(work_dir: Path, logger: logging.Logger) -> None:
"""Remove DockerCommandLineCodeExecutor temp scripts (tmp_code_<hash>.py) from work_dir."""
try:
if not work_dir.is_dir():
return
for path in work_dir.glob("tmp_code_*.py"):
try:
path.unlink(missing_ok=True)
except OSError as e:
logger.debug("Could not remove %s: %s", path, e)
except OSError as e:
logger.debug("tmp_code cleanup skipped: %s", e)


class ExecTool(BaseTool):
"""
A robust, synchronous interface around AutoGen's DockerCommandLineCodeExecutor.
Expand Down Expand Up @@ -151,6 +165,7 @@ def _run_with_timeout(self, blocks):
def close(self):
"""Explicit shutdown method."""
self._stop_executor()
_unlink_tmp_code_scripts(self.work_dir, self.logger)
self._loop.stop()
self._loop.close()

Expand Down Expand Up @@ -185,53 +200,47 @@ def _check_args(self, args):

def _exec(self, args):
super()._exec(args)
if "code" in args:
env_args = args["env"]
exec_str = f'''code = """{args["code"]}"""\n\nenv_args = {env_args}\n\nexec(code, env_args)\n'''
result = self.run_python(exec_str)
elif "command" in args:
result = self.run_shell(args["command"])
else:
raise FatalError("Invalid argument")


self.logger.debug(f"ExecTool execution result: {result}")
# Log artifact
artifact_entry = {"val_args": args}
try:
artifact_entry['exit_code'] = result.exit_code
except:
# artifact_entry['exit_code'] = None
raise FatalError("Execution did not return an exit code")
try:
artifact_entry['output'] = result.output
except:
# artifact_entry['output'] = None
raise FatalError("Execution did not return output")
try:
artifact_entry['code_file'] = str(result.code_file)
except:
# artifact_entry['code_file'] = None
raise FatalError("Execution did not return code file")
with open(self.artifact_log_path, "a", encoding="utf-8") as f:
f.write(json.dumps(artifact_entry) + "\n")


if result.exit_code != 0:
# Handle timeout case separately
if "code execution was cancelled" in result.output.lower():
raise TimeoutError(f"Execution timed out after {self.timeout} seconds")
if "code" in args:
env_args = args["env"]
exec_str = f'''code = """{args["code"]}"""\n\nenv_args = {env_args}\n\nexec(code, env_args)\n'''
result = self.run_python(exec_str)
elif "command" in args:
result = self.run_shell(args["command"])
else:
raise FatalError("Invalid argument")

self.logger.debug(f"ExecTool execution result: {result}")
# Log artifact
artifact_entry = {"val_args": args}
try:
clean_err = result.output.strip().splitlines()[-1]
except: # fallback
clean_err = result.output
raise ValueError(f"Execution failed with exit code {result.exit_code}\n{clean_err}")
else:
artifact_entry['exit_code'] = result.exit_code
except Exception:
raise FatalError("Execution did not return an exit code")
try:
artifact_entry['output'] = result.output
except Exception:
raise FatalError("Execution did not return output")
try:
artifact_entry['code_file'] = str(result.code_file)
except Exception:
raise FatalError("Execution did not return code file")
with open(self.artifact_log_path, "a", encoding="utf-8") as f:
f.write(json.dumps(artifact_entry) + "\n")

if result.exit_code != 0:
# Handle timeout case separately
if "code execution was cancelled" in result.output.lower():
raise TimeoutError(f"Execution timed out after {self.timeout} seconds")
try:
clean_err = result.output.strip().splitlines()[-1]
except Exception:
clean_err = result.output
raise ValueError(f"Execution failed with exit code {result.exit_code}\n{clean_err}")
if "code" in args:
# Parse output for PRINT FORMAT
parsed_output = parse_result_python(result.output)
self.logger.debug(f"Parsed ExecTool output: {parsed_output}")
return parsed_output
else:
return result.output

return result.output
finally:
_unlink_tmp_code_scripts(self.work_dir, self.logger)
2 changes: 2 additions & 0 deletions common_scaffold/tools/db_utils/db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

# MongoDB
MONGO_URI = os.getenv("MONGO_URI", "mongodb://localhost:27017/")
# Path to mongorestore (MongoDB Database Tools); use full path if not on PATH
MONGORESTORE = os.getenv("MONGORESTORE", "mongorestore")

# SQLite
SQLITE_PATH = os.getenv("SQLITE_PATH", "data/mydb.sqlite")
Expand Down
21 changes: 14 additions & 7 deletions common_scaffold/tools/db_utils/mongo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,20 @@ def load_db(dump_folder: str, db_name: str):
# ["mongorestore", f"--nsInclude={db_name}.*", dump_path],
# check=True
# )
result = subprocess.run(
["mongorestore", f"--nsInclude={db_name}.*", dump_path],
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
)
cmd = [db_config.MONGORESTORE, f"--uri={db_config.MONGO_URI}", f"--nsInclude={db_name}.*", str(dump_path)]
try:
result = subprocess.run(
cmd,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
except FileNotFoundError as e:
raise FatalError(
f"mongorestore not found ({db_config.MONGORESTORE!r}). Install MongoDB Database Tools "
f"(package is often `mongodb-database-tools`) or set MONGORESTORE in .env to the full path to mongorestore."
) from e
if result.stdout:
logging.getLogger(__name__).debug(f"MongoDB load stdout: {result.stdout}")
if result.stderr:
Expand Down
6 changes: 3 additions & 3 deletions download.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ echo "Downloading database (~5GB)..."
# Create directory if needed
mkdir -p "$(dirname "$OUTPUT_PATH")"

# Download using gdown
if ! command -v gdown &> /dev/null; then
# Download using gdown (python -m avoids ~/.local/bin not being on PATH)
if ! python3 -c "import gdown" 2>/dev/null; then
echo "gdown not found. Installing..."
pip install gdown
fi

gdown --id "$FILE_ID" -O "$OUTPUT_PATH"
python3 -m gdown "https://drive.google.com/uc?id=${FILE_ID}" -O "$OUTPUT_PATH"

echo "Download complete."

Expand Down
2 changes: 1 addition & 1 deletion environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ dependencies:
- pillow==12.0.0
- propcache==0.4.1
- protobuf==5.29.5
- psycopg2==2.9.9
- psycopg2-binary==2.9.9
- pyarrow==22.0.0
- pydantic==2.11.7
- pydantic-core==2.33.2
Expand Down
Loading