Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions src/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ function _forward_eval(
idx2 = last(children_indices)
@inbounds ix1 = children_arr[idx1]
@inbounds ix2 = children_arr[idx2]
v1 = _view_array(f.forward_storage, f.sizes, ix1)
v2 = _view_array(f.forward_storage, f.sizes, ix2)
out = _view_array(f.forward_storage, f.sizes, k)
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
out = _view_matrix(f.forward_storage, f.sizes, k)
LinearAlgebra.mul!(out, v1, v2)
# We deliberately don't write v1/v2 into partials_storage
# here: the matmul reverse branch reads forward_storage
Expand Down Expand Up @@ -343,8 +343,8 @@ function _forward_eval(
elseif node.index == 15 # sum
@assert N == 1
ix = children_arr[first(children_indices)]
inp = _view_array(f.forward_storage, f.sizes, ix)
fill!(_view_array(f.partials_storage, f.sizes, ix), one(T))
inp = _view_linear(f.forward_storage, f.sizes, ix)
fill!(_view_linear(f.partials_storage, f.sizes, ix), one(T))
@s f.forward_storage[k] = sum(inp)
elseif node.index == 16 # row
for j in _eachindex(f.sizes, k)
Expand Down Expand Up @@ -393,12 +393,12 @@ function _forward_eval(
child1 = first(children_indices)
@inbounds ix1 = children_arr[child1]
@inbounds ix2 = children_arr[child1+1]
out = _view_array(f.forward_storage, f.sizes, k)
v1 = _view_array(f.forward_storage, f.sizes, ix1)
v2 = _view_array(f.forward_storage, f.sizes, ix2)
out = _view_linear(f.forward_storage, f.sizes, k)
v1 = _view_linear(f.forward_storage, f.sizes, ix1)
v2 = _view_linear(f.forward_storage, f.sizes, ix2)
out .= v1 .- v2
fill!(_view_array(f.partials_storage, f.sizes, ix1), one(T))
fill!(_view_array(f.partials_storage, f.sizes, ix2), -one(T))
fill!(_view_linear(f.partials_storage, f.sizes, ix1), one(T))
fill!(_view_linear(f.partials_storage, f.sizes, ix2), -one(T))
elseif node.index == 3 # :* (broadcasted)
# Node `k` is not scalar, so we do matrix multiplication
if f.sizes.ndims[k] != 0
Expand Down Expand Up @@ -466,9 +466,9 @@ function _forward_eval(
f.forward_storage,
f.sizes.storage_offset[ix2]+1,
)
out = _view_array(f.forward_storage, f.sizes, k)
inp = _view_array(f.forward_storage, f.sizes, ix1)
partials = _view_array(f.partials_storage, f.sizes, ix1)
out = _view_linear(f.forward_storage, f.sizes, k)
inp = _view_linear(f.forward_storage, f.sizes, ix1)
partials = _view_linear(f.partials_storage, f.sizes, ix1)
if exponent == 2
out .= inp .* inp
partials .= 2 .* inp
Expand Down Expand Up @@ -518,9 +518,9 @@ function _forward_eval(
@j f.forward_storage[k] = -val
end
elseif operators.univariate_operators[node.index] === :tanh
out = _view_array(f.forward_storage, f.sizes, k)
inp = _view_array(f.forward_storage, f.sizes, child_idx)
partials = _view_array(f.partials_storage, f.sizes, child_idx)
out = _view_linear(f.forward_storage, f.sizes, k)
inp = _view_linear(f.forward_storage, f.sizes, child_idx)
partials = _view_linear(f.partials_storage, f.sizes, child_idx)
out .= tanh.(inp)
partials .= one(T) .- out .* out
else
Expand Down Expand Up @@ -618,11 +618,11 @@ function _reverse_eval(
idx2 = last(children_indices)
ix1 = children_arr[idx1]
ix2 = children_arr[idx2]
v1 = _view_array(f.forward_storage, f.sizes, ix1)
v2 = _view_array(f.forward_storage, f.sizes, ix2)
rev_parent = _view_array(f.reverse_storage, f.sizes, k)
rev_v1 = _view_array(f.reverse_storage, f.sizes, ix1)
rev_v2 = _view_array(f.reverse_storage, f.sizes, ix2)
v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
rev_parent = _view_matrix(f.reverse_storage, f.sizes, k)
rev_v1 = _view_matrix(f.reverse_storage, f.sizes, ix1)
rev_v2 = _view_matrix(f.reverse_storage, f.sizes, ix2)
LinearAlgebra.mul!(rev_v1, rev_parent, v2')
LinearAlgebra.mul!(rev_v2, v1', rev_parent)
continue
Expand Down Expand Up @@ -881,12 +881,12 @@ function _reverse_eval(
# diagonal entries are stored in `f.partials_storage`. We broadcast
# `rev_child .= rev_parent .* partial` over the whole array (with the
# 0 * Inf guard preserved).
rev_parent = _view_array(f.reverse_storage, f.sizes, k)
rev_parent = _view_linear(f.reverse_storage, f.sizes, k)
for child_idx in children_indices
ix = children_arr[child_idx]
@assert _size(f.sizes, k) == _size(f.sizes, ix)
rev_child = _view_array(f.reverse_storage, f.sizes, ix)
partial = _view_array(f.partials_storage, f.sizes, ix)
rev_child = _view_linear(f.reverse_storage, f.sizes, ix)
partial = _view_linear(f.partials_storage, f.sizes, ix)
rev_child .= ifelse.(
(rev_parent .== 0) .& .!isfinite.(partial),
rev_parent,
Expand Down
50 changes: 32 additions & 18 deletions src/sizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,41 @@ implementation just calls `getindex`; this is a hook for storage backends
# Read element `idx` of `storage` without a bounds check. Kept as a tiny
# indirection (rather than a bare `getindex`) so alternative storage
# backends can overload it.
function _scalar_load(storage::AbstractVector, idx::Int)
    return @inbounds storage[idx]
end

"""
_view_array(storage, sizes, k) -> AbstractArray
_view_linear(storage, sizes, k) -> SubArray

Return a view of the slice of `storage` that holds node `k`'s array value,
reshaped to that node's natural shape. The view aliases the underlying
`storage` (no copy), so mutating the returned array writes back into the tape.
For a scalar (`ndims[k] == 0`) node this returns a length-1 vector view.
Return a flat 1-D view of the slice of `storage` that holds node `k`'s array
value. The view aliases the underlying `storage` (no copy), so mutating it
writes back into the tape. For a scalar (`ndims[k] == 0`) node this returns
a length-1 vector view.

Use this for elementwise (broadcasted) operations and reductions that don't
need the array's natural shape — keeping the return type-stable
(`SubArray{T,1,...}`) avoids the heap-boxing that a multi-shape return type
would force.
"""
function _view_array(storage::AbstractVector, sizes::Sizes, k::Int)
nd = sizes.ndims[k]
function _view_linear(storage::AbstractVector, sizes::Sizes, k::Int)
offset = sizes.storage_offset[k]
if nd == 0
return view(storage, (offset+1):(offset+1))
elseif nd == 1
n = sizes.size[sizes.size_offset[k]+1]
return view(storage, (offset+1):(offset+n))
else
N = _length(sizes, k)
v = view(storage, (offset+1):(offset+N))
szs = ntuple(d -> sizes.size[sizes.size_offset[k]+d], nd)
return reshape(v, szs)
end
N = _length(sizes, k)
return view(storage, (offset+1):(offset+N))
end

"""
    _view_matrix(storage, sizes, k) -> ReshapedArray

Return a 2-D view of the slice of `storage` that holds node `k`'s array
value. The view aliases the underlying `storage` (no copy), so mutating it
writes back into the tape.

Node `k` must be a matrix (`sizes.ndims[k] == 2`); scalar and vector nodes
are not accepted — use `_view_linear` for those. Always returns a 2-D
`Base.ReshapedArray`, which is what callers like `LinearAlgebra.mul!`
need; keeping the return type-stable avoids heap-boxing.
"""
function _view_matrix(storage::AbstractVector, sizes::Sizes, k::Int)
    # Internal invariant, not input validation: callers only reach this
    # helper from the matrix-multiplication branches, which operate on
    # 2-D nodes exclusively.
    @assert sizes.ndims[k] == 2
    offset = sizes.storage_offset[k]
    size_off = sizes.size_offset[k]
    m = sizes.size[size_off+1]
    n = sizes.size[size_off+2]
    # Flat m*n slice of the tape, reshaped to (m, n); `reshape` of a view
    # shares the underlying data, so no copy is made.
    v = view(storage, (offset+1):(offset+m*n))
    return reshape(v, (m, n))
end

"""
Expand Down
52 changes: 52 additions & 0 deletions test/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,58 @@ function test_neural()
end
end

# Builds the same `sum((W2*tanh.(W1*X) - target)^2)` MLP that `test_neural`
# exercises and checks that, after warmup, both `eval_objective` and
# `eval_objective_gradient` are allocation-free on the CPU `Vector{Float64}`
# tape — including when the input `x` has changed since the last call (which
# is the path that actually re-runs forward+reverse, not the
# `last_x == x` short-circuit).
# Checks that evaluating the MLP objective and its gradient is
# allocation-free after warmup; see the comment block above for the full
# rationale. NOTE(review): relies on JuMP/MOI/ArrayDiff APIs not visible in
# this file — the exact call structure (typed wrappers, warmup order,
# alternating inputs) is load-bearing for the `@allocated` measurements, so
# do not restructure casually.
function test_neural_allocations()
    n = 2
    # Fixed deterministic data: 2x2 input and regression target.
    X = [1.0 0.5; 0.3 0.8]
    target = [0.5 0.2; 0.1 0.7]
    model = Model()
    # Two 2x2 weight matrices stored via ArrayDiff's array-variable
    # container (8 scalar variables total, matching `x1`/`x2`/`g` below).
    @variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
    @variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
    # Same objective shape as `test_neural`: sum((W2*tanh.(W1*X) - target)^2).
    Y = W2 * tanh.(W1 * X)
    loss = sum((Y .- target) .^ 2)
    mode = ArrayDiff.Mode()
    ad = ArrayDiff.model(mode)
    MOI.Nonlinear.set_objective(ad, JuMP.moi_function(loss))
    evaluator = MOI.Nonlinear.Evaluator(
        ad,
        mode,
        JuMP.index.(JuMP.all_variables(model)),
    )
    MOI.initialize(evaluator, [:Grad])
    # Two distinct 8-vector inputs (one entry per scalar variable) so
    # consecutive calls never hit the `last_x == x` short-circuit.
    x1 = Float64.(collect(1:8))
    x2 = Float64.(collect(2:9))
    g = zeros(8)
    # Wrapped in typed functions so `@allocated` doesn't capture the
    # return-value boxing that happens when calling `eval_objective`
    # directly from the macro's untyped scope (each `MOI.eval_objective`
    # returns a `Float64` which then escapes into `Any`-typed scope).
    _obj(ev, x) = MOI.eval_objective(ev, x)
    function _grad!(ev, g, x)
        MOI.eval_objective_gradient(ev, g, x)
        return nothing
    end
    # Warmup: trigger JIT compilation for both `eval_objective` and
    # `eval_objective_gradient`. Two distinct inputs so `_reverse_mode`'s
    # `last_x == x` short-circuit doesn't elide the work on the second call.
    _obj(evaluator, x1)
    _obj(evaluator, x2)
    _grad!(evaluator, g, x1)
    _grad!(evaluator, g, x2)
    # Now alternate: each measured call sees `last_x ≠ x`, so it actually
    # runs the full forward + reverse passes through the block tape.
    @test 0 == @allocated _obj(evaluator, x1)
    @test 0 == @allocated _obj(evaluator, x2)
    @test 0 == @allocated _grad!(evaluator, g, x1)
    @test 0 == @allocated _grad!(evaluator, g, x2)
    return
end

function test_moi_function()
model = Model()
@variable(model, W[1:2, 1:2], container = ArrayDiff.ArrayOfVariables)
Expand Down
Loading