Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 116 additions & 66 deletions crates/connect/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ impl DataFrame {
}

/// Persists the [DataFrame] with the default [storage::StorageLevel::MemoryAndDiskDeser] (MEMORY_AND_DISK_DESER).
///
/// Forwards to `persist` with that storage level and hands back whatever `persist` yields.
///
/// # Errors
///
/// With the `Result`-returning signature, the value awaited from `persist` is returned
/// unchanged, so any `SparkError` it carries reaches the caller.
// NOTE(review): this span is a rendered diff with the +/- markers stripped. The first
// signature below appears to be the removed line and the second its replacement; only one
// of them exists in the real file — confirm against the commit.
pub async fn cache(self) -> DataFrame {
pub async fn cache(self) -> Result<DataFrame, SparkError> {
self.persist(storage::StorageLevel::MemoryAndDiskDeser)
.await
}
Expand Down Expand Up @@ -211,15 +211,21 @@ impl DataFrame {
pub async fn columns(self) -> Result<Vec<String>, SparkError> {
let schema = self.schema().await?;

let struct_val = schema.kind.expect("Unwrapped an empty schema");
let struct_val = schema.kind.ok_or(SparkError::AnalysisException(
"Schema response is empty".to_string(),
))?;

let cols = match struct_val {
spark::data_type::Kind::Struct(val) => val
.fields
.iter()
.map(|field| field.name.to_string())
.collect(),
_ => unimplemented!("Unexpected schema response"),
_ => {
return Err(SparkError::NotYetImplemented(
"Unexpected schema response".to_string(),
))
}
};

Ok(cols)
Expand All @@ -240,11 +246,18 @@ impl DataFrame {
let col = result.column(0);

let data: &PrimitiveArray<Float64Type> = match col.data_type() {
DataType::Float64 => col
.as_any()
.downcast_ref()
.expect("failed to unwrap result"),
_ => panic!("Expected Float64 in response type"),
DataType::Float64 => {
col.as_any()
.downcast_ref()
.ok_or(SparkError::AnalysisException(
"Failed to extract numeric result".to_string(),
))?
}
_ => {
return Err(SparkError::AnalysisException(
"Expected Float64 in response type for corr()".to_string(),
))
}
};

Ok(data.value(0))
Expand All @@ -257,8 +270,18 @@ impl DataFrame {
let col = res.column(0);

let data: &arrow::array::Int64Array = match col.data_type() {
arrow::datatypes::DataType::Int64 => col.as_any().downcast_ref().unwrap(),
_ => unimplemented!("only Utf8 data types are currently handled currently."),
arrow::datatypes::DataType::Int64 => {
col.as_any()
.downcast_ref()
.ok_or(SparkError::AnalysisException(
"Failed to extract numeric result".to_string(),
))?
}
_ => {
return Err(SparkError::NotYetImplemented(
"only Utf8 data types are currently handled".to_string(),
))
}
};

Ok(data.value(0))
Expand All @@ -278,11 +301,18 @@ impl DataFrame {
let col = result.column(0);

let data: &PrimitiveArray<Float64Type> = match col.data_type() {
DataType::Float64 => col
.as_any()
.downcast_ref()
.expect("failed to unwrap result"),
_ => panic!("Expected Float64 in response type"),
DataType::Float64 => {
col.as_any()
.downcast_ref()
.ok_or(SparkError::AnalysisException(
"Failed to extract numeric result".to_string(),
))?
}
_ => {
return Err(SparkError::AnalysisException(
"Expected Float64 in response type for cov()".to_string(),
))
}
};

Ok(data.value(0))
Expand Down Expand Up @@ -453,35 +483,48 @@ impl DataFrame {
// Lists each field of the schema's top-level struct as a `(name, kind)` pair.
// NOTE(review): this hunk is a rendered diff with the +/- markers stripped, so removed lines
// and their replacements sit back to back. The "removed"/"added" labels below are inferred
// from the ordering — confirm against the commit.
pub async fn dtypes(self) -> Result<Vec<(String, spark::data_type::Kind)>, SparkError> {
let schema = self.schema().await?;

// Removed: panics when the schema carries no kind.
let struct_val = schema.kind.expect("unwrapped an empty schema");
// Added: the same condition is reported as an `AnalysisException` and returned via `?`.
let struct_val = schema.kind.ok_or(SparkError::AnalysisException(
"Schema response is empty".to_string(),
))?;

let dtypes = match struct_val {
// Removed arm: unwraps each field's data type and its kind, panicking if either is absent.
spark::data_type::Kind::Struct(val) => val
.fields
.iter()
.map(|field| {
(
field.name.to_string(),
field.data_type.clone().unwrap().kind.unwrap(),
)
})
.collect(),
// Removed arm: any non-struct kind hits `unimplemented!` and panics.
_ => unimplemented!("Unexpected schema response"),
// Added arm: walks the fields in order and returns an `AnalysisException` as soon as a
// field has no data type, or that data type has no kind.
spark::data_type::Kind::Struct(val) => {
let mut result = Vec::new();
for field in &val.fields {
let data_type =
field
.data_type
.clone()
.ok_or(SparkError::AnalysisException(
"Field data type is missing".to_string(),
))?;
let kind = data_type.kind.ok_or(SparkError::AnalysisException(
"Field data type kind is missing".to_string(),
))?;
result.push((field.name.to_string(), kind));
}
result
}
// Added arm: any non-struct kind is returned as a `NotYetImplemented` error.
_ => {
return Err(SparkError::NotYetImplemented(
"Unexpected schema response".to_string(),
))
}
};

Ok(dtypes)
}

/// Return a new [DataFrame] containing rows in this [DataFrame] but not in another [DataFrame] while preserving duplicates.
///
/// The returned frame keeps this frame's `spark_session` and wraps the plan produced by
/// `self.plan.except_all(other.plan)`.
///
/// # Errors
///
/// In the `Result`-returning form, an error from `check_same_session` is propagated to the
/// caller with `?`; the other form unwraps it and therefore panics instead.
// NOTE(review): rendered diff with the +/- markers stripped — the removed signature/body
// lines are followed directly by their replacements, so both variants are visible here.
pub fn except_all(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
pub fn except_all(self, other: DataFrame) -> Result<DataFrame, SparkError> {
self.check_same_session(&other)?;

let plan = self.plan.except_all(other.plan);

DataFrame {
Ok(DataFrame {
spark_session: self.spark_session,
plan,
}
})
}

/// Prints the [spark::Plan] to the console
Expand Down Expand Up @@ -612,27 +655,27 @@ impl DataFrame {
}

/// Return a new [DataFrame] containing rows only in both this [DataFrame] and another [DataFrame].
///
/// The returned frame keeps this frame's `spark_session` and wraps the plan produced by
/// `self.plan.intersect(other.plan)`.
///
/// # Errors
///
/// In the `Result`-returning form, an error from `check_same_session` is propagated to the
/// caller with `?`; the other form unwraps it and therefore panics instead.
// NOTE(review): rendered diff with the +/- markers stripped — the removed signature/body
// lines are followed directly by their replacements, so both variants are visible here.
pub fn intersect(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
pub fn intersect(self, other: DataFrame) -> Result<DataFrame, SparkError> {
self.check_same_session(&other)?;

let plan = self.plan.intersect(other.plan);

DataFrame {
Ok(DataFrame {
spark_session: self.spark_session,
plan,
}
})
}

/// Return a new [DataFrame] containing rows in both this [DataFrame] and another [DataFrame] while preserving duplicates.
///
/// The returned frame keeps this frame's `spark_session` and wraps the plan produced by
/// `self.plan.intersect_all(other.plan)`.
///
/// # Errors
///
/// In the `Result`-returning form, an error from `check_same_session` is propagated to the
/// caller with `?`; the other form unwraps it and therefore panics instead.
// NOTE(review): rendered diff with the +/- markers stripped — the removed signature/body
// lines are followed directly by their replacements, so both variants are visible here.
pub fn intersect_all(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
pub fn intersect_all(self, other: DataFrame) -> Result<DataFrame, SparkError> {
self.check_same_session(&other)?;

let plan = self.plan.intersect_all(other.plan);

DataFrame {
Ok(DataFrame {
spark_session: self.spark_session,
plan,
}
})
}

/// Checks if the DataFrame is empty and returns a boolean value.
Expand Down Expand Up @@ -756,7 +799,10 @@ impl DataFrame {
}

/// Sets the storage level to persist the contents of the [DataFrame] across operations after the first time it is computed.
pub async fn persist(self, storage_level: storage::StorageLevel) -> DataFrame {
pub async fn persist(
self,
storage_level: storage::StorageLevel,
) -> Result<DataFrame, SparkError> {
let analyze =
spark::analyze_plan_request::Analyze::Persist(spark::analyze_plan_request::Persist {
relation: Some(self.plan.clone().relation()),
Expand All @@ -765,14 +811,14 @@ impl DataFrame {

let mut client = self.spark_session.clone().client();

client.analyze(analyze).await.unwrap();
client.analyze(analyze).await?;

let plan = self.plan;

DataFrame {
Ok(DataFrame {
spark_session: self.spark_session,
plan,
}
})
}

/// Prints out the schema in the tree format to a specific level number.
Expand Down Expand Up @@ -1092,15 +1138,15 @@ impl DataFrame {
}

/// Return a new [DataFrame] containing rows in this [DataFrame] but not in another [DataFrame].
///
/// The returned frame keeps this frame's `spark_session` and wraps the plan produced by
/// `self.plan.substract(other.plan)`.
///
/// # Errors
///
/// In the `Result`-returning form, an error from `check_same_session` is propagated to the
/// caller with `?`; the other form unwraps it and therefore panics instead.
// NOTE(review): rendered diff with the +/- markers stripped — the removed signature/body
// lines are followed directly by their replacements, so both variants are visible here.
pub fn subtract(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
pub fn subtract(self, other: DataFrame) -> Result<DataFrame, SparkError> {
self.check_same_session(&other)?;

// NOTE(review): the plan method is spelled `substract` at this call site; presumably that
// matches its definition elsewhere — confirm before renaming either side.
let plan = self.plan.substract(other.plan);

DataFrame {
Ok(DataFrame {
spark_session: self.spark_session,
plan,
}
})
}

/// Computes specified statistics for numeric and string columns.
Expand Down Expand Up @@ -1236,43 +1282,47 @@ impl DataFrame {
}

/// Return a new [DataFrame] containing the union of rows in this and another [DataFrame].
///
/// The plan is built with `self.plan.union_all(other.plan)` — the same plan call that
/// `union_all` makes — and the returned frame keeps this frame's `spark_session`.
///
/// # Errors
///
/// In the `Result`-returning form, an error from `check_same_session` is propagated to the
/// caller with `?`; the other form unwraps it and therefore panics instead.
// NOTE(review): rendered diff with the +/- markers stripped — the removed signature/body
// lines are followed directly by their replacements, so both variants are visible here.
pub fn union(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
pub fn union(self, other: DataFrame) -> Result<DataFrame, SparkError> {
self.check_same_session(&other)?;

let plan = self.plan.union_all(other.plan);

DataFrame {
Ok(DataFrame {
spark_session: self.spark_session,
plan,
}
})
}

/// Return a new [DataFrame] containing the union of rows in this and another [DataFrame].
///
/// The returned frame keeps this frame's `spark_session` and wraps the plan produced by
/// `self.plan.union_all(other.plan)`.
///
/// # Errors
///
/// In the `Result`-returning form, an error from `check_same_session` is propagated to the
/// caller with `?`; the other form unwraps it and therefore panics instead.
// NOTE(review): rendered diff with the +/- markers stripped — the removed signature/body
// lines are followed directly by their replacements, so both variants are visible here.
pub fn union_all(self, other: DataFrame) -> DataFrame {
self.check_same_session(&other).unwrap();
pub fn union_all(self, other: DataFrame) -> Result<DataFrame, SparkError> {
self.check_same_session(&other)?;

let plan = self.plan.union_all(other.plan);

DataFrame {
Ok(DataFrame {
spark_session: self.spark_session,
plan,
}
})
}

/// Returns a new [DataFrame] containing union of rows in this and another [DataFrame].
///
/// `allow_missing_columns` is passed through unchanged to `self.plan.union_by_name`; the
/// returned frame keeps this frame's `spark_session`.
///
/// # Errors
///
/// In the `Result`-returning form, an error from `check_same_session` is propagated to the
/// caller with `?`; the other form unwraps it and therefore panics instead.
// NOTE(review): rendered diff with the +/- markers stripped — the single-line signature and
// its `unwrap()` call appear to be the removed lines, followed directly by the multi-line
// replacement signature and the `?` call. Confirm against the commit.
pub fn union_by_name(self, other: DataFrame, allow_missing_columns: Option<bool>) -> DataFrame {
self.check_same_session(&other).unwrap();
pub fn union_by_name(
self,
other: DataFrame,
allow_missing_columns: Option<bool>,
) -> Result<DataFrame, SparkError> {
self.check_same_session(&other)?;

let plan = self.plan.union_by_name(other.plan, allow_missing_columns);

DataFrame {
Ok(DataFrame {
spark_session: self.spark_session,
plan,
}
})
}

/// Marks the [DataFrame] as non-persistent, and remove all blocks for it from memory and disk.
pub async fn unpersist(self, blocking: Option<bool>) -> DataFrame {
pub async fn unpersist(self, blocking: Option<bool>) -> Result<DataFrame, SparkError> {
let unpersist = spark::analyze_plan_request::Analyze::Unpersist(
spark::analyze_plan_request::Unpersist {
relation: Some(self.plan.clone().relation()),
Expand All @@ -1282,12 +1332,12 @@ impl DataFrame {

let mut client = self.spark_session.clone().client();

client.analyze(unpersist).await.unwrap();
client.analyze(unpersist).await?;

DataFrame {
Ok(DataFrame {
spark_session: self.spark_session,
plan: self.plan,
}
})
}

/// Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
Expand Down Expand Up @@ -1575,7 +1625,7 @@ mod tests {
let spark = setup().await;

let df = spark.range(None, 2, 1, None);
df.clone().cache().await;
df.clone().cache().await?;

let exp = df.clone().explain(None).await?;
assert!(exp.contains("InMemoryTableScan"));
Expand Down Expand Up @@ -1944,7 +1994,7 @@ mod tests {

let df2 = spark.create_dataframe(&data2)?;

let output = df1.except_all(df2).collect().await?;
let output = df1.except_all(df2)?.collect().await?;

let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 10]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 2]));
Expand Down Expand Up @@ -2141,7 +2191,7 @@ mod tests {

let df2 = spark.create_dataframe(&data2)?;

let output = df1.intersect(df2).collect().await?;
let output = df1.intersect(df2)?.collect().await?;

let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 19]));
let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 8]));
Expand Down