Skip to content

Commit 67a7940

Browse files
committed
retry failed jobs, remove jobs, and check completion
1 parent 4116e71 commit 67a7940

26 files changed

Lines changed: 1627 additions & 14 deletions

.gitmodules

Lines changed: 0 additions & 3 deletions
This file was deleted.

src/omniq/_ops.py

Lines changed: 192 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
import redis
33

44
from dataclasses import dataclass
5-
from typing import Optional, Any
5+
from typing import Optional, Any, List
66
from threading import Lock
77

88
from .clock import now_ms
99
from .ids import new_ulid
10-
from .types import ReservePaused, ReserveJob, ReserveResult, AckFailResult
10+
from .types import ReservePaused, ReserveJob, ReserveResult, AckFailResult, BatchRemoveResult, BatchRetryFailedResult
1111
from .transport import RedisLike
1212
from .scripts import OmniqScripts
13-
from .helper import queue_base, queue_anchor
13+
from .helper import queue_base, queue_anchor, check_completion_anchor
1414

1515
@dataclass
1616
class OmniqOps:
@@ -300,6 +300,195 @@ def job_timeout_ms(self, *, queue: str, job_id: str, default_ms: int = 60_000) -
300300
except Exception:
301301
n = 0
302302
return n if n > 0 else int(default_ms)
303+
304+
def retry_failed(self, *, queue: str, job_id: str, now_ms_override: int = 0) -> None:
    """Move a single failed job back onto the queue for another attempt.

    Runs the RETRY_FAILED Redis script against the queue anchor key.
    Raises RuntimeError when the script reports ERR or returns an
    unrecognized reply shape.
    """
    ts = now_ms_override if now_ms_override else now_ms()

    res = self._evalsha_with_noscript_fallback(
        self.scripts.retry_failed.sha,
        self.scripts.retry_failed.src,
        1,
        queue_anchor(queue),
        job_id,
        str(int(ts)),
    )

    # Reply is ["OK"] on success or ["ERR", reason] on failure.
    if isinstance(res, list) and len(res) >= 1:
        status = res[0]
        if status == "OK":
            return
        if status == "ERR":
            reason = str(res[1]) if len(res) > 1 else "UNKNOWN"
            raise RuntimeError(f"RETRY_FAILED failed: {reason}")

    raise RuntimeError(f"Unexpected RETRY_FAILED response: {res}")
def retry_failed_batch(
    self,
    *,
    queue: str,
    job_ids: List[str],
    now_ms_override: int = 0,
) -> BatchRetryFailedResult:
    """Retry up to 100 failed jobs in one Redis script call.

    Returns a list of (job_id, status, reason) tuples, where reason is
    None unless that job's status is "ERR". Raises ValueError when more
    than 100 job_ids are given, and RuntimeError when the whole batch
    fails or the reply is malformed.
    """
    if len(job_ids) > 100:
        raise ValueError("retry_failed_batch max is 100 job_ids per call")

    anchor = queue_anchor(queue)
    nms = now_ms_override or now_ms()

    # ARGV layout expected by the script: now_ms, count, then the ids.
    argv: list[str] = [str(int(nms)), str(len(job_ids))]
    argv.extend([str(j) for j in job_ids])

    res = self._evalsha_with_noscript_fallback(
        self.scripts.retry_failed_batch.sha,
        self.scripts.retry_failed_batch.src,
        1,
        anchor,
        *argv,
    )

    if not isinstance(res, list):
        raise RuntimeError(f"Unexpected RETRY_FAILED_BATCH response: {res}")

    # Whole-batch failure: ["ERR", reason, extra?].
    if len(res) >= 2 and str(res[0]) == "ERR":
        reason = str(res[1])
        extra = str(res[2]) if len(res) > 2 else ""
        raise RuntimeError(f"RETRY_FAILED_BATCH failed: {reason} {extra}".strip())

    # Per-job results come back as a flat stream of pairs
    # (job_id, status) with an extra reason element after "ERR".
    out: BatchRetryFailedResult = []
    i = 0
    while i < len(res):
        # Bounds-check before indexing: a truncated/odd-length reply
        # previously raised a bare IndexError here.
        if i + 1 >= len(res):
            raise RuntimeError(f"Unexpected RETRY_FAILED_BATCH response: {res}")
        job_id = str(res[i] or "")
        status = str(res[i + 1] or "")
        reason: Optional[str] = None
        if status == "ERR":
            if i + 2 >= len(res):
                raise RuntimeError(f"Unexpected RETRY_FAILED_BATCH response: {res}")
            reason = str(res[i + 2] or "UNKNOWN")
            i += 3
        else:
            i += 2
        out.append((job_id, status, reason))
    return out
