From 603179fa67e2d6bc5ab960fb8420b5550862384c Mon Sep 17 00:00:00 2001
From: Rafael Fernandez
Date: Sun, 29 Mar 2026 11:20:06 +0200
Subject: [PATCH] [SPARK-52428] Add higher-order functions with lambda expression support

---
 crates/connect/src/functions/mod.rs | 329 +++++++++++++++++++++++++++-
 1 file changed, 324 insertions(+), 5 deletions(-)

diff --git a/crates/connect/src/functions/mod.rs b/crates/connect/src/functions/mod.rs
index b518b4d..efa9817 100644
--- a/crates/connect/src/functions/mod.rs
+++ b/crates/connect/src/functions/mod.rs
@@ -91,6 +91,198 @@ where
     create_map(map)
 }
 
+/// Creates a lambda variable reference for use in higher-order function bodies.
+///
+/// Use this instead of `col()` when referencing lambda parameters inside
+/// `transform`, `filter`, `aggregate`, etc.
+///
+/// # Example
+/// ```rust
+/// // Transform array: [1,2,3] -> [2,3,4]
+/// transform(col("array"), lvar("x") + lit(1), "x")
+/// ```
+pub fn lvar(name: &str) -> Column {
+    Column::from(spark::Expression {
+        expr_type: Some(spark::expression::ExprType::UnresolvedNamedLambdaVariable(
+            spark::expression::UnresolvedNamedLambdaVariable {
+                name_parts: vec![name.to_string()],
+            },
+        )),
+    })
+}
+
+/// Creates a lambda expression for use with higher-order functions.
+fn create_lambda(func: Column, var_names: &[&str]) -> spark::Expression {
+    let arguments = var_names
+        .iter()
+        .map(|name| spark::expression::UnresolvedNamedLambdaVariable {
+            name_parts: vec![name.to_string()],
+        })
+        .collect();
+
+    spark::Expression {
+        expr_type: Some(spark::expression::ExprType::LambdaFunction(Box::new(
+            spark::expression::LambdaFunction {
+                function: Some(Box::new(func.expression)),
+                arguments,
+            },
+        ))),
+    }
+}
+
+/// Applies a function to every element in the array.
+///
+/// # Example
+/// ```rust
+/// // Transform array elements: [1,2,3] -> [2,3,4]
+/// transform(col("array"), lvar("x") + lit(1), "x")
+/// ```
+pub fn transform(col: impl Into<Column>, func: Column, var_name: &str) -> Column {
+    let lambda = create_lambda(func, &[var_name]);
+    invoke_func("transform", vec![col.into(), Column::from(lambda)])
+}
+
+/// Filters an array using a boolean predicate.
+///
+/// # Example
+/// ```rust
+/// // Keep only the elements equal to 1
+/// filter(col("array"), lvar("x").eq(lit(1)), "x")
+/// ```
+pub fn filter(col: impl Into<Column>, func: Column, var_name: &str) -> Column {
+    let lambda = create_lambda(func, &[var_name]);
+    invoke_func("filter", vec![col.into(), Column::from(lambda)])
+}
+
+/// Returns true if the predicate holds for any element in the array.
+///
+/// # Example
+/// ```rust
+/// exists(col("array"), lvar("x").eq(lit(1)), "x")
+/// ```
+pub fn exists(col: impl Into<Column>, func: Column, var_name: &str) -> Column {
+    let lambda = create_lambda(func, &[var_name]);
+    invoke_func("exists", vec![col.into(), Column::from(lambda)])
+}
+
+/// Returns true if the predicate holds for all elements in the array.
+///
+/// # Example
+/// ```rust
+/// forall(col("array"), lvar("x").eq(lit(1)), "x")
+/// ```
+pub fn forall(col: impl Into<Column>, func: Column, var_name: &str) -> Column {
+    let lambda = create_lambda(func, &[var_name]);
+    invoke_func("forall", vec![col.into(), Column::from(lambda)])
+}
+
+/// Applies a binary function to an initial state and all elements in the array,
+/// and reduces this to a single state.
+///
+/// # Example
+/// ```rust
+/// // Sum all elements: aggregate([1,2,3], 0, (acc, x) -> acc + x)
+/// aggregate(col("array"), lit(0), lvar("acc") + lvar("x"), "acc", "x")
+/// ```
+pub fn aggregate(
+    col: impl Into<Column>,
+    initial_value: impl Into<Column>,
+    merge: Column,
+    acc_name: &str,
+    elem_name: &str,
+) -> Column {
+    let lambda = create_lambda(merge, &[acc_name, elem_name]);
+    invoke_func(
+        "aggregate",
+        vec![col.into(), initial_value.into(), Column::from(lambda)],
+    )
+}
+
+/// Sorts the given array using a comparator function.
+///
+/// # Example
+/// ```rust
+/// // Sort descending: array_sort([3,1,2], (a, b) -> b - a)
+/// array_sort_with_comp(col("array"), lvar("b") - lvar("a"), "a", "b")
+/// ```
+pub fn array_sort_with_comp(
+    col: impl Into<Column>,
+    comparator: Column,
+    left_name: &str,
+    right_name: &str,
+) -> Column {
+    let lambda = create_lambda(comparator, &[left_name, right_name]);
+    invoke_func("array_sort", vec![col.into(), Column::from(lambda)])
+}
+
+/// Filters entries from a map using a predicate.
+///
+/// # Example
+/// ```rust
+/// map_filter(col("map"), lvar("v").eq(lit(1)), "k", "v")
+/// ```
+pub fn map_filter(
+    col: impl Into<Column>,
+    func: Column,
+    key_name: &str,
+    value_name: &str,
+) -> Column {
+    let lambda = create_lambda(func, &[key_name, value_name]);
+    invoke_func("map_filter", vec![col.into(), Column::from(lambda)])
+}
+
+/// Applies a function to every key-value pair in a map and returns a map of the results.
+///
+/// # Example
+/// ```rust
+/// transform_keys(col("map"), lvar("k") + lit(1), "k", "v")
+/// ```
+pub fn transform_keys(
+    col: impl Into<Column>,
+    func: Column,
+    key_name: &str,
+    value_name: &str,
+) -> Column {
+    let lambda = create_lambda(func, &[key_name, value_name]);
+    invoke_func("transform_keys", vec![col.into(), Column::from(lambda)])
+}
+
+/// Applies a function to every key-value pair in a map and returns a map with transformed values.
+///
+/// # Example
+/// ```rust
+/// transform_values(col("map"), lvar("v") * lit(2), "k", "v")
+/// ```
+pub fn transform_values(
+    col: impl Into<Column>,
+    func: Column,
+    key_name: &str,
+    value_name: &str,
+) -> Column {
+    let lambda = create_lambda(func, &[key_name, value_name]);
+    invoke_func("transform_values", vec![col.into(), Column::from(lambda)])
+}
+
+/// Merges two arrays element-wise using a function.
+///
+/// # Example
+/// ```rust
+/// zip_with(col("arr1"), col("arr2"), lvar("x") + lvar("y"), "x", "y")
+/// ```
+pub fn zip_with(
+    left: impl Into<Column>,
+    right: impl Into<Column>,
+    func: Column,
+    left_name: &str,
+    right_name: &str,
+) -> Column {
+    let lambda = create_lambda(func, &[left_name, right_name]);
+    invoke_func(
+        "zip_with",
+        vec![left.into(), right.into(), Column::from(lambda)],
+    )
+}
+
 // Normal Functions
 
 /// Returns a [Column] based on the given column name.
