Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion crates/connect/src/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ impl Default for ChannelBuilder {
Err(_) => "sc://localhost:15002".to_string(),
};

ChannelBuilder::create(&connection).unwrap()
ChannelBuilder::create(&connection).unwrap_or_else(|_| {
ChannelBuilder::create("sc://localhost:15002")
.expect("default connection must be valid")
})
}
}

Expand Down
18 changes: 14 additions & 4 deletions crates/connect/src/client/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,20 @@ where

Box::pin(async move {
for (key, value) in &headers {
let meta_key = HeaderName::from_str(key.as_str()).unwrap();
let meta_val = HeaderValue::from_str(value.as_str()).unwrap();

request.headers_mut().insert(meta_key, meta_val);
match (
HeaderName::from_str(key.as_str()),
HeaderValue::from_str(value.as_str()),
) {
(Ok(meta_key), Ok(meta_val)) => {
request.headers_mut().insert(meta_key, meta_val);
}
(Err(e), _) => {
eprintln!("skipping header with invalid name '{}': {}", key, e);
}
(_, Err(e)) => {
eprintln!("skipping header with invalid value '{}': {}", key, e);
}
}
}

inner.call(request).await
Expand Down
28 changes: 22 additions & 6 deletions crates/connect/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ where
let req = spark::ReattachExecuteRequest {
session_id: self.session_id(),
user_context: self.user_context.clone(),
operation_id: self.operation_id.clone().unwrap(),
operation_id: self.operation_id.clone().ok_or_else(|| {
SparkError::AnalysisException("operation_id is not set".to_string())
})?,
client_type: self.builder.user_agent.clone(),
last_response_id: self.response_id.clone(),
};
Expand Down Expand Up @@ -249,7 +251,9 @@ where

async fn release_until(&mut self) -> Result<(), SparkError> {
let release_until = spark::release_execute_request::ReleaseUntil {
response_id: self.response_id.clone().unwrap(),
response_id: self.response_id.clone().ok_or_else(|| {
SparkError::AnalysisException("response_id is not set".to_string())
})?,
};

self.release_execute(Some(spark::release_execute_request::Release::ReleaseUntil(
Expand All @@ -276,7 +280,9 @@ where
let req = spark::ReleaseExecuteRequest {
session_id: self.session_id(),
user_context: self.user_context.clone(),
operation_id: self.operation_id.clone().unwrap(),
operation_id: self.operation_id.clone().ok_or_else(|| {
SparkError::AnalysisException("operation_id is not set".to_string())
})?,
client_type: self.builder.user_agent.clone(),
release,
};
Expand Down Expand Up @@ -435,7 +441,9 @@ where
}
ResponseType::ResultComplete(_) => self.handler.result_complete = true,
ResponseType::Extension(_) => {
unimplemented!("extension response types are not implemented")
return Err(SparkError::NotYetImplemented(
"extension response types are not implemented".to_string(),
))
}
}
}
Expand Down Expand Up @@ -560,8 +568,16 @@ where
let col = rows.column(0);

let data: &arrow::array::StringArray = match col.data_type() {
arrow::datatypes::DataType::Utf8 => col.as_any().downcast_ref().unwrap(),
_ => unimplemented!("only Utf8 data types are currently handled currently."),
arrow::datatypes::DataType::Utf8 => col.as_any().downcast_ref().ok_or_else(|| {
SparkError::AnalysisException(
"failed to downcast column to StringArray".to_string(),
)
})?,
_ => {
return Err(SparkError::NotYetImplemented(
"only Utf8 data types are currently handled".to_string(),
))
}
};

Ok(data.value(0).to_string())
Expand Down