Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions benchmarks/comprehensive_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Based on ai-sandbox-benchmark (Apache 2.0 License)
https://github.com/nibzard/ai-sandbox-benchmark
"""

import asyncio
import os
import sys
Expand Down Expand Up @@ -51,7 +52,8 @@
},
"prime_calculation": {
"name": "Prime Calculation",
"command": """python3 -c "
"command": (
"""python3 -c "
def is_prime(n):
if n < 2: return False
for i in range(2, int(n**0.5) + 1):
Expand All @@ -61,13 +63,15 @@ def is_prime(n):
primes = [n for n in range(2, 1000) if is_prime(n)]
print(f'Found {len(primes)} primes')
"
""",
"""
),
"runs": 5,
"description": "CPU-bound computation",
},
"file_io": {
"name": "File I/O (1000 files)",
"command": """python3 -c "
"command": (
"""python3 -c "
import os
# Write 1000 small files
for i in range(1000):
Expand All @@ -82,25 +86,30 @@ def is_prime(n):

print(f'Processed {total} bytes')
"
""",
"""
),
"runs": 3,
"description": "I/O performance test",
},
"package_install": {
"name": "pip install requests",
"command": "pip install -q requests && python3 -c 'import requests; print(f\"requests {requests.__version__}\")'",
"command": (
"pip install -q requests && python3 -c 'import requests; print(f\"requests {requests.__version__}\")'"
),
"runs": 2,
"description": "Package installation speed (requests already installed in standard image)",
},
"numpy_fft": {
"name": "NumPy FFT",
"command": """python3 -c "
"command": (
"""python3 -c "
import numpy as np
x = np.random.random(10000)
result = np.fft.fft(x)
print(f'FFT: {len(result)} points')
"
""",
"""
),
"runs": 3,
"description": "Numerical computation with pre-installed packages",
},
Expand Down
1 change: 1 addition & 0 deletions benchmarks/run_all_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

Outputs comprehensive results to benchmarks/results.txt
"""

import subprocess
import sys
import time
Expand Down
24 changes: 12 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,25 @@ dependencies = [
"typing-extensions>=4.0.0",
"click>=8.0.0",
"tabulate>=0.9.0",
"modal>=1.1.4",
"e2b>=2.0.0",
"daytona>=0.103.0",
"hopx-ai>=0.3.0",
"modal>=1.3.3",
"e2b>=2.13.2",
"daytona>=0.143.0",
"hopx-ai>=0.5.0",
"httpx>=0.27.0",
]

[project.optional-dependencies]
daytona = [
"daytona==0.103.0", # Official Daytona SDK - latest stable version
"daytona==0.143.0", # Official Daytona SDK - latest stable version
]
e2b = [
"e2b>=2.0.0", # Regular E2B SDK for standard Linux sandboxes
"e2b>=2.13.2", # Regular E2B SDK for standard Linux sandboxes
]
modal = [
"modal==1.1.4", # Latest stable version
"modal==1.3.3", # Latest stable version
]
hopx = [
"hopx-ai>=0.3.0", # Official Hopx SDK for secure cloud sandboxes
"hopx-ai>=0.5.0", # Official Hopx SDK for secure cloud sandboxes
]
# vercel = [
# "vercel-sdk>=0.1.0", # When available
Expand All @@ -53,10 +53,10 @@ hopx = [
# "cloudflare-workers-sdk>=0.1.0", # When available
# ]
all = [
"daytona==0.103.0",
"e2b>=2.0.0",
"modal==1.1.4",
"hopx-ai>=0.3.0",
"daytona==0.143.0",
"e2b>=2.13.2",
"modal==1.3.3",
"hopx-ai>=0.5.0",
]
dev = [
"pytest>=7.4.0",
Expand Down
4 changes: 3 additions & 1 deletion sandboxes/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import click

from . import __version__


def get_provider(name: str):
"""Get a provider instance by name."""
Expand Down Expand Up @@ -38,7 +40,7 @@ def get_provider(name: str):


@click.group()
@click.version_option(version="0.2.3", prog_name="cased-sandboxes")
@click.version_option(version=__version__, prog_name="cased-sandboxes")
def cli():
"""Universal AI code execution sandboxes."""
pass
Expand Down
57 changes: 52 additions & 5 deletions sandboxes/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def __init__(self, pool_config: PoolConfig | None = None):
# Locks for thread-safe operations
self._lock = asyncio.Lock()
self._condition = asyncio.Condition(self._lock)
self._ensure_min_idle_lock = asyncio.Lock()

# Template used for eager prewarming
self._warm_provider: Any | None = None
self._warm_config: SandboxConfig | None = None

# Cleanup task
self._cleanup_task: asyncio.Task | None = None
Expand All @@ -106,11 +111,15 @@ def __init__(self, pool_config: PoolConfig | None = None):
"errors": 0,
}

async def start(self):
async def start(self, provider: Any | None = None, config: SandboxConfig | None = None):
"""Start the pool and background tasks."""
if self.config.auto_cleanup:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())

if provider and config:
self._warm_provider = provider
self._warm_config = config

# Pre-create sandboxes if using eager strategy
if self.config.strategy == PoolStrategy.EAGER:
await self._ensure_min_idle()
Expand Down Expand Up @@ -149,6 +158,12 @@ async def acquire(
if self.config.max_total <= 0:
raise SandboxQuotaError("Pool limit reached: 0")

self._warm_provider = provider
self._warm_config = config

if self.config.strategy == PoolStrategy.EAGER:
await self._ensure_min_idle(provider, config)

eviction_entry: SandboxPoolEntry | None = None

try:
Expand Down Expand Up @@ -223,6 +238,9 @@ async def release(self, sandbox_id: str):
for entry in evictions:
await self._finalize_eviction(entry)

if self.config.strategy == PoolStrategy.EAGER:
await self._ensure_min_idle()

async def destroy(self, sandbox_id: str):
"""
Destroy a sandbox and remove from pool.
Expand Down Expand Up @@ -266,6 +284,7 @@ async def _create_sandbox(self, provider: Any, config: SandboxConfig) -> Sandbox

# Add to pool
self._pool[sandbox.id] = entry
self._idle_sandboxes.add(sandbox.id)

# Update label index
for key, value in entry.labels.items():
Expand Down Expand Up @@ -354,11 +373,36 @@ async def _evict_idle_sandbox(self) -> bool:
await self._finalize_eviction(entry)
return True

async def _ensure_min_idle(self):
async def _ensure_min_idle(
self, provider: Any | None = None, config: SandboxConfig | None = None
) -> None:
"""Ensure minimum idle sandboxes (for eager strategy)."""
# This would need provider and config information
# Implement based on specific requirements
pass
async with self._ensure_min_idle_lock:
if provider and config:
self._warm_provider = provider
self._warm_config = config

provider_to_use = provider or self._warm_provider
config_to_use = config or self._warm_config
if provider_to_use is None or config_to_use is None:
return

target_idle = min(self.config.min_idle, self.config.max_idle, self.config.max_total)
if target_idle <= 0:
return

while True:
async with self._lock:
idle_count = len(self._idle_sandboxes)
total_count = len(self._pool)
if idle_count >= target_idle or total_count >= self.config.max_total:
return
try:
await self._create_sandbox(provider_to_use, config_to_use)
except Exception as e:
self._stats["errors"] += 1
logger.error(f"Failed to pre-create idle sandbox: {e}")
return

async def _remove_from_pool(self, sandbox_id: str):
"""Remove sandbox from pool and indexes."""
Expand Down Expand Up @@ -450,6 +494,9 @@ async def _cleanup_expired(self):
logger.info(f"Cleaning up expired sandbox {sandbox_id}")
await self.destroy(sandbox_id)

if self.config.strategy == PoolStrategy.EAGER:
await self._ensure_min_idle()

async def _call_hook(self, hook: Callable, *args, **kwargs):
"""Call a lifecycle hook safely."""
try:
Expand Down
42 changes: 35 additions & 7 deletions sandboxes/providers/cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import asyncio
import base64
import json
import re
import shlex
import time
import uuid
from collections.abc import AsyncIterator
from contextlib import suppress
Expand All @@ -17,6 +20,7 @@
from ..security import validate_download_path, validate_upload_path

_DEFAULT_TIMEOUT = 30.0
_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


class CloudflareProvider(SandboxProvider):
Expand Down Expand Up @@ -47,6 +51,7 @@ def __init__(
self.account_id = account_id
self._transport = transport
self._user_agent = "cased-sandboxes/0.4.2"
self._last_accessed: dict[str, float] = {}

@property
def name(self) -> str:
Expand All @@ -73,6 +78,7 @@ async def create_sandbox(self, config: SandboxConfig) -> Sandbox:
"created_via": "cloudflare",
},
)
self._touch_session(session_id)
return sandbox

async def get_sandbox(self, sandbox_id: str) -> Sandbox | None:
Expand Down Expand Up @@ -133,6 +139,7 @@ async def destroy_sandbox(self, sandbox_id: str) -> bool:
)
except SandboxNotFoundError:
return False
self._last_accessed.pop(sandbox_id, None)
return True

async def stream_execution(
Expand Down Expand Up @@ -264,10 +271,11 @@ async def upload_file(
# Create directory if needed
dir_path = "/".join(remote_path.split("/")[:-1])
if dir_path:
await self.execute_command(sandbox_id, f"mkdir -p {dir_path}")
await self.execute_command(sandbox_id, f"mkdir -p {shlex.quote(dir_path)}")
# Write file using base64 decode
result = await self.execute_command(
sandbox_id, f"echo '{encoded}' | base64 -d > {remote_path}"
sandbox_id,
f"echo {shlex.quote(encoded)} | base64 -d > {shlex.quote(remote_path)}",
)
return result.success

Expand Down Expand Up @@ -295,7 +303,9 @@ async def download_file(
return True
except (SandboxError, SandboxNotFoundError):
# Fallback: use cat and base64 encoding to read file
result = await self.execute_command(sandbox_id, f"cat {remote_path} | base64")
result = await self.execute_command(
sandbox_id, f"cat {shlex.quote(remote_path)} | base64"
)
if not result.success:
return False

Expand All @@ -315,11 +325,14 @@ async def cleanup_idle_sandboxes(self, idle_timeout: int = 600) -> None:
cleans up our tracking. Actual sandbox cleanup happens automatically.
"""
sandboxes = await self.list_sandboxes()
asyncio.get_event_loop().time()
now = time.time()

for sandbox in sandboxes:
# Since we don't track last access time in the Worker,
# we'll clean up all sandboxes as a precaution
last_accessed = self._last_accessed.get(sandbox.id)
if last_accessed is None:
continue
if now - last_accessed <= idle_timeout:
continue
with suppress(SandboxNotFoundError):
await self.destroy_sandbox(sandbox.id)

Expand Down Expand Up @@ -360,13 +373,28 @@ def _apply_env_vars_to_command(
) -> str:
if not env_vars:
return command
exports = " && ".join([f"export {key}='{value}'" for key, value in env_vars.items()])
exports = " && ".join(
[
f"export {CloudflareProvider._validate_env_var_name(key)}={shlex.quote(str(value))}"
for key, value in env_vars.items()
]
)
return f"{exports} && {command}"

@staticmethod
def _validate_env_var_name(key: str) -> str:
if not _ENV_VAR_NAME_RE.match(key):
raise SandboxError(f"Invalid environment variable name: {key}")
return key

async def _ensure_session_exists(self, sandbox_id: str) -> None:
sandbox = await self.get_sandbox(sandbox_id)
if not sandbox:
raise SandboxNotFoundError(f"Session {sandbox_id} not found")
self._touch_session(sandbox_id)

def _touch_session(self, sandbox_id: str) -> None:
self._last_accessed[sandbox_id] = time.time()

async def _request(
self,
Expand Down
Loading