Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,15 @@

package org.apache.doris.nereids.processor.post;

import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.PartitionPrunablePredicate;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.util.ExpressionUtils;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

Expand Down Expand Up @@ -67,25 +56,10 @@ public Plan visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, CascadesC
if (skipPrunePredicate) {
return filter;
}
Set<Long> scanPartitions = new HashSet<>(scan.getSelectedPartitionIds());
Map<String, Slot> nameToOutputSlot = buildNameToSlotMap(scan);

Set<Expression> remaining = new LinkedHashSet<>(filter.getConjuncts());
boolean changed = false;
PartitionPrunablePredicate entry = entryOpt.get();
if (entry.getSelectedPartitionIds().containsAll(scanPartitions)) {
Map<Expression, Expression> slotReplaceMap =
buildSlotReplaceMap(entry.getSnapshotPartitionSlots(), nameToOutputSlot);
if (slotReplaceMap != null) {
for (Expression conjunct : entry.getPrunableConjuncts()) {
Expression rewritten = slotReplaceMap.isEmpty()
? conjunct : ExpressionUtils.replace(conjunct, slotReplaceMap);
if (remaining.remove(rewritten)) {
changed = true;
}
}
}
}
Set<Expression> prunableConjuncts = entry.getRewrittenPrunableConjuncts(scan, scan.getOutput());
Set<Expression> remaining = new LinkedHashSet<>(filter.getConjuncts());
boolean changed = remaining.removeAll(prunableConjuncts);
if (!changed) {
return filter;
}
Expand All @@ -95,52 +69,4 @@ public Plan visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, CascadesC
return filter.withConjunctsAndChild(remaining, scan)
.copyStatsAndGroupIdFrom((AbstractPhysicalPlan) filter);
}

private static Map<String, Slot> buildNameToSlotMap(PhysicalOlapScan scan) {
OlapTable table = scan.getTable();
List<Slot> slots = scan.getOutput();
Map<String, Slot> map = new HashMap<>(slots.size());
if (scan.getSelectedIndexId() == table.getBaseIndexId()) {
for (Slot slot : slots) {
map.put(slot.getName().toLowerCase(), slot);
}
} else {
for (Slot slot : slots) {
if (!(slot instanceof SlotReference)) {
continue;
}
SlotReference slotReference = (SlotReference) slot;
Optional<Column> columnOptional = slotReference.getOriginalColumn();
if (!columnOptional.isPresent()) {
continue;
}
Expr expr = columnOptional.get().getDefineExpr();
if (!(expr instanceof SlotRef)) {
continue;
}
map.put(((SlotRef) expr).getColumnName().toLowerCase(), slot);
}
}
return map;
}

