From 645f59c8acae1a463e3ac4aae51b1ce775676b07 Mon Sep 17 00:00:00 2001 From: jamesquinlan Date: Sun, 24 May 2026 18:41:58 -0400 Subject: [PATCH] Add batched and multi-dimensional rfft/irfft support Signed-off-by: jamesquinlan --- src/fft.jl | 72 +++++++++++++++++++++++++++++++++++++++-------- test/fft_tests.jl | 59 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 118 insertions(+), 13 deletions(-) diff --git a/src/fft.jl b/src/fft.jl index 26910bc..b9acce2 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -81,9 +81,9 @@ function generic_fft!(x) end -generic_fft(x, region) = generic_fft!(copy(x), region) +generic_fft(x, region) = generic_fft!(copy(complex(x)), region) -generic_fft(x) = generic_fft!(copy(x)) +generic_fft(x) = generic_fft!(copy(complex(x))) function generic_fft(x::AbstractVector{T}) where T<:AbstractFloats n = length(x) @@ -105,14 +105,64 @@ generic_ifft(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv generic_ifft!(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv!(T(_regionscale(x, region)), conj!(generic_fft!(conj!(x), region))) generic_rfft(v::AbstractVector{T}, region) where T<:AbstractFloats = generic_fft(v, region)[1:div(length(v),2)+1] + +function generic_rfft(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N} + d = first(region) + if length(region) > 1 + return generic_fft(generic_rfft(x, d), region[2:end]) + end + + # Batched 1D RFFT along dimension d + nout = size(x, d) ÷ 2 + 1 + sz = collect(size(x)) + sz[d] = nout + out = similar(x, Complex{real(T)}, tuple(sz...)) + + Rpre = CartesianIndices(size(x)[1:d-1]) + Rpost = CartesianIndices(size(x)[d+1:end]) + + Threads.@threads for Ipost in Rpost + for Ipre in Rpre + out[Ipre, :, Ipost] .= generic_rfft(view(x, Ipre, :, Ipost), 1) + end + end + return out +end + function generic_irfft(v::AbstractVector{T}, n::Integer, region) where T<:ComplexFloats @assert length(v) == n>>1 + 1 r = Vector{T}(undef, n) r[1:length(v)]=v r[length(v)+1:n]=reverse(conj(v[2:end])[1:n-length(v)]) - real(generic_ifft(r, region)) + return real(generic_ifft(r, region)) +end + +function generic_irfft(x::AbstractArray{T, N}, n::Integer, region) where {T<:ComplexFloats, N} + d = first(region) + if length(region) > 1 + return generic_irfft(generic_ifft(x, region[2:end]), n, d) + end + + # Batched 1D IRFFT along dimension d + sz = collect(size(x)) + sz[d] = n + out = similar(x, real(T), tuple(sz...)) + + Rpre = CartesianIndices(size(x)[1:d-1]) + Rpost = CartesianIndices(size(x)[d+1:end]) + + Threads.@threads for Ipost in Rpost + for Ipre in Rpre + out[Ipre, :, Ipost] .= generic_irfft(view(x, Ipre, :, Ipost), n, 1) + end + end + return out +end + +function generic_brfft(v::AbstractArray, n::Integer, region) + scale = n * _regionscale(v, region isa Integer ? () : region[2:end]) + return generic_irfft(v, n, region) * scale end -generic_brfft(v::AbstractArray, n::Integer, region) = generic_irfft(v, n, region)*n function _conv!(u::AbstractVector{T}, v::AbstractVector{T}) where T<:AbstractFloats nu = length(u) @@ -262,7 +312,7 @@ for P in (:DummyFFTPlan, :DummyiFFTPlan, :DummybFFTPlan, :DummyDCTPlan, :DummyiD @eval begin mutable struct $P{T,inplace,G} <: DummyPlan{T} region::G # region (iterable) of dims that are transformed - pinv::DummyPlan{T} + pinv::Plan $P{T,inplace,G}(region::G) where {T<:AbstractFloats, inplace, G} = new(region) end end @@ -272,7 +322,7 @@ for P in (:DummyrFFTPlan, :DummyirFFTPlan, :DummybrFFTPlan) mutable struct $P{T,inplace,G} <: DummyPlan{T} n::Integer region::G # region (iterable) of dims that are transformed - pinv::DummyPlan{T} + pinv::Plan $P{T,inplace,G}(n::Integer, region::G) where {T<:AbstractFloats, inplace, G} = new(n, region) end end @@ -287,8 +337,8 @@ for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan), end # Specific for rfft, irfft and brfft: -plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{T,inplace,G}(p.n, p.region) -plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{T,inplace,G}(p.n, p.region) +plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{real(T),inplace,G}(p.n, p.region) +plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{Complex{T},inplace,G}(p.n, p.region) @@ -345,11 +395,11 @@ plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false,typeof(region)}(region) plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true,typeof(region)}(region) -plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{T,false,typeof(region)}(length(x), region) +plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{T,false,typeof(region)}(size(x, first(region)), region) plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{T,false,typeof(region)}(n, region) -# A plan for irfft is created in terms of a plan for brfft. -# plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false,typeof(region)}(n, region) +# Explicitly define plan_irfft to ensure correct scaling +plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{T,false,typeof(region)}(n, region) # These don't exist for now: # plan_rfft!(x::StridedArray{T}) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},true}() diff --git a/test/fft_tests.jl b/test/fft_tests.jl index 266451e..a345b5d 100644 --- a/test/fft_tests.jl +++ b/test/fft_tests.jl @@ -44,8 +44,8 @@ function test_fft_dct(T) @test norm(dct(idct(c))-c,Inf) < 1000eps(T) @test_throws AssertionError irfft(c, 197) - @test norm(irfft(c, 198) - irfft(map(ComplexF64, c), 198), Inf) < 10eps(Float64) - @test norm(irfft(c, 199) - irfft(map(ComplexF64, c), 199), Inf) < 10eps(Float64) + @test norm(irfft(c, 198) - irfft(map(ComplexF64, c), 198), Inf) < 100eps(Float64) + @test norm(irfft(c, 199) - irfft(map(ComplexF64, c), 199), Inf) < 100eps(Float64) @test_throws AssertionError irfft(c, 200) end @@ -202,3 +202,58 @@ end @allocations generic_fft!(A2) # compile @test N+150 > @allocations generic_fft!(A2) # a few allocations is OK end + +@testset "Batched rfft/irfft" begin + for T in (Float64, BigFloat) + X = randn(T, 10, 6) + + Y1 = rfft(X, 1) # Dimension 1 + @test size(Y1) == (10÷2+1, 6) + for j in 1:6 + @test Y1[:, j] ≈ rfft(X[:, j]) + end + @test irfft(Y1, 10, 1) ≈ X + + Y2 = rfft(X, 2) # Dimension 2 + @test size(Y2) == (10, 6÷2+1) + for i in 1:10 + @test Y2[i, :] ≈ rfft(X[i, :]) + end + @test irfft(Y2, 6, 2) ≈ X + + Y12 = rfft(X, (1, 2)) # 2D RFFT + @test size(Y12) == (10÷2+1, 6) + @test Y12 ≈ fft(rfft(X, 1), 2) + @test irfft(Y12, 10, (1, 2)) ≈ X + + p1 = plan_rfft(X, 1) # Plans + @test p1 * X ≈ rfft(X, 1) + @test inv(p1) * (p1 * X) ≈ X + + p2 = plan_rfft(X, 2) + @test p2 * X ≈ rfft(X, 2) + @test inv(p2) * (p2 * X) ≈ X + end + + + for n in (7, 11) # Test a few odd lengths + X = randn(BigFloat, n, 4) + Y = rfft(X, 1) + @test size(Y) == (n÷2+1, 4) + @test irfft(Y, n, 1) ≈ X + end + + data = randn(BigFloat, 10, 10) + v = view(data, 1:8, 1:6) + @test rfft(v, 1) ≈ rfft(collect(v), 1) + @test irfft(rfft(v, 1), 8, 1) ≈ v + + X3 = randn(BigFloat, 4, 10, 4) # Test 3D Batched + + Y3 = rfft(X3, 2) # Transform along dimension 2 + @test size(Y3) == (4, 10÷2+1, 4) + for i in 1:4, k in 1:4 + @test Y3[i, :, k] ≈ rfft(X3[i, :, k]) + end + @test irfft(Y3, 10, 2) ≈ X3 +end