def remove_job(self, *, queue: str, job_id: str, lane: str) -> str:
    """Remove one job from the given lane via the REMOVE_JOB script.

    Returns the status token on success. Raises RuntimeError when the
    script reports ERR or the reply shape is unrecognized.
    """
    res = self._evalsha_with_noscript_fallback(
        self.scripts.remove_job.sha,
        self.scripts.remove_job.src,
        1,
        queue_anchor(queue),
        job_id,
        lane,
    )

    if isinstance(res, list) and len(res) >= 1:
        status = res[0]
        if status == "OK":
            # NOTE(review): this echoes the status token itself
            # (res[0] is "OK" here) — confirm res[1] wasn't intended.
            return str(res[0] or "")
        if status == "ERR":
            reason = str(res[1]) if len(res) > 1 else "UNKNOWN"
            raise RuntimeError(f"REMOVE_JOB failed: {reason}")

    raise RuntimeError(f"Unexpected REMOVE_JOB response: {res}")
def remove_jobs_batch(
    self,
    *,
    queue: str,
    lane: str,
    job_ids: List[str],
) -> BatchRemoveResult:
    """Remove up to 100 jobs from one lane in one Redis script call.

    Returns a list of (job_id, status, reason) tuples, where reason is
    None unless that job's status is "ERR". Raises ValueError when more
    than 100 job_ids are given, and RuntimeError when the whole batch
    fails or the reply is malformed.
    """
    if len(job_ids) > 100:
        raise ValueError("remove_jobs_batch max is 100 job_ids per call")

    anchor = queue_anchor(queue)

    # ARGV layout expected by the script: lane, count, then the ids.
    argv: list[str] = [str(lane), str(len(job_ids))]
    argv.extend([str(j) for j in job_ids])

    res = self._evalsha_with_noscript_fallback(
        self.scripts.remove_jobs_batch.sha,
        self.scripts.remove_jobs_batch.src,
        1,
        anchor,
        *argv,
    )

    if not isinstance(res, list):
        raise RuntimeError(f"Unexpected REMOVE_JOBS_BATCH response: {res}")

    # Whole-batch failure: ["ERR", reason, extra?].
    if len(res) >= 2 and str(res[0]) == "ERR":
        reason = str(res[1])
        extra = str(res[2]) if len(res) > 2 else ""
        raise RuntimeError(f"REMOVE_JOBS_BATCH failed: {reason} {extra}".strip())

    # Per-job results come back as a flat stream of pairs
    # (job_id, status) with an extra reason element after "ERR".
    out: BatchRemoveResult = []
    i = 0
    while i < len(res):
        # Bounds-check before indexing: a truncated/odd-length reply
        # previously raised a bare IndexError here.
        if i + 1 >= len(res):
            raise RuntimeError(f"Unexpected REMOVE_JOBS_BATCH response: {res}")
        job_id = str(res[i] or "")
        status = str(res[i + 1] or "")
        reason: Optional[str] = None
        if status == "ERR":
            if i + 2 >= len(res):
                raise RuntimeError(f"Unexpected REMOVE_JOBS_BATCH response: {res}")
            reason = str(res[i + 2] or "UNKNOWN")
            i += 3
        else:
            i += 2
        out.append((job_id, status, reason))
    return out
def check_completion_init_job_counter(self, *, key: str, expected: int) -> None:
    """Initialize a completion counter at *key* to *expected* children.

    Runs the CHECK_COMPLETION_INIT Redis script. Raises RuntimeError
    when the script reports ERR or returns an unrecognized reply shape.
    """
    res = self._evalsha_with_noscript_fallback(
        self.scripts.check_completion_init.sha,
        self.scripts.check_completion_init.src,
        1,
        check_completion_anchor(key),
        str(int(expected)),
    )

    # Reply is ["OK"] on success or ["ERR", reason] on failure.
    if isinstance(res, list) and len(res) >= 1:
        status = res[0]
        if status == "OK":
            return
        if status == "ERR":
            reason = str(res[1]) if len(res) > 1 else "UNKNOWN"
            raise RuntimeError(f"CHECK_COMPLETION_INIT failed: {reason}")

    raise RuntimeError(f"Unexpected CHECK_COMPLETION_INIT response: {res}")
