Skip to content

Commit fcd0f88

Browse files
committed
add initial implementation
1 parent 04d57fc commit fcd0f88

2 files changed

Lines changed: 139 additions & 6 deletions

File tree

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""
2+
Regression tests for issue #43 — thread-safe _workspace_path_override.
3+
4+
Run:
5+
python -m unittest tests.test_workspace_path_thread_safety -v
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import os
11+
import shutil
12+
import sys
13+
import tempfile
14+
import threading
15+
import unittest
16+
from concurrent.futures import ThreadPoolExecutor, as_completed
17+
18+
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19+
sys.path.insert(0, REPO_ROOT)
20+
21+
from utils.workspace_path import (
22+
get_workspace_path_override,
23+
resolve_workspace_path,
24+
set_workspace_path_override,
25+
)
26+
27+
28+
class TestWorkspacePathThreadSafety(unittest.TestCase):
29+
"""Concurrent set-workspace + resolve must not observe torn global state."""
30+
31+
def setUp(self):
32+
self.tmp = tempfile.mkdtemp(prefix="cursor-ws-thread-test-")
33+
self.addCleanup(shutil.rmtree, self.tmp, ignore_errors=True)
34+
self.path_a = os.path.join(self.tmp, "storage-a")
35+
self.path_b = os.path.join(self.tmp, "storage-b")
36+
os.makedirs(self.path_a)
37+
os.makedirs(self.path_b)
38+
self._prior_workspace_env = os.environ.pop("WORKSPACE_PATH", None)
39+
self.addCleanup(self._restore_workspace_env)
40+
self.addCleanup(set_workspace_path_override, None)
41+
42+
def _restore_workspace_env(self):
43+
if self._prior_workspace_env is None:
44+
os.environ.pop("WORKSPACE_PATH", None)
45+
else:
46+
os.environ["WORKSPACE_PATH"] = self._prior_workspace_env
47+
48+
def test_concurrent_set_and_resolve_never_returns_mixed_paths(self):
49+
iterations = 500
50+
errors: list[str] = []
51+
start = threading.Barrier(9) # 1 writer + 8 readers
52+
# Seed before workers start so readers never observe the unset default path.
53+
set_workspace_path_override(self.path_a)
54+
55+
def writer() -> None:
56+
start.wait()
57+
for i in range(iterations):
58+
set_workspace_path_override(self.path_a if i % 2 == 0 else self.path_b)
59+
60+
def reader() -> None:
61+
start.wait()
62+
for _ in range(iterations):
63+
override = get_workspace_path_override()
64+
if override is None:
65+
errors.append("override was unexpectedly cleared during run")
66+
continue
67+
if override not in (self.path_a, self.path_b):
68+
errors.append(f"override returned unexpected value: {override!r}")
69+
continue
70+
resolved = resolve_workspace_path()
71+
expected = os.path.realpath(override)
72+
if resolved != expected:
73+
errors.append(
74+
f"resolve {resolved!r} != realpath(override) {expected!r}"
75+
)
76+
77+
with ThreadPoolExecutor(max_workers=9) as pool:
78+
futures = [pool.submit(writer)]
79+
futures.extend(pool.submit(reader) for _ in range(8))
80+
for fut in as_completed(futures):
81+
fut.result()
82+
83+
self.assertEqual(errors, [], "\n".join(errors[:20]))
84+
85+
def test_concurrent_clear_and_set_stays_consistent(self):
86+
iterations = 200
87+
errors: list[str] = []
88+
start = threading.Barrier(5)
89+
90+
def toggler() -> None:
91+
start.wait()
92+
for i in range(iterations):
93+
if i % 3 == 0:
94+
set_workspace_path_override(None)
95+
else:
96+
set_workspace_path_override(
97+
self.path_a if i % 2 == 0 else self.path_b
98+
)
99+
100+
def reader() -> None:
101+
start.wait()
102+
for _ in range(iterations):
103+
override = get_workspace_path_override()
104+
resolved = resolve_workspace_path()
105+
if override is None:
106+
continue
107+
if override not in (self.path_a, self.path_b):
108+
errors.append(f"unexpected override: {override!r}")
109+
elif resolved != os.path.realpath(override):
110+
errors.append(
111+
f"resolve {resolved!r} != realpath({override!r})"
112+
)
113+
114+
with ThreadPoolExecutor(max_workers=5) as pool:
115+
futures = [pool.submit(toggler)]
116+
futures.extend(pool.submit(reader) for _ in range(4))
117+
for fut in as_completed(futures):
118+
fut.result()
119+
120+
self.assertEqual(errors, [], "\n".join(errors[:20]))
121+
122+
123+
if __name__ == "__main__":
124+
unittest.main()

utils/workspace_path.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,27 @@
55
import os
66
import sys
77
import subprocess
8+
import threading
89

910
from .path_helpers import expand_tilde_path
1011

11-
# Module-level override set via the /api/set-workspace endpoint
12+
# Module-level override set via POST /api/set-workspace (or --base-dir).
13+
# All reads and writes are serialized by _workspace_path_lock so threaded
14+
# WSGI workers (gunicorn --threads, waitress, etc.) cannot observe torn
15+
# state between set_workspace_path_override and resolve_workspace_path.
16+
_workspace_path_lock = threading.Lock()
1217
_workspace_path_override: str | None = None
1318

1419

15-
def set_workspace_path_override(path: str):
20+
def set_workspace_path_override(path: str | None) -> None:
1621
global _workspace_path_override
17-
_workspace_path_override = path
22+
with _workspace_path_lock:
23+
_workspace_path_override = path
1824

1925

2026
def get_workspace_path_override() -> str | None:
21-
return _workspace_path_override
27+
with _workspace_path_lock:
28+
return _workspace_path_override
2229

2330

2431
def get_default_workspace_path() -> str:
@@ -64,8 +71,10 @@ def resolve_workspace_path() -> str:
6471
is only tilde-expanded — trusted-operator escape hatch, not the same checks
6572
as the API (issue #15).
6673
"""
67-
if _workspace_path_override:
68-
return expand_tilde_path(_workspace_path_override)
74+
with _workspace_path_lock:
75+
override = _workspace_path_override
76+
if override:
77+
return expand_tilde_path(override)
6978
env_path = os.environ.get("WORKSPACE_PATH", "").strip()
7079
if env_path:
7180
return expand_tilde_path(env_path)

0 commit comments

Comments
 (0)