Skip to content
190 changes: 136 additions & 54 deletions src/algorithms/dit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,45 +50,70 @@ fn recursive_dit_fft_f64<S: Simd>(
0
};

// Remaining stages use per-stage kernels
for stage in start_stage..log_size {
stage_twiddle_idx = execute_dit_stage_f64(
let mut stage = start_stage;
while stage < log_size {
let (new_idx, consumed) = execute_dit_stages_f64(
simd,
&mut reals[..size],
&mut imags[..size],
stage,
log_size - stage,
planner,
stage_twiddle_idx,
);
stage_twiddle_idx = new_idx;
stage += consumed;
}
stage_twiddle_idx
} else {
let half = size / 2;
let log_half = half.ilog2() as usize;
let quarter = size / 4;
let log_quarter = quarter.ilog2() as usize;

let (re_first_half, re_second_half) = reals.split_at_mut(half);
let (im_first_half, im_second_half) = imags.split_at_mut(half);
// Recursively process both halves
let (re_lo, re_hi) = reals.split_at_mut(size / 2);
let (im_lo, im_hi) = imags.split_at_mut(size / 2);
let (re_q0, re_q1) = re_lo.split_at_mut(quarter);
let (im_q0, im_q1) = im_lo.split_at_mut(quarter);
let (re_q2, re_q3) = re_hi.split_at_mut(quarter);
let (im_q2, im_q3) = im_hi.split_at_mut(quarter);

// Recursively process all 4 quarters (parallel across pairs)
run_maybe_in_parallel(
size > opts.smallest_parallel_chunk_size,
|| recursive_dit_fft_f64(simd, re_first_half, im_first_half, half, planner, opts, 0),
|| recursive_dit_fft_f64(simd, re_second_half, im_second_half, half, planner, opts, 0),
|| {
run_maybe_in_parallel(
size / 2 > opts.smallest_parallel_chunk_size,
|| recursive_dit_fft_f64(simd, re_q0, im_q0, quarter, planner, opts, 0),
|| recursive_dit_fft_f64(simd, re_q1, im_q1, quarter, planner, opts, 0),
)
},
|| {
run_maybe_in_parallel(
size / 2 > opts.smallest_parallel_chunk_size,
|| recursive_dit_fft_f64(simd, re_q2, im_q2, quarter, planner, opts, 0),
|| recursive_dit_fft_f64(simd, re_q3, im_q3, quarter, planner, opts, 0),
)
},
);

// Both halves completed stages 0..log_half-1
// All 4 quarters completed stages 0..log_quarter-1.
// Now process the 2 remaining cross-block stages (log_quarter and log_quarter+1)
// using the fused kernel.
// Stages 0-5 use hardcoded twiddles, 6+ use planner
stage_twiddle_idx = log_half.saturating_sub(6);
stage_twiddle_idx = log_quarter.saturating_sub(6);

// Process remaining stages that span both halves
for stage in log_half..log_size {
stage_twiddle_idx = execute_dit_stage_f64(
let mut stage = log_quarter;
while stage < log_size {
let (new_idx, consumed) = execute_dit_stages_f64(
simd,
&mut reals[..size],
&mut imags[..size],
stage,
log_size - stage,
planner,
stage_twiddle_idx,
);
stage_twiddle_idx = new_idx;
stage += consumed;
}

stage_twiddle_idx
Expand Down Expand Up @@ -116,126 +141,183 @@ fn recursive_dit_fft_f32<S: Simd>(
0
};

// Remaining stages use per-stage kernels
for stage in start_stage..log_size {
stage_twiddle_idx = execute_dit_stage_f32(
let mut stage = start_stage;
while stage < log_size {
let (new_idx, consumed) = execute_dit_stages_f32(
simd,
&mut reals[..size],
&mut imags[..size],
stage,
log_size - stage,
planner,
stage_twiddle_idx,
);
stage_twiddle_idx = new_idx;
stage += consumed;
}
stage_twiddle_idx
} else {
let half = size / 2;
let log_half = half.ilog2() as usize;
let quarter = size / 4;
let log_quarter = quarter.ilog2() as usize;

let (re_lo, re_hi) = reals.split_at_mut(size / 2);
let (im_lo, im_hi) = imags.split_at_mut(size / 2);
let (re_q0, re_q1) = re_lo.split_at_mut(quarter);
let (im_q0, im_q1) = im_lo.split_at_mut(quarter);
let (re_q2, re_q3) = re_hi.split_at_mut(quarter);
let (im_q2, im_q3) = im_hi.split_at_mut(quarter);

let (re_first_half, re_second_half) = reals.split_at_mut(half);
let (im_first_half, im_second_half) = imags.split_at_mut(half);
// Recursively process both halves
// Recursively process all 4 quarters (parallel across pairs)
run_maybe_in_parallel(
size > opts.smallest_parallel_chunk_size,
|| recursive_dit_fft_f32(simd, re_first_half, im_first_half, half, planner, opts, 0),
|| recursive_dit_fft_f32(simd, re_second_half, im_second_half, half, planner, opts, 0),
|| {
run_maybe_in_parallel(
size / 2 > opts.smallest_parallel_chunk_size,
|| recursive_dit_fft_f32(simd, re_q0, im_q0, quarter, planner, opts, 0),
|| recursive_dit_fft_f32(simd, re_q1, im_q1, quarter, planner, opts, 0),
)
},
|| {
run_maybe_in_parallel(
size / 2 > opts.smallest_parallel_chunk_size,
|| recursive_dit_fft_f32(simd, re_q2, im_q2, quarter, planner, opts, 0),
|| recursive_dit_fft_f32(simd, re_q3, im_q3, quarter, planner, opts, 0),
)
},
);

// Both halves completed stages 0..log_half-1
// All 4 quarters completed stages 0..log_quarter-1.
// Now process the 2 remaining cross-block stages (log_quarter and log_quarter+1)
// using the fused kernel.
// Stages 0-5 use hardcoded twiddles, 6+ use planner
stage_twiddle_idx = log_half.saturating_sub(6);
stage_twiddle_idx = log_quarter.saturating_sub(6);

// Process remaining stages that span both halves
for stage in log_half..log_size {
stage_twiddle_idx = execute_dit_stage_f32(
let mut stage = log_quarter;
while stage < log_size {
let (new_idx, consumed) = execute_dit_stages_f32(
simd,
&mut reals[..size],
&mut imags[..size],
stage,
log_size - stage,
planner,
stage_twiddle_idx,
);
stage_twiddle_idx = new_idx;
stage += consumed;
}

stage_twiddle_idx
}
}

/// Execute a single DIT stage, dispatching to appropriate kernel based on chunk size.
/// Returns updated stage_twiddle_idx.
fn execute_dit_stage_f64<S: Simd>(
/// Execute one or two DIT stages, dispatching to appropriate kernel based on chunk size.
/// Returns (updated stage_twiddle_idx, number of stages consumed).
fn execute_dit_stages_f64<S: Simd>(
simd: S,
reals: &mut [f64],
imags: &mut [f64],
stage: usize,
stages_remaining: usize,
planner: &PlannerDit64,
stage_twiddle_idx: usize,
) -> usize {
) -> (usize, usize) {
let dist = 1 << stage; // 2.pow(stage)
let chunk_size = dist * 2;

if chunk_size == 2 {
simd.vectorize(|| fft_dit_chunk_2(simd, reals, imags));
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 4 {
fft_dit_chunk_4_f64(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 8 {
fft_dit_chunk_8_f64(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 16 {
fft_dit_chunk_16_f64(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 32 {
fft_dit_chunk_32_f64(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 64 {
fft_dit_chunk_64_f64(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if stages_remaining >= 2 {
// Fuse two stages into a single pass over memory
let (twiddles_re, twiddles_im) = &planner.stage_twiddles[stage_twiddle_idx];
let (twiddles_re2, twiddles_im2) = &planner.stage_twiddles[stage_twiddle_idx + 1];
fft_dit_fused_2stage_f64_narrow(
simd,
reals,
imags,
twiddles_re,
twiddles_im,
twiddles_re2,
twiddles_im2,
dist,
);
(stage_twiddle_idx + 2, 2)
} else {
// For larger chunks, use general kernel with twiddles from planner
// Last stage (odd number of stages remaining), use single-stage kernel
let (twiddles_re, twiddles_im) = &planner.stage_twiddles[stage_twiddle_idx];
fft_dit_chunk_n_f64(simd, reals, imags, twiddles_re, twiddles_im, dist);
stage_twiddle_idx + 1
(stage_twiddle_idx + 1, 1)
}
}

/// Execute a single DIT stage, dispatching to appropriate kernel based on chunk size.
/// Returns updated stage_twiddle_idx.
fn execute_dit_stage_f32<S: Simd>(
/// Execute one or two DIT stages, dispatching to appropriate kernel based on chunk size.
/// Returns (updated stage_twiddle_idx, number of stages consumed).
fn execute_dit_stages_f32<S: Simd>(
simd: S,
reals: &mut [f32],
imags: &mut [f32],
stage: usize,
stages_remaining: usize,
planner: &PlannerDit32,
stage_twiddle_idx: usize,
) -> usize {
) -> (usize, usize) {
let dist = 1 << stage; // 2.pow(stage)
let chunk_size = dist * 2;

if chunk_size == 2 {
simd.vectorize(|| fft_dit_chunk_2(simd, reals, imags));
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 4 {
fft_dit_chunk_4_f32(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 8 {
fft_dit_chunk_8_f32(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 16 {
fft_dit_chunk_16_f32(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 32 {
fft_dit_chunk_32_f32(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if chunk_size == 64 {
fft_dit_chunk_64_f32(simd, reals, imags);
stage_twiddle_idx
(stage_twiddle_idx, 1)
} else if stages_remaining >= 2 {
// Fuse two stages into a single pass over memory
let (twiddles_re, twiddles_im) = &planner.stage_twiddles[stage_twiddle_idx];
let (twiddles_re2, twiddles_im2) = &planner.stage_twiddles[stage_twiddle_idx + 1];
fft_dit_fused_2stage_f32_narrow(
simd,
reals,
imags,
twiddles_re,
twiddles_im,
twiddles_re2,
twiddles_im2,
dist,
);
(stage_twiddle_idx + 2, 2)
} else {
// For larger chunks, use general kernel with twiddles from planner
// Last stage (odd number of stages remaining), use single-stage kernel
let (twiddles_re, twiddles_im) = &planner.stage_twiddles[stage_twiddle_idx];
fft_dit_chunk_n_f32(simd, reals, imags, twiddles_re, twiddles_im, dist);
stage_twiddle_idx + 1
(stage_twiddle_idx + 1, 1)
}
}

Expand Down
Loading
Loading