diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 46050f9..c789461 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -791,11 +791,21 @@ function _reverse_eval( end continue elseif op == :sum - rev_parent = @s f.reverse_storage[k] + # `sum` is rank-reducing (1 → 0): reverse-mode broadcasts + # the parent's scalar adjoint to every child slot. ix = children_arr[children_indices[1]] - for j in _eachindex(f.sizes, ix) - @j f.reverse_storage[ix] = rev_parent - end + pos = _scalar_pos(f.sizes, k) + # Avoid the scalar read of `reverse_storage[k]` (which fails + # on # GPU storage) by indexing with a 0-dim index, the view + # is then 0-dim at the outermost type, which the + # broadcast machinery specializes as a scalar source. + rev_parent_view = + view(f.reverse_storage, reshape(pos:pos, ())) + rev_children_view = + _view_linear(f.reverse_storage, f.sizes, ix) + # On GPU this lowers to a single fill-kernel; no + # Device-to-Host round-trip. + rev_children_view .= rev_parent_view continue elseif op == :row for j in _eachindex(f.sizes, k) diff --git a/src/sizes.jl b/src/sizes.jl index 2099dea..1c76881 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -49,6 +49,15 @@ function _setscalar!(x, value, sizes::Sizes, k::Int) return x[sizes.storage_offset[k]+1] = value end +""" + _scalar_pos(sizes, k) -> Int + +Tape index of node `k`'s single scalar slot. Useful when callers want to build +a 1-element view onto `forward_storage`/`reverse_storage` to do a +broadcast-style read or write that's safe on a GPU array. +""" +@inline _scalar_pos(sizes::Sizes, k::Int) = sizes.storage_offset[k] + 1 + function _getindex(x, sizes::Sizes, k::Int, j) return x[sizes.storage_offset[k]+j] end