Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ ADTypes.SymbolicMode
ADTypes.Auto
```

## Gradient API

```@docs
ADTypes.GradientOrder
ADTypes.gradient_order
ADTypes.value_and_gradient!!
ADTypes.value_and_jacobian!!
```

## Deprecated

```@docs
Expand Down
7 changes: 7 additions & 0 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ include("dense.jl")
include("sparse.jl")
include("legacy.jl")
include("symbols.jl")
include("gradient_api.jl")

# Automatic Differentiation
export AbstractADType
Expand Down Expand Up @@ -55,6 +56,12 @@ export AutoChainRules,
@public mode
@public Auto

# Gradient API (minimal interface for backends to implement)
@public GradientOrder
@public gradient_order
@public value_and_gradient!!
@public value_and_jacobian!!

# Sparse Automatic Differentiation
export AutoSparse
@public dense_ad
Expand Down
103 changes: 103 additions & 0 deletions src/gradient_api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
## Capability trait

"""
    GradientOrder{K}

Trait type expressing that an AD backend can compute derivatives up to order `K`:

- `GradientOrder{0}()`: primal evaluation only
- `GradientOrder{1}()`: value + gradient / Jacobian
- `GradientOrder{2}()`: value + gradient + Hessian

A backend advertises its capability by implementing [`gradient_order`](@ref).
Orders are totally ordered, so consumers can compare them:
`GradientOrder{1}() ≤ GradientOrder{2}()`.
"""
struct GradientOrder{K}
    function GradientOrder{K}() where {K}
        # Normalize the parameter to an `Int` so that, e.g., `GradientOrder{0x01}()`
        # constructs the same singleton type as `GradientOrder{1}()`.
        order = Int(K)
        if order < 0
            throw(ArgumentError("GradientOrder requires K ≥ 0, got $order"))
        end
        return new{order}()
    end
end

# Convenience constructor from any integer value.
GradientOrder(K::Integer) = GradientOrder{Int(K)}()

# Total order on capability levels: a lower order is strictly less capable.
function Base.isless(::GradientOrder{J}, ::GradientOrder{K}) where {J, K}
    return J < K
end

"""
    gradient_order(backend::AbstractADType) -> GradientOrder{K} or nothing

Query which [`GradientOrder`](@ref) is supported by `backend`. The default method
returns `nothing`, meaning the backend does not implement the ADTypes gradient API.

Backends opt in by adding a method:

    ADTypes.gradient_order(::MyBackend) = GradientOrder{1}()
"""
function gradient_order(::AbstractADType)
    return nothing
end

## Interface functions

"""
    value_and_gradient!!(f, backend::AbstractADType, x)

Evaluate `y = f(x)` together with the gradient `∇f(x)` of a scalar-valued
function `f`, returning the pair `(y, g)` where `g` mirrors the structure of `x`.

The double bang `!!` warns that the backend may reuse internal cache state: the
caller owns the returned values, but mutable components (e.g. gradient arrays)
may be overwritten by the next call on the same backend — copy anything you
need to retain.

# Interface

A backend with first-order support implements

    ADTypes.value_and_gradient!!(f, ::MyBackend, x) = ...

and advertises it with

    ADTypes.gradient_order(::MyBackend) = GradientOrder{1}()

See also: [`value_and_jacobian!!`](@ref), [`gradient_order`](@ref).
"""
function value_and_gradient!! end

"""
    value_and_jacobian!!(f, backend::AbstractADType, x)

Evaluate `y = f(x)` together with the Jacobian `∂f(x)` of a general function `f`.

- For scalar-valued `f`, this coincides with [`value_and_gradient!!`](@ref).
- For vector-valued `f : ℝⁿ → ℝᵐ`, the full `m × n` Jacobian matrix is returned.

The double bang `!!` warns that the backend may reuse internal cache state; the
caller owns the returned values.

# Interface

Backends implement:

    ADTypes.value_and_jacobian!!(f, ::MyBackend, x) = ...

See also: [`value_and_gradient!!`](@ref), [`gradient_order`](@ref).
"""
function value_and_jacobian!! end

## Error fallbacks

# Catch-all method that turns a missing backend implementation into an
# actionable `ArgumentError` instead of a bare `MethodError`.
function value_and_gradient!!(f::F, ::T, x) where {F, T<:AbstractADType}
    msg = string(
        "`ADTypes.value_and_gradient!!` is not implemented for backend `$T`. ",
        "Add a method:\n ADTypes.value_and_gradient!!(f, ::$T, x) = ...\n",
        "and declare:\n ADTypes.gradient_order(::$T) = GradientOrder{1}()",
    )
    throw(ArgumentError(msg))
end

# Catch-all method that turns a missing backend implementation into an
# actionable `ArgumentError` instead of a bare `MethodError`.
function value_and_jacobian!!(f::F, ::T, x) where {F, T<:AbstractADType}
    msg = string(
        "`ADTypes.value_and_jacobian!!` is not implemented for backend `$T`. ",
        "Add a method:\n ADTypes.value_and_jacobian!!(f, ::$T, x) = ...\n",
        "and declare:\n ADTypes.gradient_order(::$T) = GradientOrder{1}()",
    )
    throw(ArgumentError(msg))
end
23 changes: 23 additions & 0 deletions test/gradient_api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Tests for the minimal gradient API: the `GradientOrder` trait, the
# `gradient_order` query, and the error fallbacks for unimplemented backends.
#
# Import every name the file uses explicitly, so the tests do not depend on
# `using ADTypes`/`using Test` having been executed by the including script.
using ADTypes: ADTypes, AbstractADType, GradientOrder, gradient_order
using Test

# Backend that deliberately opts out of the gradient API: it defines neither
# `gradient_order` nor `value_and_gradient!!`/`value_and_jacobian!!` methods.
struct UnimplementedBackend <: AbstractADType end

@testset "GradientOrder trait" begin
    @test GradientOrder{0}() isa GradientOrder
    @test GradientOrder{1}() isa GradientOrder
    @test GradientOrder{0}() < GradientOrder{1}()
    @test GradientOrder{1}() < GradientOrder{2}()
    @test !(GradientOrder{1}() < GradientOrder{1}())
    @test_throws ArgumentError GradientOrder{-1}()
end

@testset "gradient_order" begin
    # The default method signals "gradient API not supported" with `nothing`.
    @test gradient_order(UnimplementedBackend()) === nothing
end

@testset "Error fallbacks" begin
    f = x -> x^2
    backend = UnimplementedBackend()
    @test_throws ArgumentError ADTypes.value_and_gradient!!(f, backend, 1.0)
    @test_throws ArgumentError ADTypes.value_and_jacobian!!(f, backend, 1.0)
end
5 changes: 5 additions & 0 deletions test/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,10 @@ public_symbols = (
# Matrix coloring
:coloring_algorithm,
:NoColoringAlgorithm,
# Gradient API
:GradientOrder,
:gradient_order,
:value_and_gradient!!,
:value_and_jacobian!!,
)
@test public_symbols ⊆ names(ADTypes)
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ end
@testset "Miscellaneous" begin
include("misc.jl")
end
# Gradient API: `GradientOrder` trait, `gradient_order`, and error fallbacks.
@testset "Gradient API" begin
    include("gradient_api.jl")
end
if VERSION >= v"1.11.0-DEV.469"
@testset "Public" begin
include("public.jl")
Expand Down
Loading