diff --git a/datafusion/functions-nested/src/array_sum.rs b/datafusion/functions-nested/src/array_sum.rs
new file mode 100644
index 0000000000000..6314a64813fcf
--- /dev/null
+++ b/datafusion/functions-nested/src/array_sum.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ScalarUDFImpl`] definitions for array_sum function.
+
+use crate::utils::make_scalar_function;
+use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
+use arrow::datatypes::{
+    DataType,
+    DataType::{FixedSizeList, LargeList, List, Null},
+    Field,
+};
+use datafusion_common::cast::{as_float64_array, as_generic_list_array};
+use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
+use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
+};
+use datafusion_macros::user_doc;
+use std::sync::Arc;
+
+make_udf_expr_and_func!(
+    ArraySum,
+    array_sum,
+    array,
+    "returns the sum of elements in a numeric array.",
+    array_sum_udf
+);
+
+// NOTE(review): the `description` below spells the sum as
+// `array[0] + array[1] + ...`. SQL array subscripts are presumably 1-based here,
+// so the 0-based notation may mislead readers; confirm. The borders of the
+// example table are also narrower than its header row. The same text appears in
+// docs/source/user-guide/sql/scalar_functions.md, so change both together.
+/// Scalar UDF that sums the elements of each input array.
+///
+/// Registered as `array_sum` with the alias `list_sum`; the result type is always
+/// `Float64`.
+#[user_doc(
+    doc_section(label = "Array Functions"),
+    description = "Returns the sum of the elements of the input array, computed as `array[0] + array[1] + ...`. NULL elements are skipped (per SQL aggregate convention). Returns NULL if the input row is NULL or every element is NULL. Returns 0.0 for an empty array.",
+    syntax_example = "array_sum(array)",
+    sql_example = r#"```sql
+> select array_sum([1.0, 2.0, 3.0]);
++----------------------------+
+| array_sum(List([1.0,2.0,3.0])) |
++----------------------------+
+| 6.0                        |
++----------------------------+
+```"#,
+    argument(
+        name = "array",
+        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
+    )
+)]
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct ArraySum {
+    /// User-defined signature; the argument is validated in `coerce_types`.
+    signature: Signature,
+    /// Alternative names the function can be called by.
+    aliases: Vec<String>,
+}
+
+impl Default for ArraySum {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ArraySum {
+    /// Creates the UDF with an immutable, user-defined signature and the
+    /// `list_sum` alias.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec!["list_sum".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for ArraySum {
+    fn name(&self) -> &str {
+        "array_sum"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// The sum is always reported as `Float64`, whatever the input element type.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    /// Coerces the single argument to a list whose elements are `Float64`.
+    ///
+    /// # Errors
+    /// Returns a plan error when the argument count is not one, when the argument
+    /// is neither NULL nor a list type, or when the coerced element type is not
+    /// `Float64` (a nested list), which the summing kernel cannot handle.
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [arg_type] = take_function_args(self.name(), arg_types)?;
+
+        let coerced = match arg_type {
+            // A bare NULL carries no element type; use a nullable Float64 list.
+            Null => List(Arc::new(Field::new_list_field(DataType::Float64, true))),
+            List(_) | LargeList(_) | FixedSizeList(..) => {
+                let coercion = Some(&ListCoercion::FixedSizedListToList);
+                coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion)
+            }
+            _ => return plan_err!("{} does not support type {arg_type}", self.name()),
+        };
+
+        // The helper presumably rewrites only the innermost element type, so a nested
+        // list would stay nested. Reject it while planning rather than failing on the
+        // Float64 downcast of the child values at run time.
+        let is_flat = matches!(
+            &coerced,
+            List(field) | LargeList(field) if field.data_type() == &DataType::Float64
+        );
+        if !is_flat {
+            return plan_err!(
+                "{} does not support nested array type {arg_type}",
+                self.name()
+            );
+        }
+
+        Ok(vec![coerced])
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        make_scalar_function(array_sum_inner)(&args.args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
+
+/// Sums each row of the coerced argument, dispatching on the list offset width.
+///
+/// # Errors
+/// Returns an internal error if the array is neither `List` nor `LargeList`.
+fn array_sum_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let [array] = take_function_args("array_sum", args)?;
+    match array.data_type() {
+        List(_) => general_array_sum::<i32>(array),
+        LargeList(_) => general_array_sum::<i64>(array),
+        arg_type => {
+            internal_err!("array_sum received unexpected type after coercion: {arg_type}")
+        }
+    }
+}
+
+/// Computes the per-row sum of a list array whose elements are `Float64`.
+///
+/// Per row: a NULL row yields NULL, an empty array yields `0.0`, an array whose
+/// elements are all NULL yields NULL, and any other array yields the sum of its
+/// non-NULL elements, added in element order.
+///
+/// # Errors
+/// Fails if `array` is not a list with offset type `O` or if its values are not
+/// `Float64`.
+fn general_array_sum<O: OffsetSizeTrait>(array: &ArrayRef) -> Result<ArrayRef> {
+    let list_array = as_generic_list_array::<O>(array)?;
+    let values = as_float64_array(list_array.values())?;
+
+    let mut builder = Float64Array::builder(list_array.len());
+
+    // The offsets hold one more entry than there are rows; each adjacent pair
+    // delimits one row's elements inside `values`.
+    for (row, bounds) in list_array.value_offsets().windows(2).enumerate() {
+        if list_array.is_null(row) {
+            builder.append_null();
+            continue;
+        }
+
+        let start = bounds[0].as_usize();
+        let end = bounds[1].as_usize();
+
+        // Empty (non-NULL) array: report the additive identity.
+        // NOTE(review): SQL `SUM` over zero rows yields NULL rather than 0, so this
+        // is a choice made by this function; confirm it is the intended contract.
+        if start == end {
+            builder.append_value(0.0);
+            continue;
+        }
+
+        builder.append_option(sum_non_null(values, start, end));
+    }
+
+    Ok(Arc::new(builder.finish()))
+}
+
+/// Adds up the non-NULL entries of `values[start..end]` in index order.
+///
+/// Returns `None` when the range holds no non-NULL entry, so that an array whose
+/// elements are all NULL yields NULL. The child array is indexed directly instead
+/// of creating a slice per row.
+fn sum_non_null(values: &Float64Array, start: usize, end: usize) -> Option<f64> {
+    let mut sum = 0.0_f64;
+    let mut any_valid = false;
+    for i in start..end {
+        if !values.is_null(i) {
+            sum += values.value(i);
+            any_valid = true;
+        }
+    }
+    any_valid.then_some(sum)
+}
diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs
index aacc4dbd3d481..b12bedcc44839 100644
--- a/datafusion/functions-nested/src/lib.rs
+++ b/datafusion/functions-nested/src/lib.rs
@@ -48,6 +48,7 @@ pub mod array_filter;
 pub mod array_has;
 pub mod array_normalize;
 pub mod array_scale;
+pub mod array_sum;
 pub mod array_transform;
 pub mod arrays_zip;
 pub mod cardinality;
@@ -98,6 +99,7 @@ pub mod expr_fn {
     pub use super::array_has::array_has_any;
     pub use super::array_normalize::array_normalize;
     pub use super::array_scale::array_scale;
+    pub use super::array_sum::array_sum;
     pub use super::array_transform::array_transform;
     pub use super::arrays_zip::arrays_zip;
     pub use super::cardinality::cardinality;
@@ -174,6 +176,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
         length::array_length_udf(),
         array_normalize::array_normalize_udf(),
         array_scale::array_scale_udf(),
+
array_sum::array_sum_udf(), cosine_distance::cosine_distance_udf(), inner_product::inner_product_udf(), distance::array_distance_udf(), diff --git a/datafusion/sqllogictest/test_files/array_sum.slt b/datafusion/sqllogictest/test_files/array_sum.slt new file mode 100644 index 0000000000000..740ecce76e054 --- /dev/null +++ b/datafusion/sqllogictest/test_files/array_sum.slt @@ -0,0 +1,152 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +## array_sum + +# Basic case +query R +select array_sum([1.0, 2.0, 3.0]); +---- +6 + +# Single element +query R +select array_sum([5.0]); +---- +5 + +# Negative values +query R +select array_sum([-1.0, -2.0, -3.0]); +---- +-6 + +# Positive and negative cancel +query R +select array_sum([1.0, -1.0, 2.0, -2.0]); +---- +0 + +# Empty array returns 0, the additive identity (note: unlike SQL SUM, which yields NULL over zero rows) +query R +select array_sum(arrow_cast(make_array(), 'List(Float64)')); +---- +0 + +# Bare NULL input returns NULL row +query R +select array_sum(NULL); +---- +NULL + +# NULL elements are skipped (SQL aggregate convention) +query R +select array_sum([1.0, NULL, 3.0]); +---- +4 + +# Single NULL among numeric: skip the NULL +query R +select array_sum([NULL, 10.0]); +---- +10 + +# All-NULL array returns NULL row (matches SQL SUM over all-NULL) +query R +select array_sum(arrow_cast([NULL, NULL], 'List(Float64)')); +---- +NULL + +# LargeList support +query R +select array_sum(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)')); +---- +6 + +# FixedSizeList input (coerced to List) +query R +select array_sum(arrow_cast([1.0, 2.0, 3.0], 'FixedSizeList(3, Float64)')); +---- +6 + +# Float32 inner type (coerced to Float64) +query R +select array_sum(arrow_cast([1.0, 2.0, 3.0], 'List(Float32)')); +---- +6 + +# Int64 inner type (coerced to Float64) +query R +select array_sum(arrow_cast([1, 2, 3], 'List(Int64)')); +---- +6 + +# Integer literals (coerced to Float64) +query R +select array_sum([1, 2, 3]); +---- +6 + +# Unsupported non-list input (plan error) +query error array_sum does not support type +select array_sum(1); + +# Multi-row query with mix of normal, single-element, NULL elements, empty, NULL row +query R +select array_sum(column1) from (values + (make_array(1.0, 2.0, 3.0)), + (make_array(0.0)), + (make_array(1.0, NULL, 4.0)), + (arrow_cast(make_array(), 'List(Float64)')), + (NULL) +) as t(column1); +---- +6 +0 +5 +0 +NULL + +# Wrong arity (zero args) +query error array_sum function 
requires 1 argument, got 0 +select array_sum(); + +# Wrong arity (two args) +query error array_sum function requires 1 argument, got 2 +select array_sum([1.0], [2.0]); + +# Return type is Float64 +query RT +select array_sum([1.0, 2.0, 3.0]), arrow_typeof(array_sum([1.0, 2.0, 3.0])); +---- +6 Float64 + +# list_sum alias produces the same result +query R +select list_sum([1.0, 2.0, 3.0]); +---- +6 + +# list_sum alias with NULL row propagates correctly +query R +select list_sum(column1) from (values + (make_array(1.0, 2.0)), + (NULL) +) as t(column1); +---- +3 +NULL diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 955654d80e688..b599cfe3b1cfd 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3289,6 +3289,7 @@ _Alias of [current_date](#current_date)._ - [array_scale](#array_scale) - [array_slice](#array_slice) - [array_sort](#array_sort) +- [array_sum](#array_sum) - [array_to_string](#array_to_string) - [array_transform](#array_transform) - [array_union](#array_union) @@ -3345,6 +3346,7 @@ _Alias of [current_date](#current_date)._ - [list_scale](#list_scale) - [list_slice](#list_slice) - [list_sort](#list_sort) +- [list_sum](#list_sum) - [list_to_string](#list_to_string) - [list_transform](#list_transform) - [list_union](#list_union) @@ -4483,6 +4485,33 @@ array_sort(array, desc, nulls_first) - list_sort +### `array_sum` + +Returns the sum of the elements of the input array, computed as `array[0] + array[1] + ...`. NULL elements are skipped (per SQL aggregate convention). Returns NULL if the input row is NULL or every element is NULL. Returns 0.0 for an empty array. + +```sql +array_sum(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. 
+ +#### Example + +```sql +> select array_sum([1.0, 2.0, 3.0]); ++----------------------------+ +| array_sum(List([1.0,2.0,3.0])) | ++----------------------------+ +| 6.0 | ++----------------------------+ +``` + +#### Aliases + +- list_sum + ### `array_to_string` Converts each element to its text representation. @@ -4951,6 +4980,10 @@ _Alias of [array_slice](#array_slice)._ _Alias of [array_sort](#array_sort)._ +### `list_sum` + +_Alias of [array_sum](#array_sum)._ + ### `list_to_string` _Alias of [array_to_string](#array_to_string)._