Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/NUTS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Random boolean which is `true` with the given probability `exp(logprob)`, which
in which case no random value is drawn.
"""
function rand_bool_logprob(rng::AbstractRNG, logprob)
logprob ≥ 0 || (randexp(rng, Float64) > -logprob)
logprob ≥ 0 || (randexp(rng, typeof(logprob)) > -logprob)
end

function calculate_logprob2(::TrajectoryNUTS, is_doubling, ω₁, ω₂, ω)
Expand Down Expand Up @@ -175,22 +175,22 @@ developments, as described in Betancourt (2017).

$(FIELDS)
"""
struct NUTS{S}
struct NUTS{T<:Real,S}
"Maximum tree depth."
max_depth::Int
"Threshold for negative energy relative to starting point that indicates divergence."
min_Δ::Float64
min_Δ::T
"""
Turn statistic configuration. Currently only `Val(:generalized)` (the default) is
supported.
"""
turn_statistic_configuration::S
function NUTS(; max_depth = DEFAULT_MAX_TREE_DEPTH, min_Δ = -1000.0,
turn_statistic_configuration = Val{:generalized}())
function NUTS(; max_depth = DEFAULT_MAX_TREE_DEPTH, min_Δ::T = -1000.0,
turn_statistic_configuration = Val{:generalized}()) where {T<:Real}
@argcheck 0 < max_depth ≤ MAX_DIRECTIONS_DEPTH
@argcheck min_Δ < 0
S = typeof(turn_statistic_configuration)
new{S}(Int(max_depth), Float64(min_Δ), turn_statistic_configuration)
new{T,S}(Int(max_depth), min_Δ, turn_statistic_configuration)
end
end

Expand All @@ -205,15 +205,15 @@ Accessing fields directly is part of the API.

$(FIELDS)
"""
struct TreeStatisticsNUTS
struct TreeStatisticsNUTS{T<:Real}
"Log density of the Hamiltonian (negative energy)."
π::Float64
π::T
"Depth of the tree."
depth::Int
"Reason for termination. See [`InvalidTree`](@ref) and [`REACHED_MAX_DEPTH`](@ref)."
termination::InvalidTree
"Acceptance rate statistic."
acceptance_rate::Float64
acceptance_rate::T
"Number of leapfrog steps evaluated."
steps::Int
"Directions for tree doubling (useful for debugging)."
Expand All @@ -233,7 +233,8 @@ function sample_tree(rng, algorithm::NUTS, H::Hamiltonian, Q::EvaluatedLogDensit
p = rand_p(rng, H.κ), directions = rand(rng, Directions))
(; max_depth, min_Δ, turn_statistic_configuration) = algorithm
z = PhasePoint(Q, p)
trajectory = TrajectoryNUTS(H, logdensity(H, z), ϵ, min_Δ, turn_statistic_configuration)
π₀ = logdensity(H, z)
trajectory = TrajectoryNUTS(H, π₀, oftype(π₀, ϵ), oftype(π₀, min_Δ), turn_statistic_configuration)
ζ, v, termination, depth = sample_trajectory(rng, trajectory, z, max_depth, directions)
tree_statistics = TreeStatisticsNUTS(logdensity(H, ζ), depth, termination,
acceptance_rate(v), v.steps, directions)
Expand Down
4 changes: 2 additions & 2 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ $(SIGNATURES)

Gaussian kinetic energy with a diagonal inverse covariance matrix `M⁻¹=m⁻¹*I`.
"""
GaussianKineticEnergy(N::Integer, m⁻¹ = 1.0) = GaussianKineticEnergy(Diagonal(Fill(m⁻¹, N)))
GaussianKineticEnergy(N::Integer, m⁻¹ = 1.0) = GaussianKineticEnergy(Diagonal(Fill(float(m⁻¹), N)))

function Base.show(io::IO, κ::GaussianKineticEnergy{T}) where {T}
print(io::IO, "Gaussian kinetic energy ($(nameof(T))), √diag(M⁻¹): $(.√(diag(κ.M⁻¹)))")
Expand Down Expand Up @@ -121,7 +121,7 @@ $(SIGNATURES)

