Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 324 additions & 5 deletions crates/connect/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,198 @@ where
create_map(map)
}

/// Creates a lambda variable reference for use in higher-order function bodies.
///
/// Use this instead of `col()` when referencing lambda parameters inside
/// `transform`, `filter`, `aggregate`, etc.
///
/// # Example
/// ```rust
/// // Transform array: [1,2,3] -> [2,3,4]
/// transform(col("array"), lvar("x") + lit(1), "x")
/// ```
pub fn lvar(name: &str) -> Column {
Column::from(spark::Expression {
expr_type: Some(spark::expression::ExprType::UnresolvedNamedLambdaVariable(
spark::expression::UnresolvedNamedLambdaVariable {
name_parts: vec![name.to_string()],
},
)),
})
}

/// Creates a lambda expression for use with higher-order functions.
fn create_lambda(func: Column, var_names: &[&str]) -> spark::Expression {
let arguments = var_names
.iter()
.map(|name| spark::expression::UnresolvedNamedLambdaVariable {
name_parts: vec![name.to_string()],
})
.collect();

spark::Expression {
expr_type: Some(spark::expression::ExprType::LambdaFunction(Box::new(
spark::expression::LambdaFunction {
function: Some(Box::new(func.expression)),
arguments,
},
))),
}
}

/// Applies a function to every element in the array.
///
/// # Example
/// ```rust
/// // Transform array elements: [1,2,3] -> [2,3,4]
/// transform(col("array"), lvar("x") + lit(1), "x")
/// ```
pub fn transform(col: impl Into<Column>, func: Column, var_name: &str) -> Column {
    let lambda = Column::from(create_lambda(func, &[var_name]));
    invoke_func("transform", vec![col.into(), lambda])
}

/// Filters an array using a boolean predicate.
///
/// # Example
/// ```rust
/// // Keep only the elements equal to 1
/// filter(col("array"), lvar("x").eq(lit(1)), "x")
/// ```
pub fn filter(col: impl Into<Column>, func: Column, var_name: &str) -> Column {
    let lambda = Column::from(create_lambda(func, &[var_name]));
    invoke_func("filter", vec![col.into(), lambda])
}

/// Returns true if the predicate holds for any element in the array.
///
/// # Example
/// ```rust
/// exists(col("array"), lvar("x").eq(lit(1)), "x")
/// ```
pub fn exists(col: impl Into<Column>, func: Column, var_name: &str) -> Column {
    let lambda = Column::from(create_lambda(func, &[var_name]));
    invoke_func("exists", vec![col.into(), lambda])
}

/// Returns true if the predicate holds for all elements in the array.
///
/// # Example
/// ```rust
/// forall(col("array"), lvar("x").eq(lit(1)), "x")
/// ```
pub fn forall(col: impl Into<Column>, func: Column, var_name: &str) -> Column {
    let lambda = Column::from(create_lambda(func, &[var_name]));
    invoke_func("forall", vec![col.into(), lambda])
}

/// Applies a binary function to an initial state and all elements in the array,
/// and reduces this to a single state.
///
/// # Example
/// ```rust
/// // Sum all elements: aggregate([1,2,3], 0, (acc, x) -> acc + x)
/// aggregate(col("array"), lit(0), lvar("acc") + lvar("x"), "acc", "x")
/// ```
pub fn aggregate(
    col: impl Into<Column>,
    initial_value: impl Into<Column>,
    merge: Column,
    acc_name: &str,
    elem_name: &str,
) -> Column {
    // The merge lambda takes the accumulator first, then the element.
    let merge_lambda = Column::from(create_lambda(merge, &[acc_name, elem_name]));
    let args = vec![col.into(), initial_value.into(), merge_lambda];
    invoke_func("aggregate", args)
}

/// Sorts the given array using a comparator function.
///
/// # Example
/// ```rust
/// // Sort descending: array_sort([3,1,2], (a, b) -> b - a)
/// array_sort_with_comp(col("array"), lvar("b") - lvar("a"), "a", "b")
/// ```
pub fn array_sort_with_comp(
    col: impl Into<Column>,
    comparator: Column,
    left_name: &str,
    right_name: &str,
) -> Column {
    // Two-argument comparator: (left, right) -> negative / zero / positive.
    let lambda = Column::from(create_lambda(comparator, &[left_name, right_name]));
    invoke_func("array_sort", vec![col.into(), lambda])
}

