Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 61 additions & 11 deletions src/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ function generic_fft!(x)
end


generic_fft(x, region) = generic_fft!(copy(x), region)
generic_fft(x, region) = generic_fft!(copy(complex(x)), region)

generic_fft(x) = generic_fft!(copy(x))
generic_fft(x) = generic_fft!(copy(complex(x)))

function generic_fft(x::AbstractVector{T}) where T<:AbstractFloats
n = length(x)
Expand All @@ -105,14 +105,64 @@ generic_ifft(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv
generic_ifft!(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N} = ldiv!(T(_regionscale(x, region)), conj!(generic_fft!(conj!(x), region)))

generic_rfft(v::AbstractVector{T}, region) where T<:AbstractFloats = generic_fft(v, region)[1:div(length(v),2)+1]

function generic_rfft(x::AbstractArray{T, N}, region) where {T<:AbstractFloats, N}
d = first(region)
if length(region) > 1
return generic_fft(generic_rfft(x, d), region[2:end])
end

# Batched 1D RFFT along dimension d
nout = size(x, d) ÷ 2 + 1
sz = collect(size(x))
sz[d] = nout
out = similar(x, Complex{real(T)}, tuple(sz...))

Rpre = CartesianIndices(size(x)[1:d-1])
Rpost = CartesianIndices(size(x)[d+1:end])

Threads.@threads for Ipost in Rpost
for Ipre in Rpre
out[Ipre, :, Ipost] .= generic_rfft(view(x, Ipre, :, Ipost), 1)
end
end
return out
end

function generic_irfft(v::AbstractVector{T}, n::Integer, region) where T<:ComplexFloats
@assert length(v) == n>>1 + 1
r = Vector{T}(undef, n)
r[1:length(v)]=v
r[length(v)+1:n]=reverse(conj(v[2:end])[1:n-length(v)])
real(generic_ifft(r, region))
return real(generic_ifft(r, region))
end

function generic_irfft(x::AbstractArray{T, N}, n::Integer, region) where {T<:ComplexFloats, N}
d = first(region)
if length(region) > 1
return generic_irfft(generic_ifft(x, region[2:end]), n, d)
end

# Batched 1D IRFFT along dimension d
sz = collect(size(x))
sz[d] = n
out = similar(x, real(T), tuple(sz...))

Rpre = CartesianIndices(size(x)[1:d-1])
Rpost = CartesianIndices(size(x)[d+1:end])

Threads.@threads for Ipost in Rpost
for Ipre in Rpre
out[Ipre, :, Ipost] .= generic_irfft(view(x, Ipre, :, Ipost), n, 1)
end
end
return out
end

function generic_brfft(v::AbstractArray, n::Integer, region)
scale = n * _regionscale(v, region isa Integer ? () : region[2:end])
return generic_irfft(v, n, region) * scale
end
generic_brfft(v::AbstractArray, n::Integer, region) = generic_irfft(v, n, region)*n

function _conv!(u::AbstractVector{T}, v::AbstractVector{T}) where T<:AbstractFloats
nu = length(u)
Expand Down Expand Up @@ -262,7 +312,7 @@ for P in (:DummyFFTPlan, :DummyiFFTPlan, :DummybFFTPlan, :DummyDCTPlan, :DummyiD
@eval begin
mutable struct $P{T,inplace,G} <: DummyPlan{T}
region::G # region (iterable) of dims that are transformed
pinv::DummyPlan{T}
pinv::Plan
$P{T,inplace,G}(region::G) where {T<:AbstractFloats, inplace, G} = new(region)
end
end
Expand All @@ -272,7 +322,7 @@ for P in (:DummyrFFTPlan, :DummyirFFTPlan, :DummybrFFTPlan)
mutable struct $P{T,inplace,G} <: DummyPlan{T}
n::Integer
region::G # region (iterable) of dims that are transformed
pinv::DummyPlan{T}
pinv::Plan
$P{T,inplace,G}(n::Integer, region::G) where {T<:AbstractFloats, inplace, G} = new(n, region)
end
end
Expand All @@ -287,8 +337,8 @@ for (Plan,iPlan) in ((:DummyFFTPlan,:DummyiFFTPlan),
end

# Specific for rfft, irfft and brfft:
plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{T,inplace,G}(p.n, p.region)
plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{T,inplace,G}(p.n, p.region)
plan_inv(p::DummyirFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyrFFTPlan{real(T),inplace,G}(p.n, p.region)
plan_inv(p::DummyrFFTPlan{T,inplace,G}) where {T,inplace,G} = DummyirFFTPlan{Complex{T},inplace,G}(p.n, p.region)



Expand Down Expand Up @@ -345,11 +395,11 @@ plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan
plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false,typeof(region)}(region)
plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true,typeof(region)}(region)

plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{T,false,typeof(region)}(length(x), region)
plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{T,false,typeof(region)}(size(x, first(region)), region)
plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{T,false,typeof(region)}(n, region)

# A plan for irfft is created in terms of a plan for brfft.
# plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false,typeof(region)}(n, region)
# Explicitly define plan_irfft to ensure correct scaling
plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{T,false,typeof(region)}(n, region)

# These don't exist for now:
# plan_rfft!(x::StridedArray{T}) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},true}()
Expand Down
59 changes: 57 additions & 2 deletions test/fft_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ function test_fft_dct(T)
@test norm(dct(idct(c))-c,Inf) < 1000eps(T)

@test_throws AssertionError irfft(c, 197)
@test norm(irfft(c, 198) - irfft(map(ComplexF64, c), 198), Inf) < 10eps(Float64)
@test norm(irfft(c, 199) - irfft(map(ComplexF64, c), 199), Inf) < 10eps(Float64)
@test norm(irfft(c, 198) - irfft(map(ComplexF64, c), 198), Inf) < 100eps(Float64)
@test norm(irfft(c, 199) - irfft(map(ComplexF64, c), 199), Inf) < 100eps(Float64)
@test_throws AssertionError irfft(c, 200)
end

Expand Down Expand Up @@ -202,3 +202,58 @@ end
@allocations generic_fft!(A2) # compile
@test N+150 > @allocations generic_fft!(A2) # a few allocations is OK
end

@testset "Batched rfft/irfft" begin
for T in (Float64, BigFloat)
X = randn(T, 10, 6)

Y1 = rfft(X, 1) # Dimension 1
@test size(Y1) == (10÷2+1, 6)
for j in 1:6
@test Y1[:, j] ≈ rfft(X[:, j])
end
@test irfft(Y1, 10, 1) ≈ X

Y2 = rfft(X, 2) # Dimension 2
@test size(Y2) == (10, 6÷2+1)
for i in 1:10
@test Y2[i, :] ≈ rfft(X[i, :])
end
@test irfft(Y2, 6, 2) ≈ X

Y12 = rfft(X, (1, 2)) # 2D RFFT
@test size(Y12) == (10÷2+1, 6)
@test Y12 ≈ fft(rfft(X, 1), 2)
@test irfft(Y12, 10, (1, 2)) ≈ X

p1 = plan_rfft(X, 1) # Plans
@test p1 * X ≈ rfft(X, 1)
@test inv(p1) * (p1 * X) ≈ X

p2 = plan_rfft(X, 2)
@test p2 * X ≈ rfft(X, 2)
@test inv(p2) * (p2 * X) ≈ X
end


for n in (7, 11) # Test a few odd lengths
X = randn(BigFloat, n, 4)
Y = rfft(X, 1)
@test size(Y) == (n÷2+1, 4)
@test irfft(Y, n, 1) ≈ X
end

data = randn(BigFloat, 10, 10)
v = view(data, 1:8, 1:6)
@test rfft(v, 1) ≈ rfft(collect(v), 1)
@test irfft(rfft(v, 1), 8, 1) ≈ v

X3 = randn(BigFloat, 4, 10, 4) # Test 3D Batched

Y3 = rfft(X3, 2) # Transform along dimension 2
@test size(Y3) == (4, 10÷2+1, 4)
for i in 1:4, k in 1:4
@test Y3[i, :, k] ≈ rfft(X3[i, :, k])
end
@test irfft(Y3, 10, 2) ≈ X3
end
Loading