diff --git a/vortex-array/src/aggregate_fn/fns/is_constant/mod.rs b/vortex-array/src/aggregate_fn/fns/is_constant/mod.rs index b14b2a050f8..9c3a6278f02 100644 --- a/vortex-array/src/aggregate_fn/fns/is_constant/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/is_constant/mod.rs @@ -13,7 +13,6 @@ mod varbin; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_mask::Mask; use self::bool::check_bool_constant; use self::decimal::check_decimal_constant; @@ -44,6 +43,7 @@ use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; use crate::expr::stats::StatsProviderExt; +use crate::mask::MaskNullAsFalse; use crate::scalar::Scalar; use crate::scalar_fn::fns::operators::Operator; @@ -74,7 +74,7 @@ fn arrays_value_equal(a: &ArrayRef, b: &ArrayRef, ctx: &mut ExecutionCtx) -> Vor // Compare values element-wise. Result is null where both inputs are null, // true/false where both are valid. let eq_result = a.binary(b.clone(), Operator::Eq)?; - let eq_result = eq_result.execute::(ctx)?; + let eq_result = eq_result.execute::(ctx)?.into_mask(); Ok(eq_result.true_count() == valid_count) } diff --git a/vortex-array/src/mask.rs b/vortex-array/src/mask.rs index ab035670e24..7fb25d4d627 100644 --- a/vortex-array/src/mask.rs +++ b/vortex-array/src/mask.rs @@ -12,38 +12,76 @@ use crate::Executable; use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::BoolArray; -use crate::arrays::Constant; use crate::columnar::Columnar; use crate::dtype::DType; +use crate::dtype::Nullability; impl Executable for Mask { + /// Executes a boolean array into a [`Mask`]. + /// + /// The array must have a non-nullable boolean dtype. Use [`MaskNullAsFalse`] to execute a + /// nullable boolean array, coercing null elements to `false`. fn execute(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - if !matches!(array.dtype(), DType::Bool(_)) { - vortex_bail!("Mask array must have boolean dtype, not {}", array.dtype()); + if !matches!(array.dtype(), DType::Bool(Nullability::NonNullable)) { + vortex_bail!( + "Mask array must have boolean(NonNullable) dtype, not {}", + array.dtype() + ); } - if let Some(constant) = array.as_opt::() { - let mask_value = constant.scalar().as_bool().value().unwrap_or(false); - return Ok(Mask::new(array.len(), mask_value)); + let array_len = array.len(); + Ok(match array.execute(ctx)? { + Columnar::Constant(s) => { + Mask::new(array_len, s.scalar().as_bool().value().unwrap_or(false)) + } + Columnar::Canonical(a) => { + let bool = a.into_array().execute::(ctx)?; + Mask::from(bool.into_bit_buffer()) + } + }) + } +} + +/// An [`Executable`] target that executes a boolean array into a [`Mask`], coercing null +/// elements to `false`. +/// +/// [`Mask`] itself requires a non-nullable boolean array and errors on nullable input. Use this +/// wrapper for filter and pruning predicates over nullable data, where SQL semantics treat +/// `NULL` as not matching. +pub struct MaskNullAsFalse(Mask); + +impl MaskNullAsFalse { + /// Consumes the wrapper and returns the underlying [`Mask`]. + pub fn into_mask(self) -> Mask { + self.0 + } +} + +impl From for Mask { + fn from(value: MaskNullAsFalse) -> Self { + value.0 + } +} + +impl Executable for MaskNullAsFalse { + fn execute(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + if !matches!(array.dtype(), DType::Bool(_)) { + vortex_bail!("Mask array must have boolean dtype, not {}", array.dtype()); } let array_len = array.len(); - Ok(match array.execute(ctx)? { + Ok(Self(match array.execute(ctx)? { Columnar::Constant(s) => { Mask::new(array_len, s.scalar().as_bool().value().unwrap_or(false)) } Columnar::Canonical(a) => { let bool = a.into_array().execute::(ctx)?; - let mask = bool + let validity = bool .as_ref() .validity()? .execute_mask(bool.as_ref().len(), ctx)?; - let bits = bool.into_bit_buffer(); - // To handle nullable boolean arrays, we treat nulls as false in the mask. - // TODO(ngates): is this correct? Feels like we should just force the caller to - // pass non-nullable boolean arrays. - mask.bitand(&Mask::from(bits)) + validity.bitand(&Mask::from(bool.into_bit_buffer())) } - }) + })) } } diff --git a/vortex-cuda/src/layout.rs b/vortex-cuda/src/layout.rs index 76e5a5ba5fc..da0fa28931d 100644 --- a/vortex-cuda/src/layout.rs +++ b/vortex-cuda/src/layout.rs @@ -59,6 +59,7 @@ use vortex::layout::sequence::SendableSequentialStream; use vortex::layout::sequence::SequencePointer; use vortex::layout::vtable; use vortex::mask::Mask; +use vortex::mask::MaskNullAsFalse; use vortex::scalar::Scalar; use vortex::scalar::ScalarTruncation; use vortex::scalar::lower_bound; @@ -329,12 +330,12 @@ impl LayoutReader for CudaFlatReader { let array = array.apply(&expr)?; let array = array.filter(mask.clone())?; let mut ctx = session.create_execution_ctx(); - let array_mask = array.execute::(&mut ctx)?; + let array_mask = array.execute::(&mut ctx)?.into_mask(); mask.intersect_by_rank(&array_mask) } else { let array = array.apply(&expr)?; let mut ctx = session.create_execution_ctx(); - let array_mask = array.execute::(&mut ctx)?; + let array_mask = array.execute::(&mut ctx)?.into_mask(); mask.bitand(&array_mask) }; diff --git a/vortex-layout/src/layouts/dict/reader.rs b/vortex-layout/src/layouts/dict/reader.rs index 96f12d53ece..19e002b42ee 100644 --- a/vortex-layout/src/layouts/dict/reader.rs +++ b/vortex-layout/src/layouts/dict/reader.rs @@ -21,6 +21,7 @@ use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; use vortex_array::expr::Expression; use vortex_array::expr::root; +use vortex_array::mask::MaskNullAsFalse; use vortex_array::optimizer::ArrayOptimizer; use vortex_error::VortexError; use vortex_error::VortexExpect; @@ -212,7 +213,10 @@ impl LayoutReader for DictReader { let mask = mask.await?; let mut ctx = session.create_execution_ctx(); - let dict_mask = values.take(codes)?.execute::(&mut ctx)?; + let dict_mask = values + .take(codes)? + .execute::(&mut ctx)? + .into_mask(); Ok(mask.bitand(&dict_mask)) })) diff --git a/vortex-layout/src/layouts/flat/reader.rs b/vortex-layout/src/layouts/flat/reader.rs index 8817fe5c678..6b4e7b0c240 100644 --- a/vortex-layout/src/layouts/flat/reader.rs +++ b/vortex-layout/src/layouts/flat/reader.rs @@ -15,6 +15,7 @@ use vortex_array::VortexSessionExecute; use vortex_array::dtype::DType; use vortex_array::dtype::FieldMask; use vortex_array::expr::Expression; +use vortex_array::mask::MaskNullAsFalse; use vortex_array::serde::SerializedArray; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -155,14 +156,14 @@ impl LayoutReader for FlatReader { let array = array.apply(&expr)?; let array = array.filter(mask.clone())?; let mut ctx = session.create_execution_ctx(); - let array_mask = array.execute::(&mut ctx)?; + let array_mask = array.execute::(&mut ctx)?.into_mask(); mask.intersect_by_rank(&array_mask) } else { // Run over the full array, with a simpler bitand at the end. let array = array.apply(&expr)?; let mut ctx = session.create_execution_ctx(); - let array_mask = array.execute::(&mut ctx)?; + let array_mask = array.execute::(&mut ctx)?.into_mask(); mask.bitand(&array_mask) }; diff --git a/vortex-layout/src/layouts/partitioned.rs b/vortex-layout/src/layouts/partitioned.rs index d5ac3d44c55..deaccd92513 100644 --- a/vortex-layout/src/layouts/partitioned.rs +++ b/vortex-layout/src/layouts/partitioned.rs @@ -15,10 +15,10 @@ use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::expr::Expression; use vortex_array::expr::transform::PartitionedExpr; +use vortex_array::mask::MaskNullAsFalse; use vortex_array::validity::Validity; use vortex_error::VortexError; use vortex_error::VortexResult; -use vortex_mask::Mask; use vortex_session::VortexSession; use crate::ArrayFuture; @@ -90,7 +90,10 @@ impl PartitionedExprEval

