From 28ea74b4eca44b15796546715ae9044af70b9213 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Thu, 5 Mar 2026 11:58:07 -0500 Subject: [PATCH] Support Float32 and type-stable warmup/NUTS --- src/NUTS.jl | 21 ++++++----- src/hamiltonian.jl | 4 +- src/mcmc.jl | 16 +++++--- src/stepsize.jl | 36 +++++++++++++----- test/test_mcmc.jl | 91 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 141 insertions(+), 27 deletions(-) diff --git a/src/NUTS.jl b/src/NUTS.jl index 2a50103..2a2a158 100644 --- a/src/NUTS.jl +++ b/src/NUTS.jl @@ -41,7 +41,7 @@ Random boolean which is `true` with the given probability `exp(logprob)`, which in which case no random value is drawn. """ function rand_bool_logprob(rng::AbstractRNG, logprob) - logprob ≥ 0 || (randexp(rng, Float64) > -logprob) + logprob ≥ 0 || (randexp(rng, typeof(logprob)) > -logprob) end function calculate_logprob2(::TrajectoryNUTS, is_doubling, ω₁, ω₂, ω) @@ -175,22 +175,22 @@ developments, as described in Betancourt (2017). $(FIELDS) """ -struct NUTS{S} +struct NUTS{T<:Real,S} "Maximum tree depth." max_depth::Int "Threshold for negative energy relative to starting point that indicates divergence." - min_Δ::Float64 + min_Δ::T """ Turn statistic configuration. Currently only `Val(:generalized)` (the default) is supported. """ turn_statistic_configuration::S - function NUTS(; max_depth = DEFAULT_MAX_TREE_DEPTH, min_Δ = -1000.0, - turn_statistic_configuration = Val{:generalized}()) + function NUTS(; max_depth = DEFAULT_MAX_TREE_DEPTH, min_Δ::T = -1000.0, + turn_statistic_configuration = Val{:generalized}()) where {T<:Real} @argcheck 0 < max_depth ≤ MAX_DIRECTIONS_DEPTH @argcheck min_Δ < 0 S = typeof(turn_statistic_configuration) - new{S}(Int(max_depth), Float64(min_Δ), turn_statistic_configuration) + new{T,S}(Int(max_depth), min_Δ, turn_statistic_configuration) end end @@ -205,15 +205,15 @@ Accessing fields directly is part of the API. 
$(FIELDS) """ -struct TreeStatisticsNUTS +struct TreeStatisticsNUTS{T<:Real} "Log density of the Hamiltonian (negative energy)." - π::Float64 + π::T "Depth of the tree." depth::Int "Reason for termination. See [`InvalidTree`](@ref) and [`REACHED_MAX_DEPTH`](@ref)." termination::InvalidTree "Acceptance rate statistic." - acceptance_rate::Float64 + acceptance_rate::T "Number of leapfrog steps evaluated." steps::Int "Directions for tree doubling (useful for debugging)." @@ -233,7 +233,8 @@ function sample_tree(rng, algorithm::NUTS, H::Hamiltonian, Q::EvaluatedLogDensit p = rand_p(rng, H.κ), directions = rand(rng, Directions)) (; max_depth, min_Δ, turn_statistic_configuration) = algorithm z = PhasePoint(Q, p) - trajectory = TrajectoryNUTS(H, logdensity(H, z), ϵ, min_Δ, turn_statistic_configuration) + π₀ = logdensity(H, z) + trajectory = TrajectoryNUTS(H, π₀, oftype(π₀, ϵ), oftype(π₀, min_Δ), turn_statistic_configuration) ζ, v, termination, depth = sample_trajectory(rng, trajectory, z, max_depth, directions) tree_statistics = TreeStatisticsNUTS(logdensity(H, ζ), depth, termination, acceptance_rate(v), v.steps, directions) diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index a7b3305..2ce6b5e 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -84,7 +84,7 @@ $(SIGNATURES) Gaussian kinetic energy with a diagonal inverse covariance matrix `M⁻¹=m⁻¹*I`. """ -GaussianKineticEnergy(N::Integer, m⁻¹ = 1.0) = GaussianKineticEnergy(Diagonal(Fill(m⁻¹, N))) +GaussianKineticEnergy(N::Integer, m⁻¹ = 1.0) = GaussianKineticEnergy(Diagonal(Fill(float(m⁻¹), N))) function Base.show(io::IO, κ::GaussianKineticEnergy{T}) where {T} print(io::IO, "Gaussian kinetic energy ($(nameof(T))), √diag(M⁻¹): $(.√(diag(κ.M⁻¹)))") @@ -121,7 +121,7 @@ $(SIGNATURES) Generate a random momentum from a kinetic energy at position `q`. 
""" -rand_p(rng::AbstractRNG, κ::GaussianKineticEnergy, q = nothing) = κ.W * randn(rng, size(κ.W, 1)) +rand_p(rng::AbstractRNG, κ::GaussianKineticEnergy, q = nothing) = κ.W * randn(rng, eltype(κ.W), size(κ.W, 1)) #### #### Hamiltonian diff --git a/src/mcmc.jl b/src/mcmc.jl index e253623..9924aef 100644 --- a/src/mcmc.jl +++ b/src/mcmc.jl @@ -127,7 +127,7 @@ Create an initial warmup state from a random position. $(DOC_INITIAL_WARMUP_ARGS) """ function initialize_warmup_state(rng, ℓ; q = random_position(rng, dimension(ℓ)), - κ = GaussianKineticEnergy(dimension(ℓ)), ϵ = nothing) + κ = GaussianKineticEnergy(dimension(ℓ), one(eltype(q))), ϵ = nothing) WarmupState(evaluate_ℓ(ℓ, q; strict = true), κ, ϵ) end @@ -135,6 +135,8 @@ function warmup(sampling_logdensity, stepsize_search::InitialStepsizeSearch, war (; rng, ℓ, reporter) = sampling_logdensity (; Q, κ, ϵ) = warmup_state @argcheck ϵ ≡ nothing "stepsize ϵ manually specified, won't perform initial search" + T = typeof(one(eltype(Q.q))) + stepsize_search = _oftype(stepsize_search, T) z = PhasePoint(Q, rand_p(rng, κ)) try ϵ = find_initial_stepsize(stepsize_search, local_log_acceptance_ratio(Hamiltonian(κ, ℓ), z)) @@ -175,7 +177,7 @@ A `NamedTuple` with the following fields: $(FIELDS) """ -struct TuningNUTS{M,D} +struct TuningNUTS{M,D,T<:Real} "Number of samples." N::Int "Dual averaging parameters." @@ -185,12 +187,12 @@ struct TuningNUTS{M,D} rescaled by `λ` towards ``σ²I``, where ``σ²`` is the median of the diagonal. The constructor has a reasonable default. 
    """
-    λ::Float64
+    λ::T
     function TuningNUTS{M}(N::Integer, stepsize_adaptation::D,
-                           λ = 5.0/N) where {M <: Union{Nothing,Diagonal,Symmetric},D}
+                           λ::T = 5.0/N) where {M <: Union{Nothing,Diagonal,Symmetric},D,T<:Real}
         @argcheck N ≥ 20 # variance estimator is kind of meaningless for few samples
         @argcheck λ ≥ 0
