diff --git a/crates/iceberg/src/scan/mod.rs b/crates/iceberg/src/scan/mod.rs index c055c12c9a..38a138f273 100644 --- a/crates/iceberg/src/scan/mod.rs +++ b/crates/iceberg/src/scan/mod.rs @@ -27,8 +27,8 @@ use std::sync::Arc; use arrow_array::RecordBatch; use futures::channel::mpsc::{Sender, channel}; -use futures::stream::BoxStream; -use futures::{SinkExt, StreamExt, TryStreamExt}; +use futures::stream::{self, BoxStream}; +use futures::{SinkExt, StreamExt, TryStreamExt, future}; pub use task::*; use crate::arrow::ArrowReaderBuilder; @@ -60,6 +60,7 @@ pub struct TableScanBuilder<'a> { concurrency_limit_manifest_files: usize, row_group_filtering_enabled: bool, row_selection_enabled: bool, + minimum_reader_tasks: usize, } impl<'a> TableScanBuilder<'a> { @@ -78,6 +79,7 @@ impl<'a> TableScanBuilder<'a> { concurrency_limit_manifest_files: num_cpus, row_group_filtering_enabled: true, row_selection_enabled: false, + minimum_reader_tasks: 0, } } @@ -146,6 +148,14 @@ impl<'a> TableScanBuilder<'a> { self } + /// Sets the minimum reader tasks limit for this scan + /// When enabled, files are read in parallel chunks of size + /// no less than this limit, to keep each cpu busy. + pub fn with_minimum_reader_tasks(mut self, limit: usize) -> Self { + self.minimum_reader_tasks = limit; + self + } + /// Sets the manifest entry concurrency limit for this scan pub fn with_manifest_entry_concurrency_limit(mut self, limit: usize) -> Self { self.concurrency_limit_manifest_entries = limit; @@ -210,6 +220,7 @@ impl<'a> TableScanBuilder<'a> { concurrency_limit_manifest_files: self.concurrency_limit_manifest_files, row_group_filtering_enabled: self.row_group_filtering_enabled, row_selection_enabled: self.row_selection_enabled, + minimum_reader_tasks: self.minimum_reader_tasks, }); }; current_snapshot_id.clone() @@ -303,6 +314,7 @@ impl<'a> TableScanBuilder<'a> { concurrency_limit_manifest_files: self.concurrency_limit_manifest_files, row_group_filtering_enabled: self.row_group_filtering_enabled, row_selection_enabled: self.row_selection_enabled, + minimum_reader_tasks: 0, }) } } @@ -329,6 +341,11 @@ pub struct TableScan { /// be processed in parallel concurrency_limit_data_files: usize, + /// The minimum number of [`ManifestEntry`]s that will + /// be processed in parallel. If specified, new tasks + /// will be spawned to read files in parallel + minimum_reader_tasks: usize, + row_group_filtering_enabled: bool, row_selection_enabled: bool, } @@ -430,18 +447,43 @@ impl TableScan { Ok(file_scan_task_rx.boxed()) } - /// Returns an [`ArrowRecordBatchStream`]. - pub async fn to_arrow(&self) -> Result { + fn arrow_reader(&self) -> crate::arrow::ArrowReader { let mut arrow_reader_builder = ArrowReaderBuilder::new(self.file_io.clone()) .with_data_file_concurrency_limit(self.concurrency_limit_data_files) .with_row_group_filtering_enabled(self.row_group_filtering_enabled) .with_row_selection_enabled(self.row_selection_enabled); - if let Some(batch_size) = self.batch_size { arrow_reader_builder = arrow_reader_builder.with_batch_size(batch_size); } + arrow_reader_builder.build() + } + + /// Returns an [`ArrowRecordBatchStream`]. + pub async fn to_arrow(&self) -> Result { + let plan_files = self.plan_files().await?; - arrow_reader_builder.build().read(self.plan_files().await?) + if self.minimum_reader_tasks == 0 { + self.arrow_reader().read(plan_files) + } else { + // spawn chunks into their own tasks for parallelism + let files: Vec<_> = plan_files.try_collect().await?; + let workers = std::thread::available_parallelism().map_or(4, |p| p.get()); + let chunk_size = files.len().div_ceil(workers).max(self.minimum_reader_tasks); + let futs = files.chunks(chunk_size).map(|chunk| { + #[allow(clippy::unnecessary_to_owned)] + let tasks = stream::iter(chunk.to_vec().into_iter().map(Ok)); + let reader = self.arrow_reader(); + tokio::spawn(async move { reader.read(Box::pin(tasks) as _) }) + }); + + let record_streams = future::try_join_all(futs) + .await + .map_err(|e| Error::new(ErrorKind::Unexpected, e.to_string()))? + .into_iter() + .collect::>>()?; + let stream = stream::iter(record_streams.into_iter()).flatten(); + Ok(Box::pin(stream) as ArrowRecordBatchStream) + } } /// Returns a reference to the column names of the table scan. @@ -2254,4 +2296,29 @@ pub mod tests { // Assert it finished (didn't timeout) assert!(result.is_ok(), "Scan timed out - deadlock detected"); } + + #[tokio::test] + async fn test_minimum_reader_task() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Create table scan for with no minimum task reader (e.g. on the same task) + let sync_table_scan = fixture.table.clone().scan().build().unwrap(); + let sync_batch_stream = sync_table_scan.to_arrow().await.unwrap(); + let sync_batches: Vec<_> = sync_batch_stream.try_collect().await.unwrap(); + + // Create table scan for 4 minimum tasks (e.g. on the same task) + let parallel_table_scan = fixture + .table + .scan() + .with_minimum_reader_tasks(4) + .build() + .unwrap(); + let parallel_batch_stream = parallel_table_scan.to_arrow().await.unwrap(); + let parallel_batches: Vec<_> = parallel_batch_stream.try_collect().await.unwrap(); + + for (sync, par) in sync_batches.into_iter().zip(parallel_batches) { + assert_eq!(sync, par); + } + } }