Generate a random momentum from a kinetic energy at position `q`.
"""
rand_p(rng::AbstractRNG, κ::GaussianKineticEnergy, q = nothing) = κ.W * randn(rng, size(κ.W, 1))
rand_p(rng::AbstractRNG, κ::GaussianKineticEnergy, q = nothing) = κ.W * randn(rng, eltype(κ.W), size(κ.W, 1))

####
#### Hamiltonian
Expand Down
16 changes: 10 additions & 6 deletions src/mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,16 @@ Create an initial warmup state from a random position.
$(DOC_INITIAL_WARMUP_ARGS)
"""
function initialize_warmup_state(rng, ℓ; q = random_position(rng, dimension(ℓ)),
κ = GaussianKineticEnergy(dimension(ℓ)), ϵ = nothing)
κ = GaussianKineticEnergy(dimension(ℓ), one(eltype(q))), ϵ = nothing)
WarmupState(evaluate_ℓ(ℓ, q; strict = true), κ, ϵ)
end

function warmup(sampling_logdensity, stepsize_search::InitialStepsizeSearch, warmup_state)
(; rng, ℓ, reporter) = sampling_logdensity
(; Q, κ, ϵ) = warmup_state
@argcheck ϵ ≡ nothing "stepsize ϵ manually specified, won't perform initial search"
T = typeof(one(eltype(Q.q)))
stepsize_search = _oftype(stepsize_search, T)
z = PhasePoint(Q, rand_p(rng, κ))
try
ϵ = find_initial_stepsize(stepsize_search, local_log_acceptance_ratio(Hamiltonian(κ, ℓ), z))
Expand Down Expand Up @@ -175,7 +177,7 @@ A `NamedTuple` with the following fields:

$(FIELDS)
"""
struct TuningNUTS{M,D}
struct TuningNUTS{M,D,T<:Real}
"Number of samples."
N::Int
"Dual averaging parameters."
Expand All @@ -185,12 +187,12 @@ struct TuningNUTS{M,D}
rescaled by `λ` towards ``σ²I``, where ``σ²`` is the median of the diagonal. The
constructor has a reasonable default.
"""
λ::Float64
λ::T
function TuningNUTS{M}(N::Integer, stepsize_adaptation::D,
λ = 5.0/N) where {M <: Union{Nothing,Diagonal,Symmetric},D}
λ::T = 5.0/N) where {M <: Union{Nothing,Diagonal,Symmetric},D,T<:Real}
@argcheck N ≥ 20 # variance estimator is kind of meaningless for few samples
@argcheck λ ≥ 0
new{M,D}(N, stepsize_adaptation, λ)
new{M,D,T}(N, stepsize_adaptation, λ)
end
end

Expand Down Expand Up @@ -259,12 +261,14 @@ function warmup(sampling_logdensity, tuning::TuningNUTS{M}, warmup_state) where
(; rng, ℓ, algorithm, reporter) = sampling_logdensity
(; Q, κ, ϵ) = warmup_state
(; N, stepsize_adaptation, λ) = tuning
T = typeof(one(eltype(Q.q)))
stepsize_adaptation = _oftype(stepsize_adaptation, T)
posterior_matrix = _empty_posterior_matrix(Q, N)
logdensities = _empty_logdensity_vector(Q, N)
tree_statistics = Vector{TreeStatisticsNUTS}(undef, N)
H = Hamiltonian(κ, ℓ)
ϵ_state = initial_adaptation_state(stepsize_adaptation, ϵ)
ϵs = Vector{Float64}(undef, N)
ϵs = Vector{typeof(float(ϵ))}(undef, N)
mcmc_reporter = make_mcmc_reporter(reporter, N;
currently_warmup = true,
tuning = M ≡ Nothing ? "stepsize" : "stepsize and $(M) metric")
Expand Down
36 changes: 27 additions & 9 deletions src/stepsize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@ $FIELDS

