diff --git a/docs/src/index.md b/docs/src/index.md
index 85e7e55..867a691 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -123,6 +123,15 @@ ADTypes.SymbolicMode
 ADTypes.Auto
 ```
 
+## Gradient API
+
+```@docs
+ADTypes.GradientOrder
+ADTypes.gradient_order
+ADTypes.value_and_gradient!!
+ADTypes.value_and_jacobian!!
+```
+
 ## Deprecated
 
 ```@docs
diff --git a/src/ADTypes.jl b/src/ADTypes.jl
index de5f756..8f81b1f 100644
--- a/src/ADTypes.jl
+++ b/src/ADTypes.jl
@@ -25,6 +25,7 @@ include("dense.jl")
 include("sparse.jl")
 include("legacy.jl")
 include("symbols.jl")
+include("gradient_api.jl")
 
 # Automatic Differentiation
 export AbstractADType
@@ -55,6 +56,12 @@ export AutoChainRules,
 @public mode
 @public Auto
 
+# Gradient API (minimal interface for backends to implement)
+@public GradientOrder
+@public gradient_order
+@public value_and_gradient!!
+@public value_and_jacobian!!
+
 # Sparse Automatic Differentiation
 export AutoSparse
 @public dense_ad
diff --git a/src/gradient_api.jl b/src/gradient_api.jl
new file mode 100644
index 0000000..5b388b0
--- /dev/null
+++ b/src/gradient_api.jl
@@ -0,0 +1,103 @@
+## Capability trait
+
+"""
+    GradientOrder{K}
+
+Trait indicating that an AD backend supports computing derivatives up to order `K`:
+
+  - `GradientOrder{0}()`: primal evaluation only
+  - `GradientOrder{1}()`: value + gradient / Jacobian
+  - `GradientOrder{2}()`: value + gradient + Hessian
+
+Backends declare their capability by implementing [`gradient_order`](@ref).
+Consumers can compare orders: `GradientOrder{1}() ≤ GradientOrder{2}()`.
+"""
+struct GradientOrder{K}
+    function GradientOrder{K}() where {K}
+        _K = Int(K)
+        _K ≥ 0 || throw(ArgumentError("GradientOrder requires K ≥ 0, got $_K"))
+        new{_K}()
+    end
+end
+
+GradientOrder(K::Integer) = GradientOrder{Int(K)}()
+
+Base.isless(::GradientOrder{J}, ::GradientOrder{K}) where {J, K} = J < K
+
+"""
+    gradient_order(backend::AbstractADType) -> GradientOrder{K} or nothing
+
+Return the [`GradientOrder`](@ref) supported by `backend`, or `nothing` if the backend
+does not implement the ADTypes gradient API.
+
+Backends declare support by adding a method:
+
+    ADTypes.gradient_order(::MyBackend) = GradientOrder{1}()
+"""
+gradient_order(::AbstractADType) = nothing
+
+## Interface functions
+
+"""
+    value_and_gradient!!(f, backend::AbstractADType, x)
+
+Compute the primal value `y = f(x)` and gradient `∇f(x)` for a scalar-valued function `f`.
+
+Returns `(y, g)` where `g` has the same structure as `x`.
+
+The `!!` signals that the backend may mutate internal cache state. The caller owns the
+returned values: mutable components (e.g. gradient arrays) may be overwritten on the next
+call with the same backend, so copy if you need to retain them.
+
+# Interface
+
+Backends supporting first-order derivatives implement:
+
+    ADTypes.value_and_gradient!!(f, ::MyBackend, x) = ...
+
+and declare:
+
+    ADTypes.gradient_order(::MyBackend) = GradientOrder{1}()
+
+See also: [`value_and_jacobian!!`](@ref), [`gradient_order`](@ref).
+"""
+function value_and_gradient!! end
+
+"""
+    value_and_jacobian!!(f, backend::AbstractADType, x)
+
+Compute the primal value `y = f(x)` and the Jacobian `∂f(x)` for a general function `f`.
+
+  - If `f` is scalar-valued, this is equivalent to [`value_and_gradient!!`](@ref).
+  - If `f` is vector-valued (`f : ℝⁿ → ℝᵐ`), returns the full `m × n` Jacobian matrix.
+
+The `!!` signals that the backend may mutate internal cache state. The caller owns the
+returned values.
+
+# Interface
+
+Backends implement:
+
+    ADTypes.value_and_jacobian!!(f, ::MyBackend, x) = ...
+
+See also: [`value_and_gradient!!`](@ref), [`gradient_order`](@ref).
+"""
+function value_and_jacobian!! end
+
+## Error fallbacks
+
+function value_and_gradient!!(f::F, ::T, x) where {F, T<:AbstractADType}
+    throw(ArgumentError(
+        "`ADTypes.value_and_gradient!!` is not implemented for backend `$T`. " *
+        "Add a method:\n    ADTypes.value_and_gradient!!(f, ::$T, x) = ...\n" *
+        "and declare:\n    ADTypes.gradient_order(::$T) = GradientOrder{1}()"
+    ))
+end
+
+function value_and_jacobian!!(f::F, ::T, x) where {F, T<:AbstractADType}
+    throw(ArgumentError(
+        "`ADTypes.value_and_jacobian!!` is not implemented for backend `$T`. " *
+        "Add a method:\n    ADTypes.value_and_jacobian!!(f, ::$T, x) = ...\n" *
+        "and declare:\n    ADTypes.gradient_order(::$T) = GradientOrder{1}()"
+    ))
+end
diff --git a/test/gradient_api.jl b/test/gradient_api.jl
new file mode 100644
index 0000000..4ef6fe6
--- /dev/null
+++ b/test/gradient_api.jl
@@ -0,0 +1,23 @@
+using ADTypes: GradientOrder, gradient_order
+
+struct UnimplementedBackend <: AbstractADType end
+
+@testset "GradientOrder trait" begin
+    @test GradientOrder{0}() isa GradientOrder
+    @test GradientOrder{1}() isa GradientOrder
+    @test GradientOrder{0}() < GradientOrder{1}()
+    @test GradientOrder{1}() < GradientOrder{2}()
+    @test !(GradientOrder{1}() < GradientOrder{1}())
+    @test_throws ArgumentError GradientOrder{-1}()
+end
+
+@testset "gradient_order" begin
+    @test gradient_order(UnimplementedBackend()) === nothing
+end
+
+@testset "Error fallbacks" begin
+    f = x -> x^2
+    backend = UnimplementedBackend()
+    @test_throws ArgumentError ADTypes.value_and_gradient!!(f, backend, 1.0)
+    @test_throws ArgumentError ADTypes.value_and_jacobian!!(f, backend, 1.0)
+end
diff --git a/test/public.jl b/test/public.jl
index b8ed0c6..d890cde 100644
--- a/test/public.jl
+++ b/test/public.jl
@@ -19,5 +19,10 @@ public_symbols = (
     # Matrix coloring
     :coloring_algorithm,
    :NoColoringAlgorithm,
+    # Gradient API
+    :GradientOrder,
+    :gradient_order,
+    :value_and_gradient!!,
+    :value_and_jacobian!!,
 )
 @test public_symbols ⊆ names(ADTypes)
diff --git a/test/runtests.jl b/test/runtests.jl
index c77d322..33171e1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -115,6 +115,9 @@ end
     @testset "Miscellaneous" begin
         include("misc.jl")
     end
+    @testset "Gradient API" begin
+        include("gradient_api.jl")
+    end
     if VERSION >= v"1.11.0-DEV.469"
         @testset "Public" begin
             include("public.jl")