Skip to content

Commit 3894ee6

Browse files
authored
Fix thread-unsafe _workspace_path_override race (#43) (#54)
* add initial implementation * fix: avoid cross-call snapshot assertions. * fix review comments from reviewer
1 parent 5d17c08 commit 3894ee6

2 files changed

Lines changed: 132 additions & 6 deletions

File tree

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""
2+
Regression tests for issue #43 — thread-safe _workspace_path_override.
3+
4+
Run:
5+
python -m unittest tests.test_workspace_path_thread_safety -v
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import os
11+
import shutil
12+
import sys
13+
import tempfile
14+
import threading
15+
import unittest
16+
from concurrent.futures import ThreadPoolExecutor, as_completed
17+
18+
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19+
sys.path.insert(0, REPO_ROOT)
20+
21+
from utils.workspace_path import (
22+
resolve_workspace_path,
23+
set_workspace_path_override,
24+
)
25+
26+
27+
class TestWorkspacePathThreadSafety(unittest.TestCase):
28+
"""Concurrent set-workspace + resolve must not observe torn global state."""
29+
30+
def setUp(self):
31+
self.tmp = tempfile.mkdtemp(prefix="cursor-ws-thread-test-")
32+
self.addCleanup(shutil.rmtree, self.tmp, ignore_errors=True)
33+
self.path_a = os.path.join(self.tmp, "storage-a")
34+
self.path_b = os.path.join(self.tmp, "storage-b")
35+
os.makedirs(self.path_a)
36+
os.makedirs(self.path_b)
37+
# Match resolve_workspace_path() (expand_tilde only — no realpath).
38+
self.allowed_resolved = {self.path_a, self.path_b}
39+
self._prior_workspace_env = os.environ.pop("WORKSPACE_PATH", None)
40+
self.addCleanup(self._restore_workspace_env)
41+
self.addCleanup(set_workspace_path_override, None)
42+
# With WORKSPACE_PATH popped and override None, this is resolve()'s
43+
# "override cleared" path — used by test_concurrent_clear_and_set.
44+
self.fallback_resolved = resolve_workspace_path()
45+
46+
def _restore_workspace_env(self):
47+
if self._prior_workspace_env is None:
48+
os.environ.pop("WORKSPACE_PATH", None)
49+
else:
50+
os.environ["WORKSPACE_PATH"] = self._prior_workspace_env
51+
52+
def test_concurrent_set_and_resolve_never_returns_mixed_paths(self):
53+
iterations = 500
54+
errors: list[str] = []
55+
start = threading.Barrier(9) # 1 writer + 8 readers
56+
# Seed before workers start so readers never observe the unset default path.
57+
set_workspace_path_override(self.path_a)
58+
59+
def writer() -> None:
60+
start.wait()
61+
for i in range(iterations):
62+
set_workspace_path_override(self.path_a if i % 2 == 0 else self.path_b)
63+
64+
def reader() -> None:
65+
start.wait()
66+
for _ in range(iterations):
67+
resolved = resolve_workspace_path()
68+
if resolved not in self.allowed_resolved:
69+
errors.append(
70+
f"resolve returned unexpected path: {resolved!r}"
71+
)
72+
73+
with ThreadPoolExecutor(max_workers=9) as pool:
74+
futures = [pool.submit(writer)]
75+
futures.extend(pool.submit(reader) for _ in range(8))
76+
for fut in as_completed(futures):
77+
fut.result()
78+
79+
self.assertEqual(errors, [], "\n".join(errors[:20]))
80+
81+
def test_concurrent_clear_and_set_stays_consistent(self):
82+
iterations = 200
83+
errors: list[str] = []
84+
start = threading.Barrier(5)
85+
86+
def toggler() -> None:
87+
start.wait()
88+
for i in range(iterations):
89+
if i % 3 == 0:
90+
set_workspace_path_override(None)
91+
else:
92+
set_workspace_path_override(
93+
self.path_a if i % 2 == 0 else self.path_b
94+
)
95+
96+
def reader() -> None:
97+
start.wait()
98+
for _ in range(iterations):
99+
resolved = resolve_workspace_path()
100+
if (
101+
resolved in self.allowed_resolved
102+
or resolved == self.fallback_resolved
103+
):
104+
continue
105+
errors.append(f"resolve returned unexpected path: {resolved!r}")
106+
107+
with ThreadPoolExecutor(max_workers=5) as pool:
108+
futures = [pool.submit(toggler)]
109+
futures.extend(pool.submit(reader) for _ in range(4))
110+
for fut in as_completed(futures):
111+
fut.result()
112+
113+
self.assertEqual(errors, [], "\n".join(errors[:20]))
114+
115+
116+
if __name__ == "__main__":
117+
unittest.main()

utils/workspace_path.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,27 @@
55
import os
66
import sys
77
import subprocess
8+
import threading
89

910
from .path_helpers import expand_tilde_path
1011

11-
# Module-level override set via the /api/set-workspace endpoint
12+
# Module-level override set via POST /api/set-workspace (or --base-dir).
13+
# Reads and writes are serialized by _workspace_path_lock so threaded WSGI
14+
# workers (gunicorn --threads, waitress, etc.) always see the latest override
15+
# from another thread and resolve_workspace_path's snapshot+expand stays consistent.
16+
_workspace_path_lock = threading.Lock()
1217
_workspace_path_override: str | None = None
1318

1419

15-
def set_workspace_path_override(path: str):
20+
def set_workspace_path_override(path: str | None) -> None:
1621
global _workspace_path_override
17-
_workspace_path_override = path
22+
with _workspace_path_lock:
23+
_workspace_path_override = path
1824

1925

2026
def get_workspace_path_override() -> str | None:
21-
return _workspace_path_override
27+
with _workspace_path_lock:
28+
return _workspace_path_override
2229

2330

2431
def get_default_workspace_path() -> str:
@@ -64,8 +71,10 @@ def resolve_workspace_path() -> str:
6471
is only tilde-expanded — trusted-operator escape hatch, not the same checks
6572
as the API (issue #15).
6673
"""
67-
if _workspace_path_override:
68-
return expand_tilde_path(_workspace_path_override)
74+
with _workspace_path_lock:
75+
override = _workspace_path_override
76+
if override:
77+
return expand_tilde_path(override)
6978
env_path = os.environ.get("WORKSPACE_PATH", "").strip()
7079
if env_path:
7180
return expand_tilde_path(env_path)

0 commit comments

Comments
 (0)