Skip to content

Commit a7cfda4

Browse files
Add connection hardening with retry policy (#169)
Implement automatic retries with exponential backoff and jitter for transient server errors, improving SDK reliability for production use.

New module: src/durable_workflow/retry_policy.py
- RetryPolicy class with configurable retry behavior
- Retries connection errors, timeouts, 5xx server errors, and 429 rate limit responses
- Does not retry 4xx client errors (except 429)
- Exponential backoff with jitter (default: 0.1s initial, 5s max, 2x multiplier)
- Default: 3 attempts per request

Client changes:
- Added retry_policy parameter to Client.__init__ (optional, defaults to RetryPolicy())
- Updated _request() to wrap httpx calls with retry logic
- Updated health() to use _request() for consistency (so it gets retries too)

Tests:
- 15 new unit tests for RetryPolicy covering all retry scenarios
- Updated test_sync.py mock (health now uses _request, not _http.get)
- All 219 tests pass, mypy --strict clean, ruff clean

Example usage (the original snippet was lost in extraction; reconstructed from the signatures in this diff):
    client = Client("http://localhost:8080", retry_policy=RetryPolicy(max_attempts=5))

Closes a Phase 4 deliverable (connection hardening).

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent 559409d commit a7cfda4

4 files changed

Lines changed: 263 additions & 15 deletions

File tree

src/durable_workflow/client.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
WorkflowTerminated,
1414
_raise_for_status,
1515
)
16+
from .retry_policy import RetryPolicy
1617

1718
PROTOCOL_VERSION = "1.0"
1819
CONTROL_PLANE_VERSION = "2"
@@ -233,10 +234,12 @@ def __init__(
233234
token: str | None = None,
234235
namespace: str = "default",
235236
timeout: float = 60.0,
237+
retry_policy: RetryPolicy | None = None,
236238
) -> None:
237239
self.base_url = base_url.rstrip("/")
238240
self.token = token
239241
self.namespace = namespace
242+
self.retry_policy = retry_policy or RetryPolicy()
240243
self._http = httpx.AsyncClient(base_url=self.base_url, timeout=timeout)
241244

