Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions datafusion/physical-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,8 @@ name = "dictionary_group_values"
[[bench]]
harness = false
name = "multi_group_by"

[[bench]]
harness = false
name = "partial_hash_agg"
required-features = ["test_utils"]
96 changes: 96 additions & 0 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1817,4 +1817,100 @@ mod tests {

Ok(())
}

/// Verify that partial aggregation emits batches whose memory footprint is
/// proportional to `batch_size`, so a downstream `RepartitionExec` does not
/// over-reserve memory and spill unnecessarily.
///
/// With high-cardinality input (one unique group per row) the hash table
/// grows until early-emission kicks in. Each emitted batch must be ≤ a
/// small constant multiple of `batch_size × bytes_per_row`.
#[tokio::test]
async fn test_partial_agg_batch_memory_size_is_bounded() -> Result<()> {
let batch_size: usize = 1024;

let schema = Arc::new(Schema::new(vec![
Field::new("group_col", DataType::Int32, false),
Field::new("value_col", DataType::Int64, false),
]));

// One unique group per row — worst case for memory growth before emit.
let num_rows = batch_size * 20;
let group_ids: Vec<i32> = (0..num_rows as i32).collect();
let values: Vec<i64> = vec![1i64; num_rows];
let input_batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(group_ids)),
Arc::new(Int64Array::from(values)),
],
)?;

// Tight memory limit so early emission fires before all rows are processed.
let runtime = RuntimeEnvBuilder::default()
.with_memory_limit(512 * 1024, 1.0)
.build_arc()?;
let mut task_ctx = TaskContext::default().with_runtime(runtime);
let mut cfg = task_ctx.session_config().clone();
cfg = cfg.set(
"datafusion.execution.batch_size",
&datafusion_common::ScalarValue::UInt64(Some(batch_size as u64)),
);
// Disable skip-aggregation so every row goes through the hash table.
cfg = cfg.set(
"datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
&datafusion_common::ScalarValue::UInt64(Some(u64::MAX)),
);
task_ctx = task_ctx.with_session_config(cfg);
let task_ctx = Arc::new(task_ctx);

let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())];
let aggr_expr = vec![Arc::new(
AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?])
.schema(Arc::clone(&schema))
.alias("cnt")
.build()?,
)];

let exec = TestMemoryExec::try_new(
&[vec![input_batch]],
Arc::clone(&schema),
None,
)?;
let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec)));

let aggregate_exec = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(group_expr),
aggr_expr,
vec![None],
exec,
Arc::clone(&schema),
)?;

let mut stream =
GroupedHashAggregateStream::new(&aggregate_exec, &Arc::clone(&task_ctx), 0)?;

// group_col (i32) + count state (i64) = 12 bytes per row; use 4× as a
// generous upper bound to account for Arrow buffer alignment/overhead.
let bytes_per_row: usize = 12;
let max_allowed = batch_size * bytes_per_row * 4;

let mut total_rows = 0usize;
while let Some(result) = stream.next().await {
let batch = result?;
let mem = get_record_batch_memory_size(&batch);
assert!(
mem <= max_allowed,
"emitted batch has {mem} bytes, expected ≤ {max_allowed} \
(batch_size={batch_size}, rows={})",
batch.num_rows(),
);
total_rows += batch.num_rows();
}

assert_eq!(total_rows, num_rows, "all input groups must be emitted");

Ok(())
}
}