/// Filters entries from a map using a predicate over each key-value pair.
///
/// # Example
/// ```rust
/// map_filter(col("map"), lvar("v").eq(lit(1)), "k", "v")
/// ```
pub fn map_filter(
    col: impl Into<Column>,
    func: Column,
    key_name: &str,
    value_name: &str,
) -> Column {
    let lambda = Column::from(create_lambda(func, &[key_name, value_name]));
    invoke_func("map_filter", vec![col.into(), lambda])
}

/// Applies a function to every key-value pair in a map and returns a map
/// with the transformed keys (values are left unchanged).
///
/// # Example
/// ```rust
/// transform_keys(col("map"), lvar("k") + lit(1), "k", "v")
/// ```
pub fn transform_keys(
    col: impl Into<Column>,
    func: Column,
    key_name: &str,
    value_name: &str,
) -> Column {
    let lambda = Column::from(create_lambda(func, &[key_name, value_name]));
    invoke_func("transform_keys", vec![col.into(), lambda])
}

/// Applies a function to every key-value pair in a map and returns a map
/// with the transformed values (keys are left unchanged).
///
/// # Example
/// ```rust
/// transform_values(col("map"), lvar("v") * lit(2), "k", "v")
/// ```
pub fn transform_values(
    col: impl Into<Column>,
    func: Column,
    key_name: &str,
    value_name: &str,
) -> Column {
    let lambda = Column::from(create_lambda(func, &[key_name, value_name]));
    invoke_func("transform_values", vec![col.into(), lambda])
}

/// Merges two arrays element-wise using a binary function.
///
/// # Example
/// ```rust
/// zip_with(col("arr1"), col("arr2"), lvar("x") + lvar("y"), "x", "y")
/// ```
pub fn zip_with(
    left: impl Into<Column>,
    right: impl Into<Column>,
    func: Column,
    left_name: &str,
    right_name: &str,
) -> Column {
    let lambda = Column::from(create_lambda(func, &[left_name, right_name]));
    let args = vec![left.into(), right.into(), lambda];
    invoke_func("zip_with", args)
}

// Normal Functions

