diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index c64d3cad694a..675146b7c30b 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -39,14 +39,18 @@ workspace = true [features] force_hash_collisions = [] -test_utils = ["arrow/test_utils"] -tokio_coop = [] -tokio_coop_fallback = [] +# Enables `PhysicalExpr::try_to_proto` / `try_from_proto` hooks on the +# physical expressions defined in this crate (e.g. `HashExpr`). Off by +# default so consumers that never serialize plans pay nothing. proto = [ "dep:datafusion-proto-models", "dep:datafusion-proto-common", + "datafusion-physical-expr/proto", "datafusion-physical-expr-common/proto", ] +test_utils = ["arrow/test_utils"] +tokio_coop = [] +tokio_coop_fallback = [] [lib] name = "datafusion_physical_plan" diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs index 46b087ad70b2..747d56c3d564 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs @@ -27,6 +27,8 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::hash_utils::RandomState; use datafusion_common::hash_utils::{create_hashes, with_hashes}; +#[cfg(feature = "proto")] +use datafusion_common::internal_err; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::physical_expr::{ DynHash, PhysicalExpr, PhysicalExprRef, @@ -199,6 +201,63 @@ impl PhysicalExpr for HashExpr { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.description) } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + let on_columns = self + .on_columns + .iter() + .map(|e| ctx.encode_child(e)) + .collect::>>()?; + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::HashExpr( + protobuf::PhysicalHashExprNode { + on_columns, + seed0: self.seed(), + description: self.description.clone(), + }, + )), + })) + } +} + +#[cfg(feature = "proto")] +impl HashExpr { + /// Reconstruct a [`HashExpr`] from its protobuf representation. + /// + /// Takes the whole [`PhysicalExprNode`], the exact inverse of what + /// [`PhysicalExpr::try_to_proto`] produces, so every expression's + /// `try_from_proto` shares one signature. Child sub-expressions are + /// decoded recursively via [`PhysicalExprDecodeCtx::decode`]. + /// + /// [`PhysicalExprNode`]: datafusion_proto_models::protobuf::PhysicalExprNode + /// [`PhysicalExpr::try_to_proto`]: datafusion_physical_expr_common::physical_expr::PhysicalExpr::try_to_proto + /// [`PhysicalExprDecodeCtx::decode`]: datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx::decode + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + let hash_expr = match &node.expr_type { + Some(protobuf::physical_expr_node::ExprType::HashExpr(h)) => h, + _ => return internal_err!("PhysicalExprNode is not a HashExpr"), + }; + let on_columns = hash_expr + .on_columns + .iter() + .map(|e| ctx.decode(e)) + .collect::>>()?; + Ok(Arc::new(HashExpr::new( + on_columns, + SeededRandomState::with_seed(hash_expr.seed0), + hash_expr.description.clone(), + ))) + } } /// Physical expression that checks join keys in a [`Map`] (hash table or array map). @@ -498,6 +557,172 @@ mod tests { assert_eq!(compute_hash(&expr1), compute_hash(&expr2)); } + #[cfg(feature = "proto")] + mod proto_tests { + use super::*; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::internal_datafusion_err; + use datafusion_physical_expr_common::physical_expr::proto_decode::{ + PhysicalExprDecode, PhysicalExprDecodeCtx, + }; + use datafusion_physical_expr_common::physical_expr::proto_encode::{ + PhysicalExprEncode, PhysicalExprEncodeCtx, + }; + use datafusion_proto_models::protobuf; + + struct TestEncoder; + + impl PhysicalExprEncode for TestEncoder { + fn encode( + &self, + expr: &Arc, + ) -> Result { + let ctx = PhysicalExprEncodeCtx::new(self); + expr.try_to_proto(&ctx)?.ok_or_else(|| { + internal_datafusion_err!("test encoder cannot encode {expr:?}") + }) + } + } + + struct TestDecoder; + + impl PhysicalExprDecode for TestDecoder { + fn decode( + &self, + node: &protobuf::PhysicalExprNode, + schema: &Schema, + ) -> Result> { + let ctx = PhysicalExprDecodeCtx::new(schema, self); + match &node.expr_type { + Some(protobuf::physical_expr_node::ExprType::Column(_)) => { + Column::try_from_proto(node, &ctx) + } + _ => internal_err!("test decoder cannot decode {node:?}"), + } + } + } + + fn test_decode_ctx<'a>( + schema: &'a Schema, + decoder: &'a TestDecoder, + ) -> PhysicalExprDecodeCtx<'a> { + PhysicalExprDecodeCtx::new(schema, decoder) + } + + #[test] + fn hash_expr_try_to_proto() { + let expr = HashExpr::new( + vec![Arc::new(Column::new("a", 0)), Arc::new(Column::new("b", 1))], + SeededRandomState::with_seed(42), + "hash_join".to_string(), + ); + let encoder = TestEncoder; + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let proto = expr.try_to_proto(&ctx).unwrap().unwrap(); + + assert_eq!(proto.expr_id, None); + let hash_expr = match proto.expr_type.unwrap() { + protobuf::physical_expr_node::ExprType::HashExpr(hash_expr) => hash_expr, + other => panic!("expected HashExpr, got {other:?}"), + }; + assert_eq!(hash_expr.seed0, 42); + assert_eq!(hash_expr.description, "hash_join"); + assert_eq!(hash_expr.on_columns.len(), 2); + assert!( + hash_expr + .on_columns + .iter() + .all(|expr| expr.expr_id.is_none()) + ); + } + + #[test] + fn hash_expr_try_from_proto() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ]); + let decoder = TestDecoder; + let ctx = test_decode_ctx(&schema, &decoder); + let proto = protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::HashExpr( + protobuf::PhysicalHashExprNode { + on_columns: vec![ + protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some( + protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: "a".to_string(), + index: 0, + }, + ), + ), + }, + protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some( + protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: "b".to_string(), + index: 1, + }, + ), + ), + }, + ], + seed0: 42, + description: "hash_join".to_string(), + }, + )), + }; + + let expr = HashExpr::try_from_proto(&proto, &ctx).unwrap(); + let expr = expr.downcast_ref::().unwrap(); + + assert_eq!(expr.seed(), 42); + assert_eq!(expr.description(), "hash_join"); + assert_eq!(expr.on_columns().len(), 2); + assert_eq!( + expr.on_columns()[0] + .downcast_ref::() + .map(|col| (col.name(), col.index())), + Some(("a", 0)) + ); + assert_eq!( + expr.on_columns()[1] + .downcast_ref::() + .map(|col| (col.name(), col.index())), + Some(("b", 1)) + ); + } + + #[test] + fn hash_expr_try_from_proto_rejects_wrong_node_type() { + let schema = Schema::empty(); + let decoder = TestDecoder; + let ctx = test_decode_ctx(&schema, &decoder); + let proto = protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: "a".to_string(), + index: 0, + }, + )), + }; + + let err = HashExpr::try_from_proto(&proto, &ctx).unwrap_err(); + assert!( + err.to_string() + .contains("PhysicalExprNode is not a HashExpr"), + "{err}" + ); + } + } + #[test] fn test_hash_table_lookup_expr_eq_same() { let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 96144b11e9d3..a20ddbd07170 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -46,7 +46,7 @@ use datafusion_physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, }; -use datafusion_physical_plan::joins::{HashExpr, SeededRandomState}; +use datafusion_physical_plan::joins::HashExpr; use datafusion_physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; use datafusion_proto_common::common::proto_error; @@ -416,19 +416,7 @@ pub fn parse_physical_expr_with_converter( ) } ExprType::LikeExpr(_) => LikeExpr::try_from_proto(proto, &decode_ctx)?, - ExprType::HashExpr(hash_expr) => { - let on_columns = parse_physical_exprs( - &hash_expr.on_columns, - ctx, - input_schema, - proto_converter, - )?; - Arc::new(HashExpr::new( - on_columns, - SeededRandomState::with_seed(hash_expr.seed0), - hash_expr.description.clone(), - )) - } + ExprType::HashExpr(_) => HashExpr::try_from_proto(proto, &decode_ctx)?, ExprType::ScalarSubquery(sq) => { let data_type: arrow::datatypes::DataType = sq .data_type diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index c359f651c0e1..f56f7b1fb35f 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -39,7 +39,6 @@ use datafusion_physical_plan::expressions::{ CaseExpr, CastExpr, DynamicFilterPhysicalExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, }; -use datafusion_physical_plan::joins::HashExpr; use datafusion_physical_plan::udaf::AggregateFunctionExpr; use datafusion_physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; @@ -448,21 +447,6 @@ pub fn serialize_physical_expr_with_converter( }, )), }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_id, - expr_type: Some(protobuf::physical_expr_node::ExprType::HashExpr( - protobuf::PhysicalHashExprNode { - on_columns: serialize_physical_exprs( - expr.on_columns(), - codec, - proto_converter, - )?, - seed0: expr.seed(), - description: expr.description().to_string(), - }, - )), - }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { expr_id,