From 55ef3238c773850d2c9795779682611e04690988 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Tue, 17 Mar 2026 12:06:29 -0400 Subject: [PATCH] teach row group pruning about struct fields --- Cargo.lock | 49 +- Cargo.toml | 14 +- .../examples/data_io/parquet_index.rs | 30 +- .../examples/query_planning/pruning.rs | 16 +- datafusion/common/src/pruning.rs | 133 +++-- .../datasource-parquet/src/page_filter.rs | 12 +- .../src/row_group_filter.rs | 475 +++++++++++++++++- datafusion/pruning/Cargo.toml | 1 + datafusion/pruning/src/lib.rs | 4 +- datafusion/pruning/src/pruning_predicate.rs | 330 ++++++++---- 10 files changed, 828 insertions(+), 236 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 168d3bd0c1812..e63e417c0570d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -233,8 +233,7 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "602268ce9f569f282cedb9a9f6bac569b680af47b9b077d515900c03c5d190da" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-arith", "arrow-array", @@ -256,8 +255,7 @@ dependencies = [ [[package]] name = "arrow-arith" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd53c6bf277dea91f136ae8e3a5d7041b44b5e489e244e637d00ae302051f56f" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-array", "arrow-buffer", @@ -270,8 +268,7 @@ dependencies = [ [[package]] name = "arrow-array" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53796e07a6525edaf7dc28b540d477a934aff14af97967ad1d5550878969b9e" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "ahash", "arrow-buffer", @@ -289,8 +286,7 @@ dependencies = [ [[package]] name = "arrow-buffer" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2c1a85bb2e94ee10b76531d8bc3ce9b7b4c0d508cabfb17d477f63f2617bd20" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "bytes", "half", @@ -301,8 +297,7 @@ dependencies = [ [[package]] name = "arrow-cast" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89fb245db6b0e234ed8e15b644edb8664673fefe630575e94e62cd9d489a8a26" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-array", "arrow-buffer", @@ -323,8 +318,7 @@ dependencies = [ [[package]] name = "arrow-csv" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d374882fb465a194462527c0c15a93aa19a554cf690a6b77a26b2a02539937a7" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-array", "arrow-cast", @@ -338,8 +332,7 @@ dependencies = [ [[package]] name = "arrow-data" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "189d210bc4244c715fa3ed9e6e22864673cccb73d5da28c2723fb2e527329b33" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-buffer", "arrow-schema", @@ -351,8 +344,7 @@ dependencies = [ [[package]] name = "arrow-flight" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4f5cdf00ee0003ba0768d3575d0afc47d736b29673b14c3c228fdffa9a3fb29" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-arith", "arrow-array", @@ -379,8 +371,7 @@ dependencies = [ [[package]] name = "arrow-ipc" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7968c2e5210c41f4909b2ef76f6e05e172b99021c2def5edf3cc48fdd39d1d6c" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-array", "arrow-buffer", @@ -395,8 +386,7 @@ dependencies = [ [[package]] name = "arrow-json" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92111dba5bf900f443488e01f00d8c4ddc2f47f5c50039d18120287b580baa22" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-array", "arrow-buffer", @@ -419,8 +409,7 @@ dependencies = [ [[package]] name = "arrow-ord" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "211136cb253577ee1a6665f741a13136d4e563f64f5093ffd6fb837af90b9495" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-array", "arrow-buffer", @@ -432,8 +421,7 @@ dependencies = [ [[package]] name = "arrow-row" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e0f20145f9f5ea3fe383e2ba7a7487bf19be36aa9dbf5dd6a1f92f657179663" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-array", "arrow-buffer", @@ -445,8 +433,7 @@ dependencies = [ [[package]] name = "arrow-schema" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b47e0ca91cc438d2c7879fe95e0bca5329fff28649e30a88c6f760b1faeddcb" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "bitflags", "serde", @@ -457,8 +444,7 @@ dependencies = [ [[package]] name = "arrow-select" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "750a7d1dda177735f5e82a314485b6915c7cccdbb278262ac44090f4aba4a325" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "ahash", "arrow-array", @@ -471,8 +457,7 @@ dependencies = [ [[package]] name = "arrow-string" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1eab1208bc4fe55d768cdc9b9f3d9df5a794cdb3ee2586bf89f9b30dc31ad8c" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "arrow-array", "arrow-buffer", @@ -2559,6 +2544,7 @@ dependencies = [ "datafusion-datasource", "datafusion-expr", "datafusion-expr-common", + "datafusion-functions", "datafusion-functions-nested", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -4393,8 +4379,7 @@ dependencies = [ [[package]] name = "parquet" version = "58.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f491d0ef1b510194426ee67ddc18a9b747ef3c42050c19322a2cd2e1666c29b" +source = "git+https://github.com/pydantic/arrow-rs.git?branch=friendlymatthew%2Fstatistics-converter-from-col-index#a987c9e8b38e0f07efe355dcd80d4662b3a44b3f" dependencies = [ "ahash", "arrow-array", diff --git a/Cargo.toml b/Cargo.toml index a185cd874a013..aa05960313406 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,19 +91,19 @@ ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } apache-avro = { version = "0.21", default-features = false } -arrow = { version = "58.0.0", features = [ +arrow = { git = "https://github.com/pydantic/arrow-rs.git", branch = "friendlymatthew/statistics-converter-from-col-index", features = [ "prettyprint", "chrono-tz", ] } -arrow-buffer = { version = "58.0.0", default-features = false } -arrow-flight = { version = "58.0.0", features = [ +arrow-buffer = { git = "https://github.com/pydantic/arrow-rs.git", branch = "friendlymatthew/statistics-converter-from-col-index", default-features = false } +arrow-flight = { git = "https://github.com/pydantic/arrow-rs.git", branch = "friendlymatthew/statistics-converter-from-col-index", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "58.0.0", default-features = false, features = [ +arrow-ipc = { git = "https://github.com/pydantic/arrow-rs.git", branch = "friendlymatthew/statistics-converter-from-col-index", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "58.0.0", default-features = false } -arrow-schema = { version = "58.0.0", default-features = false } +arrow-ord = { git = "https://github.com/pydantic/arrow-rs.git", branch = "friendlymatthew/statistics-converter-from-col-index", default-features = false } +arrow-schema = { git = "https://github.com/pydantic/arrow-rs.git", branch = "friendlymatthew/statistics-converter-from-col-index", default-features = false } async-trait = "0.1.89" bigdecimal = "0.4.8" bytes = "1.11" @@ -168,7 +168,7 @@ memchr = "2.8.0" num-traits = { version = "0.2" } object_store = { version = "0.13.1", default-features = false } parking_lot = "0.12" -parquet = { version = "58.0.0", default-features = false, features = [ +parquet = { git = "https://github.com/pydantic/arrow-rs.git", branch = "friendlymatthew/statistics-converter-from-col-index", default-features = false, features = [ "arrow", "async", "object_store", diff --git a/datafusion-examples/examples/data_io/parquet_index.rs b/datafusion-examples/examples/data_io/parquet_index.rs index e11a303f442a4..f7b3ec44b72b1 100644 --- a/datafusion-examples/examples/data_io/parquet_index.rs +++ b/datafusion-examples/examples/data_io/parquet_index.rs @@ -25,7 +25,7 @@ use arrow::datatypes::{Int32Type, SchemaRef}; use arrow::util::pretty::pretty_format_batches; use async_trait::async_trait; use datafusion::catalog::Session; -use datafusion::common::pruning::PruningStatistics; +use datafusion::common::pruning::{PruningColumn, PruningStatistics}; use datafusion::common::{ DFSchema, DataFusionError, Result, ScalarValue, internal_datafusion_err, }; @@ -432,21 +432,19 @@ impl ParquetMetadataIndex { /// the required statistics via the [`PruningStatistics`] trait impl PruningStatistics for ParquetMetadataIndex { /// return the minimum values for the value column - fn min_values(&self, column: &Column) -> Option { - if column.name.eq("value") { - Some(self.value_column_mins().clone()) - } else { - None - } + fn min_values(&self, column: &PruningColumn) -> Option { + column + .name() + .eq("value") + .then_some(self.value_column_mins().clone()) } /// return the maximum values for the value column - fn max_values(&self, column: &Column) -> Option { - if column.name.eq("value") { - Some(self.value_column_maxes().clone()) - } else { - None - } + fn max_values(&self, column: &PruningColumn) -> Option { + column + .name() + .eq("value") + .then_some(self.value_column_maxes().clone()) } /// return the number of "containers". In this example, each "container" is @@ -457,12 +455,12 @@ impl PruningStatistics for ParquetMetadataIndex { /// Return `None` to signal we don't have any information about null /// counts in the index, - fn null_counts(&self, _column: &Column) -> Option { + fn null_counts(&self, _column: &PruningColumn) -> Option { None } /// return the row counts for each file - fn row_counts(&self, _column: &Column) -> Option { + fn row_counts(&self, _column: &PruningColumn) -> Option { Some(self.row_counts_ref().clone()) } @@ -470,7 +468,7 @@ impl PruningStatistics for ParquetMetadataIndex { /// but is not used in this example, so return `None` fn contained( &self, - _column: &Column, + _column: &PruningColumn, _values: &HashSet, ) -> Option { None diff --git a/datafusion-examples/examples/query_planning/pruning.rs b/datafusion-examples/examples/query_planning/pruning.rs index 33f3f8428a77f..d3ba0235d9702 100644 --- a/datafusion-examples/examples/query_planning/pruning.rs +++ b/datafusion-examples/examples/query_planning/pruning.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion::common::pruning::PruningStatistics; +use datafusion::common::pruning::{PruningColumn, PruningStatistics}; use datafusion::common::{DFSchema, ScalarValue}; use datafusion::error::Result; use datafusion::execution::context::ExecutionProps; @@ -148,40 +148,40 @@ impl PruningStatistics for MyCatalog { 3 } - fn min_values(&self, column: &Column) -> Option { + fn min_values(&self, column: &PruningColumn) -> Option { // The pruning predicate evaluates the bounds for multiple expressions // at once, so return an array with an element for the minimum value in // each file - match column.name.as_str() { + match column.name() { "x" => Some(i32_array(self.x_values.iter().map(|(min, _)| min))), "y" => Some(i32_array(self.y_values.iter().map(|(min, _)| min))), name => panic!("unknown column name: {name}"), } } - fn max_values(&self, column: &Column) -> Option { + fn max_values(&self, column: &PruningColumn) -> Option { // similarly to min_values, return an array with an element for the // maximum value in each file - match column.name.as_str() { + match column.name() { "x" => Some(i32_array(self.x_values.iter().map(|(_, max)| max))), "y" => Some(i32_array(self.y_values.iter().map(|(_, max)| max))), name => panic!("unknown column name: {name}"), } } - fn null_counts(&self, _column: &Column) -> Option { + fn null_counts(&self, _column: &PruningColumn) -> Option { // In this example, we know nothing about the number of nulls None } - fn row_counts(&self, _column: &Column) -> Option { + fn row_counts(&self, _column: &PruningColumn) -> Option { // In this example, we know nothing about the number of rows in each file None } fn contained( &self, - _column: &Column, + _column: &PruningColumn, _values: &HashSet, ) -> Option { // this method can be used to implement Bloom filter like filtering diff --git a/datafusion/common/src/pruning.rs b/datafusion/common/src/pruning.rs index 5a7598ea1f299..a7e61e59d5d3f 100644 --- a/datafusion/common/src/pruning.rs +++ b/datafusion/common/src/pruning.rs @@ -26,6 +26,50 @@ use crate::stats::Precision; use crate::{Column, Statistics}; use crate::{ColumnStatistics, ScalarValue}; +/// Identifies a column for statistics lookup in [`PruningStatistics`]. +/// +/// Can represent either a top-level column or a nested field within a struct. +/// For top-level columns, `field_path` is empty. For nested struct fields +/// (e.g., `s['outer']['inner']`), `field_path` contains the path from the +/// root column to the leaf field (e.g., `["outer", "inner"]`). +#[derive(Debug, Clone)] +pub struct PruningColumn { + /// The root column reference. + pub column: Column, + /// Path to nested struct field. Empty for top-level columns. + pub field_path: Vec, +} + +impl PruningColumn { + /// Returns the name of the root column. + pub fn name(&self) -> &str { + self.column.name() + } + + /// Returns true if this references a nested struct field. + pub fn is_nested(&self) -> bool { + !self.field_path.is_empty() + } +} + +impl From for PruningColumn { + fn from(column: Column) -> Self { + Self { + column, + field_path: vec![], + } + } +} + +impl From<&Column> for PruningColumn { + fn from(column: &Column) -> Self { + Self { + column: column.clone(), + field_path: vec![], + } + } +} + /// A source of runtime statistical information to [`PruningPredicate`]s. /// /// # Supported Information @@ -68,14 +112,14 @@ pub trait PruningStatistics { /// not known for any row, return `None`. /// /// Note: the returned array must contain [`Self::num_containers`] rows - fn min_values(&self, column: &Column) -> Option; + fn min_values(&self, column: &PruningColumn) -> Option; /// Return the maximum values for the named column, if known. /// /// See [`Self::min_values`] for when to return `None` and null values. /// /// Note: the returned array must contain [`Self::num_containers`] rows - fn max_values(&self, column: &Column) -> Option; + fn max_values(&self, column: &PruningColumn) -> Option; /// Return the number of containers (e.g. Row Groups) being pruned with /// these statistics. @@ -93,7 +137,7 @@ pub trait PruningStatistics { /// Note: the returned array must contain [`Self::num_containers`] rows /// /// [`UInt64Array`]: arrow::array::UInt64Array - fn null_counts(&self, column: &Column) -> Option; + fn null_counts(&self, column: &PruningColumn) -> Option; /// Return the number of rows for the named column in each container /// as an [`UInt64Array`]. @@ -103,7 +147,7 @@ pub trait PruningStatistics { /// Note: the returned array must contain [`Self::num_containers`] rows /// /// [`UInt64Array`]: arrow::array::UInt64Array - fn row_counts(&self, column: &Column) -> Option; + fn row_counts(&self, column: &PruningColumn) -> Option; /// Returns [`BooleanArray`] where each row represents information known /// about specific literal `values` in a column. @@ -123,7 +167,7 @@ pub trait PruningStatistics { /// Note: the returned array must contain [`Self::num_containers`] rows fn contained( &self, - column: &Column, + column: &PruningColumn, values: &HashSet, ) -> Option; } @@ -239,7 +283,7 @@ impl PartitionPruningStatistics { #[expect(deprecated)] impl PruningStatistics for PartitionPruningStatistics { - fn min_values(&self, column: &Column) -> Option { + fn min_values(&self, column: &PruningColumn) -> Option { let index = self.partition_schema.index_of(column.name()).ok()?; self.partition_values.get(index).and_then(|v| { if v.is_empty() || v.null_count() == v.len() { @@ -252,7 +296,7 @@ impl PruningStatistics for PartitionPruningStatistics { }) } - fn max_values(&self, column: &Column) -> Option { + fn max_values(&self, column: &PruningColumn) -> Option { self.min_values(column) } @@ -260,17 +304,17 @@ impl PruningStatistics for PartitionPruningStatistics { self.num_containers } - fn null_counts(&self, _column: &Column) -> Option { + fn null_counts(&self, _column: &PruningColumn) -> Option { None } - fn row_counts(&self, _column: &Column) -> Option { + fn row_counts(&self, _column: &PruningColumn) -> Option { None } fn contained( &self, - column: &Column, + column: &PruningColumn, values: &HashSet, ) -> Option { let index = self.partition_schema.index_of(column.name()).ok()?; @@ -327,7 +371,7 @@ impl PrunableStatistics { fn get_exact_column_statistics( &self, - column: &Column, + column: &PruningColumn, get_stat: impl Fn(&ColumnStatistics) -> &Precision, ) -> Option { let index = self.schema.index_of(column.name()).ok()?; @@ -359,11 +403,11 @@ impl PrunableStatistics { } impl PruningStatistics for PrunableStatistics { - fn min_values(&self, column: &Column) -> Option { + fn min_values(&self, column: &PruningColumn) -> Option { self.get_exact_column_statistics(column, |stat| &stat.min_value) } - fn max_values(&self, column: &Column) -> Option { + fn max_values(&self, column: &PruningColumn) -> Option { self.get_exact_column_statistics(column, |stat| &stat.max_value) } @@ -371,7 +415,7 @@ impl PruningStatistics for PrunableStatistics { self.statistics.len() } - fn null_counts(&self, column: &Column) -> Option { + fn null_counts(&self, column: &PruningColumn) -> Option { let index = self.schema.index_of(column.name()).ok()?; if self.statistics.iter().any(|s| { s.column_statistics @@ -397,7 +441,7 @@ impl PruningStatistics for PrunableStatistics { } } - fn row_counts(&self, column: &Column) -> Option { + fn row_counts(&self, column: &PruningColumn) -> Option { // If the column does not exist in the schema, return None if self.schema.index_of(column.name()).is_err() { return None; @@ -426,7 +470,7 @@ impl PruningStatistics for PrunableStatistics { fn contained( &self, - _column: &Column, + _column: &PruningColumn, _values: &HashSet, ) -> Option { None @@ -470,7 +514,7 @@ impl CompositePruningStatistics { #[expect(deprecated)] impl PruningStatistics for CompositePruningStatistics { - fn min_values(&self, column: &Column) -> Option { + fn min_values(&self, column: &PruningColumn) -> Option { for stats in &self.statistics { if let Some(array) = stats.min_values(column) { return Some(array); @@ -479,7 +523,7 @@ impl PruningStatistics for CompositePruningStatistics { None } - fn max_values(&self, column: &Column) -> Option { + fn max_values(&self, column: &PruningColumn) -> Option { for stats in &self.statistics { if let Some(array) = stats.max_values(column) { return Some(array); @@ -492,7 +536,7 @@ impl PruningStatistics for CompositePruningStatistics { self.statistics[0].num_containers() } - fn null_counts(&self, column: &Column) -> Option { + fn null_counts(&self, column: &PruningColumn) -> Option { for stats in &self.statistics { if let Some(array) = stats.null_counts(column) { return Some(array); @@ -501,7 +545,7 @@ impl PruningStatistics for CompositePruningStatistics { None } - fn row_counts(&self, column: &Column) -> Option { + fn row_counts(&self, column: &PruningColumn) -> Option { for stats in &self.statistics { if let Some(array) = stats.row_counts(column) { return Some(array); @@ -512,7 +556,7 @@ impl PruningStatistics for CompositePruningStatistics { fn contained( &self, - column: &Column, + column: &PruningColumn, values: &HashSet, ) -> Option { for stats in &self.statistics { @@ -536,6 +580,11 @@ mod tests { use arrow::datatypes::{DataType, Field}; use std::sync::Arc; + /// Helper to create a PruningColumn from a column name for tests + fn pruning_col(name: &str) -> PruningColumn { + Column::new_unqualified(name).into() + } + /// return a PartitionPruningStatistics for two columns 'a' and 'b' /// and the following stats /// @@ -559,8 +608,8 @@ mod tests { fn test_partition_pruning_statistics() { let partition_stats = partition_pruning_statistics_setup(); - let column_a = Column::new_unqualified("a"); - let column_b = Column::new_unqualified("b"); + let column_a = pruning_col("a"); + let column_b = pruning_col("b"); // Partition values don't know anything about nulls or row counts assert!(partition_stats.null_counts(&column_a).is_none()); @@ -616,7 +665,7 @@ mod tests { fn test_partition_pruning_statistics_multiple_positive_values() { let partition_stats = partition_pruning_statistics_setup(); - let column_a = Column::new_unqualified("a"); + let column_a = pruning_col("a"); // The two containers have `a` values 1 and 3, so they both only contain values from 1 and 3 let values = HashSet::from([ScalarValue::from(1i32), ScalarValue::from(3i32)]); @@ -629,7 +678,7 @@ mod tests { fn test_partition_pruning_statistics_multiple_negative_values() { let partition_stats = partition_pruning_statistics_setup(); - let column_a = Column::new_unqualified("a"); + let column_a = pruning_col("a"); // The two containers have `a` values 1 and 3, // so the first contains ONLY values from 1,2 @@ -663,9 +712,9 @@ mod tests { PartitionPruningStatistics::try_new(partition_values, partition_fields) .unwrap(); - let column_a = Column::new_unqualified("a"); - let column_b = Column::new_unqualified("b"); - let column_c = Column::new_unqualified("c"); + let column_a = pruning_col("a"); + let column_b = pruning_col("b"); + let column_c = pruning_col("c"); let values_a = HashSet::from([ScalarValue::from(1i32), ScalarValue::Int32(None)]); let contained_a = partition_stats.contained(&column_a, &values_a).unwrap(); @@ -702,8 +751,8 @@ mod tests { PartitionPruningStatistics::try_new(partition_values, partition_fields) .unwrap(); - let column_a = Column::new_unqualified("a"); - let column_b = Column::new_unqualified("b"); + let column_a = pruning_col("a"); + let column_b = pruning_col("b"); // Partition values don't know anything about nulls or row counts assert!(partition_stats.null_counts(&column_a).is_none()); @@ -766,8 +815,8 @@ mod tests { ])); let pruning_stats = PrunableStatistics::new(statistics, schema); - let column_a = Column::new_unqualified("a"); - let column_b = Column::new_unqualified("b"); + let column_a = pruning_col("a"); + let column_b = pruning_col("b"); // Min/max values are the same as the statistics let min_values_a = as_int32_array(&pruning_stats.min_values(&column_a).unwrap()) @@ -834,7 +883,7 @@ mod tests { assert_eq!(pruning_stats.num_containers(), 2); // Test with a column that has no statistics - let column_c = Column::new_unqualified("c"); + let column_c = pruning_col("c"); assert!(pruning_stats.min_values(&column_c).is_none()); assert!(pruning_stats.max_values(&column_c).is_none()); assert!(pruning_stats.null_counts(&column_c).is_none()); @@ -852,7 +901,7 @@ mod tests { assert!(pruning_stats.contained(&column_c, &values).is_none()); // Test with a column that doesn't exist - let column_d = Column::new_unqualified("d"); + let column_d = pruning_col("d"); assert!(pruning_stats.min_values(&column_d).is_none()); assert!(pruning_stats.max_values(&column_d).is_none()); assert!(pruning_stats.null_counts(&column_d).is_none()); @@ -870,8 +919,8 @@ mod tests { ])); let pruning_stats = PrunableStatistics::new(statistics, schema); - let column_a = Column::new_unqualified("a"); - let column_b = Column::new_unqualified("b"); + let column_a = pruning_col("a"); + let column_b = pruning_col("b"); // Min/max values are all missing assert!(pruning_stats.min_values(&column_a).is_none()); @@ -956,12 +1005,12 @@ mod tests { ]); // Test accessing columns that are only in partition statistics - let part_a = Column::new_unqualified("part_a"); - let part_b = Column::new_unqualified("part_b"); + let part_a = pruning_col("part_a"); + let part_b = pruning_col("part_b"); // Test accessing columns that are only in file statistics - let col_x = Column::new_unqualified("col_x"); - let col_y = Column::new_unqualified("col_y"); + let col_x = pruning_col("col_x"); + let col_y = pruning_col("col_y"); // For partition columns, should get values from partition statistics let min_values_part_a = @@ -1045,7 +1094,7 @@ mod tests { assert!(composite_stats.contained(&col_x, &values).is_none()); // Non-existent column should return None for everything - let non_existent = Column::new_unqualified("non_existent"); + let non_existent = pruning_col("non_existent"); assert!(composite_stats.min_values(&non_existent).is_none()); assert!(composite_stats.max_values(&non_existent).is_none()); assert!(composite_stats.null_counts(&non_existent).is_none()); @@ -1129,7 +1178,7 @@ mod tests { Box::new(second_stats.clone()), ]); - let col_a = Column::new_unqualified("col_a"); + let col_a = pruning_col("col_a"); // Should get values from first statistics since it has priority let min_values = as_int32_array(&composite_stats.min_values(&col_a).unwrap()) diff --git a/datafusion/datasource-parquet/src/page_filter.rs b/datafusion/datasource-parquet/src/page_filter.rs index 194e6e94fba3a..5db2d9dfe9b9f 100644 --- a/datafusion/datasource-parquet/src/page_filter.rs +++ b/datafusion/datasource-parquet/src/page_filter.rs @@ -29,7 +29,7 @@ use arrow::{ datatypes::{Schema, SchemaRef}, }; use datafusion_common::ScalarValue; -use datafusion_common::pruning::PruningStatistics; +use datafusion_common::pruning::{PruningColumn, PruningStatistics}; use datafusion_physical_expr::{PhysicalExpr, split_conjunction}; use datafusion_pruning::PruningPredicate; @@ -463,7 +463,7 @@ impl<'a> PagesPruningStatistics<'a> { } } impl PruningStatistics for PagesPruningStatistics<'_> { - fn min_values(&self, _column: &datafusion_common::Column) -> Option { + fn min_values(&self, _column: &PruningColumn) -> Option { match self.converter.data_page_mins( self.column_index, self.offset_index, @@ -477,7 +477,7 @@ impl PruningStatistics for PagesPruningStatistics<'_> { } } - fn max_values(&self, _column: &datafusion_common::Column) -> Option { + fn max_values(&self, _column: &PruningColumn) -> Option { match self.converter.data_page_maxes( self.column_index, self.offset_index, @@ -495,7 +495,7 @@ impl PruningStatistics for PagesPruningStatistics<'_> { self.page_offsets.len() } - fn null_counts(&self, _column: &datafusion_common::Column) -> Option { + fn null_counts(&self, _column: &PruningColumn) -> Option { match self.converter.data_page_null_counts( self.column_index, self.offset_index, @@ -509,7 +509,7 @@ impl PruningStatistics for PagesPruningStatistics<'_> { } } - fn row_counts(&self, _column: &datafusion_common::Column) -> Option { + fn row_counts(&self, _column: &PruningColumn) -> Option { match self.converter.data_page_row_counts( self.offset_index, self.row_group_metadatas, @@ -525,7 +525,7 @@ impl PruningStatistics for PagesPruningStatistics<'_> { fn contained( &self, - _column: &datafusion_common::Column, + _column: &PruningColumn, _values: &HashSet, ) -> Option { None diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index 932988af051e4..b2a92c1669b20 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -20,9 +20,9 @@ use std::sync::Arc; use super::{ParquetAccessPlan, ParquetFileMetrics}; use arrow::array::{ArrayRef, BooleanArray}; -use arrow::datatypes::Schema; -use datafusion_common::pruning::PruningStatistics; -use datafusion_common::{Column, Result, ScalarValue}; +use arrow::datatypes::{DataType, Schema}; +use datafusion_common::pruning::{PruningColumn, PruningStatistics}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_datasource::FileRange; use datafusion_physical_expr::PhysicalExprSimplifier; use datafusion_physical_expr::expressions::NotExpr; @@ -524,11 +524,11 @@ impl BloomFilterStatistics { } impl PruningStatistics for BloomFilterStatistics { - fn min_values(&self, _column: &Column) -> Option { + fn min_values(&self, _column: &PruningColumn) -> Option { None } - fn max_values(&self, _column: &Column) -> Option { + fn max_values(&self, _column: &PruningColumn) -> Option { None } @@ -536,11 +536,11 @@ impl PruningStatistics for BloomFilterStatistics { 1 } - fn null_counts(&self, _column: &Column) -> Option { + fn null_counts(&self, _column: &PruningColumn) -> Option { None } - fn row_counts(&self, _column: &Column) -> Option { + fn row_counts(&self, _column: &PruningColumn) -> Option { None } @@ -551,10 +551,10 @@ impl PruningStatistics for BloomFilterStatistics { /// of the values in a column are not present. fn contained( &self, - column: &Column, + column: &PruningColumn, values: &HashSet, ) -> Option { - let (sbbf, parquet_type) = self.column_sbbf.get(column.name.as_str())?; + let (sbbf, parquet_type) = self.column_sbbf.get(column.name())?; // Bloom filters are probabilistic data structures that can return false // positives (i.e. it might return true even if the value is not @@ -594,26 +594,104 @@ impl<'a> RowGroupPruningStatistics<'a> { self.row_group_metadatas.iter().copied() } - fn statistics_converter<'b>( + /// Returns a `StatisticsConverter` for the given column. + /// + /// For nested struct fields (where `column.is_nested()` is true), resolves + /// the field path to a Parquet leaf column index and uses + /// [`StatisticsConverter::from_column_index`]. + /// For top-level columns, uses [`StatisticsConverter::try_new`]. + fn statistics_converter( &'a self, - column: &'b Column, + column: &PruningColumn, ) -> Result> { + if column.is_nested() { + let leaf_idx = self.resolve_nested_leaf_index(column).ok_or_else(|| { + DataFusionError::Internal(format!( + "could not resolve nested field path {:?} for column '{}'", + column.field_path, + column.name() + )) + })?; + + let arrow_field = self.resolve_nested_field(column).ok_or_else(|| { + DataFusionError::Internal(format!( + "could not resolve Arrow field for nested path {:?} in column '{}'", + column.field_path, + column.name() + )) + })?; + + return Ok(StatisticsConverter::from_column_index( + leaf_idx, + arrow_field, + self.parquet_schema, + )?); + } + Ok(StatisticsConverter::try_new( - &column.name, + column.name(), self.arrow_schema, self.parquet_schema, )?) } + + /// Resolve a nested struct field path to the Parquet leaf column index. + /// + /// For example, given a struct column "s" with field path ["outer", "inner"], + /// this finds the leaf column in the Parquet schema corresponding to "s.outer.inner". + fn resolve_nested_leaf_index(&self, column: &PruningColumn) -> Option { + let full_path = std::iter::once(column.name().to_string()) + .chain(column.field_path.iter().cloned()) + .collect::>(); + + // Search through all leaf columns in the Parquet schema to find + // one whose path starts with our full path + let num_columns = self.parquet_schema.num_columns(); + for i in 0..num_columns { + let col_descr = self.parquet_schema.column(i); + let col_path = col_descr.path().parts(); + + if col_path.len() >= full_path.len() + && col_path[..full_path.len()] == full_path[..] + { + return Some(i); + } + } + + None + } + + /// Get the Arrow field for a nested struct field by navigating through the schema. + fn resolve_nested_field( + &self, + column: &PruningColumn, + ) -> Option<&arrow::datatypes::Field> { + let root_field = self.arrow_schema.field_with_name(column.name()).ok()?; + + let mut current_field = root_field; + + for path_element in &column.field_path { + match current_field.data_type() { + DataType::Struct(fields) => { + current_field = + fields.iter().find(|f| f.name() == path_element)?.as_ref(); + } + _ => return None, + } + } + + Some(current_field) + } } impl PruningStatistics for RowGroupPruningStatistics<'_> { - fn min_values(&self, column: &Column) -> Option { + fn min_values(&self, column: &PruningColumn) -> Option { self.statistics_converter(column) .and_then(|c| Ok(c.row_group_mins(self.metadata_iter())?)) .ok() } - fn max_values(&self, column: &Column) -> Option { + fn max_values(&self, column: &PruningColumn) -> Option { self.statistics_converter(column) .and_then(|c| Ok(c.row_group_maxes(self.metadata_iter())?)) .ok() @@ -623,14 +701,14 @@ impl PruningStatistics for RowGroupPruningStatistics<'_> { self.row_group_metadatas.len() } - fn null_counts(&self, column: &Column) -> Option { + fn null_counts(&self, column: &PruningColumn) -> Option { self.statistics_converter(column) .and_then(|c| Ok(c.row_group_null_counts(self.metadata_iter())?)) .ok() .map(|counts| Arc::new(counts) as ArrayRef) } - fn row_counts(&self, column: &Column) -> Option { + fn row_counts(&self, column: &PruningColumn) -> Option { // row counts are the same for all columns in a row group self.statistics_converter(column) .and_then(|c| Ok(c.row_group_row_counts(self.metadata_iter())?)) @@ -641,7 +719,7 @@ impl PruningStatistics for RowGroupPruningStatistics<'_> { fn contained( &self, - _column: &Column, + _column: &PruningColumn, _values: &HashSet, ) -> Option { None @@ -660,12 +738,13 @@ mod tests { use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; use datafusion_expr::{Expr, cast, col, lit}; + use datafusion_functions::core::get_field; use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use object_store::ObjectStoreExt; use parquet::arrow::ArrowSchemaConverter; use parquet::arrow::async_reader::ParquetObjectReader; - use parquet::basic::LogicalType; + use parquet::basic::{LogicalType, Repetition}; use parquet::data_type::{ByteArray, FixedLenByteArray}; use parquet::file::metadata::ColumnChunkMetaData; use parquet::{ @@ -1797,4 +1876,364 @@ mod tests { Ok(pruned_row_groups) } + + /// Build a Parquet SchemaDescriptor for a struct column. + /// + /// Creates: `s: struct { value: INT32, label: BYTE_ARRAY }` + /// Parquet leaves: s.value (index 0), s.label (index 1) + fn get_struct_schema_descr() -> SchemaDescPtr { + use parquet::schema::types::Type as SchemaType; + + let value_field = Arc::new( + SchemaType::primitive_type_builder("value", PhysicalType::INT32) + .build() + .unwrap(), + ); + let label_field = Arc::new( + SchemaType::primitive_type_builder("label", PhysicalType::BYTE_ARRAY) + .with_logical_type(Some(LogicalType::String)) + .build() + .unwrap(), + ); + let struct_group = Arc::new( + SchemaType::group_type_builder("s") + .with_fields(vec![value_field, label_field]) + .with_repetition(Repetition::REQUIRED) + .build() + .unwrap(), + ); + let schema = SchemaType::group_type_builder("schema") + .with_fields(vec![struct_group]) + .build() + .unwrap(); + Arc::new(SchemaDescriptor::new(Arc::new(schema))) + } + + /// Test that row group pruning works for struct field predicates. + /// + /// Creates two row groups with statistics for `s.value`: + /// - RG0: s.value min=1, max=10 + /// - RG1: s.value min=11, max=20 + /// + /// Predicate: `get_field(s, 'value') > 15` + /// Expected: RG0 is pruned (max=10 < 15), RG1 remains (max=20 >= 15) + #[test] + fn row_group_pruning_predicate_struct_field() { + let struct_fields: arrow::datatypes::Fields = vec![ + Arc::new(Field::new("value", DataType::Int32, false)), + Arc::new(Field::new("label", DataType::Utf8, false)), + ] + .into(); + let schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Struct(struct_fields), + false, + )])); + + // get_field(s, 'value') > 15 + let get_field_expr = get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("value".to_string())), None), + ]); + let predicate = get_field_expr.gt(lit(15i32)); + let expr = logical2physical(&predicate, &schema); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + + let schema_descr = get_struct_schema_descr(); + + // RG0: s.value min=1, max=10 → should be PRUNED (max < 15) + let rgm0 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(1), Some(10), None, Some(0), false), + ParquetStatistics::byte_array(None, None, None, Some(0), false), + ], + ); + // RG1: s.value min=11, max=20 → should REMAIN (max >= 15) + let rgm1 = get_row_group_meta_data( + &schema_descr, + vec![ + ParquetStatistics::int32(Some(11), Some(20), None, Some(0), false), + ParquetStatistics::byte_array(None, None, None, Some(0), false), + ], + ); + + let metrics = parquet_file_metrics(); + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(2)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm0, rgm1], + &pruning_predicate, + &metrics, + ); + + // Only RG1 (index 1) should remain + assert_pruned(row_groups, ExpectedPruning::Some(vec![1])); + } + + /// Test that row group pruning works when all row groups match + /// the struct field predicate. + #[test] + fn row_group_pruning_predicate_struct_field_no_pruning() { + let struct_fields: arrow::datatypes::Fields = + vec![Arc::new(Field::new("value", DataType::Int32, false))].into(); + let schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Struct(struct_fields), + false, + )])); + + // get_field(s, 'value') > 0 — both row groups should pass + let get_field_expr = get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("value".to_string())), None), + ]); + let predicate = get_field_expr.gt(lit(0i32)); + let expr = logical2physical(&predicate, &schema); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + + // Single leaf column (s.value) + let schema_descr = { + use parquet::schema::types::Type as SchemaType; + let value_field = Arc::new( + SchemaType::primitive_type_builder("value", PhysicalType::INT32) + .build() + .unwrap(), + ); + let struct_group = Arc::new( + SchemaType::group_type_builder("s") + .with_fields(vec![value_field]) + .with_repetition(Repetition::REQUIRED) + .build() + .unwrap(), + ); + let schema = SchemaType::group_type_builder("schema") + .with_fields(vec![struct_group]) + .build() + .unwrap(); + Arc::new(SchemaDescriptor::new(Arc::new(schema))) + }; + + let rgm0 = get_row_group_meta_data( + &schema_descr, + vec![ParquetStatistics::int32( + Some(1), + Some(10), + None, + Some(0), + false, + )], + ); + let rgm1 = get_row_group_meta_data( + &schema_descr, + vec![ParquetStatistics::int32( + Some(11), + Some(20), + None, + Some(0), + false, + )], + ); + + let metrics = parquet_file_metrics(); + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(2)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm0, rgm1], + &pruning_predicate, + &metrics, + ); + + // Both row groups should remain (all have max > 0) + assert_pruned(row_groups, ExpectedPruning::None); + } + + /// Test that row group pruning works when ALL row groups can be pruned + /// by the struct field predicate. + #[test] + fn row_group_pruning_predicate_struct_field_all_pruned() { + let struct_fields: arrow::datatypes::Fields = + vec![Arc::new(Field::new("value", DataType::Int32, false))].into(); + let schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Struct(struct_fields), + false, + )])); + + // get_field(s, 'value') > 100 — no row groups match + let get_field_expr = get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("value".to_string())), None), + ]); + let predicate = get_field_expr.gt(lit(100i32)); + let expr = logical2physical(&predicate, &schema); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + + let schema_descr = { + use parquet::schema::types::Type as SchemaType; + let value_field = Arc::new( + SchemaType::primitive_type_builder("value", PhysicalType::INT32) + .build() + .unwrap(), + ); + let struct_group = Arc::new( + SchemaType::group_type_builder("s") + .with_fields(vec![value_field]) + .with_repetition(Repetition::REQUIRED) + .build() + .unwrap(), + ); + let schema = SchemaType::group_type_builder("schema") + .with_fields(vec![struct_group]) + .build() + .unwrap(); + Arc::new(SchemaDescriptor::new(Arc::new(schema))) + }; + + // RG0: max=10, RG1: max=20 — both below 100 + let rgm0 = get_row_group_meta_data( + &schema_descr, + vec![ParquetStatistics::int32( + Some(1), + Some(10), + None, + Some(0), + false, + )], + ); + let rgm1 = get_row_group_meta_data( + &schema_descr, + vec![ParquetStatistics::int32( + Some(11), + Some(20), + None, + Some(0), + false, + )], + ); + + let metrics = parquet_file_metrics(); + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(2)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm0, rgm1], + &pruning_predicate, + &metrics, + ); + + // All row groups should be pruned + assert_pruned(row_groups, ExpectedPruning::All); + } + + /// End-to-end test: write a parquet file with struct columns and multiple + /// row groups, then verify row group pruning actually skips the right ones. + #[test] + fn row_group_pruning_struct_field_end_to_end() { + use arrow::array::{Int32Array, StringArray, StructArray}; + use arrow::record_batch::RecordBatch; + use datafusion_functions::core::get_field; + use parquet::arrow::ArrowWriter; + use parquet::file::properties::WriterProperties; + use tempfile::NamedTempFile; + + let struct_fields: arrow::datatypes::Fields = vec![ + Arc::new(Field::new("value", DataType::Int32, false)), + Arc::new(Field::new("label", DataType::Utf8, false)), + ] + .into(); + let schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Struct(struct_fields.clone()), + false, + )])); + + // Write two row groups: + // RG0: s.value in [1, 5, 10] + // RG1: s.value in [11, 15, 20] + let batch0 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 5, 10])) as _, + Arc::new(StringArray::from(vec!["a", "b", "c"])) as _, + ], + None, + ))], + ) + .unwrap(); + + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(StructArray::new( + struct_fields, + vec![ + Arc::new(Int32Array::from(vec![11, 15, 20])) as _, + Arc::new(StringArray::from(vec!["d", "e", "f"])) as _, + ], + None, + ))], + ) + .unwrap(); + + let file = NamedTempFile::new().unwrap(); + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(3)) // force each batch into its own row group + .build(); + let mut writer = ArrowWriter::try_new( + file.reopen().unwrap(), + Arc::clone(&schema), + Some(props), + ) + .unwrap(); + writer.write(&batch0).unwrap(); + writer.write(&batch1).unwrap(); + writer.close().unwrap(); + + // Read back and verify two row groups were created + let reader_file = file.reopen().unwrap(); + let builder = + parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder::try_new( + reader_file, + ) + .unwrap(); + let metadata = builder.metadata().clone(); + assert_eq!( + metadata.num_row_groups(), + 2, + "expected 2 row groups in test file" + ); + + let file_schema = builder.schema().clone(); + + // Predicate: get_field(s, 'value') > 12 + // RG0 has max=10, should be pruned + // RG1 has max=20, should remain + let get_field_expr = get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("value".to_string())), None), + ]); + let predicate = get_field_expr.gt(lit(12i32)); + let expr = logical2physical(&predicate, &file_schema); + let pruning_predicate = + PruningPredicate::try_new(expr, file_schema.clone()).unwrap(); + + let metrics = parquet_file_metrics(); + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all( + metadata.num_row_groups(), + )); + row_groups.prune_by_statistics( + &file_schema, + metadata.file_metadata().schema_descr(), + metadata.row_groups(), + &pruning_predicate, + &metrics, + ); + + // Only RG1 should remain + assert_pruned(row_groups, ExpectedPruning::Some(vec![1])); + } } diff --git a/datafusion/pruning/Cargo.toml b/datafusion/pruning/Cargo.toml index e6f4bb6f273c9..181291a401961 100644 --- a/datafusion/pruning/Cargo.toml +++ b/datafusion/pruning/Cargo.toml @@ -22,6 +22,7 @@ datafusion-datasource = { workspace = true } datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +datafusion-functions = { workspace = true } datafusion-physical-plan = { workspace = true } log = { workspace = true } diff --git a/datafusion/pruning/src/lib.rs b/datafusion/pruning/src/lib.rs index be17f29eaafa0..fe27340081a0d 100644 --- a/datafusion/pruning/src/lib.rs +++ b/datafusion/pruning/src/lib.rs @@ -22,6 +22,6 @@ mod pruning_predicate; pub use file_pruner::FilePruner; pub use pruning_predicate::{ - PredicateRewriter, PruningPredicate, PruningStatistics, RequiredColumns, - UnhandledPredicateHook, build_pruning_predicate, + PredicateRewriter, PruningColumn, PruningPredicate, PruningStatistics, + RequiredColumns, UnhandledPredicateHook, build_pruning_predicate, }; diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 6f6b00e80abc2..48fabf4dd6ec6 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -29,7 +29,7 @@ use arrow::{ record_batch::{RecordBatch, RecordBatchOptions}, }; // pub use for backwards compatibility -pub use datafusion_common::pruning::PruningStatistics; +pub use datafusion_common::pruning::{PruningColumn, PruningStatistics}; use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; use datafusion_physical_plan::metrics::Count; use log::{debug, trace}; @@ -42,6 +42,8 @@ use datafusion_common::{ tree_node::{Transformed, TreeNode}, }; use datafusion_expr_common::operator::Operator; +use datafusion_functions::core::getfield::GetFieldFunc; +use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::expressions::CastColumnExpr; use datafusion_physical_expr::utils::{Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{PhysicalExprRef, expressions as phys_expr}; @@ -531,7 +533,8 @@ impl PruningPredicate { guarantee, literals, } = literal_guarantee; - if let Some(results) = statistics.contained(column, literals) { + let pruning_column = PruningColumn::from(column); + if let Some(results) = statistics.contained(&pruning_column, literals) { match guarantee { // `In` means the values in the column must be one of the // values in the set for the predicate to evaluate to true. @@ -725,7 +728,8 @@ pub struct RequiredColumns { /// * Statistics type (e.g. Min or Max or Null_Count) /// * The field the statistics value should be placed in for /// pruning predicate evaluation (e.g. `min_value` or `max_value`) - columns: Vec<(phys_expr::Column, StatisticsType, Field)>, + /// * The nested field path for struct field accesses (empty for top-level columns) + columns: Vec<(phys_expr::Column, StatisticsType, Field, Vec)>, } impl RequiredColumns { @@ -743,7 +747,7 @@ impl RequiredColumns { /// * `true` returns None pub fn single_column(&self) -> Option<&phys_expr::Column> { if self.columns.windows(2).all(|w| { - // check if all columns are the same (ignoring statistics and field) + // check if all columns are the same (ignoring statistics, field, and field_path) let c1 = &w[0].0; let c2 = &w[1].0; c1 == c2 @@ -764,7 +768,7 @@ impl RequiredColumns { let fields = self .columns .iter() - .map(|(_c, _t, f)| f.clone()) + .map(|(_c, _t, f, _fp)| f.clone()) .collect::>(); Schema::new(fields) } @@ -773,7 +777,8 @@ impl RequiredColumns { /// `self.columns` for details) pub(crate) fn iter( &self, - ) -> impl Iterator { + ) -> impl Iterator)> + { self.columns.iter() } @@ -781,6 +786,7 @@ impl RequiredColumns { &self, column: &phys_expr::Column, statistics_type: StatisticsType, + field_path: &[String], ) -> Option { match statistics_type { StatisticsType::RowCount => { @@ -788,15 +794,17 @@ impl RequiredColumns { self.columns .iter() .enumerate() - .find(|(_i, (_c, t, _f))| t == &statistics_type) - .map(|(i, (_c, _t, _f))| i) + .find(|(_i, (_c, t, _f, _fp))| t == &statistics_type) + .map(|(i, (_c, _t, _f, _fp))| i) } _ => self .columns .iter() .enumerate() - .find(|(_i, (c, t, _f))| c == column && t == &statistics_type) - .map(|(i, (_c, _t, _f))| i), + .find(|(_i, (c, t, _f, fp))| { + c == column && t == &statistics_type && fp.as_slice() == field_path + }) + .map(|(i, (_c, _t, _f, _fp))| i), } } @@ -814,17 +822,26 @@ impl RequiredColumns { column_expr: &Arc, field: &Field, stat_type: StatisticsType, + field_path: &[String], ) -> Result> { - let (idx, need_to_insert) = match self.find_stat_column(column, stat_type) { - Some(idx) => (idx, false), - None => (self.columns.len(), true), - }; + let (idx, need_to_insert) = + match self.find_stat_column(column, stat_type, field_path) { + Some(idx) => (idx, false), + None => (self.columns.len(), true), + }; let column_name = column.name(); + let path_suffix = if field_path.is_empty() { + String::new() + } else { + format!(".{}", field_path.join(".")) + }; let stat_column_name = match stat_type { - StatisticsType::Min => format!("{column_name}_min"), - StatisticsType::Max => format!("{column_name}_max"), - StatisticsType::NullCount => format!("{column_name}_null_count"), + StatisticsType::Min => format!("{column_name}{path_suffix}_min"), + StatisticsType::Max => format!("{column_name}{path_suffix}_max"), + StatisticsType::NullCount => { + format!("{column_name}{path_suffix}_null_count") + } StatisticsType::RowCount => "row_count".to_string(), }; @@ -836,7 +853,12 @@ impl RequiredColumns { let nullable = true; let stat_field = Field::new(stat_column.name(), field.data_type().clone(), nullable); - self.columns.push((column.clone(), stat_type, stat_field)); + self.columns.push(( + column.clone(), + stat_type, + stat_field, + field_path.to_vec(), + )); } rewrite_column_expr(Arc::clone(column_expr), column, &stat_column) } @@ -847,8 +869,9 @@ impl RequiredColumns { column: &phys_expr::Column, column_expr: &Arc, field: &Field, + field_path: &[String], ) -> Result> { - self.stat_column_expr(column, column_expr, field, StatisticsType::Min) + self.stat_column_expr(column, column_expr, field, StatisticsType::Min, field_path) } /// rewrite col --> col_max @@ -857,8 +880,9 @@ impl RequiredColumns { column: &phys_expr::Column, column_expr: &Arc, field: &Field, + field_path: &[String], ) -> Result> { - self.stat_column_expr(column, column_expr, field, StatisticsType::Max) + self.stat_column_expr(column, column_expr, field, StatisticsType::Max, field_path) } /// rewrite col --> col_null_count @@ -867,8 +891,15 @@ impl RequiredColumns { column: &phys_expr::Column, column_expr: &Arc, field: &Field, + field_path: &[String], ) -> Result> { - self.stat_column_expr(column, column_expr, field, StatisticsType::NullCount) + self.stat_column_expr( + column, + column_expr, + field, + StatisticsType::NullCount, + field_path, + ) } /// rewrite col --> col_row_count @@ -877,14 +908,26 @@ impl RequiredColumns { column: &phys_expr::Column, column_expr: &Arc, field: &Field, + field_path: &[String], ) -> Result> { - self.stat_column_expr(column, column_expr, field, StatisticsType::RowCount) + self.stat_column_expr( + column, + column_expr, + field, + StatisticsType::RowCount, + field_path, + ) } } impl From> for RequiredColumns { fn from(columns: Vec<(phys_expr::Column, StatisticsType, Field)>) -> Self { - Self { columns } + Self { + columns: columns + .into_iter() + .map(|(c, t, f)| (c, t, f, vec![])) + .collect(), + } } } @@ -919,8 +962,11 @@ fn build_statistics_record_batch( ) -> Result { let mut arrays = Vec::::new(); // For each needed statistics column: - for (column, statistics_type, stat_field) in required_columns.iter() { - let column = Column::from_name(column.name()); + for (column, statistics_type, stat_field, field_path) in required_columns.iter() { + let column = PruningColumn { + column: Column::from_name(column.name()), + field_path: field_path.clone(), + }; let data_type = stat_field.data_type(); let num_containers = statistics.num_containers(); @@ -966,6 +1012,9 @@ struct PruningExpressionBuilder<'a> { op: Operator, scalar_expr: Arc, field: &'a Field, + /// For struct field accesses like `get_field(s, 'value')`, the path from + /// the root column to the nested field. Empty for top-level columns. + field_path: Vec, required_columns: &'a mut RequiredColumns, } @@ -1010,17 +1059,54 @@ impl<'a> PruningExpressionBuilder<'a> { }; let df_schema = DFSchema::try_from(Arc::clone(schema))?; - let (column_expr, correct_operator, scalar_expr) = rewrite_expr_to_prunable( - column_expr, - correct_operator, - scalar_expr, - df_schema, - )?; - let field = match schema.column_with_name(column.name()) { - Some((_, f)) => f, - _ => { - return plan_err!("Field not found in schema"); + let (column_expr, correct_operator, scalar_expr, field_path) = + rewrite_expr_to_prunable( + column_expr, + correct_operator, + scalar_expr, + df_schema, + )?; + + // For struct field accesses, navigate through the struct type to find the leaf field. + let field = if field_path.is_empty() { + match schema.column_with_name(column.name()) { + Some((_, f)) => f, + _ => { + return plan_err!("Field not found in schema"); + } } + } else { + // Navigate through nested struct fields to find the leaf field + let root_field = match schema.column_with_name(column.name()) { + Some((_, f)) => f, + _ => { + return plan_err!("Field not found in schema"); + } + }; + let mut current_field = root_field; + for path_element in &field_path { + match current_field.data_type() { + DataType::Struct(fields) => { + current_field = + match fields.iter().find(|f| f.name() == path_element) { + Some(f) => f.as_ref(), + None => { + return plan_err!( + "Struct field '{path_element}' not found in {}", + current_field.name() + ); + } + }; + } + _ => { + return plan_err!( + "Expected struct type for field path navigation, got {}", + current_field.data_type() + ); + } + } + } + current_field }; Ok(Self { @@ -1029,6 +1115,7 @@ impl<'a> PruningExpressionBuilder<'a> { op: correct_operator, scalar_expr, field, + field_path, required_columns, }) } @@ -1042,52 +1129,42 @@ impl<'a> PruningExpressionBuilder<'a> { } fn min_column_expr(&mut self) -> Result> { - self.required_columns - .min_column_expr(&self.column, &self.column_expr, self.field) + self.required_columns.min_column_expr( + &self.column, + &self.column_expr, + self.field, + &self.field_path, + ) } fn max_column_expr(&mut self) -> Result> { - self.required_columns - .max_column_expr(&self.column, &self.column_expr, self.field) + self.required_columns.max_column_expr( + &self.column, + &self.column_expr, + self.field, + &self.field_path, + ) } - /// This function is to simply retune the `null_count` physical expression no matter what the - /// predicate expression is - /// - /// i.e., x > 5 => x_null_count, - /// cast(x as int) < 10 => x_null_count, - /// try_cast(x as float) < 10.0 => x_null_count fn null_count_column_expr(&mut self) -> Result> { - // Retune to [`phys_expr::Column`] let column_expr = Arc::new(self.column.clone()) as _; - - // null_count is DataType::UInt64, which is different from the column's data type (i.e. self.field) let null_count_field = &Field::new(self.field.name(), DataType::UInt64, true); - self.required_columns.null_count_column_expr( &self.column, &column_expr, null_count_field, + &self.field_path, ) } - /// This function is to simply retune the `row_count` physical expression no matter what the - /// predicate expression is - /// - /// i.e., x > 5 => x_row_count, - /// cast(x as int) < 10 => x_row_count, - /// try_cast(x as float) < 10.0 => x_row_count fn row_count_column_expr(&mut self) -> Result> { - // Retune to [`phys_expr::Column`] let column_expr = Arc::new(self.column.clone()) as _; - - // row_count is DataType::UInt64, which is different from the column's data type (i.e. self.field) let row_count_field = &Field::new(self.field.name(), DataType::UInt64, true); - self.required_columns.row_count_column_expr( &self.column, &column_expr, row_count_field, + &self.field_path, ) } } @@ -1104,12 +1181,16 @@ impl<'a> PruningExpressionBuilder<'a> { /// 6. `try_cast(can_prunable_expr) > 10` /// /// More rewrite rules are still in progress. +/// Returns `(column_expr, operator, scalar_expr, field_path)`. +/// +/// `field_path` is non-empty when the expression accesses a nested struct field +/// via `get_field`, e.g. `get_field(s, 'value') > 5` returns field_path `["value"]`. fn rewrite_expr_to_prunable( column_expr: &PhysicalExprRef, op: Operator, scalar_expr: &PhysicalExprRef, schema: DFSchema, -) -> Result<(PhysicalExprRef, Operator, PhysicalExprRef)> { +) -> Result<(PhysicalExprRef, Operator, PhysicalExprRef, Vec)> { if !is_compare_op(op) { return plan_err!("rewrite_expr_to_prunable only support compare expression"); } @@ -1121,34 +1202,66 @@ fn rewrite_expr_to_prunable( .is_some() { // `col op lit()` - Ok((Arc::clone(column_expr), op, Arc::clone(scalar_expr))) + Ok((Arc::clone(column_expr), op, Arc::clone(scalar_expr), vec![])) + } else if let Some(func) = + ScalarFunctionExpr::try_downcast_func::(column_expr.as_ref()) + { + // `get_field(col, 'field_name') op lit()` + let args = func.args(); + if let Some(inner_column) = args + .first() + .and_then(|a| a.as_any().downcast_ref::()) + { + let field_path = args[1..] + .iter() + .map(|arg| { + arg.as_any() + .downcast_ref::() + .and_then(|lit| { + lit.value().try_as_str().flatten().map(|s| s.to_string()) + }) + }) + .collect::>>(); + + match field_path { + Some(path) if !path.is_empty() => { + let inner_col_expr = Arc::new(inner_column.clone()) as _; + Ok((inner_col_expr, op, Arc::clone(scalar_expr), path)) + } + _ => { + plan_err!( + "get_field with non-literal field names is not supported for pruning" + ) + } + } + } else { + plan_err!( + "get_field with non-column first argument is not supported for pruning" + ) + } } else if let Some(cast) = column_expr_any.downcast_ref::() { // `cast(col) op lit()` let arrow_schema = schema.as_arrow(); let from_type = cast.expr().data_type(arrow_schema)?; verify_support_type_for_prune(&from_type, cast.cast_type())?; - let (left, op, right) = + let (left, op, right, field_path) = rewrite_expr_to_prunable(cast.expr(), op, scalar_expr, schema)?; let left = Arc::new(phys_expr::CastExpr::new( left, cast.cast_type().clone(), None, )); - Ok((left, op, right)) + Ok((left, op, right, field_path)) } else if let Some(cast_col) = column_expr_any.downcast_ref::() { // `cast_column(col) op lit()` - same as CastExpr but uses CastColumnExpr let arrow_schema = schema.as_arrow(); let from_type = cast_col.expr().data_type(arrow_schema)?; let to_type = cast_col.target_field().data_type(); verify_support_type_for_prune(&from_type, to_type)?; - let (left, op, right) = + let (left, op, right, field_path) = rewrite_expr_to_prunable(cast_col.expr(), op, scalar_expr, schema)?; - // Predicate pruning / statistics generally don't support struct columns yet. - // In the future we may want to support pruning on nested fields, in which case we probably need to - // do something more sophisticated here. - // But for now since we don't support pruning on nested fields, we can just cast to the target type directly. let left = Arc::new(phys_expr::CastExpr::new(left, to_type.clone(), None)); - Ok((left, op, right)) + Ok((left, op, right, field_path)) } else if let Some(try_cast) = column_expr_any.downcast_ref::() { @@ -1156,19 +1269,19 @@ fn rewrite_expr_to_prunable( let arrow_schema = schema.as_arrow(); let from_type = try_cast.expr().data_type(arrow_schema)?; verify_support_type_for_prune(&from_type, try_cast.cast_type())?; - let (left, op, right) = + let (left, op, right, field_path) = rewrite_expr_to_prunable(try_cast.expr(), op, scalar_expr, schema)?; let left = Arc::new(phys_expr::TryCastExpr::new( left, try_cast.cast_type().clone(), )); - Ok((left, op, right)) + Ok((left, op, right, field_path)) } else if let Some(neg) = column_expr_any.downcast_ref::() { // `-col > lit()` --> `col < -lit()` - let (left, op, right) = + let (left, op, right, field_path) = rewrite_expr_to_prunable(neg.arg(), op, scalar_expr, schema)?; let right = Arc::new(phys_expr::NegativeExpr::new(right)); - Ok((left, reverse_operator(op)?, right)) + Ok((left, reverse_operator(op)?, right, field_path)) } else if let Some(not) = column_expr_any.downcast_ref::() { // `!col = true` --> `col = !true` if op != Operator::Eq && op != Operator::NotEq { @@ -1182,7 +1295,7 @@ fn rewrite_expr_to_prunable( { let left = Arc::clone(not.arg()); let right = Arc::new(phys_expr::NotExpr::new(Arc::clone(scalar_expr))); - Ok((left, reverse_operator(op)?, right)) + Ok((left, reverse_operator(op)?, right, vec![])) } else { plan_err!("Not with complex expression {column_expr:?} is not supported") } @@ -1277,10 +1390,10 @@ fn build_single_column_expr( let col_ref = Arc::new(column.clone()) as _; let min = required_columns - .min_column_expr(column, &col_ref, field) + .min_column_expr(column, &col_ref, field, &[]) .ok()?; let max = required_columns - .max_column_expr(column, &col_ref, field) + .max_column_expr(column, &col_ref, field, &[]) .ok()?; // remember -- we want an expression that is: @@ -1322,10 +1435,10 @@ fn build_is_null_column_expr( let null_count_field = &Field::new(field.name(), DataType::UInt64, true); if with_not { if let Ok(row_count_expr) = - required_columns.row_count_column_expr(col, expr, null_count_field) + required_columns.row_count_column_expr(col, expr, null_count_field, &[]) { required_columns - .null_count_column_expr(col, expr, null_count_field) + .null_count_column_expr(col, expr, null_count_field, &[]) .map(|null_count_column_expr| { // IsNotNull(column) => null_count != row_count Arc::new(phys_expr::BinaryExpr::new( @@ -1340,7 +1453,7 @@ fn build_is_null_column_expr( } } else { required_columns - .null_count_column_expr(col, expr, null_count_field) + .null_count_column_expr(col, expr, null_count_field, &[]) .map(|null_count_column_expr| { // IsNull(column) => null_count > 0 Arc::new(phys_expr::BinaryExpr::new( @@ -2269,16 +2382,16 @@ mod tests { } impl PruningStatistics for TestStatistics { - fn min_values(&self, column: &Column) -> Option { + fn min_values(&self, column: &PruningColumn) -> Option { self.stats - .get(column) + .get(&column.column) .map(|container_stats| container_stats.min()) .unwrap_or(None) } - fn max_values(&self, column: &Column) -> Option { + fn max_values(&self, column: &PruningColumn) -> Option { self.stats - .get(column) + .get(&column.column) .map(|container_stats| container_stats.max()) .unwrap_or(None) } @@ -2291,27 +2404,27 @@ mod tests { .unwrap_or(0) } - fn null_counts(&self, column: &Column) -> Option { + fn null_counts(&self, column: &PruningColumn) -> Option { self.stats - .get(column) + .get(&column.column) .map(|container_stats| container_stats.null_counts()) .unwrap_or(None) } - fn row_counts(&self, column: &Column) -> Option { + fn row_counts(&self, column: &PruningColumn) -> Option { self.stats - .get(column) + .get(&column.column) .map(|container_stats| container_stats.row_counts()) .unwrap_or(None) } fn contained( &self, - column: &Column, + column: &PruningColumn, values: &HashSet, ) -> Option { self.stats - .get(column) + .get(&column.column) .and_then(|container_stats| container_stats.contained(values)) } } @@ -2324,11 +2437,11 @@ mod tests { } impl PruningStatistics for OneContainerStats { - fn min_values(&self, _column: &Column) -> Option { + fn min_values(&self, _column: &PruningColumn) -> Option { self.min_values.clone() } - fn max_values(&self, _column: &Column) -> Option { + fn max_values(&self, _column: &PruningColumn) -> Option { self.max_values.clone() } @@ -2336,17 +2449,17 @@ mod tests { self.num_containers } - fn null_counts(&self, _column: &Column) -> Option { + fn null_counts(&self, _column: &PruningColumn) -> Option { None } - fn row_counts(&self, _column: &Column) -> Option { + fn row_counts(&self, _column: &PruningColumn) -> Option { None } fn contained( &self, - _column: &Column, + _column: &PruningColumn, _values: &HashSet, ) -> Option { None @@ -2374,7 +2487,7 @@ mod tests { // Fields in required schema should be unique, otherwise when creating batches // it will fail because of duplicate field names let mut fields = HashSet::new(); - for (_col, _ty, field) in p.required_columns().iter() { + for (_col, _ty, field, _fp) in p.required_columns().iter() { let was_new = fields.insert(field); if !was_new { panic!( @@ -3049,7 +3162,8 @@ mod tests { ( phys_expr::Column::new("c1", 0), StatisticsType::Min, - c1_min_field.with_nullable(true) // could be nullable if stats are not present + c1_min_field.with_nullable(true), // could be nullable if stats are not present + vec![], ) ); // c1 < 1 should add c1_null_count @@ -3059,7 +3173,8 @@ mod tests { ( phys_expr::Column::new("c1", 0), StatisticsType::NullCount, - c1_null_count_field.with_nullable(true) // could be nullable if stats are not present + c1_null_count_field.with_nullable(true), // could be nullable if stats are not present + vec![], ) ); // c1 < 1 should add row_count @@ -3069,7 +3184,8 @@ mod tests { ( phys_expr::Column::new("c1", 0), StatisticsType::RowCount, - row_count_field.with_nullable(true) // could be nullable if stats are not present + row_count_field.with_nullable(true), // could be nullable if stats are not present + vec![], ) ); // c2 = 2 should add c2_min and c2_max @@ -3079,7 +3195,8 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Min, - c2_min_field.with_nullable(true) // could be nullable if stats are not present + c2_min_field.with_nullable(true), // could be nullable if stats are not present + vec![], ) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); @@ -3088,7 +3205,8 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::Max, - c2_max_field.with_nullable(true) // could be nullable if stats are not present + c2_max_field.with_nullable(true), // could be nullable if stats are not present + vec![], ) ); // c2 = 2 should add c2_null_count @@ -3098,7 +3216,8 @@ mod tests { ( phys_expr::Column::new("c2", 1), StatisticsType::NullCount, - c2_null_count_field.with_nullable(true) // could be nullable if stats are not present + c2_null_count_field.with_nullable(true), // could be nullable if stats are not present + vec![], ) ); // c2 = 1 should add row_count @@ -3108,7 +3227,8 @@ mod tests { ( phys_expr::Column::new("c1", 0), StatisticsType::RowCount, - row_count_field.with_nullable(true) // could be nullable if stats are not present + row_count_field.with_nullable(true), // could be nullable if stats are not present + vec![], ) ); // c2 = 3 shouldn't add any new statistics fields @@ -4785,7 +4905,7 @@ mod tests { let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int32(Some(12))); let right_input = logical2physical(&right_input, &schema); - let (result_left, _, result_right) = rewrite_expr_to_prunable( + let (result_left, _, result_right, _) = rewrite_expr_to_prunable( &left_input, Operator::Eq, &right_input, @@ -4800,7 +4920,7 @@ mod tests { let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Decimal128(Some(12), 20, 3)); let right_input = logical2physical(&right_input, &schema); - let (result_left, _, result_right) = rewrite_expr_to_prunable( + let (result_left, _, result_right, _) = rewrite_expr_to_prunable( &left_input, Operator::Gt, &right_input, @@ -4815,7 +4935,7 @@ mod tests { let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int64(Some(12))); let right_input = logical2physical(&right_input, &schema); - let (result_left, _, result_right) = + let (result_left, _, result_right, _) = rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema) .unwrap(); assert_eq!(result_left.to_string(), left_input.to_string());