diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index c789461..a13924d 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -852,33 +852,43 @@ function _reverse_eval( continue end elseif op == :^ - # Broadcasted array .^ scalar: per-j reverse for the base, - # and a sum-reduced reverse for the (scalar) exponent. + # Broadcasted array .^ scalar: vectorize the per-element + # base reverse (with 0*Inf guard preserved) and reduce + # the exponent contribution as a single `sum` over GPU + # arrays. @assert length(children_indices) == 2 idx1 = first(children_indices) idx2 = last(children_indices) @inbounds ix1 = children_arr[idx1] @inbounds ix2 = children_arr[idx2] - for j in _eachindex(f.sizes, k) - rev_parent = @j f.reverse_storage[k] - partial = @j f.partials_storage[ix1] - val = ifelse( - rev_parent == 0.0 && !isfinite(partial), - rev_parent, - rev_parent * partial, - ) - @j f.reverse_storage[ix1] = val - end - rev_exp = zero(Float64) - for j in _eachindex(f.sizes, k) - rev_parent = @j f.reverse_storage[k] - base = @j f.forward_storage[ix1] - out = @j f.forward_storage[k] - if base > 0 - rev_exp += rev_parent * out * log(base) - end + rev_parent = _view_linear(f.reverse_storage, f.sizes, k) + rev_v1 = _view_linear(f.reverse_storage, f.sizes, ix1) + partial = _view_linear(f.partials_storage, f.sizes, ix1) + rev_v1 .= ifelse.( + (rev_parent .== 0) .& .!isfinite.(partial), + rev_parent, + rev_parent .* partial, + ) + base_view = _view_linear(f.forward_storage, f.sizes, ix1) + out_view = _view_linear(f.forward_storage, f.sizes, k) + # `mapreduce(f, +, base_view, rev_parent, out_view)` + # would express this directly, but multi-iterable + # `mapreduce` materializes an intermediate today + # (JuliaLang/julia#53417). Wrap the inputs in `zip` so + # the single-iterable specialization fires and the + # reduction stays allocation-free. Once + # https://github.com/JuliaLang/julia/pull/55301 lands + # we can drop the `zip` and use the multi-arg form. + T = eltype(rev_parent) + rev_exp_total = mapreduce( + +, + zip(base_view, rev_parent, out_view); + init = zero(T), + ) do (b, rp, o) + return b > 0 ? rp * o * log(b) : zero(T) end - @s f.reverse_storage[ix2] = rev_exp + pos2 = _scalar_pos(f.sizes, ix2) + view(f.reverse_storage, pos2:pos2) .= rev_exp_total continue end end diff --git a/test/JuMP.jl b/test/JuMP.jl index 20cf661..3043900 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -314,6 +314,9 @@ end # is the path that actually re-runs forward+reverse, not the # `last_x == x` short-circuit). function test_neural_allocations() + if VERSION < v"1.12" + return + end n = 2 X = [1.0 0.5; 0.3 0.8] target = [0.5 0.2; 0.1 0.7]