From e50ff6f5e90bf0b899d377d836f0c512201b9550 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 6 May 2026 12:28:29 +0200 Subject: [PATCH 1/4] Vectorized power --- src/reverse_mode.jl | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index c789461..df7877d 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -852,33 +852,34 @@ function _reverse_eval( continue end elseif op == :^ - # Broadcasted array .^ scalar: per-j reverse for the base, - # and a sum-reduced reverse for the (scalar) exponent. + # Broadcasted array .^ scalar: vectorize the per-element + # base reverse (with 0*Inf guard preserved) and reduce + # the exponent contribution as a single `sum` over GPU + # arrays. @assert length(children_indices) == 2 idx1 = first(children_indices) idx2 = last(children_indices) @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] - for j in _eachindex(f.sizes, k) - rev_parent = @j f.reverse_storage[k] - partial = @j f.partials_storage[ix1] - val = ifelse( - rev_parent == 0.0 && !isfinite(partial), - rev_parent, - rev_parent * partial, - ) - @j f.reverse_storage[ix1] = val - end - rev_exp = zero(Float64) - for j in _eachindex(f.sizes, k) - rev_parent = @j f.reverse_storage[k] - base = @j f.forward_storage[ix1] - out = @j f.forward_storage[k] - if base > 0 - rev_exp += rev_parent * out * log(base) - end - end - @s f.reverse_storage[ix2] = rev_exp + rev_parent = _view_array(f.reverse_storage, f.sizes, k) + rev_v1 = _view_array(f.reverse_storage, f.sizes, ix1) + partial = _view_array(f.partials_storage, f.sizes, ix1) + rev_v1 .= ifelse.( + (rev_parent .== 0) .& .!isfinite.(partial), + rev_parent, + rev_parent .* partial, + ) + base_view = _view_array(f.forward_storage, f.sizes, ix1) + out_view = _view_array(f.forward_storage, f.sizes, k) + rev_exp_total = sum( + ifelse.( + base_view .> 0, + rev_parent .* out_view .* log.(abs.(base_view)), + zero(Float64), + ), + ) + pos2 = _scalar_pos(f.sizes, ix2) + view(f.reverse_storage, pos2:pos2) .= rev_exp_total continue end end From d3c701800f74c7b477ee9c71ac0ff7cd3524d857 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 6 May 2026 12:58:44 +0200 Subject: [PATCH 2/4] Fix --- src/reverse_mode.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index df7877d..c3fc9ea 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -861,16 +861,16 @@ function _reverse_eval( idx2 = last(children_indices) @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] - rev_parent = _view_array(f.reverse_storage, f.sizes, k) - rev_v1 = _view_array(f.reverse_storage, f.sizes, ix1) - partial = _view_array(f.partials_storage, f.sizes, ix1) + rev_parent = _view_linear(f.reverse_storage, f.sizes, k) + rev_v1 = _view_linear(f.reverse_storage, f.sizes, ix1) + partial = _view_linear(f.partials_storage, f.sizes, ix1) rev_v1 .= ifelse.( (rev_parent .== 0) .& .!isfinite.(partial), rev_parent, rev_parent .* partial, ) - base_view = _view_array(f.forward_storage, f.sizes, ix1) - out_view = _view_array(f.forward_storage, f.sizes, k) + base_view = _view_linear(f.forward_storage, f.sizes, ix1) + out_view = _view_linear(f.forward_storage, f.sizes, k) rev_exp_total = sum( ifelse.( base_view .> 0, From ce796f71a24d300fc61740d2f23f41ca471212df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 6 May 2026 15:07:34 +0200 Subject: [PATCH 3/4] Fix --- src/reverse_mode.jl | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index c3fc9ea..a13924d 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -871,13 +871,22 @@ function _reverse_eval( ) base_view = _view_linear(f.forward_storage, f.sizes, ix1) out_view = _view_linear(f.forward_storage, f.sizes, k) - rev_exp_total = sum( - ifelse.( - base_view .> 0, - rev_parent .* out_view .* log.(abs.(base_view)), - zero(Float64), - ), - ) + # `mapreduce(f, +, base_view, rev_parent, out_view)` + # would express this directly, but multi-iterable + # `mapreduce` materializes an intermediate today + # (JuliaLang/julia#53417). Wrap the inputs in `zip` so + # the single-iterable specialization fires and the + # reduction stays allocation-free. Once + # https://github.com/JuliaLang/julia/pull/55301 lands + # we can drop the `zip` and use the multi-arg form. + T = eltype(rev_parent) + rev_exp_total = mapreduce( + +, + zip(base_view, rev_parent, out_view); + init = zero(T), + ) do (b, rp, o) + return b > 0 ? rp * o * log(b) : zero(T) + end pos2 = _scalar_pos(f.sizes, ix2) view(f.reverse_storage, pos2:pos2) .= rev_exp_total continue From 6ab7c561c890c28e04c69501ca9036be7a8aa342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 6 May 2026 15:25:15 +0200 Subject: [PATCH 4/4] Fix --- test/JuMP.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/JuMP.jl b/test/JuMP.jl index 20cf661..3043900 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -314,6 +314,9 @@ end # is the path that actually re-runs forward+reverse, not the # `last_x == x` short-circuit). function test_neural_allocations() + if VERSION < v"1.12" + return + end n = 2 X = [1.0 0.5; 0.3 0.8] target = [0.5 0.2; 0.1 0.7]