-        new{M,D}(N, stepsize_adaptation, λ)
+        new{M,D,T}(N, stepsize_adaptation, λ)
     end
 end
@@ -259,12 +261,14 @@ function warmup(sampling_logdensity, tuning::TuningNUTS{M}, warmup_state) where
     (; rng, ℓ, algorithm, reporter) = sampling_logdensity
     (; Q, κ, ϵ) = warmup_state
     (; N, stepsize_adaptation, λ) = tuning
+    T = typeof(one(eltype(Q.q)))
+    stepsize_adaptation = _oftype(stepsize_adaptation, T)
     posterior_matrix = _empty_posterior_matrix(Q, N)
     logdensities = _empty_logdensity_vector(Q, N)
-    tree_statistics = Vector{TreeStatisticsNUTS}(undef, N)
+    tree_statistics = Vector{TreeStatisticsNUTS{T}}(undef, N)
     H = Hamiltonian(κ, ℓ)
     ϵ_state = initial_adaptation_state(stepsize_adaptation, ϵ)
-    ϵs = Vector{Float64}(undef, N)
+    ϵs = Vector{typeof(float(ϵ))}(undef, N)
     mcmc_reporter = make_mcmc_reporter(reporter, N; currently_warmup = true,
                                        tuning = M ≡ Nothing ? "stepsize" : "stepsize and $(M) metric")
