Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,31 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{ArrayType, DataType}

abstract class VeloxCollect(child: Expression)
abstract class VeloxCollect(child: Expression, val ignoreNulls: Boolean)
extends DeclarativeAggregate
with UnaryLike[Expression] {

protected lazy val buffer: AttributeReference = AttributeReference("buffer", dataType)()

override def dataType: DataType = ArrayType(child.dataType, false)
override def dataType: DataType = ArrayType(child.dataType, !ignoreNulls)

override def nullable: Boolean = false

override def aggBufferAttributes: Seq[AttributeReference] = Seq(buffer)

override lazy val initialValues: Seq[Expression] = Seq(Literal.create(Array(), dataType))

override lazy val updateExpressions: Seq[Expression] = Seq(
If(
IsNull(child),
buffer,
Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false))))
)
override lazy val updateExpressions: Seq[Expression] = {
val append = if (ignoreNulls) {
If(
IsNull(child),
buffer,
Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false))))
} else {
Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false)))
}
Seq(append)
}

override lazy val mergeExpressions: Seq[Expression] = Seq(
Concat(Seq(buffer.left, buffer.right))
Expand All @@ -49,7 +54,8 @@ abstract class VeloxCollect(child: Expression)
override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))
}

case class VeloxCollectSet(child: Expression) extends VeloxCollect(child) {
case class VeloxCollectSet(child: Expression, override val ignoreNulls: Boolean = true)
extends VeloxCollect(child, ignoreNulls) {

override lazy val evaluateExpression: Expression =
ArrayDistinct(buffer)
Expand All @@ -60,7 +66,8 @@ case class VeloxCollectSet(child: Expression) extends VeloxCollect(child) {
copy(child = newChild)
}

case class VeloxCollectList(child: Expression) extends VeloxCollect(child) {
case class VeloxCollectList(child: Expression, override val ignoreNulls: Boolean = true)
extends VeloxCollect(child, ignoreNulls) {

override val evaluateExpression: Expression = buffer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,24 @@ object CollectRewriteRule {
/**
 * Extractor that swaps a Spark `CollectSet` / `CollectList` aggregate function for its Velox
 * counterpart.
 *
 * A match succeeds only when `expr` is an `AggregateExpression` wrapping one of those two
 * functions and `has[...]` reports the corresponding Velox class; the result is a copy of the
 * aggregate expression carrying the Velox function. Every other expression yields `None`.
 *
 * NOTE(review): each `newAggExpr` definition appears twice in this span (a one-argument and a
 * two-argument constructor form) — presumably the before/after lines of a rendered diff.
 * Confirm which form the real file carries.
 */
def unapply(expr: Expression): Option[Expression] = expr match {
// collect_set: rewritten only if the VeloxCollectSet check in the guard passes.
case aggExpr @ AggregateExpression(s: CollectSet, _, _, _, _) if has[VeloxCollectSet] =>
val newAggExpr =
aggExpr.copy(aggregateFunction = VeloxCollectSet(s.child))
aggExpr.copy(aggregateFunction = VeloxCollectSet(s.child, getIgnoreNulls(s)))
Some(newAggExpr)
// collect_list: rewritten only if the VeloxCollectList check in the guard passes.
case aggExpr @ AggregateExpression(l: CollectList, _, _, _, _) if has[VeloxCollectList] =>
val newAggExpr = aggExpr.copy(VeloxCollectList(l.child))
val newAggExpr = aggExpr.copy(VeloxCollectList(l.child, getIgnoreNulls(l)))
Some(newAggExpr)
// Anything else is left untouched.
case _ => None
}
}

/**
 * Reads the `ignoreNulls` flag from `expr` via reflection.
 *
 * A public zero-argument `ignoreNulls` method is looked up on the runtime class of `expr`.
 * When the class has no such method the result defaults to `true` (ignore nulls). Any other
 * failure raised by the lookup or by the invocation propagates to the caller.
 *
 * @param expr expression whose runtime class is inspected
 * @return the value returned by `expr.ignoreNulls`, or `true` when the method does not exist
 */
private def getIgnoreNulls(expr: Expression): Boolean = {
  // Only the lookup can raise NoSuchMethodException, so only the lookup is guarded.
  val accessor =
    try {
      Some(expr.getClass.getMethod("ignoreNulls"))
    } catch {
      case _: NoSuchMethodException => None
    }
  accessor match {
    case Some(method) => method.invoke(expr).asInstanceOf[Boolean]
    case None => true // Default: ignore nulls
  }
}

/**
 * Tells whether the runtime class of `T` is a key of `ExpressionMappings.expressionsMap`.
 *
 * @tparam T expression type to look up; its `ClassTag` supplies the runtime class
 * @return `true` when the map contains that class, `false` otherwise
 */
private def has[T <: Expression: ClassTag]: Boolean = {
  val expressionClass = implicitly[ClassTag[T]].runtimeClass
  ExpressionMappings.expressionsMap.contains(expressionClass)
}
}
30 changes: 23 additions & 7 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,30 @@ std::string SubstraitToVeloxPlanConverter::toAggregationFunctionName(
// The merge_extract function is registered without suffix.
return functionName;
}
// The merge_extract function must be registered with suffix based on result type.
functionName += ("_" + companionFunctionSuffix(resultType));
signatures = exec::getAggregateFunctionSignatures(functionName);
VELOX_CHECK(
signatures.has_value() && signatures.value().size() > 0,
// The merge_extract function must be registered with suffix based on
// result type. First try exact concrete type suffix.
auto suffixedName =
functionName + "_" + companionFunctionSuffix(resultType);
signatures = exec::getAggregateFunctionSignatures(suffixedName);
if (signatures.has_value() && signatures.value().size() > 0) {
return suffixedName;
}
// When companion functions are registered with generic type variables
// (e.g., "collect_set_merge_extract_array_T"), look up companion
// function names from the aggregate function registry.
auto companionSigs = exec::getCompanionFunctionSignatures(baseName);
if (companionSigs.has_value()) {
for (const auto& entry : companionSigs->mergeExtract) {
auto entrySigs =
exec::getAggregateFunctionSignatures(entry.functionName);
if (entrySigs.has_value() && entrySigs.value().size() > 0) {
return entry.functionName;
}
}
}
VELOX_FAIL(
"Cannot find function signature for {} in final aggregation step.",
functionName);
return functionName;
suffixedName);
}
case core::AggregationNode::Step::kIntermediate:
suffix = "_merge";
Expand Down
2 changes: 1 addition & 1 deletion ep/build-velox/src/get-velox.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Trace commands (-x), abort on the first failing command (-e) and on unset variables (-u).
set -exu

# Absolute path of the directory that contains this script.
CURRENT_DIR=$(cd "$(dirname "$BASH_SOURCE")"; pwd)
# NOTE(review): VELOX_REPO is assigned twice in this span; read as plain shell, the later
# assignment (the zhouyuan fork) overrides the IBM one. Presumably these are the before/after
# lines of a rendered diff — confirm which repository is intended.
VELOX_REPO=https://github.com/IBM/velox.git
VELOX_REPO=https://github.com/zhouyuan/velox.git
VELOX_BRANCH=dft-2026_03_24
VELOX_ENHANCED_BRANCH=ibm-2026_03_24
# Starts out empty.
VELOX_HOME=""
Expand Down
Loading