diff --git a/be/src/exec/exchange/local_exchange_source_operator.h b/be/src/exec/exchange/local_exchange_source_operator.h index 58252b24ec2c23..3fdf90b50f0075 100644 --- a/be/src/exec/exchange/local_exchange_source_operator.h +++ b/be/src/exec/exchange/local_exchange_source_operator.h @@ -82,6 +82,11 @@ class LocalExchangeSourceOperatorX final : public OperatorX::required_data_distribution( state); } + const bool child_breaks_distribution = child_breaks_local_key_distribution(state); if (!_needs_finalize && !state->enable_local_exchange_before_agg() && - !(_is_merge && _child && _child->is_serial_operator())) { + !child_breaks_distribution) { return DataSinkOperatorX::required_data_distribution(state); } return _is_colocate && _require_bucket_distribution diff --git a/be/src/exec/operator/distinct_streaming_aggregation_operator.h b/be/src/exec/operator/distinct_streaming_aggregation_operator.h index abf42eb50cc977..4ae652d498c65c 100644 --- a/be/src/exec/operator/distinct_streaming_aggregation_operator.h +++ b/be/src/exec/operator/distinct_streaming_aggregation_operator.h @@ -118,7 +118,8 @@ class DistinctStreamingAggOperatorX final if (_needs_finalize && _probe_expr_ctxs.empty()) { return {ExchangeType::NOOP}; } - if (!_needs_finalize && !state->enable_local_exchange_before_agg()) { + if (!_needs_finalize && !state->enable_local_exchange_before_agg() && + !child_breaks_local_key_distribution(state)) { return StatefulOperatorX::required_data_distribution( state); } @@ -142,6 +143,7 @@ class DistinctStreamingAggOperatorX final private: friend class DistinctStreamingAggLocalState; + void init_make_nullable(RuntimeState* state); TupleId _output_tuple_id; TupleDescriptor* _output_tuple_desc = nullptr; diff --git a/be/src/exec/operator/operator.cpp b/be/src/exec/operator/operator.cpp index d03f75306d5a8a..1ce7dc8727d688 100644 --- a/be/src/exec/operator/operator.cpp +++ b/be/src/exec/operator/operator.cpp @@ -147,6 +147,23 @@ DataDistribution OperatorBase::required_data_distribution(RuntimeState* /*state* : DataDistribution(ExchangeType::NOOP); } +bool OperatorBase::is_hash_shuffle(ExchangeType exchange_type) { + return exchange_type == ExchangeType::HASH_SHUFFLE || + exchange_type == ExchangeType::BUCKET_HASH_SHUFFLE; +} + +bool OperatorBase::child_breaks_local_key_distribution(RuntimeState* state) const { + if (!_child) { + return false; + } + if (_child->is_serial_operator()) { + return true; + } + const auto child_distribution = _child->required_data_distribution(state); + return child_distribution.need_local_exchange() && + !is_hash_shuffle(child_distribution.distribution_type); +} + const RowDescriptor& OperatorBase::row_desc() const { return _child->row_desc(); } diff --git a/be/src/exec/operator/operator.h b/be/src/exec/operator/operator.h index d4b09a4aed3bb5..ce5df951e22497 100644 --- a/be/src/exec/operator/operator.h +++ b/be/src/exec/operator/operator.h @@ -187,6 +187,9 @@ class OperatorBase { RuntimeState* /*state*/) const; protected: + [[nodiscard]] static bool is_hash_shuffle(ExchangeType exchange_type); + [[nodiscard]] bool child_breaks_local_key_distribution(RuntimeState* state) const; + OperatorPtr _child = nullptr; bool _is_closed; diff --git a/be/src/exec/operator/streaming_aggregation_operator.h b/be/src/exec/operator/streaming_aggregation_operator.h index 40a8de2824446a..007da4188651f1 100644 --- a/be/src/exec/operator/streaming_aggregation_operator.h +++ b/be/src/exec/operator/streaming_aggregation_operator.h @@ -224,9 +224,10 @@ class StreamingAggOperatorX MOCK_REMOVE(final) : public StatefulOperatorXis_hash_join_probe() && state->enable_streaming_agg_hash_join_force_passthrough()) { - return DataDistribution(ExchangeType::PASSTHROUGH); + return {ExchangeType::PASSTHROUGH}; } - if (!_needs_finalize && !state->enable_local_exchange_before_agg()) { + if (!_needs_finalize && !state->enable_local_exchange_before_agg() && + !child_breaks_local_key_distribution(state)) { return StatefulOperatorX::required_data_distribution(state); } if (_partition_exprs.empty()) { @@ -235,7 +236,7 @@ class StreamingAggOperatorX MOCK_REMOVE(final) : public StatefulOperatorX::required_data_distribution( state); } - return DataDistribution(ExchangeType::HASH_SHUFFLE, _partition_exprs); + return {ExchangeType::HASH_SHUFFLE, _partition_exprs}; } private: diff --git a/be/src/exec/pipeline/pipeline.cpp b/be/src/exec/pipeline/pipeline.cpp index de3c852ada1bb6..60c395e447f4b5 100644 --- a/be/src/exec/pipeline/pipeline.cpp +++ b/be/src/exec/pipeline/pipeline.cpp @@ -21,6 +21,7 @@ #include #include +#include "exec/exchange/local_exchange_source_operator.h" #include "exec/operator/operator.h" #include "exec/pipeline/pipeline_fragment_context.h" #include "exec/pipeline/pipeline_task.h" @@ -53,7 +54,14 @@ bool Pipeline::need_to_local_exchange(const DataDistribution target_data_distrib // If non-serial operators exist, we should improve parallelism for those. return true; } - + if (auto local_exchange_source = + std::dynamic_pointer_cast(_operators.front()); + local_exchange_source && is_hash_exchange(target_data_distribution.distribution_type)) { + const auto source_exchange_type = local_exchange_source->exchange_type(); + if (source_exchange_type != ExchangeType::NOOP && !is_hash_exchange(source_exchange_type)) { + return true; + } + } if (target_data_distribution.distribution_type != ExchangeType::BUCKET_HASH_SHUFFLE && target_data_distribution.distribution_type != ExchangeType::HASH_SHUFFLE) { // Always do local exchange if non-hash-partition exchanger is required. diff --git a/be/test/exec/operator/agg_operator_test.cpp b/be/test/exec/operator/agg_operator_test.cpp index 75592bfa0978ee..02b6a79bb72f6f 100644 --- a/be/test/exec/operator/agg_operator_test.cpp +++ b/be/test/exec/operator/agg_operator_test.cpp @@ -22,12 +22,14 @@ #include "core/data_type/data_type_nullable.h" #include "core/data_type/data_type_number.h" +#include "exec/exchange/local_exchange_source_operator.h" #include "exec/operator/aggregation_sink_operator.h" #include "exec/operator/aggregation_source_operator.h" #include "exec/operator/assert_num_rows_operator.h" #include "exec/operator/mock_operator.h" #include "exec/operator/operator_helper.h" #include "exec/pipeline/dependency.h" +#include "exec/pipeline/pipeline.h" #include "testutil/column_helper.h" #include "testutil/mock/mock_agg_fn_evaluator.h" #include "testutil/mock/mock_slot_ref.h" @@ -98,6 +100,23 @@ struct MockAggSourceOperator : public AggSourceOperatorX { std::unique_ptr mock_row_descriptor; }; +class MockDistributionOperator final : public OperatorX { +public: + MockDistributionOperator(ExchangeType exchange_type) : _exchange_type(exchange_type) {} + + Status get_block(RuntimeState* /*state*/, Block* /*block*/, bool* eos) override { + *eos = true; + return Status::OK(); + } + + DataDistribution required_data_distribution(RuntimeState* /*state*/) const override { + return {_exchange_type}; + } + +private: + ExchangeType _exchange_type; +}; + std::shared_ptr create_agg_sink_op(OperatorContext& ctx, bool is_merge, bool without_key) { auto op = std::make_shared(); @@ -108,6 +127,44 @@ std::shared_ptr create_agg_sink_op(OperatorContext& ctx, bool return op; } +TEST(AggOperatorRequiredDistributionTest, require_hash_shuffle_after_non_hash_child_exchange) { + OperatorContext ctx; + auto sink_op = std::make_shared(); + sink_op->_partition_exprs.emplace_back(); + sink_op->_needs_finalize = false; + OperatorPtr child = + std::make_shared(ExchangeType::ADAPTIVE_PASSTHROUGH); + sink_op->_child = child; + + const auto distribution = sink_op->required_data_distribution(&ctx.state); + EXPECT_EQ(ExchangeType::HASH_SHUFFLE, distribution.distribution_type); +} + +TEST(AggOperatorRequiredDistributionTest, require_hash_shuffle_after_non_hash_local_exchange) { + OperatorContext ctx; + auto sink_op = std::make_shared(); + sink_op->_needs_finalize = false; + OperatorPtr child = std::make_shared(); + EXPECT_TRUE(child->init(ExchangeType::ADAPTIVE_PASSTHROUGH).ok()); + sink_op->_child = child; + + TExpr distinct_agg_expr; + distinct_agg_expr.nodes.emplace_back(); + distinct_agg_expr.nodes[0].fn.name.function_name = "multi_distinct_count"; + TPlanNode tnode; + tnode.agg_node.aggregate_functions.push_back(distinct_agg_expr); + tnode.__set_distribute_expr_lists({{TExpr {}}}); + sink_op->update_operator(tnode, false, false); + + const auto distribution = sink_op->required_data_distribution(&ctx.state); + EXPECT_EQ(ExchangeType::HASH_SHUFFLE, distribution.distribution_type); + + Pipeline pipeline(0, 4, 4); + EXPECT_TRUE(pipeline.add_operator(child, 0).ok()); + pipeline.set_data_distribution(DataDistribution(ExchangeType::HASH_SHUFFLE)); + EXPECT_TRUE(pipeline.need_to_local_exchange(distribution, 1)); +} + std::shared_ptr create_agg_source_op(OperatorContext& ctx, bool without_key, bool needs_finalize) { auto op = std::make_shared(); diff --git a/be/test/exec/operator/distinct_streaming_aggregation_operator_test.cpp b/be/test/exec/operator/distinct_streaming_aggregation_operator_test.cpp index 88434e47fd7e01..17282356625d9a 100644 --- a/be/test/exec/operator/distinct_streaming_aggregation_operator_test.cpp +++ b/be/test/exec/operator/distinct_streaming_aggregation_operator_test.cpp @@ -22,6 +22,7 @@ #include #include "core/block/block.h" +#include "exec/exchange/local_exchange_source_operator.h" #include "exec/operator/mock_operator.h" #include "exec/operator/operator_helper.h" #include "testutil/column_helper.h" @@ -97,6 +98,20 @@ TEST_F(DistinctStreamingAggOperatorTest, test1) { } } +TEST_F(DistinctStreamingAggOperatorTest, require_hash_shuffle_after_non_hash_local_exchange) { + state->_query_options.__set_enable_local_exchange_before_agg(false); + op->_is_streaming_preagg = false; + op->_partition_exprs.emplace_back(); + op->_probe_expr_ctxs = MockSlotRef::create_mock_contexts(0, std::make_shared()); + + OperatorPtr child = std::make_shared(); + EXPECT_TRUE(child->init(ExchangeType::ADAPTIVE_PASSTHROUGH).ok()); + op->_child = child; + + const auto distribution = op->required_data_distribution(state.get()); + EXPECT_EQ(ExchangeType::HASH_SHUFFLE, distribution.distribution_type); +} + TEST_F(DistinctStreamingAggOperatorTest, test2) { op->_is_streaming_preagg = false; op->_limit = 3; diff --git a/be/test/exec/operator/streaming_agg_operator_test.cpp b/be/test/exec/operator/streaming_agg_operator_test.cpp index 0421d58bfd256b..bbe54ebec5d21f 100644 --- a/be/test/exec/operator/streaming_agg_operator_test.cpp +++ b/be/test/exec/operator/streaming_agg_operator_test.cpp @@ -23,6 +23,7 @@ #include "core/data_type/data_type_bitmap.h" #include "core/data_type/data_type_number.h" #include "core/value/bitmap_value.h" +#include "exec/exchange/local_exchange_source_operator.h" #include "exec/operator/aggregation_sink_operator.h" #include "exec/operator/aggregation_source_operator.h" #include "exec/operator/mock_operator.h" @@ -152,6 +153,19 @@ TEST_F(StreamingAggOperatorTest, test1) { { EXPECT_TRUE(local_state->close(state.get()).ok()); } } +TEST_F(StreamingAggOperatorTest, require_hash_shuffle_after_non_hash_local_exchange) { + state->_query_options.__set_enable_local_exchange_before_agg(false); + op->_needs_finalize = false; + op->_partition_exprs.emplace_back(); + + OperatorPtr child = std::make_shared(); + EXPECT_TRUE(child->init(ExchangeType::ADAPTIVE_PASSTHROUGH).ok()); + EXPECT_TRUE(op->set_child(child)); + + const auto distribution = op->required_data_distribution(state.get()); + EXPECT_EQ(ExchangeType::HASH_SHUFFLE, distribution.distribution_type); +} + TEST_F(StreamingAggOperatorTest, test2) { op->_aggregate_evaluators.push_back(create_mock_agg_fn_evaluator( pool, MockSlotRef::create_mock_contexts(1, std::make_shared()), false, diff --git a/regression-test/data/query_p0/join/test_agg_after_nested_loop_join_local_exchange.out b/regression-test/data/query_p0/join/test_agg_after_nested_loop_join_local_exchange.out new file mode 100644 index 00000000000000..cd357b13436302 --- /dev/null +++ b/regression-test/data/query_p0/join/test_agg_after_nested_loop_join_local_exchange.out @@ -0,0 +1,3 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !agg_after_nlj_local_exchange -- +10 5070261 diff --git a/regression-test/suites/query_p0/join/test_agg_after_nested_loop_join_local_exchange.groovy b/regression-test/suites/query_p0/join/test_agg_after_nested_loop_join_local_exchange.groovy new file mode 100644 index 00000000000000..91889b319c8a4b --- /dev/null +++ b/regression-test/suites/query_p0/join/test_agg_after_nested_loop_join_local_exchange.groovy @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_agg_after_nested_loop_join_local_exchange", "query_p0") { + sql "DROP TABLE IF EXISTS test_agg_after_nlj_local_exchange_t1" + sql "DROP TABLE IF EXISTS test_agg_after_nlj_local_exchange_t2" + + sql """ + CREATE TABLE test_agg_after_nlj_local_exchange_t1 ( + col_bigint_undef_signed BIGINT, + col_varchar_10__undef_signed VARCHAR(10), + col_varchar_64__undef_signed VARCHAR(64), + pk INT + ) + ENGINE=OLAP + DISTRIBUTED BY HASH(pk) BUCKETS 10 + PROPERTIES("replication_num" = "1") + """ + + sql """ + INSERT INTO test_agg_after_nlj_local_exchange_t1 + (pk, col_bigint_undef_signed, col_varchar_10__undef_signed, col_varchar_64__undef_signed) + VALUES + (0, -94, 'had', 'y'), + (1, 672609, 'k', 'h'), + (2, -3766684, 'a', 'p'), + (3, 5070261, 'on', 'x'), + (4, NULL, 'u', 'at'), + (5, -86, 'v', 'c'), + (6, 21910, 'how', 'm'), + (7, -63, 'that''s', 'go'), + (8, -8276281, 's', 'a'), + (9, -101, 'w', 'y') + """ + + sql """ + CREATE TABLE test_agg_after_nlj_local_exchange_t2 ( + pk INT, + col_varchar_10__undef_signed VARCHAR(10), + col_bigint_undef_signed BIGINT, + col_varchar_64__undef_signed VARCHAR(64) + ) + ENGINE=OLAP + DUPLICATE KEY(pk, col_varchar_10__undef_signed) + DISTRIBUTED BY HASH(pk) BUCKETS 10 + PROPERTIES("replication_num" = "1") + """ + + sql """ + INSERT INTO test_agg_after_nlj_local_exchange_t2 + (pk, col_bigint_undef_signed, col_varchar_10__undef_signed, col_varchar_64__undef_signed) + VALUES + (0, NULL, 'right', 'g'), + (1, -486256, 'on', 'on'), + (2, -1, 'I''ll', 'at'), + (3, 29263, 'h', 'don''t'), + (4, 5453, 'a', 's'), + (5, -119, 'j', 'can''t'), + (6, 89, 'one', 'n'), + (7, -7227, 's', 'u'), + (8, 94, 'time', 'b'), + (9, 1816630, 'yes', 'yes') + """ + + sql "SYNC" + + sql "SET default_variant_doc_hash_shard_count = 0" + sql "SET default_variant_max_subcolumns_count = 4" + sql "SET default_variant_sparse_hash_shard_count = 4" + sql "SET disable_join_reorder = true" + sql "SET disable_streaming_preaggregations = true" + sql "SET enable_common_expr_pushdown = false" + sql "SET enable_common_expr_pushdown_for_inverted_index = false" + sql "SET enable_distinct_streaming_agg_force_passthrough = false" + sql "SET enable_function_pushdown = true" + sql "SET enable_local_exchange_before_agg = false" + sql "SET enable_runtime_filter_partition_prune = false" + sql "SET enable_runtime_filter_prune = false" + sql "SET enable_strong_consistency_read = true" + sql "SET enable_sync_runtime_filter_size = false" + sql "SET exchange_multi_blocks_byte_size = 5563624" + sql "SET experimental_enable_parallel_scan = false" + sql "SET parallel_pipeline_task_num = 4" + sql "SET parallel_prepare_threshold = 28" + sql "SET query_timeout = 600" + sql "SET runtime_filter_type = 'IN,MIN_MAX'" + sql "SET runtime_filter_wait_time_ms = 5000" + sql "SET topn_opt_limit_threshold = 1000" + sql "SET agg_phase = 4" + + order_qt_agg_after_nlj_local_exchange """ + SELECT + COUNT(DISTINCT table1.`pk`) AS field1, + MAX(table1.col_bigint_undef_signed) AS field2 + FROM + test_agg_after_nlj_local_exchange_t1 AS table1 + LEFT OUTER JOIN test_agg_after_nlj_local_exchange_t2 AS table2 + ON table2.col_varchar_10__undef_signed = table2.col_varchar_64__undef_signed + LEFT JOIN test_agg_after_nlj_local_exchange_t1 AS table3 + ON table2.col_varchar_10__undef_signed = table2.col_varchar_64__undef_signed + WHERE + table1.`pk` > 3 + AND table1.`pk` < (3 + 25) + OR table1.col_varchar_64__undef_signed > 'cnvUBxJyCp' + AND table1.col_varchar_64__undef_signed <= 'z' + OR table1.col_bigint_undef_signed != 2 + OR table1.`pk` NOT BETWEEN 2 AND (2 + 1) + AND table1.`pk` > 7 + AND table1.`pk` < (7 + 2) + AND table1.`pk` IN (2, 10, 2) + ORDER BY + field1, + field2 + LIMIT 1000 + """ +}