/**
* Map each recorded snapshot slot to the scan's current output slot of the
* same column name. Returns null when any snapshot slot cannot be located,
* so the caller can skip the entry.
*/
private static Map<Expression, Expression> buildSlotReplaceMap(
List<Slot> snapshotSlots, Map<String, Slot> nameToOutputSlot) {
Map<Expression, Expression> replaceMap = new HashMap<>(snapshotSlots.size());
for (Slot snapshot : snapshotSlots) {
Slot current = nameToOutputSlot.get(snapshot.getName().toLowerCase());
if (current == null) {
return null;
}
if (!snapshot.equals(current)) {
replaceMap.put(snapshot, current);
}
}
return replaceMap;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.PartitionPrunablePredicate;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
Expand Down Expand Up @@ -137,6 +138,7 @@
import org.apache.doris.nereids.trees.plans.physical.PhysicalWorkTableReference;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.AnalysisManager;
Expand Down Expand Up @@ -165,6 +167,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -1208,7 +1211,55 @@ public Statistics computeAssertNumRows(AssertNumRowsElement assertNumRowsElement
* computeFilter
*/
public Statistics computeFilter(Filter filter, Statistics inputStats) {
return new FilterEstimation().estimate(filter.getPredicate(), inputStats);
Set<Expression> conjuncts = getFilterConjunctsForStats(filter);
if (conjuncts.isEmpty()) {
return inputStats;
}
return new FilterEstimation().estimate(ExpressionUtils.and(conjuncts), inputStats);
}

private Set<Expression> getFilterConjunctsForStats(Filter filter) {
Optional<Plan> scanPlanOpt = getFilterChildOlapScan(filter);
if (!scanPlanOpt.isPresent()) {
return filter.getConjuncts();
}
Plan scanPlan = scanPlanOpt.get();
OlapScan scan = (OlapScan) scanPlan;
Optional<PartitionPrunablePredicate> entryOpt = scan.getPartitionPrunablePredicates();
if (!entryOpt.isPresent()) {
return filter.getConjuncts();
}
Set<Expression> prunableConjuncts = entryOpt.get().getRewrittenPrunableConjuncts(scan, scanPlan.getOutput());
if (prunableConjuncts.isEmpty()) {
return filter.getConjuncts();
}
Set<Expression> remaining = new LinkedHashSet<>(filter.getConjuncts());
remaining.removeAll(prunableConjuncts);
return remaining;
}

private Optional<Plan> getFilterChildOlapScan(Filter filter) {
Optional<Plan> child = getFilterChild(filter);
if (child.isPresent() && child.get() instanceof OlapScan) {
return child;
}
if (groupExpression != null) {
Plan childPlan = groupExpression.getFirstChildPlan(OlapScan.class);
if (childPlan instanceof OlapScan) {
return Optional.of(childPlan);
}
}
return Optional.empty();
}

private Optional<Plan> getFilterChild(Filter filter) {
if (filter instanceof LogicalFilter) {
return Optional.of(((LogicalFilter<?>) filter).child());
}
if (filter instanceof PhysicalFilter) {
return Optional.of(((PhysicalFilter<?>) filter).child());
}
return Optional.empty();
}

private ColumnStatistic getColumnStatistic(TableIf table, String colName, long idxId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,24 @@

package org.apache.doris.nereids.trees.plans;

import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
Expand Down Expand Up @@ -92,4 +102,69 @@ public List<Slot> getSnapshotPartitionSlots() {
public Set<Expression> getPrunableConjuncts() {
return prunableConjuncts;
}

public Set<Expression> getRewrittenPrunableConjuncts(OlapScan scan, List<Slot> output) {
if (!selectedPartitionIds.containsAll(scan.getSelectedPartitionIds())) {
return ImmutableSet.of();
}
Map<Expression, Expression> slotReplaceMap = buildSlotReplaceMap(
snapshotPartitionSlots, buildNameToSlotMap(scan, output));
if (slotReplaceMap == null) {
return ImmutableSet.of();
}
ImmutableSet.Builder<Expression> rewrittenConjuncts =
ImmutableSet.builderWithExpectedSize(prunableConjuncts.size());
for (Expression conjunct : prunableConjuncts) {
rewrittenConjuncts.add(slotReplaceMap.isEmpty()
? conjunct : ExpressionUtils.replace(conjunct, slotReplaceMap));
}
return rewrittenConjuncts.build();
}

private static Map<String, Slot> buildNameToSlotMap(OlapScan scan, List<Slot> output) {
OlapTable table = scan.getTable();
Map<String, Slot> map = new HashMap<>(output.size());
if (scan.getSelectedIndexId() == table.getBaseIndexId()) {
for (Slot slot : output) {
map.put(slot.getName().toLowerCase(), slot);
}
} else {
for (Slot slot : output) {
if (!(slot instanceof SlotReference)) {
continue;
}
SlotReference slotReference = (SlotReference) slot;
Optional<Column> columnOptional = slotReference.getOriginalColumn();
if (!columnOptional.isPresent()) {
continue;
}
Expr expr = columnOptional.get().getDefineExpr();
if (!(expr instanceof SlotRef)) {
continue;
}
map.put(((SlotRef) expr).getColumnName().toLowerCase(), slot);
}
}
return map;
}

/**
* Map each recorded snapshot slot to the scan's current output slot of the
* same column name. Returns null when any snapshot slot cannot be located,
* so the caller can skip this record.
*/
private static Map<Expression, Expression> buildSlotReplaceMap(
List<Slot> snapshotSlots, Map<String, Slot> nameToOutputSlot) {
Map<Expression, Expression> replaceMap = new HashMap<>(snapshotSlots.size());
for (Slot snapshot : snapshotSlots) {
Slot current = nameToOutputSlot.get(snapshot.getName().toLowerCase());
if (current == null) {
return null;
}
if (!snapshot.equals(current)) {
replaceMap.put(snapshot, current);
}
}
return replaceMap;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
package org.apache.doris.nereids.trees.plans.algebra;

import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.trees.plans.PartitionPrunablePredicate;

import java.util.List;
import java.util.Optional;

/** OlapScan */
public interface OlapScan {
Expand All @@ -32,6 +34,8 @@ public interface OlapScan {

List<Long> getSelectedTabletIds();

Optional<PartitionPrunablePredicate> getPartitionPrunablePredicates();

/** getScanTabletNum */
default int getScanTabletNum() {
List<Long> selectedTabletIds = getSelectedTabletIds();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.LimitPhase;
import org.apache.doris.nereids.trees.plans.PartitionPrunablePredicate;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
Expand Down Expand Up @@ -87,6 +88,13 @@ private Group newFakeGroup() {
return group;
}

private SlotReference getSlot(LogicalOlapScan scan, String slotName) {
return (SlotReference) scan.getOutput().stream()
.filter(slot -> slot.getName().equals(slotName))
.findFirst()
.get();
}

@Test
public void testFilter() {
List<String> qualifier = Lists.newArrayList();
Expand Down Expand Up @@ -135,6 +143,56 @@ public void testFilter() {
ownerGroupOr.getStatistics().getRowCount(), 0.1);
}

@Test
public void testFilterIgnoresRecordedPrunedPredicatesInStats() {
LogicalOlapScan scan = PlanConstructor.newDpHyperLogicalOlapScan(10, "t_pruned_stats", 0);
SlotReference partitionSlot = getSlot(scan, "id");
SlotReference dataSlot = getSlot(scan, "age");

EqualTo partitionEqual = new EqualTo(partitionSlot, new IntegerLiteral(1));
EqualTo dataEqual = new EqualTo(dataSlot, new IntegerLiteral(2));
PartitionPrunablePredicate prunablePredicate = new PartitionPrunablePredicate(
ImmutableSet.copyOf(scan.getSelectedPartitionIds()),
ImmutableList.of(partitionSlot),
ImmutableSet.of(partitionEqual));
LogicalOlapScan prunedScan = scan.withPartitionPrunablePredicates(Optional.of(prunablePredicate));

ColumnStatisticBuilder partitionColumnStats = new ColumnStatisticBuilder();
partitionColumnStats.setNdv(100);
partitionColumnStats.setMinValue(0);
partitionColumnStats.setMaxValue(1000);
partitionColumnStats.setNumNulls(0);
ColumnStatisticBuilder dataColumnStats = new ColumnStatisticBuilder();
dataColumnStats.setNdv(10);
dataColumnStats.setMinValue(0);
dataColumnStats.setMaxValue(1000);
dataColumnStats.setNumNulls(0);

Map<Expression, ColumnStatistic> slotColumnStatsMap = new HashMap<>();
slotColumnStatsMap.put(partitionSlot, partitionColumnStats.build());
slotColumnStatsMap.put(dataSlot, dataColumnStats.build());
Statistics childStats = new Statistics(1000, slotColumnStatsMap);

GroupExpression scanGroupExpression = new GroupExpression(prunedScan);
Group childGroup = new Group(null, scanGroupExpression,
new LogicalProperties(prunedScan::getOutput, () -> DataTrait.EMPTY_TRAIT));
childGroup.setStatistics(childStats);
GroupPlan groupPlan = new GroupPlan(childGroup);

LogicalFilter<GroupPlan> filter = new LogicalFilter<>(ImmutableSet.of(partitionEqual, dataEqual), groupPlan);
GroupExpression filterGroupExpression = new GroupExpression(filter, ImmutableList.of(childGroup));
Group ownerGroup = new Group(null, filterGroupExpression, null);
StatsCalculator.estimate(filterGroupExpression, null);
Assertions.assertEquals(100, ownerGroup.getStatistics().getRowCount(), 0.001);

LogicalFilter<GroupPlan> partitionOnlyFilter = new LogicalFilter<>(ImmutableSet.of(partitionEqual), groupPlan);
GroupExpression partitionOnlyGroupExpression =
new GroupExpression(partitionOnlyFilter, ImmutableList.of(childGroup));
Group partitionOnlyOwnerGroup = new Group(null, partitionOnlyGroupExpression, null);
StatsCalculator.estimate(partitionOnlyGroupExpression, null);
Assertions.assertEquals(1000, partitionOnlyOwnerGroup.getStatistics().getRowCount(), 0.001);
}

// a, b are in (0,100)
// a=200 and b=300 => output: 0 rows
@org.junit.Test
Expand Down
Loading