From 8dbe39f827408f30054cc32b963b4a71930e39f9 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Thu, 12 Mar 2026 13:38:30 -0400 Subject: [PATCH] add native union type casting support --- arrow-cast/src/cast/mod.rs | 385 ++++++++++++++++++++++++++++++++++++- 1 file changed, 383 insertions(+), 2 deletions(-) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 9f1eba1057fd..443d8eb9eba4 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -63,7 +63,7 @@ use crate::parse::{ string_to_datetime, }; use arrow_array::{builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *}; -use arrow_buffer::{ArrowNativeType, OffsetBuffer, i256}; +use arrow_buffer::{ArrowNativeType, BooleanBuffer, NullBuffer, OffsetBuffer, i256}; use arrow_data::ArrayData; use arrow_data::transform::MutableArrayData; use arrow_schema::*; @@ -230,7 +230,24 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } (Struct(_), _) => false, (_, Struct(_)) => false, - + (Union(from_fields, _), Union(to_fields, _)) => { + from_fields.len() == to_fields.len() + && from_fields.iter().all(|(from_id, from_field)| { + to_fields + .iter() + .find(|(to_id, _)| *to_id == from_id) + .is_some_and(|(_, to_field)| { + can_cast_types(from_field.data_type(), to_field.data_type()) + }) + }) + } + (Union(fields, _), _) => { + to_type == &Null + || fields + .iter() + .any(|(_, f)| can_cast_types(f.data_type(), to_type)) + } + (_, Union(_, _)) => false, (_, Boolean) => from_type.is_integer() || from_type.is_floating() || from_type.is_string(), (Boolean, _) => to_type.is_integer() || to_type.is_floating() || to_type.is_string(), @@ -807,6 +824,22 @@ pub fn cast_with_options( "Casting from type {from_type} to dictionary type {to_type} not supported", ))), }, + // Union casts must come before list/scalar arms since (_, List) would match unions + (Union(from_fields, _), Union(to_fields, to_mode)) => cast_union_to_union( + array.as_any().downcast_ref::().unwrap(), + from_fields, + to_fields, + *to_mode, + cast_options, + ), + (Union(_, _), _) => cast_union_to_type( + array.as_any().downcast_ref::().unwrap(), + to_type, + cast_options, + ), + (_, Union(_, _)) => Err(ArrowError::CastError(format!( + "Casting from {from_type} to {to_type} not supported" + ))), // Casting between lists of same types (cast inner values) (List(_), List(to)) => cast_list_values::(array, to, cast_options), (LargeList(_), LargeList(to)) => cast_list_values::(array, to, cast_options), @@ -2276,6 +2309,137 @@ fn cast_struct_fields_in_order( .collect::, ArrowError>>() } +/// Cast a UnionArray to another UnionArray by casting each child array. +fn cast_union_to_union( + array: &UnionArray, + from_fields: &UnionFields, + to_fields: &UnionFields, + _to_mode: UnionMode, + cast_options: &CastOptions, +) -> Result { + let type_ids = array.type_ids().clone(); + let offsets = array.offsets().cloned(); + + let new_children: Vec = from_fields + .iter() + .map(|(from_id, _from_field)| { + let (_, to_field) = to_fields + .iter() + .find(|(to_id, _)| *to_id == from_id) + .ok_or_else(|| { + ArrowError::CastError(format!( + "Cannot cast union: type_id {from_id} not found in target union" + )) + })?; + let child = array.child(from_id); + cast_with_options(child.as_ref(), to_field.data_type(), cast_options) + }) + .collect::>()?; + + let union = UnionArray::try_new(to_fields.clone(), type_ids, offsets, new_children)?; + Ok(Arc::new(union)) +} + +/// Cast a UnionArray to a non-union type. +/// +/// Finds the first variant whose type matches or can be cast to `to_type`, +/// extracts values for rows where that variant is active (NULLing other rows), +/// then casts the result to the target type. Prefers an exact type match over +/// a cast-compatible one. +/// +/// Since union extraction inherently introduces nulls for non-matching rows, +/// the target type's inner fields are made nullable to avoid validation errors +/// (e.g., casting to `List(non-nullable Utf8)` becomes `List(nullable Utf8)`). +fn cast_union_to_type( + array: &UnionArray, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let DataType::Union(fields, _) = array.data_type() else { + return Err(ArrowError::CastError( + "expected Union data type".to_string(), + )); + }; + + let len = array.len(); + let type_ids = array.type_ids(); + + if to_type == &DataType::Null { + return Ok(new_null_array(to_type, len)); + } + + // Make inner fields nullable since union extraction introduces nulls + let nullable_to_type = make_nullable(to_type); + + // Find a matching variant: prefer exact type match, fall back to cast-compatible + let matching_type_id = fields + .iter() + .find_map(|(id, f)| (f.data_type() == to_type).then_some(id)) + .or_else(|| { + fields + .iter() + .find_map(|(id, f)| can_cast_types(f.data_type(), to_type).then_some(id)) + }); + + let Some(match_id) = matching_type_id else { + return Err(ArrowError::CastError(format!( + "Casting from {} to {to_type} not supported: no union variant is cast-compatible", + array.data_type(), + ))); + }; + + let matching_child = array.child(match_id); + + // Extract values for the matching variant, NULLing rows with different active variants + match array.offsets() { + Some(offsets) => { + // Dense union: use offsets to index into child, take to gather + let indices = Int32Array::try_new( + offsets.clone(), + Some(BooleanBuffer::from_iter(type_ids.iter().map(|&tid| tid == match_id)).into()), + )?; + let extracted = take(matching_child.as_ref(), &indices, None)?; + if matching_child.data_type() == &nullable_to_type { + Ok(extracted) + } else { + cast_with_options(extracted.as_ref(), &nullable_to_type, cast_options) + } + } + None => { + // Sparse union: child is same length as union, apply null mask + let null_mask = BooleanBuffer::from_iter(type_ids.iter().map(|&tid| tid == match_id)); + let nulls = NullBuffer::new(null_mask); + + let data = matching_child + .to_data() + .into_builder() + .nulls(Some(nulls)) + .build()?; + let extracted = make_array(data); + + if matching_child.data_type() == &nullable_to_type { + Ok(extracted) + } else { + cast_with_options(extracted.as_ref(), &nullable_to_type, cast_options) + } + } + } +} + +/// Make inner fields of a data type nullable. +/// This is needed for union extraction which inherently introduces nulls. +fn make_nullable(data_type: &DataType) -> DataType { + use DataType::*; + match data_type { + List(f) => List(Arc::new(f.as_ref().clone().with_nullable(true))), + LargeList(f) => LargeList(Arc::new(f.as_ref().clone().with_nullable(true))), + FixedSizeList(f, s) => FixedSizeList(Arc::new(f.as_ref().clone().with_nullable(true)), *s), + ListView(f) => ListView(Arc::new(f.as_ref().clone().with_nullable(true))), + LargeListView(f) => LargeListView(Arc::new(f.as_ref().clone().with_nullable(true))), + _ => data_type.clone(), + } +} + fn cast_from_decimal( array: &dyn Array, base: D::Native, @@ -13347,4 +13511,221 @@ mod tests { assert_eq!(expected, actual); } + + #[test] + fn test_can_cast_union_to_type() { + let union_type = DataType::Union( + UnionFields::from_fields(vec![ + Field::new("i", DataType::Int32, false), + Field::new("s", DataType::Utf8, false), + ]), + UnionMode::Sparse, + ); + + // Can cast to Int32 (exact match on variant) + assert!(can_cast_types(&union_type, &DataType::Int32)); + // Can cast to Int64 (Int32 variant castable to Int64) + assert!(can_cast_types(&union_type, &DataType::Int64)); + // Can cast to Utf8 (exact match on variant) + assert!(can_cast_types(&union_type, &DataType::Utf8)); + // Cannot cast to a type no variant can cast to + assert!(!can_cast_types( + &union_type, + &DataType::Struct(Fields::empty()) + )); + // Cannot cast non-union to union + assert!(!can_cast_types(&DataType::Int32, &union_type)); + } + + #[test] + fn test_can_cast_union_to_union() { + let from = DataType::Union( + UnionFields::from_fields(vec![ + Field::new("i", DataType::Int32, false), + Field::new("s", DataType::Utf8, false), + ]), + UnionMode::Dense, + ); + let to = DataType::Union( + UnionFields::from_fields(vec![ + Field::new("i", DataType::Int64, false), + Field::new("s", DataType::LargeUtf8, false), + ]), + UnionMode::Dense, + ); + assert!(can_cast_types(&from, &to)); + + // Mismatched type IDs + let to_bad = DataType::Union( + UnionFields::try_new( + vec![2, 3], + vec![ + Field::new("i", DataType::Int64, false), + Field::new("s", DataType::LargeUtf8, false), + ], + ) + .expect("valid"), + UnionMode::Dense, + ); + assert!(!can_cast_types(&from, &to_bad)); + } + + #[test] + fn test_cast_sparse_union_to_int32() { + let fields = UnionFields::from_fields(vec![ + Field::new("i", DataType::Int32, false), + Field::new("s", DataType::Utf8, false), + ]); + + // 5 rows: i=1, s="hello", i=3, s="world", i=5 + let union = UnionArray::try_new( + fields, + vec![0, 1, 0, 1, 0].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![1, 0, 3, 0, 5])), + Arc::new(StringArray::from(vec!["", "hello", "", "world", ""])), + ], + ) + .unwrap(); + + let result = cast(&union, &DataType::Int32).unwrap(); + let result = result.as_primitive::(); + + // Rows where variant is "i" get values, others are NULL + assert_eq!(result.value(0), 1); + assert!(result.is_null(1)); + assert_eq!(result.value(2), 3); + assert!(result.is_null(3)); + assert_eq!(result.value(4), 5); + } + + #[test] + fn test_cast_dense_union_to_int64() { + let fields = UnionFields::from_fields(vec![ + Field::new("i", DataType::Int32, false), + Field::new("s", DataType::Utf8, false), + ]); + + // 4 rows: i=10, s="a", i=20, s="b" + let union = UnionArray::try_new( + fields, + vec![0, 1, 0, 1].into(), + Some(vec![0, 0, 1, 1].into()), + vec![ + Arc::new(Int32Array::from(vec![10, 20])), + Arc::new(StringArray::from(vec!["a", "b"])), + ], + ) + .unwrap(); + + // Cast to Int64 (Int32 is cast-compatible) + let result = cast(&union, &DataType::Int64).unwrap(); + let result = result.as_primitive::(); + + assert_eq!(result.value(0), 10); + assert!(result.is_null(1)); + assert_eq!(result.value(2), 20); + assert!(result.is_null(3)); + } + + #[test] + fn test_cast_union_to_union() { + let from_fields = UnionFields::from_fields(vec![ + Field::new("i", DataType::Int32, false), + Field::new("f", DataType::Float32, false), + ]); + let to_fields = UnionFields::from_fields(vec![ + Field::new("i", DataType::Int64, false), + Field::new("f", DataType::Float64, false), + ]); + + let union = UnionArray::try_new( + from_fields, + vec![0, 1, 0].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![1, 0, 3])), + Arc::new(Float32Array::from(vec![0.0, 2.5, 0.0])), + ], + ) + .unwrap(); + + let to_type = DataType::Union(to_fields.clone(), UnionMode::Sparse); + let result = cast(&union, &to_type).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.type_id(0), 0); + assert_eq!(result.type_id(1), 1); + assert_eq!(result.type_id(2), 0); + + let child_i = result.child(0); + let child_i = child_i.as_primitive::(); + assert_eq!(child_i.value(0), 1); + assert_eq!(child_i.value(2), 3); + + let child_f = result.child(1); + let child_f = child_f.as_primitive::(); + assert!((child_f.value(1) - 2.5).abs() < f64::EPSILON); + } + + #[test] + fn test_cast_union_prefers_exact_type_match() { + // Union(Int32, Int64) → Int64: Int64 is an exact match, so only + // rows with the Int64 variant produce values; Int32 rows become NULL. + let fields = UnionFields::from_fields(vec![ + Field::new("i32", DataType::Int32, false), + Field::new("i64", DataType::Int64, false), + ]); + + // 4 rows: i32=1, i64=2, i32=3, i64=4 + let union = UnionArray::try_new( + fields, + vec![0, 1, 0, 1].into(), + None, + vec![ + Arc::new(Int32Array::from(vec![1, 0, 3, 0])), + Arc::new(Int64Array::from(vec![0, 2, 0, 4])), + ], + ) + .unwrap(); + + let result = cast(&union, &DataType::Int64).unwrap(); + let result = result.as_primitive::(); + + // Int32 rows are NULL because we pick the exact-match Int64 variant + assert!(result.is_null(0)); + assert_eq!(result.value(1), 2); + assert!(result.is_null(2)); + assert_eq!(result.value(3), 4); + } + + #[test] + fn test_cast_dense_union_to_non_nullable_list() { + // Regression test: casting Union to LargeList with non-nullable inner field + // should succeed by making the inner field nullable (union extraction introduces nulls) + let fields = UnionFields::try_new( + vec![0, 1], + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int64, false), + ], + ) + .unwrap(); + + let union = UnionArray::try_new( + fields, + vec![0, 1].into(), + Some(vec![0, 0].into()), + vec![ + Arc::new(Int32Array::from(vec![1])) as ArrayRef, + Arc::new(Int64Array::from(vec![2])) as ArrayRef, + ], + ) + .unwrap(); + + let to_type = DataType::LargeList(Arc::new(Field::new_list_field(DataType::Utf8, false))); + let result = cast(&union, &to_type); + assert!(result.is_ok(), "cast should succeed: {:?}", result.err()); + } }