diff --git a/src/stepsize.jl b/src/stepsize.jl
index a2beacd..cb0f2d0 100644
--- a/src/stepsize.jl
+++ b/src/stepsize.jl
@@ -20,18 +20,19 @@ $FIELDS
 The algorithm is from Hoffman and Gelman (2014), default threshold modified to
 `0.8` following later practice in Stan.
 """
-struct InitialStepsizeSearch
+struct InitialStepsizeSearch{T<:Real}
     "The stepsize where the search is started."
-    initial_ϵ::Float64
+    initial_ϵ::T
     "Log of the threshold that needs to be crossed."
-    log_threshold::Float64
+    log_threshold::T
     "Maximum number of iterations for crossing the threshold."
maxiter_crossing::Int - function InitialStepsizeSearch(; log_threshold::Float64 = log(0.8), initial_ϵ = 0.1, maxiter_crossing = 400) + function InitialStepsizeSearch(; log_threshold = log(0.8), initial_ϵ = 0.1, maxiter_crossing = 400) + T = promote_type(typeof(log_threshold), typeof(initial_ϵ)) @argcheck isfinite(log_threshold) && log_threshold < 0 @argcheck isfinite(initial_ϵ) && 0 < initial_ϵ @argcheck maxiter_crossing ≥ 50 - new(initial_ϵ, log_threshold, maxiter_crossing) + new{T}(T(initial_ϵ), T(log_threshold), maxiter_crossing) end end @@ -131,10 +132,10 @@ $(SIGNATURES) Return an initial adaptation state for the adaptation method and a stepsize `ϵ`. """ -function initial_adaptation_state(::DualAveraging, ϵ) +function initial_adaptation_state(::DualAveraging{T}, ϵ) where T <: AbstractFloat @argcheck ϵ > 0 logϵ = log(ϵ) - DualAveragingState(; μ = log(10) + logϵ, m = 1, H̄ = zero(logϵ), logϵ, logϵ̄ = zero(logϵ)) + DualAveragingState{T}(; μ = log(T(10)) + logϵ, m = 1, H̄ = zero(logϵ), logϵ, logϵ̄ = zero(logϵ)) end """ @@ -150,8 +151,9 @@ function adapt_stepsize(parameters::DualAveraging, A::DualAveragingState, a) (; μ, m, H̄, logϵ, logϵ̄) = A m += 1 H̄ += (δ - a - H̄) / (m + t₀) - logϵ = μ - √m/γ * H̄ - logϵ̄ += m^(-κ)*(logϵ - logϵ̄) + T_m = oftype(μ, m) + logϵ = μ - sqrt(T_m)/γ * H̄ + logϵ̄ += T_m^(-κ)*(logϵ - logϵ̄) DualAveragingState(; μ, m, H̄, logϵ, logϵ̄) end @@ -187,3 +189,19 @@ adapt_stepsize(::FixedStepsize, ϵ, a) = ϵ current_ϵ(ϵ::Real) = ϵ final_ϵ(ϵ::Real) = ϵ + +### +### type conversion helpers for warmup pipeline +### + +_oftype(da::DualAveraging{T}, ::Type{T}) where {T} = da +_oftype(da::DualAveraging, ::Type{T}) where {T<:AbstractFloat} = + DualAveraging(T(da.δ), T(da.γ), T(da.κ), da.t₀) + +_oftype(iss::InitialStepsizeSearch{T}, ::Type{T}) where {T} = iss +_oftype(iss::InitialStepsizeSearch, ::Type{T}) where {T<:Real} = + InitialStepsizeSearch(; log_threshold = T(iss.log_threshold), + initial_ϵ = T(iss.initial_ϵ), + maxiter_crossing = 
iss.maxiter_crossing)
+
+_oftype(fs::FixedStepsize, ::Type) = fs
diff --git a/test/test_mcmc.jl b/test/test_mcmc.jl
index 6355246..b2c87f2 100644
--- a/test/test_mcmc.jl
+++ b/test/test_mcmc.jl
@@ -71,6 +71,97 @@ end
 @test M == 0
 end
 
+# Float32 multivariate normal: ℓ(q) = -½ (q - μ)ᵀ Σ⁻¹ (q - μ) with Σ = I
+struct Float32Normal{V <: AbstractVector}
+    μ::V
+end
+LogDensityProblems.capabilities(::Type{<:Float32Normal}) = LogDensityProblems.LogDensityOrder{1}()
+LogDensityProblems.dimension(ℓ::Float32Normal) = length(ℓ.μ)
+function LogDensityProblems.logdensity_and_gradient(ℓ::Float32Normal, q::AbstractVector)
+    r = q - ℓ.μ
+    T = eltype(q)
+    T(-dot(r, r) / 2), -r
+end
+
+# A log density that errors if position is not Float32,
+# catching any accidental promotion in leapfrog/adaptation
+struct StrictFloat32Normal{V <: AbstractVector{Float32}}
+    μ::V
+end
+LogDensityProblems.capabilities(::Type{<:StrictFloat32Normal}) = LogDensityProblems.LogDensityOrder{1}()
+LogDensityProblems.dimension(ℓ::StrictFloat32Normal) = length(ℓ.μ)
+function LogDensityProblems.logdensity_and_gradient(ℓ::StrictFloat32Normal, q::AbstractVector)
+    @assert eltype(q) === Float32 "position promoted to $(eltype(q)), expected Float32"
+    r = q - ℓ.μ
+    Float32(-dot(r, r) / 2), -r
+end
+@testset "Float32 support" begin
+    @testset "type propagation" begin
+        ℓ32 = Float32Normal(zeros(Float32, 3))
+        q0 = randn(Float32, 3)
+        results = mcmc_with_warmup(RNG, ℓ32, 100;
+                                   initialization = (q = q0,),
+                                   reporter = NoProgressReport())
+        @test eltype(results.posterior_matrix) == Float32
+        @test eltype(results.logdensities) == Float32
+        @test results.tree_statistics[1].π isa Float32
+        @test results.tree_statistics[1].acceptance_rate isa Float32
+        @test results.ϵ isa Float32
+    end
+
+    @testset "no type promotion in compute" begin
+        ℓ_strict = StrictFloat32Normal(zeros(Float32, 3))
+        q0 = randn(Float32, 3)
+        # runs full warmup (stepsize search + dual averaging + metric adaptation)
+        # and inference — any Float64 promotion in leapfrog would trigger the assertion
+        results = mcmc_with_warmup(RNG, ℓ_strict, 100;
+                                   initialization = (q = q0,),
+                                   reporter = NoProgressReport())
+        @test eltype(results.posterior_matrix) == Float32
+        @test results.ϵ isa Float32
+    end
+
+    @testset "sample correctness" begin
+        μ = Float32[1.0, -0.5, 2.0, 0.0, -1.5]
+        ℓ32 = Float32Normal(μ)
+        q0 = randn(Float32, 5)
+        results = mcmc_with_warmup(RNG, ℓ32, 10000;
+                                   initialization = (q = q0,),
+                                   reporter = NoProgressReport())
+        Z = results.posterior_matrix
+        @test eltype(Z) == Float32
+        @test norm(mean(Z; dims = 2) .- μ, Inf) < 0.06
+        @test norm(std(Z; dims = 2) .- ones(Float32, 5), Inf) < 0.06
+        @test mean(x -> x.acceptance_rate, results.tree_statistics) ≥ 0.7
+    end
+
+    @testset "fixed stepsize" begin
+        ℓ32 = Float32Normal(ones(Float32, 3))
+        q0 = randn(Float32, 3)
+        results = mcmc_with_warmup(RNG, ℓ32, 5000;
+                                   initialization = (q = q0, ϵ = Float32(1.0)),
+                                   warmup_stages = fixed_stepsize_warmup_stages(),
+                                   reporter = NoProgressReport())
+        Z = results.posterior_matrix
+        @test eltype(Z) == Float32
+        @test norm(mean(Z; dims = 2) .- ones(Float32, 3), Inf) < 0.1
+    end
+
+    @testset "stepwise" begin
+        ℓ32 = Float32Normal(zeros(Float32, 3))
+        q0 = randn(Float32, 3)
+        results = mcmc_keep_warmup(RNG, ℓ32, 0;
+                                   initialization = (q = q0,),
+                                   reporter = NoProgressReport())
+        steps = mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
+        Q = results.final_warmup_state.Q
+        @test eltype(Q.q) == Float32
+        qs = [(Q = first(mcmc_next_step(steps, Q)); Q.q) for _ in 1:1000]
+        @test eltype(qs[1]) == Float32
+        @test norm(mean(reduce(hcat, qs); dims = 2), Inf) ≤ 0.15
+    end
+end
+
 @testset "posterior accessors sanity checks" begin
     D, N, K = 5, 100, 7
     ℓ = multivariate_normal(ones(5))