Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions datafusion/functions-nested/src/array_sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`ScalarUDFImpl`] definitions for array_sum function.

use crate::utils::make_scalar_function;
use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait};
use arrow::datatypes::{
DataType,
DataType::{FixedSizeList, LargeList, List, Null},
Field,
};
use datafusion_common::cast::{as_float64_array, as_generic_list_array};
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_macros::user_doc;
use std::sync::Arc;

make_udf_expr_and_func!(
ArraySum,
array_sum,
array,
"returns the sum of elements in a numeric array.",
array_sum_udf
);

#[user_doc(
doc_section(label = "Array Functions"),
description = "Returns the sum of the elements of the input array, computed as `array[0] + array[1] + ...`. NULL elements are skipped (per SQL aggregate convention). Returns NULL if the input row is NULL or every element is NULL. Returns 0.0 for an empty array.",
syntax_example = "array_sum(array)",
sql_example = r#"```sql
> select array_sum([1.0, 2.0, 3.0]);
+----------------------------+
| array_sum(List([1.0,2.0,3.0])) |
+----------------------------+
| 6.0 |
+----------------------------+
```"#,
argument(
name = "array",
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
)
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArraySum {
signature: Signature,
aliases: Vec<String>,
}

impl Default for ArraySum {
fn default() -> Self {
Self::new()
}
}

impl ArraySum {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec!["list_sum".to_string()],
}
}
}

impl ScalarUDFImpl for ArraySum {
fn name(&self) -> &str {
"array_sum"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [arg_type] = take_function_args(self.name(), arg_types)?;
let coercion = Some(&ListCoercion::FixedSizedListToList);

if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
return plan_err!("{} does not support type {arg_type}", self.name());
}

let coerced = if matches!(arg_type, Null) {
List(Arc::new(Field::new_list_field(DataType::Float64, true)))
} else {
coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion)
};

Ok(vec![coerced])
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_sum_inner)(&args.args)
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

fn array_sum_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array] = take_function_args("array_sum", args)?;
match array.data_type() {
List(_) => general_array_sum::<i32>(array),
LargeList(_) => general_array_sum::<i64>(array),
arg_type => {
internal_err!("array_sum received unexpected type after coercion: {arg_type}")
}
}
}

fn general_array_sum<O: OffsetSizeTrait>(array: &ArrayRef) -> Result<ArrayRef> {
let list_array = as_generic_list_array::<O>(array)?;
let values = as_float64_array(list_array.values())?;
let offsets = list_array.value_offsets();

let mut builder = Float64Array::builder(list_array.len());

for row in 0..list_array.len() {
if list_array.is_null(row) {
builder.append_null();
continue;
}

let start = offsets[row].as_usize();
let end = offsets[row + 1].as_usize();
let len = end - start;

// Empty array: sum is the additive identity. Matches SQL SUM(<empty>) = 0
// and DuckDB's list_sum(([]) = 0 conventions.
if len == 0 {
builder.append_value(0.0);
continue;
}

// `slice` resets the logical offset to 0, so `i` below is 0-based within the slice.
let slice = values.slice(start, len);

// Skip NULL elements per SQL aggregate convention (matches PostgreSQL
// array_sum, DuckDB list_sum, Spark aggregate). A row with every
// element NULL yields NULL — same behavior as SQL SUM over all-NULL.
let mut sum = 0.0_f64;
let mut any_valid = false;
for i in 0..len {
if !slice.is_null(i) {
sum += slice.value(i);
any_valid = true;
}
}

if any_valid {
builder.append_value(sum);
} else {
builder.append_null();
}
}

Ok(Arc::new(builder.finish()))
}
3 changes: 3 additions & 0 deletions datafusion/functions-nested/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub mod array_filter;
pub mod array_has;
pub mod array_normalize;
pub mod array_scale;
pub mod array_sum;
pub mod array_transform;
pub mod arrays_zip;
pub mod cardinality;
Expand Down Expand Up @@ -98,6 +99,7 @@ pub mod expr_fn {
pub use super::array_has::array_has_any;
pub use super::array_normalize::array_normalize;
pub use super::array_scale::array_scale;
pub use super::array_sum::array_sum;
pub use super::array_transform::array_transform;
pub use super::arrays_zip::arrays_zip;
pub use super::cardinality::cardinality;
Expand Down Expand Up @@ -174,6 +176,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
length::array_length_udf(),
array_normalize::array_normalize_udf(),
array_scale::array_scale_udf(),
array_sum::array_sum_udf(),
cosine_distance::cosine_distance_udf(),
inner_product::inner_product_udf(),
distance::array_distance_udf(),
Expand Down
152 changes: 152 additions & 0 deletions datafusion/sqllogictest/test_files/array_sum.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

## array_sum

# Basic case
query R
select array_sum([1.0, 2.0, 3.0]);
----
6

# Single element
query R
select array_sum([5.0]);
----
5

# Negative values
query R
select array_sum([-1.0, -2.0, -3.0]);
----
-6

# Positive and negative cancel
query R
select array_sum([1.0, -1.0, 2.0, -2.0]);
----
0

# Empty array returns 0 (additive identity, per SQL SUM convention)
query R
select array_sum(arrow_cast(make_array(), 'List(Float64)'));
----
0

# Bare NULL input returns NULL row
query R
select array_sum(NULL);
----
NULL

# NULL elements are skipped (SQL aggregate convention)
query R
select array_sum([1.0, NULL, 3.0]);
----
4

# Single NULL among numeric: skip the NULL
query R
select array_sum([NULL, 10.0]);
----
10

# All-NULL array returns NULL row (matches SQL SUM over all-NULL)
query R
select array_sum(arrow_cast([NULL, NULL], 'List(Float64)'));
----
NULL

# LargeList support
query R
select array_sum(arrow_cast([1.0, 2.0, 3.0], 'LargeList(Float64)'));
----
6

# FixedSizeList input (coerced to List)
query R
select array_sum(arrow_cast([1.0, 2.0, 3.0], 'FixedSizeList(3, Float64)'));
----
6

# Float32 inner type (coerced to Float64)
query R
select array_sum(arrow_cast([1.0, 2.0, 3.0], 'List(Float32)'));
----
6

# Int64 inner type (coerced to Float64)
query R
select array_sum(arrow_cast([1, 2, 3], 'List(Int64)'));
----
6

# Integer literals (coerced to Float64)
query R
select array_sum([1, 2, 3]);
----
6

# Unsupported non-list input (plan error)
query error array_sum does not support type
select array_sum(1);

# Multi-row query with mix of normal, single-element, NULL elements, empty, NULL row
query R
select array_sum(column1) from (values
(make_array(1.0, 2.0, 3.0)),
(make_array(0.0)),
(make_array(1.0, NULL, 4.0)),
(arrow_cast(make_array(), 'List(Float64)')),
(NULL)
) as t(column1);
----
6
0
5
0
NULL

# Wrong arity (zero args)
query error array_sum function requires 1 argument, got 0
select array_sum();

# Wrong arity (two args)
query error array_sum function requires 1 argument, got 2
select array_sum([1.0], [2.0]);

# Return type is Float64
query RT
select array_sum([1.0, 2.0, 3.0]), arrow_typeof(array_sum([1.0, 2.0, 3.0]));
----
6 Float64

# list_sum alias produces the same result
query R
select list_sum([1.0, 2.0, 3.0]);
----
6

# list_sum alias with NULL row propagates correctly
query R
select list_sum(column1) from (values
(make_array(1.0, 2.0)),
(NULL)
) as t(column1);
----
3
NULL
Loading
Loading