From 9b6741c275e0fd7b8b7c68149612e8043523cce7 Mon Sep 17 00:00:00 2001 From: Ariel Miculas Date: Tue, 26 May 2026 13:49:15 +0300 Subject: [PATCH] feat: demonstrate issue with partial aggregation Related to #22526 --- datafusion/physical-plan/Cargo.toml | 5 + .../physical-plan/src/aggregates/row_hash.rs | 96 +++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 5a05173eb370f..d15baabf8e84b 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -125,3 +125,8 @@ name = "dictionary_group_values" [[bench]] harness = false name = "multi_group_by" + +[[bench]] +harness = false +name = "partial_hash_agg" +required-features = ["test_utils"] diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 1164fb37b384a..e2d366cbe27ba 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1817,4 +1817,100 @@ mod tests { Ok(()) } + + /// Verify that partial aggregation emits batches whose memory footprint is + /// proportional to `batch_size`, so a downstream `RepartitionExec` does not + /// over-reserve memory and spill unnecessarily. + /// + /// With high-cardinality input (one unique group per row) the hash table + /// grows until early-emission kicks in. Each emitted batch must be ≤ a + /// small constant multiple of `batch_size × bytes_per_row`. + #[tokio::test] + async fn test_partial_agg_batch_memory_size_is_bounded() -> Result<()> { + let batch_size: usize = 1024; + + let schema = Arc::new(Schema::new(vec![ + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int64, false), + ])); + + // One unique group per row — worst case for memory growth before emit. + let num_rows = batch_size * 20; + let group_ids: Vec = (0..num_rows as i32).collect(); + let values: Vec = vec![1i64; num_rows]; + let input_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(group_ids)), + Arc::new(Int64Array::from(values)), + ], + )?; + + // Tight memory limit so early emission fires before all rows are processed. + let runtime = RuntimeEnvBuilder::default() + .with_memory_limit(512 * 1024, 1.0) + .build_arc()?; + let mut task_ctx = TaskContext::default().with_runtime(runtime); + let mut cfg = task_ctx.session_config().clone(); + cfg = cfg.set( + "datafusion.execution.batch_size", + &datafusion_common::ScalarValue::UInt64(Some(batch_size as u64)), + ); + // Disable skip-aggregation so every row goes through the hash table. + cfg = cfg.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &datafusion_common::ScalarValue::UInt64(Some(u64::MAX)), + ); + task_ctx = task_ctx.with_session_config(cfg); + let task_ctx = Arc::new(task_ctx); + + let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("cnt") + .build()?, + )]; + + let exec = TestMemoryExec::try_new( + &[vec![input_batch]], + Arc::clone(&schema), + None, + )?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(group_expr), + aggr_expr, + vec![None], + exec, + Arc::clone(&schema), + )?; + + let mut stream = + GroupedHashAggregateStream::new(&aggregate_exec, &Arc::clone(&task_ctx), 0)?; + + // group_col (i32) + count state (i64) = 12 bytes per row; use 4× as a + // generous upper bound to account for Arrow buffer alignment/overhead. + let bytes_per_row: usize = 12; + let max_allowed = batch_size * bytes_per_row * 4; + + let mut total_rows = 0usize; + while let Some(result) = stream.next().await { + let batch = result?; + let mem = get_record_batch_memory_size(&batch); + assert!( + mem <= max_allowed, + "emitted batch has {mem} bytes, expected ≤ {max_allowed} \ + (batch_size={batch_size}, rows={})", + batch.num_rows(), + ); + total_rows += batch.num_rows(); + } + + assert_eq!(total_rows, num_rows, "all input groups must be emitted"); + + Ok(()) + } }