diff --git a/crates/connect/src/dataframe.rs b/crates/connect/src/dataframe.rs index a69f655..5737282 100644 --- a/crates/connect/src/dataframe.rs +++ b/crates/connect/src/dataframe.rs @@ -169,7 +169,7 @@ impl DataFrame { } /// Persists the [DataFrame] with the default [storage::StorageLevel::MemoryAndDiskDeser] (MEMORY_AND_DISK_DESER). - pub async fn cache(self) -> DataFrame { + pub async fn cache(self) -> Result { self.persist(storage::StorageLevel::MemoryAndDiskDeser) .await } @@ -211,7 +211,9 @@ impl DataFrame { pub async fn columns(self) -> Result, SparkError> { let schema = self.schema().await?; - let struct_val = schema.kind.expect("Unwrapped an empty schema"); + let struct_val = schema.kind.ok_or(SparkError::AnalysisException( + "Schema response is empty".to_string(), + ))?; let cols = match struct_val { spark::data_type::Kind::Struct(val) => val @@ -219,7 +221,11 @@ impl DataFrame { .iter() .map(|field| field.name.to_string()) .collect(), - _ => unimplemented!("Unexpected schema response"), + _ => { + return Err(SparkError::NotYetImplemented( + "Unexpected schema response".to_string(), + )) + } }; Ok(cols) @@ -240,11 +246,18 @@ impl DataFrame { let col = result.column(0); let data: &PrimitiveArray = match col.data_type() { - DataType::Float64 => col - .as_any() - .downcast_ref() - .expect("failed to unwrap result"), - _ => panic!("Expected Float64 in response type"), + DataType::Float64 => { + col.as_any() + .downcast_ref() + .ok_or(SparkError::AnalysisException( + "Failed to extract numeric result".to_string(), + ))? + } + _ => { + return Err(SparkError::AnalysisException( + "Expected Float64 in response type for corr()".to_string(), + )) + } }; Ok(data.value(0)) @@ -257,8 +270,18 @@ impl DataFrame { let col = res.column(0); let data: &arrow::array::Int64Array = match col.data_type() { - arrow::datatypes::DataType::Int64 => col.as_any().downcast_ref().unwrap(), - _ => unimplemented!("only Utf8 data types are currently handled currently."), + arrow::datatypes::DataType::Int64 => { + col.as_any() + .downcast_ref() + .ok_or(SparkError::AnalysisException( + "Failed to extract numeric result".to_string(), + ))? + } + _ => { + return Err(SparkError::NotYetImplemented( + "only Utf8 data types are currently handled".to_string(), + )) + } }; Ok(data.value(0)) @@ -278,11 +301,18 @@ impl DataFrame { let col = result.column(0); let data: &PrimitiveArray = match col.data_type() { - DataType::Float64 => col - .as_any() - .downcast_ref() - .expect("failed to unwrap result"), - _ => panic!("Expected Float64 in response type"), + DataType::Float64 => { + col.as_any() + .downcast_ref() + .ok_or(SparkError::AnalysisException( + "Failed to extract numeric result".to_string(), + ))? + } + _ => { + return Err(SparkError::AnalysisException( + "Expected Float64 in response type for cov()".to_string(), + )) + } }; Ok(data.value(0)) @@ -453,35 +483,48 @@ impl DataFrame { pub async fn dtypes(self) -> Result, SparkError> { let schema = self.schema().await?; - let struct_val = schema.kind.expect("unwrapped an empty schema"); + let struct_val = schema.kind.ok_or(SparkError::AnalysisException( + "Schema response is empty".to_string(), + ))?; let dtypes = match struct_val { - spark::data_type::Kind::Struct(val) => val - .fields - .iter() - .map(|field| { - ( - field.name.to_string(), - field.data_type.clone().unwrap().kind.unwrap(), - ) - }) - .collect(), - _ => unimplemented!("Unexpected schema response"), + spark::data_type::Kind::Struct(val) => { + let mut result = Vec::new(); + for field in &val.fields { + let data_type = + field + .data_type + .clone() + .ok_or(SparkError::AnalysisException( + "Field data type is missing".to_string(), + ))?; + let kind = data_type.kind.ok_or(SparkError::AnalysisException( + "Field data type kind is missing".to_string(), + ))?; + result.push((field.name.to_string(), kind)); + } + result + } + _ => { + return Err(SparkError::NotYetImplemented( + "Unexpected schema response".to_string(), + )) + } }; Ok(dtypes) } /// Return a new [DataFrame] containing rows in this [DataFrame] but not in another [DataFrame] while preserving duplicates. - pub fn except_all(self, other: DataFrame) -> DataFrame { - self.check_same_session(&other).unwrap(); + pub fn except_all(self, other: DataFrame) -> Result { + self.check_same_session(&other)?; let plan = self.plan.except_all(other.plan); - DataFrame { + Ok(DataFrame { spark_session: self.spark_session, plan, - } + }) } /// Prints the [spark::Plan] to the console @@ -612,27 +655,27 @@ impl DataFrame { } /// Return a new [DataFrame] containing rows only in both this [DataFrame] and another [DataFrame]. - pub fn intersect(self, other: DataFrame) -> DataFrame { - self.check_same_session(&other).unwrap(); + pub fn intersect(self, other: DataFrame) -> Result { + self.check_same_session(&other)?; let plan = self.plan.intersect(other.plan); - DataFrame { + Ok(DataFrame { spark_session: self.spark_session, plan, - } + }) } /// Return a new [DataFrame] containing rows in both this [DataFrame] and another [DataFrame] while preserving duplicates. - pub fn intersect_all(self, other: DataFrame) -> DataFrame { - self.check_same_session(&other).unwrap(); + pub fn intersect_all(self, other: DataFrame) -> Result { + self.check_same_session(&other)?; let plan = self.plan.intersect_all(other.plan); - DataFrame { + Ok(DataFrame { spark_session: self.spark_session, plan, - } + }) } /// Checks if the DataFrame is empty and returns a boolean value. @@ -756,7 +799,10 @@ impl DataFrame { } /// Sets the storage level to persist the contents of the [DataFrame] across operations after the first time it is computed. - pub async fn persist(self, storage_level: storage::StorageLevel) -> DataFrame { + pub async fn persist( + self, + storage_level: storage::StorageLevel, + ) -> Result { let analyze = spark::analyze_plan_request::Analyze::Persist(spark::analyze_plan_request::Persist { relation: Some(self.plan.clone().relation()), @@ -765,14 +811,14 @@ impl DataFrame { let mut client = self.spark_session.clone().client(); - client.analyze(analyze).await.unwrap(); + client.analyze(analyze).await?; let plan = self.plan; - DataFrame { + Ok(DataFrame { spark_session: self.spark_session, plan, - } + }) } /// Prints out the schema in the tree format to a specific level number. @@ -1092,15 +1138,15 @@ impl DataFrame { } /// Return a new [DataFrame] containing rows in this [DataFrame] but not in another [DataFrame]. - pub fn subtract(self, other: DataFrame) -> DataFrame { - self.check_same_session(&other).unwrap(); + pub fn subtract(self, other: DataFrame) -> Result { + self.check_same_session(&other)?; let plan = self.plan.substract(other.plan); - DataFrame { + Ok(DataFrame { spark_session: self.spark_session, plan, - } + }) } /// Computes specified statistics for numeric and string columns. @@ -1236,43 +1282,47 @@ impl DataFrame { } /// Return a new [DataFrame] containing the union of rows in this and another [DataFrame]. - pub fn union(self, other: DataFrame) -> DataFrame { - self.check_same_session(&other).unwrap(); + pub fn union(self, other: DataFrame) -> Result { + self.check_same_session(&other)?; let plan = self.plan.union_all(other.plan); - DataFrame { + Ok(DataFrame { spark_session: self.spark_session, plan, - } + }) } /// Return a new [DataFrame] containing the union of rows in this and another [DataFrame]. - pub fn union_all(self, other: DataFrame) -> DataFrame { - self.check_same_session(&other).unwrap(); + pub fn union_all(self, other: DataFrame) -> Result { + self.check_same_session(&other)?; let plan = self.plan.union_all(other.plan); - DataFrame { + Ok(DataFrame { spark_session: self.spark_session, plan, - } + }) } /// Returns a new [DataFrame] containing union of rows in this and another [DataFrame]. - pub fn union_by_name(self, other: DataFrame, allow_missing_columns: Option) -> DataFrame { - self.check_same_session(&other).unwrap(); + pub fn union_by_name( + self, + other: DataFrame, + allow_missing_columns: Option, + ) -> Result { + self.check_same_session(&other)?; let plan = self.plan.union_by_name(other.plan, allow_missing_columns); - DataFrame { + Ok(DataFrame { spark_session: self.spark_session, plan, - } + }) } /// Marks the [DataFrame] as non-persistent, and remove all blocks for it from memory and disk. - pub async fn unpersist(self, blocking: Option) -> DataFrame { + pub async fn unpersist(self, blocking: Option) -> Result { let unpersist = spark::analyze_plan_request::Analyze::Unpersist( spark::analyze_plan_request::Unpersist { relation: Some(self.plan.clone().relation()), @@ -1282,12 +1332,12 @@ impl DataFrame { let mut client = self.spark_session.clone().client(); - client.analyze(unpersist).await.unwrap(); + client.analyze(unpersist).await?; - DataFrame { + Ok(DataFrame { spark_session: self.spark_session, plan: self.plan, - } + }) } /// Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set. @@ -1575,7 +1625,7 @@ mod tests { let spark = setup().await; let df = spark.range(None, 2, 1, None); - df.clone().cache().await; + df.clone().cache().await?; let exp = df.clone().explain(None).await?; assert!(exp.contains("InMemoryTableScan")); @@ -1944,7 +1994,7 @@ mod tests { let df2 = spark.create_dataframe(&data2)?; - let output = df1.except_all(df2).collect().await?; + let output = df1.except_all(df2)?.collect().await?; let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 10])); let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 2])); @@ -2141,7 +2191,7 @@ mod tests { let df2 = spark.create_dataframe(&data2)?; - let output = df1.intersect(df2).collect().await?; + let output = df1.intersect(df2)?.collect().await?; let c1: ArrayRef = Arc::new(Int64Array::from(vec![1, 19])); let c2: ArrayRef = Arc::new(Int64Array::from(vec![1, 8]));