66import queue
77import sys
88import time
9+ from collections import deque
910from contextlib import ExitStack
1011from multiprocessing import Manager
1112from typing import TYPE_CHECKING
@@ -61,6 +62,9 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
6162 __tracebackhide__ = True
6263 reports = session .execution_reports
6364 running_tasks : dict [str , Future [Any ]] = {}
65+ running_try_last : set [str ] = set ()
66+ queued_tasks : deque [str ] = deque ()
67+ queued_try_last_tasks : deque [str ] = deque ()
6468 sleeper = _Sleeper ()
6569 debug_status = _is_debug_status_enabled ()
6670
@@ -97,10 +101,26 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
97101 elif status_queue_factory == "simple" :
98102 session .config ["_status_queue" ] = queue .SimpleQueue ()
99103
104+ if live_execution :
105+ live_execution .initial_status = start_execution_state
106+
100107 i = 0
108+ prefetch_factor = (
109+ 2
110+ if session .config ["parallel_backend" ]
111+ in (
112+ ParallelBackend .PROCESSES ,
113+ ParallelBackend .LOKY ,
114+ ParallelBackend .THREADS ,
115+ )
116+ else 1
117+ )
118+ use_prefetch_queue = prefetch_factor > 1
101119 while session .scheduler .is_active ():
102120 try :
103121 newly_collected_reports = []
122+ did_enqueue = False
123+ did_submit = False
104124
105125 # If there is any coiled function, the user probably wants to exploit
106126 # adaptive scaling. Thus, we need to submit all ready tasks.
@@ -110,42 +130,104 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
110130 if any_coiled_task :
111131 n_new_tasks = 10_000
112132 else :
113- n_new_tasks = session .config ["n_workers" ] - len (running_tasks )
133+ if use_prefetch_queue :
134+ n_new_tasks = (
135+ session .config ["n_workers" ] * prefetch_factor
136+ ) - (
137+ len (running_tasks )
138+ + len (queued_tasks )
139+ + len (queued_try_last_tasks )
140+ )
141+ else :
142+ n_new_tasks = session .config ["n_workers" ] - len (running_tasks )
114143
115144 ready_tasks = (
116145 list (session .scheduler .get_ready (n_new_tasks ))
117146 if n_new_tasks >= 1
118147 else []
119148 )
120149
121- for task_signature in ready_tasks :
122- task = session .dag .nodes [task_signature ]["task" ]
123- if debug_status :
124- _log_status (
125- "PENDING"
126- if start_execution_state == TaskExecutionStatus .PENDING
127- else "RUNNING" ,
128- task_signature ,
150+ if use_prefetch_queue :
151+ for task_signature in ready_tasks :
152+ task = session .dag .nodes [task_signature ]["task" ]
153+ if debug_status :
154+ _log_status ("PENDING" , task_signature )
155+ session .hook .pytask_execute_task_log_start (
156+ session = session ,
157+ task = task ,
158+ status = start_execution_state ,
129159 )
130- session .hook .pytask_execute_task_log_start (
131- session = session , task = task , status = start_execution_state
132- )
133- try :
134- session .hook .pytask_execute_task_setup (
135- session = session , task = task
136- )
137- running_tasks [task_signature ] = (
138- session .hook .pytask_execute_task (session = session , task = task )
160+ if get_marks (task , "try_last" ):
161+ queued_try_last_tasks .append (task_signature )
162+ else :
163+ queued_tasks .append (task_signature )
164+ did_enqueue = True
165+
166+ def _can_run_try_last () -> bool :
167+ return not (
168+ queued_tasks
169+ or (len (running_tasks ) > len (running_try_last ))
139170 )
140- sleeper .reset ()
141- except Exception : # noqa: BLE001
142- report = ExecutionReport .from_task_and_exception (
143- task , sys .exc_info ()
171+
172+ while len (running_tasks ) < session .config ["n_workers" ]:
173+ if queued_tasks :
174+ task_signature = queued_tasks .popleft ()
175+ elif queued_try_last_tasks and _can_run_try_last ():
176+ task_signature = queued_try_last_tasks .popleft ()
177+ else :
178+ break
179+ task = session .dag .nodes [task_signature ]["task" ]
180+ try :
181+ session .hook .pytask_execute_task_setup (
182+ session = session , task = task
183+ )
184+ running_tasks [task_signature ] = (
185+ session .hook .pytask_execute_task (
186+ session = session , task = task
187+ )
188+ )
189+ if get_marks (task , "try_last" ):
190+ running_try_last .add (task_signature )
191+ sleeper .reset ()
192+ did_submit = True
193+ except Exception : # noqa: BLE001
194+ report = ExecutionReport .from_task_and_exception (
195+ task , sys .exc_info ()
196+ )
197+ newly_collected_reports .append (report )
198+ session .scheduler .done (task_signature )
199+ else :
200+ for task_signature in ready_tasks :
201+ task = session .dag .nodes [task_signature ]["task" ]
202+ if debug_status :
203+ _log_status (
204+ "PENDING"
205+ if start_execution_state == TaskExecutionStatus .PENDING
206+ else "RUNNING" ,
207+ task_signature ,
208+ )
209+ session .hook .pytask_execute_task_log_start (
210+ session = session , task = task , status = start_execution_state
144211 )
145- newly_collected_reports .append (report )
146- session .scheduler .done (task_signature )
212+ try :
213+ session .hook .pytask_execute_task_setup (
214+ session = session , task = task
215+ )
216+ running_tasks [task_signature ] = (
217+ session .hook .pytask_execute_task (
218+ session = session , task = task
219+ )
220+ )
221+ sleeper .reset ()
222+ did_submit = True
223+ except Exception : # noqa: BLE001
224+ report = ExecutionReport .from_task_and_exception (
225+ task , sys .exc_info ()
226+ )
227+ newly_collected_reports .append (report )
228+ session .scheduler .done (task_signature )
147229
148- if not ready_tasks :
230+ if not ready_tasks and not did_enqueue and not did_submit :
149231 sleeper .increment ()
150232
151233 for task_signature in list (running_tasks ):
@@ -173,6 +255,7 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
173255 )
174256 )
175257 running_tasks .pop (task_signature )
258+ running_try_last .discard (task_signature )
176259 session .scheduler .done (task_signature )
177260 else :
178261 task = session .dag .nodes [task_signature ]["task" ]
@@ -192,6 +275,7 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
192275 report = ExecutionReport .from_task (task )
193276
194277 running_tasks .pop (task_signature )
278+ running_try_last .discard (task_signature )
195279 newly_collected_reports .append (report )
196280 session .scheduler .done (task_signature )
197281
0 commit comments