def check_completion_job_decrement(self, *, key: str, child_id: str) -> int:
    """Decrement the completion counter at *key* for *child_id*.

    Runs the CHECK_COMPLETION_DECREMENT Redis script and returns the
    remaining count as an int. Raises RuntimeError when the script
    reports ERR, the remaining count is not an integer, or the reply
    shape is unrecognized.
    """
    anchor = check_completion_anchor(key)

    res = self._evalsha_with_noscript_fallback(
        self.scripts.check_completion_decrement.sha,
        self.scripts.check_completion_decrement.src,
        1,
        anchor,
        str(child_id),
    )

    if not isinstance(res, list) or len(res) < 2:
        raise RuntimeError(f"Unexpected CHECK_COMPLETION_DECREMENT response: {res}")

    if res[0] == "OK":
        try:
            return int(res[1])
        # Narrowed from bare `except Exception` to the errors int() can
        # actually raise, and chained so the original cause is kept.
        except (TypeError, ValueError) as e:
            raise RuntimeError(
                f"Unexpected CHECK_COMPLETION_DECREMENT remaining: {res}"
            ) from e

    if res[0] == "ERR":
        reason = str(res[1]) if len(res) > 1 else "UNKNOWN"
        raise RuntimeError(f"CHECK_COMPLETION_DECREMENT failed: {reason}")

    raise RuntimeError(f"Unexpected CHECK_COMPLETION_DECREMENT response: {res}")
303492

304493
@staticmethod
305494
def paused_backoff_s(poll_interval_s: float) -> float:

src/omniq/check_completion.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from dataclasses import dataclass
2+
from typing import Optional
3+
4+
from ._ops import OmniqOps
5+
6+
7+
@dataclass(frozen=True)
class CheckCompletion:
    """Small facade over the ops-layer completion-counter scripts.

    Bound to one consumer job via default_child_id, so handlers can
    decrement without re-stating their own job id.
    """

    # Ops layer that executes the underlying Redis scripts.
    ops: OmniqOps
    # Child id used when JobDecrement is called without an explicit one.
    default_child_id: str

    def InitJobCounter(self, key: str, expected: int) -> None:
        """Initialize the counter at *key* to *expected* outstanding children."""
        self.ops.check_completion_init_job_counter(key=key, expected=int(expected))

    def JobDecrement(self, key: str, child_id: Optional[str] = None) -> int:
        """Decrement the counter at *key*; return the remaining count.

        Falls back to default_child_id when child_id is not given; raises
        ValueError when neither yields a non-blank id.
        """
        chosen = child_id if child_id else self.default_child_id
        chosen = (chosen or "").strip()
        if not chosen:
            raise ValueError("child_id is required (or provide default_child_id)")
        return int(self.ops.check_completion_job_decrement(key=key, child_id=chosen))

src/omniq/client.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,18 @@ def resume(self, *, queue: str) -> int:
102102
def is_paused(self, *, queue: str) -> bool:
103103
return self._ops.is_paused(queue=queue)
104104

105+
def retry_failed(self, *, queue: str, job_id: str, now_ms_override: int = 0) -> None:
    # Thin wrapper: delegates to the ops layer, which raises RuntimeError on failure.
    return self._ops.retry_failed(queue=queue, job_id=job_id, now_ms_override=now_ms_override)
108+
def retry_failed_batch(self, *, queue: str, job_ids: list[str], now_ms_override: int = 0) -> list:
    # Delegates to the ops layer and propagates its per-job result list
    # of (job_id, status, reason) tuples; the previous "-> None"
    # annotation was wrong because the result is returned to the caller.
    return self._ops.retry_failed_batch(queue=queue, job_ids=job_ids, now_ms_override=now_ms_override)
111+
def remove_job(self, *, queue: str, job_id: str, lane: str) -> str:
    # Delegates to the ops layer, which returns the status token string
    # on success; the previous "-> None" annotation was wrong because
    # that value is returned to the caller.
    return self._ops.remove_job(queue=queue, job_id=job_id, lane=lane)
114+
def remove_jobs_batch(self, *, queue: str, lane: str, job_ids: list[str]) -> list:
    # Delegates to the ops layer and propagates its per-job result list
    # of (job_id, status, reason) tuples.
    return self._ops.remove_jobs_batch(queue=queue, lane=lane, job_ids=job_ids)
