From 5c49cf765ccfd06d64a1b6e42990310ae78e1c7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 6 May 2026 12:28:17 +0200 Subject: [PATCH 1/4] Vectorized sum --- src/reverse_mode.jl | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 46050f9..7e4f80e 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -791,11 +791,21 @@ function _reverse_eval( end continue elseif op == :sum - rev_parent = @s f.reverse_storage[k] + # `sum` is rank-reducing (1 → 0): reverse-mode broadcasts + # the parent's scalar adjoint to every child slot. Avoid + # the scalar read of `reverse_storage[k]` (which fails on + # GPU storage) by indexing with a 0-dim index — the view + # is then 0-dim at the outermost type, which the + # broadcast machinery specializes as a scalar source. On + # GPU this lowers to a single fill-kernel; no D2H + # round-trip. ix = children_arr[children_indices[1]] - for j in _eachindex(f.sizes, ix) - @j f.reverse_storage[ix] = rev_parent - end + pos = _scalar_pos(f.sizes, k) + rev_parent_view = + view(f.reverse_storage, reshape(pos:pos, ())) + rev_children_view = + _view_array(f.reverse_storage, f.sizes, ix) + rev_children_view .= rev_parent_view continue elseif op == :row for j in _eachindex(f.sizes, k) From 5fda65c297edb89acd49441aa54079573fe176bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 6 May 2026 12:54:23 +0200 Subject: [PATCH 2/4] Add _scalar_pos --- src/sizes.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/sizes.jl b/src/sizes.jl index 2099dea..1c76881 100644 --- a/src/sizes.jl +++ b/src/sizes.jl @@ -49,6 +49,15 @@ function _setscalar!(x, value, sizes::Sizes, k::Int) return x[sizes.storage_offset[k]+1] = value end +""" + _scalar_pos(sizes, k) -> Int + +Tape index of node `k`'s single scalar slot. 
Useful when callers want to build +a 1-element view onto `forward_storage`/`reverse_storage` to do a +broadcast-style read or write that's safe on a GPU array. +""" +@inline _scalar_pos(sizes::Sizes, k::Int) = sizes.storage_offset[k] + 1 + function _getindex(x, sizes::Sizes, k::Int, j) return x[sizes.storage_offset[k]+j] end From 9784e57f225012a60a1552c448ef01563db9df3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 6 May 2026 13:02:25 +0200 Subject: [PATCH 3/4] Fix --- src/reverse_mode.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 7e4f80e..0638718 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -804,7 +804,7 @@ function _reverse_eval( rev_parent_view = view(f.reverse_storage, reshape(pos:pos, ())) rev_children_view = - _view_array(f.reverse_storage, f.sizes, ix) + _view_linear(f.reverse_storage, f.sizes, ix) rev_children_view .= rev_parent_view continue elseif op == :row From 482e8dc8cc8837bbed8648dd2d43518641fbd8e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 6 May 2026 13:04:38 +0200 Subject: [PATCH 4/4] Improve comment --- src/reverse_mode.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl index 0638718..c789461 100644 --- a/src/reverse_mode.jl +++ b/src/reverse_mode.jl @@ -792,19 +792,19 @@ function _reverse_eval( continue elseif op == :sum # `sum` is rank-reducing (1 → 0): reverse-mode broadcasts - # the parent's scalar adjoint to every child slot. Avoid - # the scalar read of `reverse_storage[k]` (which fails on - # GPU storage) by indexing with a 0-dim index — the view - # is then 0-dim at the outermost type, which the - # broadcast machinery specializes as a scalar source. On - # GPU this lowers to a single fill-kernel; no D2H - # round-trip. + # the parent's scalar adjoint to every child slot. 
ix = children_arr[children_indices[1]] pos = _scalar_pos(f.sizes, k) + # Avoid the scalar read of `reverse_storage[k]` (which fails + # on GPU storage) by indexing with a 0-dim index; the view + # is then 0-dim at the outermost type, which the + # broadcast machinery specializes as a scalar source. rev_parent_view = view(f.reverse_storage, reshape(pos:pos, ())) rev_children_view = _view_linear(f.reverse_storage, f.sizes, ix) + # On GPU this lowers to a single fill-kernel; no + # Device-to-Host round-trip. rev_children_view .= rev_parent_view continue elseif op == :row