Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions tests/workflows/test_workflows_fan_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from taskito.workflows import NodeStatus, Workflow, WorkflowState

WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
PollUntil = Any # the conftest fixture's runtime type


def test_fan_out_each(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
Expand Down Expand Up @@ -214,7 +215,9 @@ def aggregate(results: list[int]) -> str:
assert final.nodes["collect"].status == NodeStatus.SKIPPED


def test_fan_out_cancellation(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
def test_fan_out_cancellation(
queue: Queue, workflow_worker: WorkflowWorkerFactory, poll_until: PollUntil
) -> None:
"""Cancelling a workflow mid-fan-out skips pending children."""

@queue.task()
Expand All @@ -223,7 +226,9 @@ def source() -> list[int]:

@queue.task()
def slow_process(x: int) -> int:
time.sleep(10) # will be cancelled
# Intentional pacing — the test needs the children to be running so the
# cancel happens mid-fan-out.
time.sleep(10)
return x

@queue.task()
Expand All @@ -237,8 +242,11 @@ def aggregate(results: list[int]) -> str:

with workflow_worker():
run = queue.submit_workflow(wf)
# Wait for source to complete and fan-out to expand
time.sleep(2)
poll_until(
lambda: run.node_status("fetch") == NodeStatus.COMPLETED,
timeout=10,
message="source did not complete; fan-out never expanded",
)
run.cancel()
snapshot = run.status()

Expand Down
33 changes: 26 additions & 7 deletions tests/workflows/test_workflows_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
from __future__ import annotations

import threading
import time
from collections.abc import Callable
from contextlib import AbstractContextManager
from typing import Any

from taskito import Queue
from taskito.workflows import NodeStatus, Workflow, WorkflowState

WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
PollUntil = Any # the conftest fixture's runtime type


def test_gate_pauses_workflow(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
def test_gate_pauses_workflow(
queue: Queue, workflow_worker: WorkflowWorkerFactory, poll_until: PollUntil
) -> None:
"""A gate node enters WAITING_APPROVAL and blocks downstream."""

@queue.task()
Expand All @@ -27,7 +30,11 @@ def ok_task() -> str:

with workflow_worker():
run = queue.submit_workflow(wf)
time.sleep(3) # Let "a" complete
poll_until(
lambda: run.node_status("approve") == NodeStatus.WAITING_APPROVAL,
timeout=10,
message="gate did not reach WAITING_APPROVAL",
)
snapshot = run.status()

assert snapshot.state == WorkflowState.RUNNING
Expand All @@ -36,7 +43,9 @@ def ok_task() -> str:
assert snapshot.nodes["b"].status == NodeStatus.PENDING


def test_approve_gate_resumes(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
def test_approve_gate_resumes(
queue: Queue, workflow_worker: WorkflowWorkerFactory, poll_until: PollUntil
) -> None:
"""Approving a gate lets downstream steps run to completion."""

collected: list[str] = []
Expand All @@ -53,7 +62,11 @@ def ok_task() -> str:

with workflow_worker():
run = queue.submit_workflow(wf)
time.sleep(2) # Let "a" complete and gate enter WAITING_APPROVAL
poll_until(
lambda: run.node_status("approve") == NodeStatus.WAITING_APPROVAL,
timeout=10,
message="gate did not reach WAITING_APPROVAL",
)
queue.approve_gate(run.id, "approve")
final = run.wait(timeout=15)

Expand All @@ -63,7 +76,9 @@ def ok_task() -> str:
assert len(collected) == 2 # "a" and "b"


def test_reject_gate_fails(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
def test_reject_gate_fails(
queue: Queue, workflow_worker: WorkflowWorkerFactory, poll_until: PollUntil
) -> None:
"""Rejecting a gate fails it and skips downstream."""

@queue.task()
Expand All @@ -77,7 +92,11 @@ def ok_task() -> str:

with workflow_worker():
run = queue.submit_workflow(wf)
time.sleep(2)
poll_until(
lambda: run.node_status("approve") == NodeStatus.WAITING_APPROVAL,
timeout=10,
message="gate did not reach WAITING_APPROVAL",
)
queue.reject_gate(run.id, "approve", error="not approved")
final = run.wait(timeout=15)

Expand Down
14 changes: 12 additions & 2 deletions tests/workflows/test_workflows_subworkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import threading
from collections.abc import Callable
from contextlib import AbstractContextManager
from typing import Any

from taskito import Queue
from taskito.workflows import NodeStatus, Workflow, WorkflowState

WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
PollUntil = Any # the conftest fixture's runtime type


def test_sub_workflow_executes(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
Expand Down Expand Up @@ -119,13 +121,17 @@ def broken_sub() -> Workflow:
assert final.nodes["after"].status == NodeStatus.SKIPPED


def test_cancel_parent_cascades(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
def test_cancel_parent_cascades(
queue: Queue, workflow_worker: WorkflowWorkerFactory, poll_until: PollUntil
) -> None:
"""Cancelling a parent workflow cancels the child sub-workflow too."""

import time

@queue.task()
def slow_task() -> str:
# Intentional pacing — keeps the child running so the cancel cascades
# mid-flight rather than after natural completion.
time.sleep(30)
return "slow"

Expand All @@ -140,7 +146,11 @@ def slow_sub() -> Workflow:

with workflow_worker():
run = queue.submit_workflow(wf)
time.sleep(2) # Let sub-workflow submit
poll_until(
lambda: run.node_status("sub") == NodeStatus.RUNNING,
timeout=10,
message="sub-workflow did not reach RUNNING",
)
run.cancel()
snapshot = run.status()

Expand Down
Loading