116+
105117
def consume(
106118
self,
107119
*,

src/omniq/consumer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ._ops import OmniqOps
99
from .types import JobCtx, ReserveJob
10+
from .check_completion import CheckCompletion
1011

1112
@dataclass
1213
class StopController:
@@ -176,6 +177,7 @@ def on_sigint(signum, frame):
176177
except Exception:
177178
payload_obj = res.payload
178179

180+
cc = CheckCompletion(ops=ops, default_child_id=res.job_id)
179181
ctx = JobCtx(
180182
queue=queue,
181183
job_id=res.job_id,
@@ -185,6 +187,7 @@ def on_sigint(signum, frame):
185187
lock_until_ms=res.lock_until_ms,
186188
lease_token=res.lease_token,
187189
gid=res.gid,
190+
check_completion=cc,
188191
)
189192

190193
if verbose:

src/omniq/core

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
local anchor = KEYS[1]
2+
local job_id = ARGV[1]
3+
local now_ms = tonumber(ARGV[2] or "0")
4+
local lease_token = ARGV[3]
5+
local err_msg = ARGV[4]
6+
7+
local DEFAULT_GROUP_LIMIT = 1
8+
local MAX_ERR_BYTES = 4096
9+
10+
local function derive_base(a)
11+
if a == nil or a == "" then return "" end
12+
if string.sub(a, -5) == ":meta" then
13+
return string.sub(a, 1, -6)
14+
end
15+
return a
16+
end
17+
18+
local base = derive_base(anchor)
19+
20+
local k_job = base .. ":job:" .. job_id
21+
local k_active = base .. ":active"
22+
local k_delayed = base .. ":delayed"
23+
local k_failed = base .. ":failed"
24+
local k_gready = base .. ":groups:ready"
25+
26+
local function to_i(v)
27+
if v == false or v == nil or v == '' then return 0 end
28+
local n = tonumber(v)
29+
if n == nil then return 0 end
30+
return math.floor(n)
31+
end
32+
33+
local function dec_floor0(key)
34+
local v = to_i(redis.call("DECR", key))
35+
if v < 0 then
36+
redis.call("SET", key, "0")
37+
return 0
38+
end
39+
return v
40+
end
41+
42+
local function group_limit_for(gid)
43+
local k_glimit = base .. ":g:" .. gid .. ":limit"
44+
local lim = to_i(redis.call("GET", k_glimit))
45+
if lim <= 0 then return DEFAULT_GROUP_LIMIT end
46+
return lim
47+
end
48+
49+
local function maybe_store_last_error()
50+
if err_msg == nil or err_msg == "" then return end
51+
if string.len(err_msg) > MAX_ERR_BYTES then
52+
err_msg = string.sub(err_msg, 1, MAX_ERR_BYTES)
53+
end
54+
redis.call("HSET", k_job,
55+
"last_error", err_msg,
56+
"last_error_ms", tostring(now_ms)
57+
)
58+
end
59+
60+
if lease_token == nil or lease_token == "" then
61+
return {"ERR", "TOKEN_REQUIRED"}
62+
end
63+
64+
local cur_token = redis.call("HGET", k_job, "lease_token") or ""
65+
if cur_token ~= lease_token then
66+
return {"ERR", "TOKEN_MISMATCH"}
67+
end
68+
69+
if redis.call("ZREM", k_active, job_id) ~= 1 then
70+
return {"ERR", "NOT_ACTIVE"}
71+
end
72+
73+
maybe_store_last_error()
74+
75+
local gid = redis.call("HGET", k_job, "gid")
76+
if gid and gid ~= "" then
77+
local k_ginflight = base .. ":g:" .. gid .. ":inflight"
78+
local inflight = dec_floor0(k_ginflight)
79+
local limit = group_limit_for(gid)
80+
local k_gwait = base .. ":g:" .. gid .. ":wait"
81+
if inflight < limit and to_i(redis.call("LLEN", k_gwait)) > 0 then
82+
redis.call("ZADD", k_gready, now_ms, gid)
83+
end
84+
end
85+
86+
local attempt = to_i(redis.call("HGET", k_job, "attempt"))
87+
local max_attempts = to_i(redis.call("HGET", k_job, "max_attempts"))
88+
if max_attempts <= 0 then max_attempts = 1 end
89+
local backoff_ms = to_i(redis.call("HGET", k_job, "backoff_ms"))
90+
91+
if attempt >= max_attempts then
92+
redis.call("HSET", k_job,
93+
"state", "failed",
94+
"updated_ms", tostring(now_ms),
95+
"lease_token", "",
96+
"lock_until_ms", ""
97+
)
98+
redis.call("LPUSH", k_failed, job_id)
99+
return {"FAILED"}
100+
end
101+
102+
local due_ms = now_ms + backoff_ms
103+
redis.call("HSET", k_job,
104+
"state", "delayed",
105+
"due_ms", tostring(due_ms),
106+
"updated_ms", tostring(now_ms),
107+
"lease_token", "",
108+
"lock_until_ms", ""
109+
)
110+
redis.call("ZADD", k_delayed, due_ms, job_id)
111+
112+
return {"RETRY", tostring(due_ms)}

0 commit comments

Comments
 (0)