Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion be/src/exec/operator/aggregation_sink_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class AggSinkOperatorX MOCK_REMOVE(final) : public DataSinkOperatorX<AggSinkLoca
state);
}
if (!_needs_finalize && !state->enable_local_exchange_before_agg() &&
!(_is_merge && _child && _child->is_serial_operator())) {
!(_is_merge && _child_breaks_local_key_distribution(state))) {
return DataSinkOperatorX<AggSinkLocalState>::required_data_distribution(state);
}
return _is_colocate && _require_bucket_distribution
Expand Down Expand Up @@ -193,6 +193,23 @@ class AggSinkOperatorX MOCK_REMOVE(final) : public DataSinkOperatorX<AggSinkLoca
using DataSinkOperatorX<AggSinkLocalState>::get_local_state;

protected:
static bool _is_hash_shuffle(ExchangeType exchange_type) {
return exchange_type == ExchangeType::HASH_SHUFFLE ||
exchange_type == ExchangeType::BUCKET_HASH_SHUFFLE;
}

bool _child_breaks_local_key_distribution(RuntimeState* state) const {
if (!_child) {
return false;
}
if (_child->is_serial_operator()) {
return true;
}
const auto child_distribution = _child->required_data_distribution(state);
return child_distribution.need_local_exchange() &&
!_is_hash_shuffle(child_distribution.distribution_type);
}

MOCK_FUNCTION Status _init_probe_expr_ctx(RuntimeState* state);

MOCK_FUNCTION Status _init_aggregate_evaluators(RuntimeState* state);
Expand Down
30 changes: 30 additions & 0 deletions be/test/exec/operator/agg_operator_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,23 @@ struct MockAggSourceOperator : public AggSourceOperatorX {
std::unique_ptr<RowDescriptor> mock_row_descriptor;
};

class MockDistributionOperator final : public OperatorX<MockLocalState> {
public:
MockDistributionOperator(ExchangeType exchange_type) : _exchange_type(exchange_type) {}

Status get_block(RuntimeState* /*state*/, Block* /*block*/, bool* eos) override {
*eos = true;
return Status::OK();
}

DataDistribution required_data_distribution(RuntimeState* /*state*/) const override {
return {_exchange_type};
}

private:
ExchangeType _exchange_type;
};

std::shared_ptr<AggSinkOperatorX> create_agg_sink_op(OperatorContext& ctx, bool is_merge,
bool without_key) {
auto op = std::make_shared<MockAggsinkOperator>();
Expand All @@ -108,6 +125,19 @@ std::shared_ptr<AggSinkOperatorX> create_agg_sink_op(OperatorContext& ctx, bool
return op;
}

TEST(AggOperatorRequiredDistributionTest, require_hash_shuffle_after_non_hash_child_exchange) {
OperatorContext ctx;
auto sink_op = std::make_shared<MockAggsinkOperator>();
sink_op->_partition_exprs.emplace_back();
sink_op->_needs_finalize = false;
sink_op->_is_merge = true;
sink_op->_child =
std::make_shared<MockDistributionOperator>(ExchangeType::ADAPTIVE_PASSTHROUGH);

const auto distribution = sink_op->required_data_distribution(&ctx.state);
EXPECT_EQ(ExchangeType::HASH_SHUFFLE, distribution.distribution_type);
}

std::shared_ptr<AggSourceOperatorX> create_agg_source_op(OperatorContext& ctx, bool without_key,
bool needs_finalize) {
auto op = std::make_shared<MockAggSourceOperator>();
Expand Down
Loading