Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions crates/connect/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,100 @@ where
}
}

/// Represents a Spark Map literal value.
///
/// Used with [`lit`](crate::functions::lit) to create a map column literal.
#[derive(Clone, Debug)]
pub struct MapLiteral<K, V>
where
K: Into<spark::expression::Literal> + Clone,
V: Into<spark::expression::Literal> + Clone,
spark::DataType: From<K>,
spark::DataType: From<V>,
{
pub keys: Vec<K>,
pub values: Vec<V>,
}

impl<K, V> MapLiteral<K, V>
where
K: Into<spark::expression::Literal> + Clone,
V: Into<spark::expression::Literal> + Clone,
spark::DataType: From<K>,
spark::DataType: From<V>,
{
pub fn new(keys: Vec<K>, values: Vec<V>) -> Self {
assert_eq!(
keys.len(),
values.len(),
"Keys and values must have the same length"
);
Self { keys, values }
}
}

impl<K, V> From<MapLiteral<K, V>> for spark::expression::Literal
where
K: Into<spark::expression::Literal> + Clone,
V: Into<spark::expression::Literal> + Clone,
spark::DataType: From<K>,
spark::DataType: From<V>,
{
fn from(value: MapLiteral<K, V>) -> Self {
let key_type = Some(spark::DataType::from(
value.keys.first().expect("Map cannot be empty").clone(),
));
let value_type = Some(spark::DataType::from(
value.values.first().expect("Map cannot be empty").clone(),
));

let keys = value.keys.into_iter().map(|k| k.into()).collect();
let values = value.values.into_iter().map(|v| v.into()).collect();

let map = spark::expression::literal::Map {
key_type,
value_type,
keys,
values,
};

spark::expression::Literal {
literal_type: Some(spark::expression::literal::LiteralType::Map(map)),
}
}
}

/// Represents a Spark Struct literal value.
///
/// Used with [`lit`](crate::functions::lit) to create a struct column literal.
#[derive(Clone, Debug)]
pub struct StructLiteral {
pub struct_type: spark::DataType,
pub elements: Vec<spark::expression::Literal>,
}

impl StructLiteral {
pub fn new(struct_type: spark::DataType, elements: Vec<spark::expression::Literal>) -> Self {
Self {
struct_type,
elements,
}
}
}

impl From<StructLiteral> for spark::expression::Literal {
fn from(value: StructLiteral) -> Self {
let struct_lit = spark::expression::literal::Struct {
struct_type: Some(value.struct_type),
elements: value.elements,
};

spark::expression::Literal {
literal_type: Some(spark::expression::literal::LiteralType::Struct(struct_lit)),
}
}
}

impl From<&str> for spark::expression::cast::CastToType {
fn from(value: &str) -> Self {
spark::expression::cast::CastToType::TypeStr(value.to_string())
Expand All @@ -266,3 +360,60 @@ impl From<DataType> for spark::expression::cast::CastToType {
spark::expression::cast::CastToType::Type(value.into())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_map_literal() {
let map = MapLiteral::new(
vec!["key1".to_string(), "key2".to_string()],
vec![1i32, 2i32],
);
let literal: spark::expression::Literal = map.into();
match literal.literal_type {
Some(spark::expression::literal::LiteralType::Map(m)) => {
assert_eq!(m.keys.len(), 2);
assert_eq!(m.values.len(), 2);
assert!(m.key_type.is_some());
assert!(m.value_type.is_some());
}
_ => panic!("Expected Map literal"),
}
}

#[test]
fn test_struct_literal() {
let struct_type =
crate::types::DataType::Struct(Box::new(crate::types::StructType::new(vec![
crate::types::StructField {
name: "name",
data_type: crate::types::DataType::String,
nullable: true,
metadata: None,
},
crate::types::StructField {
name: "age",
data_type: crate::types::DataType::Integer,
nullable: true,
metadata: None,
},
])));

let elements = vec![
spark::expression::Literal::from("Alice".to_string()),
spark::expression::Literal::from(30i32),
];

let struct_lit = StructLiteral::new(struct_type.into(), elements);
let literal: spark::expression::Literal = struct_lit.into();
match literal.literal_type {
Some(spark::expression::literal::LiteralType::Struct(s)) => {
assert_eq!(s.elements.len(), 2);
assert!(s.struct_type.is_some());
}
_ => panic!("Expected Struct literal"),
}
}
}