diff --git a/src/reverse_mode.jl b/src/reverse_mode.jl
index 6aaa6c2..46050f9 100644
--- a/src/reverse_mode.jl
+++ b/src/reverse_mode.jl
@@ -165,9 +165,9 @@ function _forward_eval(
         idx2 = last(children_indices)
         @inbounds ix1 = children_arr[idx1]
         @inbounds ix2 = children_arr[idx2]
-        v1 = _view_array(f.forward_storage, f.sizes, ix1)
-        v2 = _view_array(f.forward_storage, f.sizes, ix2)
-        out = _view_array(f.forward_storage, f.sizes, k)
+        v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
+        v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
+        out = _view_matrix(f.forward_storage, f.sizes, k)
         LinearAlgebra.mul!(out, v1, v2)
         # We deliberately don't write v1/v2 into partials_storage
         # here: the matmul reverse branch reads forward_storage
@@ -343,8 +343,8 @@ function _forward_eval(
     elseif node.index == 15 # sum
         @assert N == 1
         ix = children_arr[first(children_indices)]
-        inp = _view_array(f.forward_storage, f.sizes, ix)
-        fill!(_view_array(f.partials_storage, f.sizes, ix), one(T))
+        inp = _view_linear(f.forward_storage, f.sizes, ix)
+        fill!(_view_linear(f.partials_storage, f.sizes, ix), one(T))
         @s f.forward_storage[k] = sum(inp)
     elseif node.index == 16 # row
         for j in _eachindex(f.sizes, k)
@@ -393,12 +393,12 @@ function _forward_eval(
         child1 = first(children_indices)
         @inbounds ix1 = children_arr[child1]
         @inbounds ix2 = children_arr[child1+1]
-        out = _view_array(f.forward_storage, f.sizes, k)
-        v1 = _view_array(f.forward_storage, f.sizes, ix1)
-        v2 = _view_array(f.forward_storage, f.sizes, ix2)
+        out = _view_linear(f.forward_storage, f.sizes, k)
+        v1 = _view_linear(f.forward_storage, f.sizes, ix1)
+        v2 = _view_linear(f.forward_storage, f.sizes, ix2)
         out .= v1 .- v2
-        fill!(_view_array(f.partials_storage, f.sizes, ix1), one(T))
-        fill!(_view_array(f.partials_storage, f.sizes, ix2), -one(T))
+        fill!(_view_linear(f.partials_storage, f.sizes, ix1), one(T))
+        fill!(_view_linear(f.partials_storage, f.sizes, ix2), -one(T))
     elseif node.index == 3 # :* (broadcasted)
         # Node `k` is not scalar, so we do matrix multiplication
         if f.sizes.ndims[k] != 0
@@ -466,9 +466,9 @@ function _forward_eval(
            f.forward_storage,
            f.sizes.storage_offset[ix2]+1,
        )
-        out = _view_array(f.forward_storage, f.sizes, k)
-        inp = _view_array(f.forward_storage, f.sizes, ix1)
-        partials = _view_array(f.partials_storage, f.sizes, ix1)
+        out = _view_linear(f.forward_storage, f.sizes, k)
+        inp = _view_linear(f.forward_storage, f.sizes, ix1)
+        partials = _view_linear(f.partials_storage, f.sizes, ix1)
         if exponent == 2
             out .= inp .* inp
             partials .= 2 .* inp
@@ -518,9 +518,9 @@ function _forward_eval(
            @j f.forward_storage[k] = -val
        end
    elseif operators.univariate_operators[node.index] === :tanh
-        out = _view_array(f.forward_storage, f.sizes, k)
-        inp = _view_array(f.forward_storage, f.sizes, child_idx)
-        partials = _view_array(f.partials_storage, f.sizes, child_idx)
+        out = _view_linear(f.forward_storage, f.sizes, k)
+        inp = _view_linear(f.forward_storage, f.sizes, child_idx)
+        partials = _view_linear(f.partials_storage, f.sizes, child_idx)
        out .= tanh.(inp)
        partials .= one(T) .- out .* out
    else
@@ -618,11 +618,11 @@ function _reverse_eval(
         idx2 = last(children_indices)
         ix1 = children_arr[idx1]
         ix2 = children_arr[idx2]
-        v1 = _view_array(f.forward_storage, f.sizes, ix1)
-        v2 = _view_array(f.forward_storage, f.sizes, ix2)
-        rev_parent = _view_array(f.reverse_storage, f.sizes, k)
-        rev_v1 = _view_array(f.reverse_storage, f.sizes, ix1)
-        rev_v2 = _view_array(f.reverse_storage, f.sizes, ix2)
+        v1 = _view_matrix(f.forward_storage, f.sizes, ix1)
+        v2 = _view_matrix(f.forward_storage, f.sizes, ix2)
+        rev_parent = _view_matrix(f.reverse_storage, f.sizes, k)
+        rev_v1 = _view_matrix(f.reverse_storage, f.sizes, ix1)
+        rev_v2 = _view_matrix(f.reverse_storage, f.sizes, ix2)
         LinearAlgebra.mul!(rev_v1, rev_parent, v2')
         LinearAlgebra.mul!(rev_v2, v1', rev_parent)
         continue
@@ -881,12 +881,12 @@ function _reverse_eval(
         # diagonal entries are stored in `f.partials_storage`. We broadcast
         # `rev_child .= rev_parent .* partial` over the whole array (with the
         # 0 * Inf guard preserved).
-        rev_parent = _view_array(f.reverse_storage, f.sizes, k)
+        rev_parent = _view_linear(f.reverse_storage, f.sizes, k)
         for child_idx in children_indices
             ix = children_arr[child_idx]
             @assert _size(f.sizes, k) == _size(f.sizes, ix)
-            rev_child = _view_array(f.reverse_storage, f.sizes, ix)
-            partial = _view_array(f.partials_storage, f.sizes, ix)
+            rev_child = _view_linear(f.reverse_storage, f.sizes, ix)
+            partial = _view_linear(f.partials_storage, f.sizes, ix)
             rev_child .= ifelse.(
                 (rev_parent .== 0) .& .!isfinite.(partial),
                 rev_parent,
diff --git a/src/sizes.jl b/src/sizes.jl
index 04c45d5..2099dea 100644
--- a/src/sizes.jl
+++ b/src/sizes.jl
@@ -68,27 +68,41 @@ implementation just calls `getindex`; this is a hook for storage backends
 _scalar_load(storage::AbstractVector, idx::Int) = @inbounds storage[idx]
 
 """
-    _view_array(storage, sizes, k) -> AbstractArray
+    _view_linear(storage, sizes, k) -> SubArray
 
-Return a view of the slice of `storage` that holds node `k`'s array value,
-reshaped to that node's natural shape. The view aliases the underlying
-`storage` (no copy), so mutating the returned array writes back into the tape.
-For a scalar (`ndims[k] == 0`) node this returns a length-1 vector view.
+Return a flat 1-D view of the slice of `storage` that holds node `k`'s array
+value. The view aliases the underlying `storage` (no copy), so mutating it
+writes back into the tape. For a scalar (`ndims[k] == 0`) node this returns
+a length-1 vector view.
+
+Use this for elementwise (broadcasted) operations and reductions that don't
+need the array's natural shape — keeping the return type-stable
+(`SubArray{T,1,...}`) avoids the heap-boxing that a multi-shape return type
+would force.
 """
-function _view_array(storage::AbstractVector, sizes::Sizes, k::Int)
-    nd = sizes.ndims[k]
+function _view_linear(storage::AbstractVector, sizes::Sizes, k::Int)
     offset = sizes.storage_offset[k]
-    if nd == 0
-        return view(storage, (offset+1):(offset+1))
-    elseif nd == 1
-        n = sizes.size[sizes.size_offset[k]+1]
-        return view(storage, (offset+1):(offset+n))
-    else
-        N = _length(sizes, k)
-        v = view(storage, (offset+1):(offset+N))
-        szs = ntuple(d -> sizes.size[sizes.size_offset[k]+d], nd)
-        return reshape(v, szs)
-    end
+    N = _length(sizes, k)
+    return view(storage, (offset+1):(offset+N))
+end
+
+"""
+    _view_matrix(storage, sizes, k) -> ReshapedArray
+
+Return a 2-D view of the slice of `storage` that holds node `k`'s array
+value. Node `k` must itself be 2-D (this is asserted); the flat slice of
+length `m * n` is reshaped to `(m, n)`. The result is always a 2-D
+`Base.ReshapedArray`, which is what callers like `LinearAlgebra.mul!` need;
+keeping the return type-stable avoids heap-boxing.
+""" +function _view_matrix(storage::AbstractVector, sizes::Sizes, k::Int) + @assert sizes.ndims[k] == 2 + offset = sizes.storage_offset[k] + size_off = sizes.size_offset[k] + m = sizes.size[size_off+1] + n = sizes.size[size_off+2] + v = view(storage, (offset+1):(offset+m*n)) + return reshape(v, (m, n)) end """ diff --git a/test/JuMP.jl b/test/JuMP.jl index 2774d81..912a97c 100644 --- a/test/JuMP.jl +++ b/test/JuMP.jl @@ -279,6 +279,58 @@ function test_neural() end end +# Builds the same `sum((W2*tanh.(W1*X) - target)^2)` MLP that `test_neural` +# exercises and checks that, after warmup, both `eval_objective` and +# `eval_objective_gradient` are allocation-free on the CPU `Vector{Float64}` +# tape — including when the input `x` has changed since the last call (which +# is the path that actually re-runs forward+reverse, not the +# `last_x == x` short-circuit). +function test_neural_allocations() + n = 2 + X = [1.0 0.5; 0.3 0.8] + target = [0.5 0.2; 0.1 0.7] + model = Model() + @variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + @variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables) + Y = W2 * tanh.(W1 * X) + loss = sum((Y .- target) .^ 2) + mode = ArrayDiff.Mode() + ad = ArrayDiff.model(mode) + MOI.Nonlinear.set_objective(ad, JuMP.moi_function(loss)) + evaluator = MOI.Nonlinear.Evaluator( + ad, + mode, + JuMP.index.(JuMP.all_variables(model)), + ) + MOI.initialize(evaluator, [:Grad]) + x1 = Float64.(collect(1:8)) + x2 = Float64.(collect(2:9)) + g = zeros(8) + # Wrapped in typed functions so `@allocated` doesn't capture the + # return-value boxing that happens when calling `eval_objective` + # directly from the macro's untyped scope (each `MOI.eval_objective` + # returns a `Float64` which then escapes into `Any`-typed scope). + _obj(ev, x) = MOI.eval_objective(ev, x) + function _grad!(ev, g, x) + MOI.eval_objective_gradient(ev, g, x) + return nothing + end + # Warmup: trigger JIT compilation for both `eval_objective` and + # `eval_objective_gradient`. Two distinct inputs so `_reverse_mode`'s + # `last_x == x` short-circuit doesn't elide the work on the second call. + _obj(evaluator, x1) + _obj(evaluator, x2) + _grad!(evaluator, g, x1) + _grad!(evaluator, g, x2) + # Now alternate: each measured call sees `last_x ≠ x`, so it actually + # runs the full forward + reverse passes through the block tape. + @test 0 == @allocated _obj(evaluator, x1) + @test 0 == @allocated _obj(evaluator, x2) + @test 0 == @allocated _grad!(evaluator, g, x1) + @test 0 == @allocated _grad!(evaluator, g, x2) + return +end + function test_moi_function() model = Model() @variable(model, W[1:2, 1:2], container = ArrayDiff.ArrayOfVariables)