The algorithm is from Hoffman and Gelman (2014), default threshold modified to `0.8` following later practice in Stan.
"""
struct InitialStepsizeSearch
struct InitialStepsizeSearch{T<:Real}
"The stepsize where the search is started."
initial_ϵ::Float64
initial_ϵ::T
"Log of the threshold that needs to be crossed."
log_threshold::Float64
log_threshold::T
"Maximum number of iterations for crossing the threshold."
maxiter_crossing::Int
function InitialStepsizeSearch(; log_threshold::Float64 = log(0.8), initial_ϵ = 0.1, maxiter_crossing = 400)
function InitialStepsizeSearch(; log_threshold = log(0.8), initial_ϵ = 0.1, maxiter_crossing = 400)
T = promote_type(typeof(log_threshold), typeof(initial_ϵ))
@argcheck isfinite(log_threshold) && log_threshold < 0
@argcheck isfinite(initial_ϵ) && 0 < initial_ϵ
@argcheck maxiter_crossing ≥ 50
new(initial_ϵ, log_threshold, maxiter_crossing)
new{T}(T(initial_ϵ), T(log_threshold), maxiter_crossing)
end
end

Expand Down Expand Up @@ -131,10 +132,10 @@ $(SIGNATURES)

Return an initial adaptation state for the adaptation method and a stepsize `ϵ`.
"""
function initial_adaptation_state(::DualAveraging, ϵ)
function initial_adaptation_state(::DualAveraging{T}, ϵ) where T <: AbstractFloat
@argcheck ϵ > 0
logϵ = log(ϵ)
DualAveragingState(; μ = log(10) + logϵ, m = 1, H̄ = zero(logϵ), logϵ, logϵ̄ = zero(logϵ))
DualAveragingState{T}(; μ = log(T(10)) + logϵ, m = 1, H̄ = zero(logϵ), logϵ, logϵ̄ = zero(logϵ))
end

"""
Expand All @@ -150,8 +151,9 @@ function adapt_stepsize(parameters::DualAveraging, A::DualAveragingState, a)
(; μ, m, H̄, logϵ, logϵ̄) = A
m += 1
H̄ += (δ - a - H̄) / (m + t₀)
logϵ = μ - √m/γ * H̄
logϵ̄ += m^(-κ)*(logϵ - logϵ̄)
T_m = oftype(μ, m)
logϵ = μ - sqrt(T_m)/γ * H̄
logϵ̄ += T_m^(-κ)*(logϵ - logϵ̄)
DualAveragingState(; μ, m, H̄, logϵ, logϵ̄)
end

Expand Down Expand Up @@ -187,3 +189,19 @@ adapt_stepsize(::FixedStepsize, ϵ, a) = ϵ
current_ϵ(ϵ::Real) = ϵ

final_ϵ(ϵ::Real) = ϵ

###
### type conversion helpers for warmup pipeline
###

"""
$(SIGNATURES)