for PartitionedExpr

{ .into_array(); let mut ctx = session.create_execution_ctx(); - let root_mask = root_scope.apply(&self.root)?.execute::(&mut ctx)?; + let root_mask = root_scope + .apply(&self.root)? + .execute::(&mut ctx)? + .into_mask(); let mask = mask.bitand(&root_mask); diff --git a/vortex-layout/src/layouts/row_idx/mod.rs b/vortex-layout/src/layouts/row_idx/mod.rs index 18fc62b05cd..5693ae7dc38 100644 --- a/vortex-layout/src/layouts/row_idx/mod.rs +++ b/vortex-layout/src/layouts/row_idx/mod.rs @@ -32,6 +32,7 @@ use vortex_array::expr::root; use vortex_array::expr::transform::PartitionedExpr; use vortex_array::expr::transform::partition; use vortex_array::expr::transform::replace; +use vortex_array::mask::MaskNullAsFalse; use vortex_array::scalar::PValue; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -295,7 +296,10 @@ fn row_idx_mask_future( let array = idx_array(row_offset, &row_range).into_array(); let mut ctx = session.create_execution_ctx(); - let result_mask = array.apply(&expr)?.execute::(&mut ctx)?; + let result_mask = array + .apply(&expr)? + .execute::(&mut ctx)? + .into_mask(); Ok(result_mask.bitand(&mask.await?)) }) diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 16360ff287d..69b98d5e41a 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -14,6 +14,7 @@ use vortex_array::arrays::StructArray; use vortex_array::dtype::DType; use vortex_array::expr::Expression; use vortex_array::expr::stats::Stat; +use vortex_array::mask::MaskNullAsFalse; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::validity::Validity; @@ -99,12 +100,14 @@ impl ZoneMap { let applied = self.array.clone().into_array().apply(predicate)?; if num_zones == 0 || !contains_row_count(&applied) { - return applied.execute::(&mut ctx); + return Ok(applied.execute::(&mut ctx)?.into_mask()); } let row_count_array = row_count_array(self.zone_len, self.row_count, num_zones)?; let substituted = substitute_row_count(applied, &row_count_array)?; - substituted.execute::(&mut ctx) + Ok(substituted + .execute::(&mut ctx)? + .into_mask()) } } diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index 8668de339cb..b5241797a01 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -76,6 +76,7 @@ pub mod layout { } pub mod mask { + pub use vortex_array::mask::MaskNullAsFalse; pub use vortex_mask::*; }