@@ -733,11 +925,6 @@ gen_func!(element_at, [col: Column, extraction: Column], "Returns element of arr
 gen_func!(array_append, [col: Column, value: Column], "Returns an array of the elements in col1 along with the added element in col2 at the last of the array.");
 gen_func!(array_size, [col: Column], "Returns the total number of elements in the array.");
 
-#[allow(unused_variables)]
-pub fn array_sort(col: impl Into<Column>, compactor: Option<impl Into<Column>>) -> Column {
-    unimplemented!()
-}
-
 /// adds an item into a given array at a specified array index.
 pub fn array_insert(
     col: impl Into<Column>,
@@ -2620,4 +2807,136 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_transform_lambda_structure() {
+        let result = transform(col("arr"), lvar("x") + lit(1), "x");
+        match result.expression.expr_type {
+            Some(spark::expression::ExprType::UnresolvedFunction(ref f)) => {
+                assert_eq!(f.function_name, "transform");
+                assert_eq!(f.arguments.len(), 2);
+                // Second arg should be a LambdaFunction
+                match &f.arguments[1].expr_type {
+                    Some(spark::expression::ExprType::LambdaFunction(ref lambda)) => {
+                        assert!(lambda.function.is_some());
+                        assert_eq!(lambda.arguments.len(), 1);
+                        assert_eq!(lambda.arguments[0].name_parts, vec!["x".to_string()]);
+                    }
+                    _ => panic!("Expected LambdaFunction as second argument"),
+                }
+            }
+            _ => panic!("Expected UnresolvedFunction"),
+        }
+    }
+
+    #[test]
+    fn test_filter_lambda_structure() {
+        let result = filter(col("arr"), lvar("x").eq(lit(0)), "x");
+        match result.expression.expr_type {
+            Some(spark::expression::ExprType::UnresolvedFunction(ref f)) => {
+                assert_eq!(f.function_name, "filter");
+                assert_eq!(f.arguments.len(), 2);
+            }
+            _ => panic!("Expected UnresolvedFunction"),
+        }
+    }
+
+    #[test]
+    fn test_aggregate_lambda_structure() {
+        let result = aggregate(col("arr"), lit(0), lvar("acc") + lvar("x"), "acc", "x");
+        match result.expression.expr_type {
+            Some(spark::expression::ExprType::UnresolvedFunction(ref f)) => {
+                assert_eq!(f.function_name, "aggregate");
+                assert_eq!(f.arguments.len(), 3);
+                // Third arg is the lambda with 2 variables
+                match &f.arguments[2].expr_type {
+                    Some(spark::expression::ExprType::LambdaFunction(ref lambda)) => {
+                        assert_eq!(lambda.arguments.len(), 2);
+                        assert_eq!(lambda.arguments[0].name_parts, vec!["acc".to_string()]);
+                        assert_eq!(lambda.arguments[1].name_parts, vec!["x".to_string()]);
+                    }
+                    _ => panic!("Expected LambdaFunction"),
+                }
+            }
+            _ => panic!("Expected UnresolvedFunction"),
+        }
+    }
+
+    #[test]
+    fn test_exists_lambda_structure() {
+        let result = exists(col("arr"), lvar("x").eq(lit(0)), "x");
+        match result.expression.expr_type {
+            Some(spark::expression::ExprType::UnresolvedFunction(ref f)) => {
+                assert_eq!(f.function_name, "exists");
+                assert_eq!(f.arguments.len(), 2);
+            }
+            _ => panic!("Expected UnresolvedFunction"),
+        }
+    }
+
+    #[test]
+    fn test_forall_lambda_structure() {
+        let result = forall(col("arr"), lvar("x").eq(lit(0)), "x");
+        match result.expression.expr_type {
+            Some(spark::expression::ExprType::UnresolvedFunction(ref f)) => {
+                assert_eq!(f.function_name, "forall");
+            }
+            _ => panic!("Expected UnresolvedFunction"),
+        }
+    }
+
+    #[test]
+    fn test_array_sort_with_comp_structure() {
+        let result = array_sort_with_comp(col("arr"), lvar("a") - lvar("b"), "a", "b");
+        match result.expression.expr_type {
+            Some(spark::expression::ExprType::UnresolvedFunction(ref f)) => {
+                assert_eq!(f.function_name, "array_sort");
+                assert_eq!(f.arguments.len(), 2);
+                match &f.arguments[1].expr_type {
+                    Some(spark::expression::ExprType::LambdaFunction(ref lambda)) => {
+                        assert_eq!(lambda.arguments.len(), 2);
+                        assert_eq!(lambda.arguments[0].name_parts, vec!["a".to_string()]);
+                        assert_eq!(lambda.arguments[1].name_parts, vec!["b".to_string()]);
+                    }
+                    _ => panic!("Expected LambdaFunction"),
+                }
+            }
+            _ => panic!("Expected UnresolvedFunction"),
+        }
+    }
+
+    #[test]
+    fn test_map_filter_structure() {
+        let result = map_filter(col("map"), lvar("v").eq(lit(0)), "k", "v");
+        match result.expression.expr_type {
+            Some(spark::expression::ExprType::UnresolvedFunction(ref f)) => {
+                assert_eq!(f.function_name, "map_filter");
+                assert_eq!(f.arguments.len(), 2);
+            }
+            _ => panic!("Expected UnresolvedFunction"),
+        }
+    }
+
+    #[test]
+    fn test_transform_keys_structure() {
+        let result = transform_keys(col("map"), lvar("k") + lit(1), "k", "v");
+        match result.expression.expr_type {
+            Some(spark::expression::ExprType::UnresolvedFunction(ref f)) => {
+                assert_eq!(f.function_name, "transform_keys");
+            }
+            _ => panic!("Expected UnresolvedFunction"),
+        }
+    }
+
+    #[test]
+    fn test_zip_with_structure() {
+        let result = zip_with(col("a"), col("b"), lvar("x") + lvar("y"), "x", "y");
+        match result.expression.expr_type {
+            Some(spark::expression::ExprType::UnresolvedFunction(ref f)) => {
+                assert_eq!(f.function_name, "zip_with");
+                assert_eq!(f.arguments.len(), 3);
+            }
+            _ => panic!("Expected UnresolvedFunction"),
+        }
+    }
 }