diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0d08b5db906ce..a3d2224f04fd3 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2212,6 +2212,33 @@ impl Expr { } } } + Expr::SetComparison(SetComparison { + expr, + subquery, + op: _, + quantifier: _, + }) => { + let subquery_schema = subquery.subquery.schema(); + match &subquery_schema.fields()[..] { + [subquery_field] => { + let column = Expr::Column(Column::new_unqualified( + subquery_field.name().clone(), + )); + rewrite_placeholder( + expr.as_mut(), + &column, + subquery_schema, + )?; + } + _ => { + return plan_err!( + "SetComparison should only return one column, but found {}: {}", + subquery_schema.fields().len(), + subquery_schema.field_names().join(", ") + ); + } + } + } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?; @@ -3945,6 +3972,112 @@ mod test { } } + #[test] + fn infer_placeholder_set_comparison_any() { + // WHERE $1 = ANY (SELECT a FROM t) -- parallel to infer_placeholder_in_subquery + let subquery_field = Field::new("a", DataType::Int32, false); + let subquery_schema = Arc::new( + DFSchema::from_unqualified_fields( + vec![subquery_field].into(), + Default::default(), + ) + .unwrap(), + ); + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: subquery_schema, + })), + outer_ref_columns: vec![], + spans: Spans::new(), + }; + + let set_cmp = Expr::SetComparison(SetComparison { + expr: Box::new(Expr::Placeholder(Placeholder { + id: "$1".to_string(), + field: None, + })), + subquery, + op: Operator::Eq, + quantifier: SetQuantifier::Any, + }); + + let outer_schema = DFSchema::empty(); + let (inferred_expr, contains_placeholder) = + set_cmp.infer_placeholder_types(&outer_schema).unwrap(); + + assert!(contains_placeholder); + + match inferred_expr { + Expr::SetComparison(sc) => { + assert_eq!(sc.quantifier, SetQuantifier::Any); + match *sc.expr { + Expr::Placeholder(p) => { + let inferred = + p.field.expect("placeholder field should be Int32"); + assert_eq!(inferred.data_type(), &DataType::Int32); + assert!(inferred.is_nullable()); + } + _ => panic!("Expected Placeholder expression in SetComparison"), + } + } + _ => panic!("Expected SetComparison expression"), + } + } + + #[test] + fn infer_placeholder_set_comparison_all() { + // WHERE $1 <> ALL (SELECT a FROM t) + let subquery_field = Field::new("a", DataType::Int32, false); + let subquery_schema = Arc::new( + DFSchema::from_unqualified_fields( + vec![subquery_field].into(), + Default::default(), + ) + .unwrap(), + ); + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: subquery_schema, + })), + outer_ref_columns: vec![], + spans: Spans::new(), + }; + + let set_cmp = Expr::SetComparison(SetComparison { + expr: Box::new(Expr::Placeholder(Placeholder { + id: "$1".to_string(), + field: None, + })), + subquery, + op: Operator::NotEq, + quantifier: SetQuantifier::All, + }); + + let outer_schema = DFSchema::empty(); + let (inferred_expr, contains_placeholder) = + set_cmp.infer_placeholder_types(&outer_schema).unwrap(); + + assert!(contains_placeholder); + + match inferred_expr { + Expr::SetComparison(sc) => { + assert_eq!(sc.quantifier, SetQuantifier::All); + match *sc.expr { + Expr::Placeholder(p) => { + let inferred = + p.field.expect("placeholder field should be Int32"); + assert_eq!(inferred.data_type(), &DataType::Int32); + assert!(inferred.is_nullable()); + } + _ => panic!("Expected Placeholder expression in SetComparison"), + } + } + _ => panic!("Expected SetComparison expression"), + } + } + #[test] fn infer_placeholder_like_and_similar_to() { // name LIKE $1