diff --git a/Makefile b/Makefile index 4300d45b..4e8619db 100644 --- a/Makefile +++ b/Makefile @@ -59,6 +59,10 @@ tests-ind: tests-timing: @make tests-ind 2>&1 | ./scripts/test_times.py +.PHONY : waiting-task-stress +waiting-task-stress: + $(RUN) python -m unittest -v tests.SpiffWorkflow.bpmn.WaitingTaskStressBenchmark + wheel: clean $(RUN) python -m build --sdist --wheel --outdir dist/ diff --git a/SpiffWorkflow/bpmn/serializer/default/workflow.py b/SpiffWorkflow/bpmn/serializer/default/workflow.py index d2c34c16..9284c1d1 100644 --- a/SpiffWorkflow/bpmn/serializer/default/workflow.py +++ b/SpiffWorkflow/bpmn/serializer/default/workflow.py @@ -201,6 +201,7 @@ def from_dict(self, dct): # Handle the remaining top workflow attributes self.subprocesses_from_dict(dct['subprocesses'], workflow) workflow.bpmn_events = self.registry.restore(dct.pop('bpmn_events', [])) + workflow._rebuild_waiting_task_index() return workflow diff --git a/SpiffWorkflow/bpmn/util/subworkflow.py b/SpiffWorkflow/bpmn/util/subworkflow.py index 6968a1f1..1f60e7dc 100644 --- a/SpiffWorkflow/bpmn/util/subworkflow.py +++ b/SpiffWorkflow/bpmn/util/subworkflow.py @@ -35,6 +35,12 @@ def data_objects(self): def get_tasks_iterator(self, first_task=None, **kwargs): return BpmnTaskIterator(first_task or self.task_tree, **kwargs) + def update_waiting_tasks(self): + self.top_workflow._refresh_internal_waiting_tasks() + + def _task_state_changed_notify(self, task, old_state, new_state): + self.top_workflow._waiting_task_state_changed(task, old_state, new_state) + class BpmnSubWorkflow(BpmnBaseWorkflow): @@ -68,4 +74,3 @@ def collect_log_extras(self, dct=None): dct = super().collect_log_extras() dct.update({'parent_task_id': self.parent_task_id}) return dct - diff --git a/SpiffWorkflow/bpmn/workflow.py b/SpiffWorkflow/bpmn/workflow.py index 96015265..476b7819 100644 --- a/SpiffWorkflow/bpmn/workflow.py +++ b/SpiffWorkflow/bpmn/workflow.py @@ -17,6 +17,9 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA # 02110-1301 USA +import heapq +from datetime import datetime, timezone + from SpiffWorkflow.task import Task from SpiffWorkflow.util.task import TaskState from SpiffWorkflow.exceptions import WorkflowException @@ -24,6 +27,8 @@ from SpiffWorkflow.bpmn.specs.mixins.events.event_types import CatchingEvent from SpiffWorkflow.bpmn.specs.mixins.events.start_event import StartEvent from SpiffWorkflow.bpmn.specs.mixins.subworkflow_task import CallActivity +from SpiffWorkflow.bpmn.specs.event_definitions.multiple import MultipleEventDefinition +from SpiffWorkflow.bpmn.specs.event_definitions.timer import TimerEventDefinition from SpiffWorkflow.bpmn.specs.event_definitions.item_aware_event import CodeEventDefinition from SpiffWorkflow.bpmn.specs.control import BoundaryEventSplit @@ -33,6 +38,98 @@ from .script_engine.python_engine import PythonScriptEngine +class _WaitingTaskIndex: + + def __init__(self): + self.waiting_tasks = {} + self.timer_tasks = {} + self.timer_due_at = {} + self.timer_heap = [] + self._sequence = 0 + + def task_state_changed(self, task, old_state, new_state): + if old_state == TaskState.WAITING: + self._remove(task) + if new_state == TaskState.WAITING: + self._add(task) + + def refresh_internal_tasks(self, refresh_task): + for task in list(self.waiting_tasks.values()): + if task.id not in self.timer_tasks and task.state == TaskState.WAITING: + refresh_task(task) + self.refresh_due_timers(refresh_task) + + def refresh_due_timers(self, refresh_task): + self._schedule_missing_timer_due_times(refresh_task) + now = datetime.now(timezone.utc).timestamp() + while self.timer_heap and self.timer_heap[0][0] <= now: + due_at, _sequence, task_id = heapq.heappop(self.timer_heap) + if self.timer_due_at.get(task_id) != due_at: + continue + task = self.timer_tasks.get(task_id) + if task is None or task.state != TaskState.WAITING: + continue + refresh_task(task) + + def refresh_tasks(self, tasks, refresh_task): + for task in tasks: + refresh_task(task) + + def reschedule_timer(self, task): + if task.id in self.timer_tasks: + self._schedule_timer(task) + + def _add(self, task): + self.waiting_tasks[task.id] = task + if self._is_timer_task(task): + self.timer_tasks[task.id] = task + self._schedule_timer(task) + + def _remove(self, task): + self.waiting_tasks.pop(task.id, None) + self.timer_tasks.pop(task.id, None) + self.timer_due_at.pop(task.id, None) + + def _schedule_missing_timer_due_times(self, refresh_task): + for task in list(self.timer_tasks.values()): + if task.state != TaskState.WAITING: + continue + if task.id not in self.timer_due_at: + refresh_task(task) + + def _schedule_timer(self, task): + due_at = self._get_timer_due_at(task) + if due_at is None: + self.timer_due_at.pop(task.id, None) + return + self.timer_due_at[task.id] = due_at + self._sequence += 1 + heapq.heappush(self.timer_heap, (due_at, self._sequence, task.id)) + + def _is_timer_task(self, task): + return isinstance(task.task_spec, CatchingEvent) and self._has_timer_definition(task.task_spec.event_definition) + + def _has_timer_definition(self, event_definition): + if isinstance(event_definition, TimerEventDefinition): + return True + if isinstance(event_definition, MultipleEventDefinition): + return any(self._has_timer_definition(definition) for definition in event_definition.event_definitions) + return False + + def _get_timer_due_at(self, task): + event_value = task._get_internal_data('event_value') + if event_value is None: + return None + if isinstance(event_value, dict): + if event_value.get('cycles') == 0: + return 0 + next_event = event_value.get('next') + if next_event is None: + return None + return TimerEventDefinition.get_datetime(next_event).timestamp() + return TimerEventDefinition.get_datetime(event_value).timestamp() + + class BpmnWorkflow(BpmnBaseWorkflow): """ The engine that executes a BPMN workflow. This specialises the standard @@ -51,6 +148,8 @@ def __init__(self, spec, subprocess_specs=None, script_engine=None, **kwargs): self.subprocesses = {} self.bpmn_events = [] self.correlations = {} + self._waiting_task_index = _WaitingTaskIndex() + self._refreshing_waiting_tasks = False super(BpmnWorkflow, self).__init__(spec, **kwargs) for obj in self.spec.data_objects: @@ -129,7 +228,7 @@ def catch(self, event): for task in tasks: task.task_spec.catch(task, event) if len(tasks) > 0: - self.refresh_waiting_tasks() + self._refresh_caught_tasks(tasks) def send_event(self, event): """Allows this workflow to catch an externally generated event.""" @@ -142,7 +241,7 @@ def send_event(self, event): raise WorkflowException(f"This process is not waiting for {event.event_definition.name}") for task in tasks: task.task_spec.catch(task, event) - self.refresh_waiting_tasks() + self._refresh_caught_tasks(tasks) def get_events(self): """Returns the list of events that cannot be handled from within this workflow.""" @@ -164,8 +263,10 @@ def do_engine_steps(self, will_complete_task=None, did_complete_task=None): :param will_complete_task: Callback that will be called prior to completing a task :param did_complete_task: Callback that will be called after completing a task """ + self._refresh_due_waiting_tasks() count = self._do_engine_steps(will_complete_task, did_complete_task) while count > 0: + self._refresh_due_waiting_tasks() count = self._do_engine_steps(will_complete_task, did_complete_task) def _do_engine_steps(self, will_complete_task=None, did_complete_task=None): @@ -197,25 +298,58 @@ def update_workflow(wf): def refresh_waiting_tasks(self, will_refresh_task=None, did_refresh_task=None): """ - Refresh the state of all WAITING tasks. This will, for example, update - Catching Timer Events whose waiting time has passed. + Compatibility no-op. + + BPMN workflows now refresh WAITING task internals through engine steps, + targeted event catches, and task completion notifications. :param will_refresh_task: Callback that will be called prior to refreshing a task :param did_refresh_task: Callback that will be called after refreshing a task """ - def update_task(task): - if will_refresh_task is not None: - will_refresh_task(task) + pass + + def _waiting_task_state_changed(self, task, old_state, new_state): + self._waiting_task_index.task_state_changed(task, old_state, new_state) + + def _rebuild_waiting_task_index(self): + self._waiting_task_index = _WaitingTaskIndex() + workflows = [self] + list(self.subprocesses.values()) + for workflow in workflows: + for task in workflow.tasks.values(): + if task.state == TaskState.WAITING: + self._waiting_task_index.task_state_changed(task, None, TaskState.WAITING) + + def _refresh_internal_waiting_tasks(self): + if self._refreshing_waiting_tasks: + return + self._refreshing_waiting_tasks = True + try: + self._waiting_task_index.refresh_internal_tasks(self._refresh_waiting_task) + finally: + self._refreshing_waiting_tasks = False + + def _refresh_due_waiting_tasks(self): + if self._refreshing_waiting_tasks: + return + self._refreshing_waiting_tasks = True + try: + self._waiting_task_index.refresh_due_timers(self._refresh_waiting_task) + finally: + self._refreshing_waiting_tasks = False + + def _refresh_caught_tasks(self, tasks): + if self._refreshing_waiting_tasks: + return + self._refreshing_waiting_tasks = True + try: + self._waiting_task_index.refresh_tasks(tasks, self._refresh_waiting_task) + finally: + self._refreshing_waiting_tasks = False + + def _refresh_waiting_task(self, task): + if task.state == TaskState.WAITING: task.task_spec._update(task) - if did_refresh_task is not None: - did_refresh_task(task) - - for subprocess in sorted(self.get_active_subprocesses(), key=lambda v: v.depth, reverse=True): - for task in subprocess.get_tasks_iterator(skip_subprocesses=True, state=TaskState.WAITING): - update_task(task) - - for task in self.get_tasks_iterator(skip_subprocesses=True, state=TaskState.WAITING): - update_task(task) + self._waiting_task_index.reschedule_timer(task) def get_task_from_id(self, task_id): if task_id not in self.tasks: diff --git a/SpiffWorkflow/task.py b/SpiffWorkflow/task.py index 5a1ff2ba..0151adb8 100644 --- a/SpiffWorkflow/task.py +++ b/SpiffWorkflow/task.py @@ -292,9 +292,12 @@ def _set_state(self, value): """Force set the state on a task""" if value != self.state: + old_state = self._state elapsed = time.time() - self.last_state_change self.last_state_change = time.time() self._state = value + if hasattr(self.workflow, '_task_state_changed_notify'): + self.workflow._task_state_changed_notify(self, old_state, value) logger.info( f'State changed to {TaskState.get_name(value)}', extra=self.collect_log_extras({'elapsed': elapsed}) diff --git a/SpiffWorkflow/workflow.py b/SpiffWorkflow/workflow.py index ceaa8d3a..14ce896b 100644 --- a/SpiffWorkflow/workflow.py +++ b/SpiffWorkflow/workflow.py @@ -285,6 +285,8 @@ def _remove_task(self, task_id): task = self.tasks[task_id] for child in task.children: self._remove_task(child.id) + if hasattr(self, '_task_state_changed_notify'): + self._task_state_changed_notify(task, task.state, None) task.parent._children.remove(task.id) self.tasks.pop(task_id) diff --git a/tests/SpiffWorkflow/bpmn/WaitingTaskStressBenchmark.py b/tests/SpiffWorkflow/bpmn/WaitingTaskStressBenchmark.py new file mode 100644 index 00000000..fc18e811 --- /dev/null +++ b/tests/SpiffWorkflow/bpmn/WaitingTaskStressBenchmark.py @@ -0,0 +1,141 @@ +""" +Run with: + make RUN='uv run' waiting-task-stress + +Useful scale knobs: + SPIFF_WAITING_STRESS_TIMERS=500 + SPIFF_WAITING_STRESS_READY_STEPS=500 + SPIFF_WAITING_STRESS_DUE_TIMERS=50 + SPIFF_WAITING_STRESS_FUTURE_TIMERS=450 + +Optional guard for optimized branches: + SPIFF_WAITING_STRESS_MAX_TIMER_CHECKS=500 +""" + +import os +import time +from unittest.mock import patch + +from SpiffWorkflow import TaskState +from SpiffWorkflow.bpmn.specs.event_definitions.timer import DurationTimerEventDefinition + +from .BpmnWorkflowTestCase import BpmnWorkflowTestCase +from .waiting_task_stress import StressBpmnKind, WaitingTaskStressConfig, load_stress_workflow + + +class WaitingTaskStressBenchmark(BpmnWorkflowTestCase): + + def test_ready_hot_path_with_many_dormant_timers(self): + config = WaitingTaskStressConfig( + waiting_timers=_env_int("SPIFF_WAITING_STRESS_TIMERS", 100), + ready_steps=_env_int("SPIFF_WAITING_STRESS_READY_STEPS", 100), + ) + workflow = load_stress_workflow(self, StressBpmnKind.READY_HOT_PATH, config) + + with _count_duration_timer_checks() as counter: + started_at = time.perf_counter() + workflow.do_engine_steps() + elapsed = time.perf_counter() - started_at + + waiting_timers = _tasks_with_bpmn_id_prefix(workflow, TaskState.WAITING, "timer_wait_") + completed_hot_steps = _tasks_with_bpmn_id_prefix(workflow, TaskState.COMPLETED, "hot_step_") + + self.assertEqual(config.waiting_timers, len(waiting_timers)) + self.assertEqual(config.ready_steps, len(completed_hot_steps)) + _print_metrics( + "READY HOT PATH WITH DORMANT TIMERS", + { + "waiting_timers": config.waiting_timers, + "ready_steps": config.ready_steps, + "timer_has_fired_calls": counter.calls, + "elapsed_seconds": f"{elapsed:.6f}", + }, + ) + _assert_optional_max("SPIFF_WAITING_STRESS_MAX_TIMER_CHECKS", counter.calls, self) + + def test_staggered_timers_refresh_cost(self): + due_timers = _env_int("SPIFF_WAITING_STRESS_DUE_TIMERS", 10) + waiting_timers = _env_int("SPIFF_WAITING_STRESS_FUTURE_TIMERS", 90) + config = WaitingTaskStressConfig( + waiting_timers=waiting_timers, + due_timers=due_timers, + due_duration="PT0.01S", + ) + workflow = load_stress_workflow(self, StressBpmnKind.STAGGERED_TIMERS, config) + workflow.do_engine_steps() + + time.sleep(0.02) + with _count_duration_timer_checks() as counter: + started_at = time.perf_counter() + workflow.refresh_waiting_tasks() + workflow.do_engine_steps() + elapsed = time.perf_counter() - started_at + + waiting_timer_tasks = _tasks_with_bpmn_id_prefix(workflow, TaskState.WAITING, "timer_wait_") + completed_timer_tasks = _tasks_with_bpmn_id_prefix(workflow, TaskState.COMPLETED, "timer_wait_") + + self.assertEqual(waiting_timers, len(waiting_timer_tasks)) + self.assertEqual(due_timers, len(completed_timer_tasks)) + _print_metrics( + "STAGGERED TIMER REFRESH", + { + "due_timers": due_timers, + "future_timers": waiting_timers, + "timer_has_fired_calls": counter.calls, + "elapsed_seconds": f"{elapsed:.6f}", + }, + ) + _assert_optional_max("SPIFF_WAITING_STRESS_MAX_TIMER_CHECKS", counter.calls, self) + + +class _TimerCheckCounter: + def __init__(self): + self.calls = 0 + + +def _count_duration_timer_checks(): + counter = _TimerCheckCounter() + original = DurationTimerEventDefinition.has_fired + + def counted_has_fired(event_definition, task): + counter.calls += 1 + return original(event_definition, task) + + patcher = patch.object(DurationTimerEventDefinition, "has_fired", counted_has_fired) + + class TimerCheckContext: + def __enter__(self): + patcher.start() + return counter + + def __exit__(self, exc_type, exc_value, traceback): + patcher.stop() + + return TimerCheckContext() + + +def _tasks_with_bpmn_id_prefix(workflow, state, prefix): + return [ + task for task in workflow.get_tasks(state=state) + if task.task_spec.bpmn_id is not None and task.task_spec.bpmn_id.startswith(prefix) + ] + + +def _env_int(name, default): + value = os.environ.get(name) + return default if value is None else int(value) + + +def _assert_optional_max(env_name, actual, test_case): + expected = os.environ.get(env_name) + if expected is not None: + test_case.assertLessEqual(actual, int(expected)) + + +def _print_metrics(title, metrics): + print("\n" + "=" * 80) + print(f"WAITING TASK STRESS: {title}") + print("=" * 80) + for key, value in metrics.items(): + print(f" {key}: {value}") + print("=" * 80) diff --git a/tests/SpiffWorkflow/bpmn/WaitingTaskStressTest.py b/tests/SpiffWorkflow/bpmn/WaitingTaskStressTest.py new file mode 100644 index 00000000..6fe25ba4 --- /dev/null +++ b/tests/SpiffWorkflow/bpmn/WaitingTaskStressTest.py @@ -0,0 +1,110 @@ +import time +from unittest.mock import patch + +from SpiffWorkflow import TaskState +from SpiffWorkflow.bpmn.specs.event_definitions.timer import DurationTimerEventDefinition + +from .BpmnWorkflowTestCase import BpmnWorkflowTestCase +from .waiting_task_stress import StressBpmnKind, WaitingTaskStressConfig, load_stress_workflow + + +class WaitingTaskStressTest(BpmnWorkflowTestCase): + + def test_ready_hot_path_stress_fixture_keeps_many_dormant_waiting_timers(self): + config = WaitingTaskStressConfig(waiting_timers=6, ready_steps=5) + workflow = load_stress_workflow(self, StressBpmnKind.READY_HOT_PATH, config) + + workflow.do_engine_steps() + + waiting_timer_tasks = [ + task for task in workflow.get_tasks(state=TaskState.WAITING) + if task.task_spec.bpmn_id is not None and task.task_spec.bpmn_id.startswith("timer_wait_") + ] + completed_hot_steps = [ + task for task in workflow.get_tasks(state=TaskState.COMPLETED) + if task.task_spec.bpmn_id is not None and task.task_spec.bpmn_id.startswith("hot_step_") + ] + + self.assertEqual(config.waiting_timers, len(waiting_timer_tasks)) + self.assertEqual(config.ready_steps, len(completed_hot_steps)) + self.assertFalse(workflow.completed) + + def test_ready_hot_path_does_not_poll_dormant_timers_per_step(self): + config = WaitingTaskStressConfig(waiting_timers=8, ready_steps=6) + workflow = load_stress_workflow(self, StressBpmnKind.READY_HOT_PATH, config) + + with _count_duration_timer_checks() as counter: + workflow.do_engine_steps() + + self.assertLessEqual(counter.calls, config.waiting_timers) + + def test_refresh_waiting_tasks_is_noop_and_engine_steps_refresh_due_timers(self): + config = WaitingTaskStressConfig(waiting_timers=0, due_timers=1, due_duration="PT0.01S") + workflow = load_stress_workflow(self, StressBpmnKind.STAGGERED_TIMERS, config) + + workflow.do_engine_steps() + timer_task = workflow.get_tasks(state=TaskState.WAITING, spec_name="timer_wait_0")[0] + callbacks = [] + time.sleep(0.02) + + workflow.refresh_waiting_tasks(callbacks.append, callbacks.append) + + self.assertEqual(TaskState.WAITING, timer_task.state) + self.assertEqual([], callbacks) + + workflow.do_engine_steps() + + self.assertEqual(TaskState.COMPLETED, timer_task.state) + + def test_get_tasks_does_not_refresh_due_timers_by_inspection(self): + config = WaitingTaskStressConfig(waiting_timers=0, due_timers=1, due_duration="PT0.01S") + workflow = load_stress_workflow(self, StressBpmnKind.STAGGERED_TIMERS, config) + + workflow.do_engine_steps() + time.sleep(0.02) + + waiting_tasks = workflow.get_tasks(state=TaskState.WAITING, spec_name="timer_wait_0") + ready_tasks = workflow.get_tasks(state=TaskState.READY, spec_name="timer_wait_0") + + self.assertEqual(1, len(waiting_tasks)) + self.assertEqual(0, len(ready_tasks)) + + def test_due_timer_survives_save_restore_without_public_refresh(self): + config = WaitingTaskStressConfig(waiting_timers=0, due_timers=1, due_duration="PT0.01S") + self.workflow = load_stress_workflow(self, StressBpmnKind.STAGGERED_TIMERS, config) + + self.workflow.do_engine_steps() + timer_task = self.workflow.get_tasks(state=TaskState.WAITING, spec_name="timer_wait_0")[0] + self.save_restore() + time.sleep(0.02) + + self.workflow.do_engine_steps() + + timer_task = self.workflow.get_task_from_id(timer_task.id) + self.assertEqual(TaskState.COMPLETED, timer_task.state) + + +class _TimerCheckCounter: + def __init__(self): + self.calls = 0 + + +def _count_duration_timer_checks(): + counter = _TimerCheckCounter() + original = DurationTimerEventDefinition.has_fired + + def counted_has_fired(event_definition, task): + counter.calls += 1 + return original(event_definition, task) + + patcher = patch.object(DurationTimerEventDefinition, "has_fired", counted_has_fired) + + class TimerCheckContext: + def __enter__(self): + patcher.start() + return counter + + def __exit__(self, exc_type, exc_value, traceback): + patcher.stop() + + return TimerCheckContext() diff --git a/tests/SpiffWorkflow/bpmn/events/TimerCycleTest.py b/tests/SpiffWorkflow/bpmn/events/TimerCycleTest.py index 75a8236b..1d287753 100644 --- a/tests/SpiffWorkflow/bpmn/events/TimerCycleTest.py +++ b/tests/SpiffWorkflow/bpmn/events/TimerCycleTest.py @@ -48,7 +48,7 @@ def actual_test(self,save_restore = False): self.workflow.do_engine_steps() if save_restore: self.save_restore() - self.workflow.refresh_waiting_tasks() + self.workflow.do_engine_steps() events = self.workflow.waiting_events() refill = self.workflow.get_tasks(spec_name='Refill_Coffee') # Wait time is 0.1s, with a limit of 2 children, so by the 3rd iteration, the event should be complete diff --git a/tests/SpiffWorkflow/bpmn/events/TimerDurationTest.py b/tests/SpiffWorkflow/bpmn/events/TimerDurationTest.py index b29a787f..77886aba 100644 --- a/tests/SpiffWorkflow/bpmn/events/TimerDurationTest.py +++ b/tests/SpiffWorkflow/bpmn/events/TimerDurationTest.py @@ -35,7 +35,7 @@ def actual_test(self,save_restore = False): self.save_restore() self.workflow.script_engine = self.script_engine time.sleep(0.1) - self.workflow.refresh_waiting_tasks() + self.workflow.do_engine_steps() loopcount += 1 endtime = datetime.now() duration = endtime - starttime diff --git a/tests/SpiffWorkflow/bpmn/events/TimerIntermediateTest.py b/tests/SpiffWorkflow/bpmn/events/TimerIntermediateTest.py index 91326e0b..3e78088d 100644 --- a/tests/SpiffWorkflow/bpmn/events/TimerIntermediateTest.py +++ b/tests/SpiffWorkflow/bpmn/events/TimerIntermediateTest.py @@ -31,9 +31,6 @@ def testRunThroughHappy(self): time.sleep(0.02) self.assertEqual(1, len(self.workflow.get_tasks(state=TaskState.WAITING))) - self.workflow.refresh_waiting_tasks() - self.assertEqual(0, len(self.workflow.get_tasks(state=TaskState.WAITING))) - self.assertEqual(1, len(self.workflow.get_tasks(state=TaskState.READY))) - self.workflow.do_engine_steps() + self.assertEqual(0, len(self.workflow.get_tasks(state=TaskState.WAITING))) self.assertEqual(0, len(self.workflow.get_tasks(state=TaskState.READY|TaskState.WAITING))) diff --git a/tests/SpiffWorkflow/bpmn/waiting_task_stress.py b/tests/SpiffWorkflow/bpmn/waiting_task_stress.py new file mode 100644 index 00000000..a5379d85 --- /dev/null +++ b/tests/SpiffWorkflow/bpmn/waiting_task_stress.py @@ -0,0 +1,152 @@ +import os +from dataclasses import dataclass +from enum import Enum +from uuid import uuid4 + +from SpiffWorkflow.bpmn.workflow import BpmnWorkflow + + +class StressBpmnKind(Enum): + READY_HOT_PATH = "ready_hot_path" + STAGGERED_TIMERS = "staggered_timers" + + +@dataclass(frozen=True) +class WaitingTaskStressConfig: + waiting_timers: int = 100 + ready_steps: int = 100 + due_timers: int = 0 + future_duration: str = "PT24H" + due_duration: str = "PT0S" + + +def build_stress_bpmn(kind, config): + if kind == StressBpmnKind.READY_HOT_PATH: + return _build_ready_hot_path_bpmn(config) + if kind == StressBpmnKind.STAGGERED_TIMERS: + return _build_staggered_timers_bpmn(config) + raise ValueError(f"Unsupported stress BPMN kind: {kind}") + + +def load_stress_workflow(test_case, kind, config): + filename = write_stress_bpmn(test_case, kind, config) + try: + spec, subprocesses = test_case.load_workflow_spec(filename, "waiting_task_stress", validate=False) + return BpmnWorkflow(spec, subprocesses) + finally: + path = _data_path(filename) + if os.path.exists(path): + os.unlink(path) + + +def write_stress_bpmn(test_case, kind, config): + filename = f"_generated_waiting_task_stress_{kind.value}_{uuid4().hex}.bpmn" + path = _data_path(filename) + with open(path, "w") as bpmn_file: + bpmn_file.write(build_stress_bpmn(kind, config)) + return filename + + +def _data_path(filename): + return os.path.join(os.path.dirname(__file__), "data", filename) + + +def _build_ready_hot_path_bpmn(config): + timer_branches = [ + _timer_branch(idx, config.future_duration) + for idx in range(config.waiting_timers) + ] + timer_flows = [ + f'flow_split_timer_{idx}' + for idx in range(config.waiting_timers) + ] + return _definitions( + "\n".join([ + _start_and_split(timer_flows + ["flow_split_hot_0"]), + "\n".join(timer_branches), + _hot_path(config.ready_steps), + ]) + ) + + +def _build_staggered_timers_bpmn(config): + timer_count = config.waiting_timers + config.due_timers + timer_flows = [ + f'flow_split_timer_{idx}' + for idx in range(timer_count) + ] + branches = [] + for idx in range(timer_count): + duration = config.due_duration if idx < config.due_timers else config.future_duration + branches.append(_timer_branch(idx, duration)) + return _definitions("\n".join([ + _start_and_split(timer_flows), + "\n".join(branches), + ])) + + +def _definitions(process_body): + return f""" + + +{_indent(process_body, 4)} + + +""" + + +def _start_and_split(split_outgoing): + outgoing = "\n".join(split_outgoing) + return f""" + flow_start_split + + + flow_start_split +{_indent(outgoing, 2)} + +""" + + +def _timer_branch(idx, duration): + return f""" + flow_split_timer_{idx} + flow_timer_{idx}_end + + "{duration}" + + + + flow_timer_{idx}_end + + +""" + + +def _hot_path(ready_steps): + if ready_steps < 1: + raise ValueError("ready_steps must be at least 1") + + tasks = [] + flows = [''] + for idx in range(ready_steps): + incoming = "flow_split_hot_0" if idx == 0 else f"flow_hot_{idx - 1}_{idx}" + outgoing = "flow_hot_last_end" if idx == ready_steps - 1 else f"flow_hot_{idx}_{idx + 1}" + tasks.append(f""" + {incoming} + {outgoing} + hot_path_steps = hot_path_steps + 1 if 'hot_path_steps' in locals() else 1 +""") + if idx < ready_steps - 1: + flows.append( + f'' + ) + flows.append('flow_hot_last_end') + flows.append( + f'' + ) + return "\n".join(tasks + flows) + + +def _indent(text, spaces): + prefix = " " * spaces + return "\n".join(f"{prefix}{line}" if line else line for line in text.splitlines())