242245
async def aclose(self) -> None:
@@ -269,19 +272,29 @@ async def _request(
269272
timeout: float | None = None,
270273
context: str = "",
271274
) -> Any:
272-
resp = await self._http.request(
273-
method,
274-
f"/api{path}",
275-
headers=self._headers(worker=worker),
276-
json=json,
277-
timeout=timeout,
278-
)
279-
if resp.status_code >= 400:
275+
async def _do_request() -> httpx.Response:
276+
resp = await self._http.request(
277+
method,
278+
f"/api{path}",
279+
headers=self._headers(worker=worker),
280+
json=json,
281+
timeout=timeout,
282+
)
283+
# Raise HTTPStatusError for 4xx/5xx so retry policy can catch it
284+
resp.raise_for_status()
285+
return resp
286+
287+
try:
288+
resp = await self.retry_policy.execute(_do_request)
289+
except httpx.HTTPStatusError as exc:
290+
# Convert to our custom exception types
280291
try:
281-
body = resp.json()
292+
body = exc.response.json()
282293
except ValueError:
283-
body = resp.text
284-
_raise_for_status(resp.status_code, body, context=context)
294+
body = exc.response.text
295+
_raise_for_status(exc.response.status_code, body, context=context)
296+
raise # unreachable, but keeps type checker happy
297+
285298
if resp.status_code == 204 or not resp.content:
286299
return None
287300
return resp.json()
@@ -293,9 +306,8 @@ def get_workflow_handle(
293306

294307
# ── Health ─────────────────────────────────────────────────────────
295308
async def health(self) -> dict[str, Any]:
296-
resp = await self._http.get("/api/health")
297-
resp.raise_for_status()
298-
result: dict[str, Any] = resp.json()
309+
result = await self._request("GET", "/health")
310+
assert isinstance(result, dict)
299311
return result
300312

301313
# ── Workflows ──────────────────────────────────────────────────────
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import random
5+
from collections.abc import Awaitable, Callable
6+
from dataclasses import dataclass
7+
from typing import TypeVar
8+
9+
import httpx
10+
11+
T = TypeVar("T")
12+
13+
14+
@dataclass
class RetryPolicy:
    """
    Retry policy for transient server errors.

    Retries requests that fail with transient errors (connection errors,
    timeouts, 5xx server errors, 429 rate limit). Does not retry client
    errors (4xx except 429) or exceptions that are not httpx transport or
    status errors.

    Uses exponential backoff with optional jitter to avoid thundering herd.
    The delay returned by ``backoff_seconds`` never exceeds
    ``max_backoff_seconds``, even after jitter has been applied.

    Raises:
        ValueError: on construction, if ``max_attempts`` is less than 1.
    """

    # Total number of calls ``execute`` may make, the first try included.
    max_attempts: int = 3
    # Delay before the first retry (attempt index 0), in seconds.
    initial_backoff_seconds: float = 0.1
    # Hard upper bound on any single delay, in seconds.
    max_backoff_seconds: float = 5.0
    # Growth factor applied to the delay for each further attempt.
    backoff_multiplier: float = 2.0
    # When true, each delay is scaled by a random factor in [0.75, 1.25].
    jitter: bool = True

    def __post_init__(self) -> None:
        """Reject configurations under which ``execute`` could never call ``fn``."""
        if self.max_attempts < 1:
            raise ValueError("max_attempts must be at least 1")

    def should_retry(self, exc: Exception, attempt: int) -> bool:
        """
        Check if the error is retryable and we haven't exceeded max attempts.

        Args:
            exc: The exception raised by the failed call.
            attempt: Attempt index; values >= ``max_attempts`` are never retried.

        Returns:
            True for httpx connection/network errors and timeouts, and for
            ``httpx.HTTPStatusError`` with a 5xx or 429 status; False otherwise.
        """
        if attempt >= self.max_attempts:
            return False

        # Retry connection errors and timeouts
        if isinstance(exc, (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError)):
            return True

        # Retry 5xx server errors and 429 rate limit
        if isinstance(exc, httpx.HTTPStatusError):
            status = exc.response.status_code
            return status >= 500 or status == 429

        return False

    def backoff_seconds(self, attempt: int) -> float:
        """
        Calculate backoff duration for the given attempt number (0-indexed).

        The base delay is ``initial_backoff_seconds * backoff_multiplier**attempt``
        capped at ``max_backoff_seconds``. With jitter enabled the base delay is
        scaled by a random factor in [0.75, 1.25] and then capped again.
        """
        backoff = min(
            self.initial_backoff_seconds * (self.backoff_multiplier**attempt),
            self.max_backoff_seconds,
        )
        if self.jitter:
            # ±25% jitter. Re-apply the cap afterwards: the upward half of the
            # jitter range would otherwise allow up to 1.25 * max_backoff_seconds.
            backoff = min(backoff * random.uniform(0.75, 1.25), self.max_backoff_seconds)
        return backoff

    async def execute(self, fn: Callable[[], Awaitable[T]]) -> T:
        """
        Execute the given async function with retries.

        Calls ``fn`` up to ``max_attempts`` times, sleeping for
        ``backoff_seconds(attempt)`` between consecutive attempts.

        Returns:
            The result of the first call to ``fn`` that succeeds.

        Raises:
            Exception: whatever ``fn`` raised, immediately if ``should_retry``
                rejects it, otherwise once the final attempt has failed.
            RuntimeError: only if ``max_attempts`` was lowered below 1 after
                construction, so that ``fn`` was never called.
        """
        for attempt in range(self.max_attempts):
            try:
                return await fn()
            except Exception as exc:
                # Re-raise in place (keeps the original traceback) when the
                # error is not retryable or this was the last permitted attempt.
                if not self.should_retry(exc, attempt) or attempt + 1 >= self.max_attempts:
                    raise

            # Sleep outside the handler so the failed call's exception is not
            # held as the active exception while waiting.
            await asyncio.sleep(self.backoff_seconds(attempt))

        raise RuntimeError("retry loop exhausted with no exception")

tests/test_retry_policy.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from __future__ import annotations
2+
3+
import httpx
4+
import pytest
5+
6+
from durable_workflow.retry_policy import RetryPolicy
7+
8+
9+
class TestRetryPolicy:
    """Unit tests for RetryPolicy: retry classification, backoff values and execute()."""

    @staticmethod
    def _status_error(status_code: int, message: str) -> httpx.HTTPStatusError:
        """Build an HTTPStatusError around a bare response carrying ``status_code``."""
        response = httpx.Response(
            status_code=status_code,
            request=httpx.Request("GET", "http://test"),
        )
        return httpx.HTTPStatusError(message, request=response.request, response=response)

    def test_should_retry_connection_error(self) -> None:
        """Connection errors are retryable until the attempt index reaches max_attempts."""
        policy = RetryPolicy(max_attempts=3)
        error = httpx.ConnectError("connection failed")
        for attempt in range(3):
            assert policy.should_retry(error, attempt=attempt) is True
        assert policy.should_retry(error, attempt=3) is False  # max_attempts reached

    def test_should_retry_timeout(self) -> None:
        """Timeouts are retryable."""
        policy = RetryPolicy(max_attempts=3)
        assert policy.should_retry(httpx.TimeoutException("timeout"), attempt=0) is True

    def test_should_retry_network_error(self) -> None:
        """Generic network errors are retryable."""
        policy = RetryPolicy(max_attempts=3)
        assert policy.should_retry(httpx.NetworkError("network error"), attempt=0) is True

    def test_should_retry_5xx_server_error(self) -> None:
        """A 500 response is retryable."""
        policy = RetryPolicy(max_attempts=3)
        error = self._status_error(500, "server error")
        assert policy.should_retry(error, attempt=0) is True

    def test_should_retry_429_rate_limit(self) -> None:
        """A 429 response is retryable."""
        policy = RetryPolicy(max_attempts=3)
        error = self._status_error(429, "rate limited")
        assert policy.should_retry(error, attempt=0) is True

    def test_should_not_retry_4xx_client_error(self) -> None:
        """A 404 response is not retryable."""
        policy = RetryPolicy(max_attempts=3)
        error = self._status_error(404, "not found")
        assert policy.should_retry(error, attempt=0) is False

    def test_should_not_retry_400_bad_request(self) -> None:
        """A 400 response is not retryable."""
        policy = RetryPolicy(max_attempts=3)
        error = self._status_error(400, "bad request")
        assert policy.should_retry(error, attempt=0) is False

    def test_should_not_retry_other_exceptions(self) -> None:
        """Exceptions unrelated to httpx are not retryable."""
        policy = RetryPolicy(max_attempts=3)
        assert policy.should_retry(ValueError("not a network error"), attempt=0) is False

    def test_backoff_calculation(self) -> None:
        """Without jitter the delay doubles per attempt and is capped at the maximum."""
        policy = RetryPolicy(
            initial_backoff_seconds=0.1,
            max_backoff_seconds=5.0,
            backoff_multiplier=2.0,
            jitter=False,
        )
        expected_by_attempt = {0: 0.1, 1: 0.2, 2: 0.4, 10: 5.0}  # attempt 10 is capped at max
        for attempt, expected in expected_by_attempt.items():
            assert policy.backoff_seconds(attempt) == expected

    def test_backoff_with_jitter(self) -> None:
        """With jitter the delay stays within ±25% of the base value."""
        policy = RetryPolicy(initial_backoff_seconds=1.0, jitter=True)
        delay = policy.backoff_seconds(0)
        assert 0.75 <= delay <= 1.25

    @pytest.mark.asyncio
    async def test_execute_success_on_first_try(self) -> None:
        """A call that succeeds immediately is invoked exactly once."""
        policy = RetryPolicy(max_attempts=3)
        calls: list[int] = []

        async def always_ok() -> str:
            calls.append(1)
            return "success"

        assert await policy.execute(always_ok) == "success"
        assert len(calls) == 1

    @pytest.mark.asyncio
    async def test_execute_success_after_retry(self) -> None:
        """Two connection failures followed by a success yields the success value."""
        policy = RetryPolicy(max_attempts=3, initial_backoff_seconds=0.01, jitter=False)
        calls: list[int] = []

        async def flaky() -> str:
            calls.append(1)
            if len(calls) < 3:
                raise httpx.ConnectError("connection failed")
            return "success"

        assert await policy.execute(flaky) == "success"
        assert len(calls) == 3

    @pytest.mark.asyncio
    async def test_execute_exhausted_retries(self) -> None:
        """When every attempt fails, the error propagates after max_attempts calls."""
        policy = RetryPolicy(max_attempts=3, initial_backoff_seconds=0.01, jitter=False)
        calls: list[int] = []

        async def always_down() -> str:
            calls.append(1)
            raise httpx.ConnectError("connection failed")

        with pytest.raises(httpx.ConnectError):
            await policy.execute(always_down)

        assert len(calls) == 3

    @pytest.mark.asyncio
    async def test_execute_non_retryable_error(self) -> None:
        """A 400 error propagates after a single call."""
        policy = RetryPolicy(max_attempts=3)
        calls: list[int] = []
        make_error = self._status_error

        async def bad_request() -> str:
            calls.append(1)
            raise make_error(400, "bad request")

        with pytest.raises(httpx.HTTPStatusError):
            await policy.execute(bad_request)

        assert len(calls) == 1  # Should not retry 400 errors

    @pytest.mark.asyncio
    async def test_execute_retries_500_but_not_400(self) -> None:
        """A single 500 is retried and the following success is returned."""
        policy = RetryPolicy(max_attempts=3, initial_backoff_seconds=0.01, jitter=False)
        calls: list[int] = []
        make_error = self._status_error

        async def fails_once() -> str:
            calls.append(1)
            if len(calls) == 1:
                raise make_error(500, "server error")
            return "success"

        assert await policy.execute(fails_once) == "success"
        assert len(calls) == 2  # First attempt 500, second attempt success

tests/test_sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class TestSyncClientHealth:
2323
def test_health(self) -> None:
2424
client = Client("http://localhost:8080")
2525
resp = _mock_response(200, {"status": "ok"})
26-
with patch.object(client._async._http, "get", new_callable=AsyncMock, return_value=resp):
26+
with patch.object(client._async._http, "request", new_callable=AsyncMock, return_value=resp):
2727
result = client.health()
2828
assert result["status"] == "ok"
2929

0 commit comments

Comments
 (0)