From 28c8cc7fe248240d43f4be0bd38588b5272ba31c Mon Sep 17 00:00:00 2001
From: Joe Isaacs
Date: Tue, 26 May 2026 18:29:36 +0100
Subject: [PATCH] u

Signed-off-by: Joe Isaacs
---
 vortex-array/benches/cast_primitive.rs   |  58 ++++++++
 .../src/arrays/primitive/compute/cast.rs | 129 +++++++++++++-----
 2 files changed, 153 insertions(+), 34 deletions(-)

diff --git a/vortex-array/benches/cast_primitive.rs b/vortex-array/benches/cast_primitive.rs
index 86895fb2ce7..3921ef2c777 100644
--- a/vortex-array/benches/cast_primitive.rs
+++ b/vortex-array/benches/cast_primitive.rs
@@ -1,6 +1,12 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright the Vortex contributors
 
+use std::sync::Arc;
+
+use arrow_array::UInt32Array;
+use arrow_buffer::NullBuffer;
+use arrow_cast::CastOptions;
+use arrow_schema::DataType as ArrowDataType;
 use divan::Bencher;
 use rand::prelude::*;
 use vortex_array::Canonical;
@@ -13,6 +19,9 @@ use vortex_array::dtype::DType;
 use vortex_array::dtype::Nullability;
 use vortex_array::dtype::PType;
 use vortex_array::expr::stats::Stat;
+use vortex_array::validity::Validity;
+use vortex_buffer::BitBuffer;
+use vortex_buffer::Buffer;
 
 fn main() {
     divan::main();
@@ -46,3 +55,52 @@ fn cast_u16_to_u32(bencher: Bencher) {
             .execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())
     });
 }
+
+// Slow-path inputs: u32 -> u8 with mixed validity, all values in-range (so the cast succeeds),
+// no precomputed min/max stats — forces `cast_values` and the `Mask::Values` arm.
+fn slow_path_inputs() -> (Vec<u32>, BitBuffer) {
+    let mut rng = StdRng::seed_from_u64(42);
+    let values: Vec<u32> = (0..N).map(|_| rng.random_range(0..=200u32)).collect();
+    let validity: BitBuffer = (0..N).map(|_| rng.random_bool(0.7)).collect();
+    (values, validity)
+}
+
+#[divan::bench]
+fn cast_u32_u8_vortex(bencher: Bencher) {
+    let (values, validity) = slow_path_inputs();
+    let arr = PrimitiveArray::new(Buffer::from(values), Validity::from(validity)).into_array();
+    bencher.with_inputs(|| arr.clone()).bench_refs(|a| {
+        #[expect(clippy::unwrap_used)]
+        a.cast(DType::Primitive(PType::U8, Nullability::Nullable))
+            .unwrap()
+            .execute::<Canonical>(&mut LEGACY_SESSION.create_execution_ctx())
+    });
+}
+
+#[divan::bench]
+fn cast_u32_u8_arrow(bencher: Bencher) {
+    let (values, validity) = slow_path_inputs();
+    let nulls = NullBuffer::from(validity.iter().collect::<Vec<bool>>());
+    let arr: Arc<dyn arrow_array::Array> = Arc::new(UInt32Array::new(values.into(), Some(nulls)));
+    let opts = CastOptions { safe: false, ..Default::default() };
+    bencher.with_inputs(|| Arc::clone(&arr)).bench_refs(|a| {
+        #[expect(clippy::unwrap_used)]
+        arrow_cast::cast_with_options(a.as_ref(), &ArrowDataType::UInt8, &opts).unwrap()
+    });
+}
+
+// Pure scalar baseline: no validity mask at all, checked cast on every element. Bails on
+// the first overflow (which never happens for our in-range inputs).
+#[divan::bench]
+fn cast_u32_u8_checked_no_validity(bencher: Bencher) {
+    let (values, _) = slow_path_inputs();
+    bencher.with_inputs(|| values.clone()).bench_refs(|vs| {
+        let mut out = Vec::with_capacity(vs.len());
+        for &v in vs.iter() {
+            #[expect(clippy::expect_used)]
+            out.push(u8::try_from(v).expect("in-range"));
+        }
+        out
+    });
+}
+
diff --git a/vortex-array/src/arrays/primitive/compute/cast.rs b/vortex-array/src/arrays/primitive/compute/cast.rs
index 10c0b8d6eba..e475519dd48 100644
--- a/vortex-array/src/arrays/primitive/compute/cast.rs
+++ b/vortex-array/src/arrays/primitive/compute/cast.rs
@@ -3,11 +3,11 @@
 
 use num_traits::AsPrimitive;
 use num_traits::NumCast;
