Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -791,11 +791,21 @@ function _reverse_eval(
end
continue
elseif op == :sum
rev_parent = @s f.reverse_storage[k]
# `sum` is rank-reducing (1 → 0): reverse-mode broadcasts
# the parent's scalar adjoint to every child slot.
ix = children_arr[children_indices[1]]
for j in _eachindex(f.sizes, ix)
@j f.reverse_storage[ix] = rev_parent
end
pos = _scalar_pos(f.sizes, k)
# Avoid the scalar read of `reverse_storage[k]` (which fails
# on GPU storage) by indexing with a 0-dim index; the view
# is then 0-dim at the outermost type, which the
# broadcast machinery specializes as a scalar source.
rev_parent_view =
view(f.reverse_storage, reshape(pos:pos, ()))
rev_children_view =
_view_linear(f.reverse_storage, f.sizes, ix)
# On GPU this lowers to a single fill-kernel; no
# Device-to-Host round-trip.
rev_children_view .= rev_parent_view
continue
elseif op == :row
for j in _eachindex(f.sizes, k)
Expand Down
9 changes: 9 additions & 0 deletions src/sizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ function _setscalar!(x, value, sizes::Sizes, k::Int)
return x[sizes.storage_offset[k]+1] = value
end

"""
_scalar_pos(sizes, k) -> Int

Tape index of node `k`'s single scalar slot. Useful when callers want to build
a 1-element view onto `forward_storage`/`reverse_storage` to do a
broadcast-style read or write that's safe on a GPU array.
"""
@inline _scalar_pos(sizes::Sizes, k::Int) = sizes.storage_offset[k] + 1

"""
    _getindex(x, sizes, k, j)

Read the entry of `x` at position `sizes.storage_offset[k] + j`, i.e. the
`j`-th slot counted from the storage offset recorded for node `k`.
"""
function _getindex(x, sizes::Sizes, k::Int, j)
    node_offset = sizes.storage_offset[k]
    return x[node_offset+j]
end
Expand Down
Loading