Internal helper: return a version of a stepsize-related configuration object whose
real-valued parameters are converted to element type `T`. When the object is already
parametrized by `T` (or carries no numeric parameters, like `FixedStepsize`), it is
returned unchanged.
"""
function _oftype(da::DualAveraging, ::Type{T}) where {T<:AbstractFloat}
    # NOTE(review): relies on the positional field order (δ, γ, κ, t₀) of the
    # `DualAveraging` constructor — confirm against its definition.
    return DualAveraging(T(da.δ), T(da.γ), T(da.κ), da.t₀)
end

# already at the requested element type — nothing to convert
_oftype(da::DualAveraging{T}, ::Type{T}) where {T} = da

function _oftype(iss::InitialStepsizeSearch, ::Type{T}) where {T<:Real}
    return InitialStepsizeSearch(; log_threshold = T(iss.log_threshold),
                                 initial_ϵ = T(iss.initial_ϵ),
                                 maxiter_crossing = iss.maxiter_crossing)
end

# already at the requested element type — nothing to convert
_oftype(iss::InitialStepsizeSearch{T}, ::Type{T}) where {T} = iss

# fixed stepsizes carry no tunable numeric parameters
_oftype(fs::FixedStepsize, ::Type) = fs
91 changes: 91 additions & 0 deletions test/test_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,97 @@ end
@test M == 0
end

# NOTE(review): `struct` expressions and global method definitions are not legal
# inside the local (`let`) scope that `@testset` introduces — as originally written
# this block fails with `syntax: "struct" expression not at top level` (and the
# analogous "Global method definition needs to be placed at the top level" for the
# `LogDensityProblems` methods). The fixtures are therefore hoisted to top level,
# before the testset; the tests themselves are unchanged.

# Float32 multivariate normal: ℓ(q) = -½ (q - μ)ᵀ Σ⁻¹ (q - μ) with Σ = I
struct Float32Normal{V <: AbstractVector}
    μ::V
end
LogDensityProblems.capabilities(::Type{<:Float32Normal}) = LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.dimension(ℓ::Float32Normal) = length(ℓ.μ)
function LogDensityProblems.logdensity_and_gradient(ℓ::Float32Normal, q::AbstractVector)
    r = q - ℓ.μ
    # return the log density in the position's element type so Float32 stays Float32
    T = eltype(q)
    T(-dot(r, r) / 2), -r
end

# A log density that errors if position is not Float32,
# catching any accidental promotion in leapfrog/adaptation
struct StrictFloat32Normal{V <: AbstractVector{Float32}}
    μ::V
end
LogDensityProblems.capabilities(::Type{<:StrictFloat32Normal}) = LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.dimension(ℓ::StrictFloat32Normal) = length(ℓ.μ)
function LogDensityProblems.logdensity_and_gradient(ℓ::StrictFloat32Normal, q::AbstractVector)
    @assert eltype(q) === Float32 "position promoted to $(eltype(q)), expected Float32"
    r = q - ℓ.μ
    Float32(-dot(r, r) / 2), -r
end

@testset "Float32 support" begin
    @testset "type propagation" begin
        ℓ32 = Float32Normal(zeros(Float32, 3))
        q0 = randn(Float32, 3)
        results = mcmc_with_warmup(RNG, ℓ32, 100;
                                   initialization = (q = q0,),
                                   reporter = NoProgressReport())
        @test eltype(results.posterior_matrix) == Float32
        @test eltype(results.logdensities) == Float32
        @test results.tree_statistics[1].π isa Float32
        @test results.tree_statistics[1].acceptance_rate isa Float32
        @test results.ϵ isa Float32
    end

    @testset "no type promotion in compute" begin
        ℓ_strict = StrictFloat32Normal(zeros(Float32, 3))
        q0 = randn(Float32, 3)
        # runs full warmup (stepsize search + dual averaging + metric adaptation)
        # and inference — any Float64 promotion in leapfrog would trigger the assertion
        results = mcmc_with_warmup(RNG, ℓ_strict, 100;
                                   initialization = (q = q0,),
                                   reporter = NoProgressReport())
        @test eltype(results.posterior_matrix) == Float32
        @test results.ϵ isa Float32
    end

    @testset "sample correctness" begin
        μ = Float32[1.0, -0.5, 2.0, 0.0, -1.5]
        ℓ32 = Float32Normal(μ)
        q0 = randn(Float32, 5)
        results = mcmc_with_warmup(RNG, ℓ32, 10000;
                                   initialization = (q = q0,),
                                   reporter = NoProgressReport())
        Z = results.posterior_matrix
        @test eltype(Z) == Float32
        @test norm(mean(Z; dims = 2) .- μ, Inf) < 0.06
        @test norm(std(Z; dims = 2) .- ones(Float32, 5), Inf) < 0.06
        @test mean(x -> x.acceptance_rate, results.tree_statistics) ≥ 0.7
    end

    @testset "fixed stepsize" begin
        ℓ32 = Float32Normal(ones(Float32, 3))
        q0 = randn(Float32, 3)
        results = mcmc_with_warmup(RNG, ℓ32, 5000;
                                   initialization = (q = q0, ϵ = Float32(1.0)),
                                   warmup_stages = fixed_stepsize_warmup_stages(),
                                   reporter = NoProgressReport())
        Z = results.posterior_matrix
        @test eltype(Z) == Float32
        @test norm(mean(Z; dims = 2) .- ones(Float32, 3), Inf) < 0.1
    end

    @testset "stepwise" begin
        ℓ32 = Float32Normal(zeros(Float32, 3))
        q0 = randn(Float32, 3)
        results = mcmc_keep_warmup(RNG, ℓ32, 0;
                                   initialization = (q = q0,),
                                   reporter = NoProgressReport())
        steps = mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
        Q = results.final_warmup_state.Q
        @test eltype(Q.q) == Float32
        qs = [(Q = first(mcmc_next_step(steps, Q)); Q.q) for _ in 1:1000]
        @test eltype(qs[1]) == Float32
        @test norm(mean(reduce(hcat, qs); dims = 2), Inf) ≤ 0.15
    end
end

@testset "posterior accessors sanity checks" begin
D, N, K = 5, 100, 7
ℓ = multivariate_normal(ones(5))
Expand Down
Loading