+use vortex_buffer::BitBuffer;
 use vortex_buffer::Buffer;
 use vortex_buffer::BufferMut;
 use vortex_error::VortexResult;
 use vortex_error::vortex_bail;
-use vortex_error::vortex_err;
 use vortex_mask::Mask;
 
 use crate::ArrayRef;
@@ -102,9 +102,11 @@ impl CastKernel for Primitive {
     }
 }
 
-/// Cast values from `F` to `T`. For infallible casts this is a pure pass; for fallible casts
-/// each valid value goes through a checked `NumCast::from` and the kernel bails if any of them
-/// overflow `T`. Invalid positions use the wrapping `as` cast since their values are masked out.
+/// Cast values from `F` to `T`. For infallible casts this is a pure pass. For fallible casts
+/// where cached stats can't prove fit, the hot loop is unconditional `as_()` + a parallel range
+/// check whose results OR-reduce into a single `fail_acc` word — one pass, no `?` in the inner
+/// body, fully SIMD-vectorizable. If `fail_acc` is set, a cold scalar pass walks the array to
+/// attribute the failure to a specific index for a precise error message.
 fn cast_values<F, T>(
     array: ArrayView<'_, Primitive>,
     new_validity: Validity,
@@ -116,43 +118,102 @@ where
 {
     let values = array.as_slice::<F>();
 
-    // Fast path: statically infallible, or cached min/max prove every valid value fits in `T`.
-    // The cached check never triggers a stats computation — if the bounds aren't already known
-    // we fall through to the per-lane loop below.
     if values_always_fit(F::PTYPE, T::PTYPE) || values_fit_in(array, T::PTYPE, ctx, false) {
         return Ok(PrimitiveArray::new(cast::<F, T>(values), new_validity).into_array());
     }
 
-    // TODO(joe): if the values source and target have the same bit-width we can
-    // mutate in place.
-
-    // Fallible: invalid lanes are pre-multiplied to zero so the checked cast always succeeds for
-    // them; valid lanes go through `NumCast::from` and the whole cast bails on the first overflow.
     let mask = array.validity()?.execute_mask(array.len(), ctx)?;
-    let overflow = || {
-        vortex_err!(
+    let mut buffer = BufferMut::<T>::zeroed(values.len());
+    let out = buffer.as_mut_slice();
+    let mut fail_acc: u32 = 0;
+
+    match &mask {
+        Mask::AllFalse(_) => {
+            // No valid lanes — buffer is already zeroed.
+        }
+        Mask::AllTrue(_) => {
+            for (i, &v) in values.iter().enumerate() {
+                out[i] = v.as_();
+                fail_acc |= <T as NumCast>::from(v).is_none() as u32;
+            }
+        }
+        Mask::Values(m) => {
+            fail_acc = fallible_cast_with_validity::<F, T>(values, m.bit_buffer(), out);
+        }
+    }
+
+    if fail_acc != 0 {
+        // Cold scalar fallback: identify the failing index for a precise error.
+        for (idx, (&v, valid)) in values.iter().zip(mask_iter(&mask, values.len())).enumerate() {
+            if valid && <T as NumCast>::from(v).is_none() {
+                vortex_bail!(
+                    Compute: "Cannot cast {} to {} — value at index {} exceeds target range",
+                    F::PTYPE, T::PTYPE, idx,
+                );
+            }
+        }
+        // Should be unreachable, but emit a generic error if the hot/cold paths disagree.
+        vortex_bail!(
             Compute: "Cannot cast {} to {} — value exceeds target range",
             F::PTYPE, T::PTYPE,
-        )
-    };
-    let buffer: Buffer<T> = match &mask {
-        Mask::AllTrue(_) => BufferMut::try_from_trusted_len_iter(
-            values
-                .iter()
-                .map(|&v| <T as NumCast>::from(v).ok_or_else(overflow)),
-        )?
-        .freeze(),
-        Mask::AllFalse(_) => BufferMut::<T>::zeroed(values.len()).freeze(),
-        Mask::Values(m) => BufferMut::try_from_trusted_len_iter(
-            values.iter().zip(m.bit_buffer().iter()).map(|(&v, valid)| {
-                let factor = if valid { F::one() } else { F::zero() };
-                <T as NumCast>::from(v * factor).ok_or_else(overflow)
-            }),
-        )?
-        .freeze(),
-    };
-
-    Ok(PrimitiveArray::new(buffer, new_validity).into_array())
+        );
+    }
+
+    Ok(PrimitiveArray::new(buffer.freeze(), new_validity).into_array())
+}
+
+/// Unconditional `as_()` cast of every lane in `values` into `out`, with a SIMD-reducible
+/// overflow detector that returns a nonzero failure word iff any valid lane would overflow `T`.
+/// Walks validity in 64-lane blocks (`from_fn` lane-mask + uniform inner body, fully unrollable)
+/// and bails at the block boundary on the first failure — branch is outside the SIMD region.
+#[inline]
+fn fallible_cast_with_validity<F, T>(
+    values: &[F],
+    bit_buffer: &BitBuffer,
+    out: &mut [T],
+) -> u32
+where
+    F: NativePType + AsPrimitive<T>,
+    T: NativePType,
+{
+    debug_assert_eq!(values.len(), bit_buffer.len());
+    debug_assert_eq!(values.len(), out.len());
+    let bit_chunks = bit_buffer.chunks();
+    let mut fail_acc: u32 = 0;
+    let mut idx = 0usize;
+    for word in bit_chunks.iter() {
+        let valid: [bool; 64] = std::array::from_fn(|i| (word >> i) & 1 != 0);
+        for i in 0..64 {
+            let v = values[idx + i];
+            out[idx + i] = v.as_();
+            // Mask invalid lanes to F::zero (always fits any T) so they don't pollute fail_acc.
+            let v_for_check = if valid[i] { v } else { F::zero() };
+            fail_acc |= <T as NumCast>::from(v_for_check).is_none() as u32;
+        }
+        idx += 64;
+        if fail_acc != 0 {
+            return fail_acc;
+        }
+    }
+    let rem = bit_chunks.remainder_bits();
+    for b in 0..bit_chunks.remainder_len() {
+        let v = values[idx + b];
+        out[idx + b] = v.as_();
+        let valid = (rem >> b) & 1 != 0;
+        let v_for_check = if valid { v } else { F::zero() };
+        fail_acc |= <T as NumCast>::from(v_for_check).is_none() as u32;
+    }
+    fail_acc
+}
+
+/// Cold-path iterator over a `Mask` as a sequence of `bool`s. Only used after `fail_acc != 0`
+/// to attribute the failure to a specific index.
+fn mask_iter<'a>(mask: &'a Mask, len: usize) -> Box<dyn Iterator<Item = bool> + 'a> {
+    match mask {
+        Mask::AllTrue(_) => Box::new(std::iter::repeat_n(true, len)),
+        Mask::AllFalse(_) => Box::new(std::iter::repeat_n(false, len)),
+        Mask::Values(m) => Box::new(m.bit_buffer().iter()),
+    }
 }
 
 /// Out-of-range values at invalid positions are truncated/wrapped by `as`, which is fine because