Skip to content

Commit f59afeb

Browse files
committed
Fix: finalize and catch errors
1 parent 51acd18 commit f59afeb

2 files changed

Lines changed: 100 additions & 17 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ParallelOperations"
22
uuid = "09c1cff2-b94d-4c31-85a1-721512e21c63"
33
authors = ["islent <leoislent@gmail.com>"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

src/ParallelOperations.jl

Lines changed: 99 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,12 @@ end
114114

115115
function transfer(src::Int, target::Int, from_expr, to_expr, to_mod::Module = Main, from_mod::Module = Main)
116116
r = RemoteChannel(src)
117-
@spawnat(src, put!(r, Core.eval(from_mod, from_expr)))
118-
@sync @spawnat(target, Core.eval(to_mod, Expr(:(=), to_expr, fetch(r))))
117+
try
118+
@spawnat(src, put!(r, Core.eval(from_mod, from_expr)))
119+
@sync @spawnat(target, Core.eval(to_mod, Expr(:(=), to_expr, fetch(r))))
120+
finally
121+
close(r)
122+
end
119123
end
120124

121125
# broadcast
@@ -296,11 +300,37 @@ function Base.reduce(f::Function, pids::Array, expr, mod::Module = Main; timeout
296300
futures = reduce_async(f, pids, expr, mod)
297301

298302
local_results = Vector{Any}(undef, length(futures))
299-
@sync for (i, future) in enumerate(futures)
300-
@async local_results[i] = fetch_with_timeout(future, timeout)
301-
end
302303

303-
return reduce(f, local_results)
304+
try
305+
@sync for (i, future) in enumerate(futures)
306+
@async begin
307+
try
308+
local_results[i] = fetch_with_timeout(future, timeout)
309+
catch e
310+
local_results[i] = e
311+
end
312+
end
313+
end
314+
315+
for (i, result) in enumerate(local_results)
316+
if result isa Exception
317+
throw(result)
318+
end
319+
end
320+
321+
return reduce(f, local_results)
322+
catch e
323+
for future in futures
324+
if isopen(future)
325+
try
326+
fetch(future)
327+
catch
328+
# ignore error
329+
end
330+
end
331+
end
332+
rethrow(e)
333+
end
304334
end
305335

306336
function gather_async(pids::Array, expr, mod::Module = Main)
@@ -312,11 +342,36 @@ function gather(pids::Array, expr, mod::Module = Main; timeout::Float64 = 5.0)
312342
futures = gather_async(pids, expr, mod)
313343
results = Vector{Any}(undef, length(futures))
314344

315-
@sync for (i, future) in enumerate(futures)
316-
@async results[i] = fetch_with_timeout(future, timeout)
345+
try
346+
@sync for (i, future) in enumerate(futures)
347+
@async begin
348+
try
349+
results[i] = fetch_with_timeout(future, timeout)
350+
catch e
351+
results[i] = e
352+
end
353+
end
354+
end
355+
356+
for (i, result) in enumerate(results)
357+
if result isa Exception
358+
throw(result)
359+
end
360+
end
361+
362+
return results
363+
catch e
364+
for future in futures
365+
if isopen(future)
366+
try
367+
fetch(future)
368+
catch
369+
# ignore error
370+
end
371+
end
372+
end
373+
rethrow(e)
317374
end
318-
319-
return results
320375
end
321376

322377
macro gather(pids, expr, mod::Symbol = :Main)
@@ -366,11 +421,25 @@ function gather(f::Function, pids::Array, expr, mod::Module = Main; timeout::Flo
366421
end
367422

368423
function allgather_async(pids::Array, src_expr, target_expr = src_expr, mod::Module = Main)
424+
#TODO fully async
369425
gather_futures = gather_async(pids, src_expr, mod)
370-
gather_result = fetch.(gather_futures)
371-
bcast_futures = bcast_async(pids, target_expr, gather_result, mod)
372426

373-
return bcast_futures
427+
try
428+
gather_result = fetch.(gather_futures)
429+
bcast_futures = bcast_async(pids, target_expr, gather_result, mod)
430+
return bcast_futures
431+
catch e
432+
for future in gather_futures
433+
if isopen(future)
434+
try
435+
fetch(future)
436+
catch
437+
# ignore error
438+
end
439+
end
440+
end
441+
rethrow(e)
442+
end
374443
end
375444

376445
function allgather(pids::Array, src_expr, target_expr = src_expr, mod::Module = Main; timeout::Float64 = 5.0)
@@ -379,11 +448,25 @@ function allgather(pids::Array, src_expr, target_expr = src_expr, mod::Module =
379448
end
380449

381450
function allreduce_async(f::Function, pids::Array, src_expr, target_expr = src_expr, mod::Module = Main)
451+
#TODO fully async
382452
reduce_futures = reduce_async(f, pids, src_expr, mod)
383-
reduce_result = reduce(f, fetch.(reduce_futures))
384-
bcast_futures = bcast_async(pids, target_expr, reduce_result, mod)
385453

386-
return bcast_futures
454+
try
455+
reduce_result = reduce(f, fetch.(reduce_futures))
456+
bcast_futures = bcast_async(pids, target_expr, reduce_result, mod)
457+
return bcast_futures
458+
catch e
459+
for future in reduce_futures
460+
if isopen(future)
461+
try
462+
fetch(future)
463+
catch
464+
# ignore error
465+
end
466+
end
467+
end
468+
rethrow(e)
469+
end
387470
end
388471

389472
function allreduce(f::Function, pids::Array, src_expr, target_expr = src_expr, mod::Module = Main; timeout::Float64 = 5.0)

0 commit comments

Comments
 (0)