/// Returns a [Column] based on the given column name.
Expand Down Expand Up @@ -733,11 +925,6 @@ gen_func!(element_at, [col: Column, extraction: Column], "Returns element of arr
gen_func!(array_append, [col: Column, value: Column], "Returns an array of the elements in col1 along with the added element in col2 at the last of the array.");
gen_func!(array_size, [col: Column], "Returns the total number of elements in the array.");

#[allow(unused_variables)]
pub fn array_sort(col: impl Into<Column>, compactor: Option<impl Into<Column>>) -> Column {
unimplemented!()
}

/// adds an item into a given array at a specified array index.
pub fn array_insert(
col: impl Into<Column>,
Expand Down Expand Up @@ -2620,4 +2807,136 @@ mod tests {

Ok(())
}

#[test]
fn test_transform_lambda_structure() {
    let result = transform(col("arr"), lvar("x") + lit(1), "x");
    // Outer expression must be transform(col, lambda).
    if let Some(spark::expression::ExprType::UnresolvedFunction(f)) = result.expression.expr_type {
        assert_eq!(f.function_name, "transform");
        assert_eq!(f.arguments.len(), 2);
        // Second argument must be a one-parameter lambda named "x".
        if let Some(spark::expression::ExprType::LambdaFunction(lambda)) =
            &f.arguments[1].expr_type
        {
            assert!(lambda.function.is_some());
            assert_eq!(lambda.arguments.len(), 1);
            assert_eq!(lambda.arguments[0].name_parts, vec!["x".to_string()]);
        } else {
            panic!("Expected LambdaFunction as second argument");
        }
    } else {
        panic!("Expected UnresolvedFunction");
    }
}

#[test]
fn test_filter_lambda_structure() {
    let result = filter(col("arr"), lvar("x").eq(lit(0)), "x");
    if let Some(spark::expression::ExprType::UnresolvedFunction(f)) = result.expression.expr_type {
        assert_eq!(f.function_name, "filter");
        assert_eq!(f.arguments.len(), 2);
    } else {
        panic!("Expected UnresolvedFunction");
    }
}

#[test]
fn test_aggregate_lambda_structure() {
    let result = aggregate(col("arr"), lit(0), lvar("acc") + lvar("x"), "acc", "x");
    if let Some(spark::expression::ExprType::UnresolvedFunction(f)) = result.expression.expr_type {
        assert_eq!(f.function_name, "aggregate");
        assert_eq!(f.arguments.len(), 3);
        // Third argument is the merge lambda with (acc, x) parameters.
        if let Some(spark::expression::ExprType::LambdaFunction(lambda)) =
            &f.arguments[2].expr_type
        {
            assert_eq!(lambda.arguments.len(), 2);
            assert_eq!(lambda.arguments[0].name_parts, vec!["acc".to_string()]);
            assert_eq!(lambda.arguments[1].name_parts, vec!["x".to_string()]);
        } else {
            panic!("Expected LambdaFunction");
        }
    } else {
        panic!("Expected UnresolvedFunction");
    }
}

#[test]
fn test_exists_lambda_structure() {
    let result = exists(col("arr"), lvar("x").eq(lit(0)), "x");
    if let Some(spark::expression::ExprType::UnresolvedFunction(f)) = result.expression.expr_type {
        assert_eq!(f.function_name, "exists");
        assert_eq!(f.arguments.len(), 2);
    } else {
        panic!("Expected UnresolvedFunction");
    }
}

#[test]
fn test_forall_lambda_structure() {
    let result = forall(col("arr"), lvar("x").eq(lit(0)), "x");
    if let Some(spark::expression::ExprType::UnresolvedFunction(f)) = result.expression.expr_type {
        assert_eq!(f.function_name, "forall");
    } else {
        panic!("Expected UnresolvedFunction");
    }
}

#[test]
fn test_array_sort_with_comp_structure() {
    let result = array_sort_with_comp(col("arr"), lvar("a") - lvar("b"), "a", "b");
    if let Some(spark::expression::ExprType::UnresolvedFunction(f)) = result.expression.expr_type {
        assert_eq!(f.function_name, "array_sort");
        assert_eq!(f.arguments.len(), 2);
        // Comparator lambda must carry both parameter names in order.
        if let Some(spark::expression::ExprType::LambdaFunction(lambda)) =
            &f.arguments[1].expr_type
        {
            assert_eq!(lambda.arguments.len(), 2);
            assert_eq!(lambda.arguments[0].name_parts, vec!["a".to_string()]);
            assert_eq!(lambda.arguments[1].name_parts, vec!["b".to_string()]);
        } else {
            panic!("Expected LambdaFunction");
        }
    } else {
        panic!("Expected UnresolvedFunction");
    }
}

#[test]
fn test_map_filter_structure() {
    let result = map_filter(col("map"), lvar("v").eq(lit(0)), "k", "v");
    if let Some(spark::expression::ExprType::UnresolvedFunction(f)) = result.expression.expr_type {
        assert_eq!(f.function_name, "map_filter");
        assert_eq!(f.arguments.len(), 2);
    } else {
        panic!("Expected UnresolvedFunction");
    }
}

#[test]
fn test_transform_keys_structure() {
    let result = transform_keys(col("map"), lvar("k") + lit(1), "k", "v");
    if let Some(spark::expression::ExprType::UnresolvedFunction(f)) = result.expression.expr_type {
        assert_eq!(f.function_name, "transform_keys");
    } else {
        panic!("Expected UnresolvedFunction");
    }
}

#[test]
fn test_zip_with_structure() {
    let result = zip_with(col("a"), col("b"), lvar("x") + lvar("y"), "x", "y");
    if let Some(spark::expression::ExprType::UnresolvedFunction(f)) = result.expression.expr_type {
        assert_eq!(f.function_name, "zip_with");
        assert_eq!(f.arguments.len(), 3);
    } else {
        panic!("Expected UnresolvedFunction");
    }
}
}