From b4beb7199b026d2e7d3f3692ed148cd0aae0fc00 Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Mon, 25 May 2026 10:58:35 -0400 Subject: [PATCH] . --- datafusion/functions/benches/atan2.rs | 56 +------------ datafusion/functions/src/macros.rs | 79 ++++++------------- datafusion/functions/src/math/monotonicity.rs | 10 +-- datafusion/sqllogictest/test_files/scalar.slt | 18 ++++- .../source/user-guide/sql/scalar_functions.md | 10 +-- 5 files changed, 51 insertions(+), 122 deletions(-) diff --git a/datafusion/functions/benches/atan2.rs b/datafusion/functions/benches/atan2.rs index f1c9756a0cc08..2a95286a99d1c 100644 --- a/datafusion/functions/benches/atan2.rs +++ b/datafusion/functions/benches/atan2.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::{DataType, Field, Float32Type, Float64Type}; +use arrow::datatypes::{DataType, Field, Float64Type}; use arrow::util::bench_util::create_primitive_array; use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; @@ -32,34 +32,6 @@ fn criterion_benchmark(c: &mut Criterion) { let config_options = Arc::new(ConfigOptions::default()); for size in [1024, 4096, 8192] { - let y_f32 = Arc::new(create_primitive_array::(size, 0.2)); - let x_f32 = Arc::new(create_primitive_array::(size, 0.2)); - let f32_args = vec![ColumnarValue::Array(y_f32), ColumnarValue::Array(x_f32)]; - let f32_arg_fields = f32_args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let return_field_f32 = Field::new("f", DataType::Float32, true).into(); - - c.bench_function(&format!("atan2 f32 array: {size}"), |b| { - b.iter(|| { - black_box( - atan2_fn - .invoke_with_args(ScalarFunctionArgs { - args: f32_args.clone(), - arg_fields: f32_arg_fields.clone(), - number_rows: size, - return_field: Arc::clone(&return_field_f32), - config_options: Arc::clone(&config_options), - }) - .unwrap(), - ) - }) - }); - let y_f64 = Arc::new(create_primitive_array::(size, 0.2)); let x_f64 = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(y_f64), ColumnarValue::Array(x_f64)]; @@ -89,32 +61,6 @@ fn criterion_benchmark(c: &mut Criterion) { }); } - let scalar_f32_args = vec![ - ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), - ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), - ]; - let scalar_f32_arg_fields = vec![ - Field::new("a", DataType::Float32, false).into(), - Field::new("b", DataType::Float32, false).into(), - ]; - let return_field_f32 = Field::new("f", DataType::Float32, false).into(); - - c.bench_function("atan2 f32 scalar", |b| { - b.iter(|| { - black_box( - atan2_fn - .invoke_with_args(ScalarFunctionArgs { - args: scalar_f32_args.clone(), - arg_fields: scalar_f32_arg_fields.clone(), - number_rows: 1, - return_field: Arc::clone(&return_field_f32), - config_options: Arc::clone(&config_options), - }) - .unwrap(), - ) - }) - }); - let scalar_f64_args = vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 79e19313699cb..cba07327b1530 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -345,8 +345,8 @@ macro_rules! make_math_unary_udf { /// Macro to create a binary math UDF. /// -/// A binary math function takes two arguments of types Float32 or Float64, -/// applies a binary floating function to the argument, and returns a value of the same type. +/// A binary math function takes two numeric arguments, coerces them to Float64, +/// applies a binary floating function to the arguments, and returns Float64. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` /// $NAME: the name of the function @@ -362,10 +362,9 @@ macro_rules! make_math_binary_udf { use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; - use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use arrow::datatypes::{DataType, Float64Type}; use datafusion_common::utils::take_function_args; use datafusion_common::{Result, ScalarValue, internal_err}; - use datafusion_expr::TypeSignature; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, @@ -381,11 +380,8 @@ macro_rules! make_math_binary_udf { pub fn new() -> Self { use DataType::*; Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Float32, Float32]), - TypeSignature::Exact(vec![Float64, Float64]), - ], + signature: Signature::exact( + vec![Float64, Float64], Volatility::Immutable, ), } @@ -401,14 +397,8 @@ macro_rules! make_math_binary_udf { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let arg_type = &arg_types[0]; - - match arg_type { - DataType::Float32 => Ok(DataType::Float32), - // For other types (possible values float64/null/int), use Float64 - _ => Ok(DataType::Float64), - } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) } fn output_ordering( @@ -422,10 +412,7 @@ macro_rules! make_math_binary_udf { &self, args: ScalarFunctionArgs, ) -> Result { - let ScalarFunctionArgs { - args, return_field, .. - } = args; - let return_type = return_field.data_type(); + let ScalarFunctionArgs { args, .. } = args; let [y, x] = take_function_args(self.name(), args)?; match (y, x) { @@ -434,8 +421,7 @@ macro_rules! make_math_binary_udf { ColumnarValue::Scalar(x_scalar), ) => match (&y_scalar, &x_scalar) { (y, x) if y.is_null() || x.is_null() => { - ColumnarValue::Scalar(ScalarValue::Null) - .cast_to(return_type, None) + Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))) } ( ScalarValue::Float64(Some(yv)), @@ -443,12 +429,6 @@ macro_rules! make_math_binary_udf { ) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( f64::$BINARY_FUNC(*yv, *xv), )))), - ( - ScalarValue::Float32(Some(yv)), - ScalarValue::Float32(Some(xv)), - ) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some( - f32::$BINARY_FUNC(*yv, *xv), - )))), _ => internal_err!( "Unexpected scalar types for function {}: {:?}, {:?}", self.name(), @@ -458,38 +438,25 @@ macro_rules! make_math_binary_udf { }, (y, x) => { let args = ColumnarValue::values_to_arrays(&[y, x])?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { + match (args[0].data_type(), args[1].data_type()) { + (DataType::Float64, DataType::Float64) => { let y = args[0].as_primitive::(); let x = args[1].as_primitive::(); - let result = - arrow::compute::binary::<_, _, _, Float64Type>( - y, - x, - |y, x| f64::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ - } - DataType::Float32 => { - let y = args[0].as_primitive::(); - let x = args[1].as_primitive::(); - let result = - arrow::compute::binary::<_, _, _, Float32Type>( - y, - x, - |y, x| f32::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ + let result = arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } - other => { - return internal_err!( - "Unsupported data type {other:?} for function {}", + (left, right) => { + internal_err!( + "Unexpected array types for function {}: {left:?}, {right:?}", self.name() - ); + ) } - }; - - Ok(ColumnarValue::Array(arr)) + } } } } diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index 4a0db9ef0cf7a..52449f9c9e0b9 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -262,11 +262,11 @@ Can be a constant, column, or function, and any combination of arithmetic operat ) .with_sql_example(r#"```sql > SELECT atan2(1, 1); -+------------+ -| atan2(1,1) | -+------------+ -| 0.7853982 | -+------------+ ++--------------------+ +| atan2(1,1) | ++--------------------+ +| 0.7853981633974483 | ++--------------------+ ```"#) .build() }); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 38f76f13151bc..fefc39060cf6c 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -234,7 +234,23 @@ select round(atanh(a), 5), round(atanh(b), 5), round(atanh(c), 5) from small_flo query RRR rowsort select atan2(0, 1), atan2(1, 2), atan2(2, 2); ---- -0 0.4636476 0.7853982 +0 0.463647609001 0.785398163397 + +# atan2 always returns Float64, including integer, Float32, and NULL inputs +query TTTT +select + arrow_typeof(atan2(1, 1)), + arrow_typeof(atan2(arrow_cast(1.0, 'Float32'), arrow_cast(1.0, 'Float32'))), + arrow_typeof(atan2(null, null)), + arrow_typeof(atan2(null, 64)); +---- +Float64 Float64 Float64 Float64 + +# atan2 with integer inputs is computed in double precision +query B +select atan2(1, 1000000) = atan2(1.0, 1000000.0); +---- +true # atan2 scalar nulls query R rowsort diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 6bf61391eb10e..8ff0032723f90 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -227,11 +227,11 @@ atan2(expression_y, expression_x) ```sql > SELECT atan2(1, 1); -+------------+ -| atan2(1,1) | -+------------+ -| 0.7853982 | -+------------+ ++--------------------+ +| atan2(1,1) | ++--------------------+ +| 0.7853981633974483 | ++--------------------+ ``` ### `atanh`