diff --git a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala index 8a5c2c19a0f..db0f1dd8649 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetFileFormat.scala @@ -273,6 +273,21 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter]) None } + // Spatial filters that translate to Parquet row-group predicates (e.g. Box2D bounds + // comparisons on a Box2D-typed column) are AND'd into the pushed-down filter so Parquet + // can skip row groups whose column statistics disprove them. Gated on the same Spark + // SQL flag as ordinary Parquet pushdown so disabling `spark.sql.parquet.filterPushdown` + // also disables Sedona-injected row-group predicates. + val combinedPushed = if (enableParquetFilterPushDown) { + spatialFilter.flatMap(_.toParquetFilter) match { + case Some(spatialPredicate) => + Some(pushed.fold(spatialPredicate)(p => FilterApi.and(p, spatialPredicate))) + case None => pushed + } + } else { + pushed + } + // Prune file scans using pushed down spatial filters and per-column bboxes in geoparquet metadata val shouldScanFile = GeoParquetMetaData.parseKeyValueMetaData(footerFileMetaData.getKeyValueMetaData).forall { @@ -304,8 +319,10 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter]) // Try to push down filters when filter push-down is enabled. // Notice: This push-down is RowGroups level, not individual records. 
- if (pushed.isDefined) { - ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) + if (combinedPushed.isDefined) { + ParquetInputFormat.setFilterPredicate( + hadoopAttemptContext.getConfiguration, + combinedPushed.get) } if (enableVectorizedReader) { logWarning( @@ -319,8 +336,8 @@ class GeoParquetFileFormat(val spatialFilter: Option[GeoParquetSpatialFilter]) datetimeRebaseSpec, int96RebaseSpec, options) - val reader = if (pushed.isDefined && enableRecordFilter) { - val parquetFilter = FilterCompat.get(pushed.get, null) + val reader = if (combinedPushed.isDefined && enableRecordFilter) { + val parquetFilter = FilterCompat.get(combinedPushed.get, null) new ParquetRecordReader[InternalRow](readSupport, parquetFilter) } else { new ParquetRecordReader[InternalRow](readSupport) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSpatialFilter.scala b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSpatialFilter.scala index cacc64d94b4..65e4437b8a8 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSpatialFilter.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/execution/datasources/geoparquet/GeoParquetSpatialFilter.scala @@ -18,6 +18,8 @@ */ package org.apache.spark.sql.execution.datasources.geoparquet +import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} +import org.apache.sedona.common.geometryObjects.Box2D import org.apache.sedona.core.spatialOperator.SpatialPredicate import org.locationtech.jts.geom.Envelope import org.locationtech.jts.geom.Geometry @@ -28,7 +30,21 @@ import org.locationtech.jts.geom.Geometry * [[org.apache.spark.sql.sedona_sql.optimization.SpatialFilterPushDownForGeoParquet]]. */ trait GeoParquetSpatialFilter { + + /** + * File-level evaluation against GeoParquet column metadata. 
Used for cheap whole-file pruning + * before reading row-group statistics. Filters that cannot soundly prune at the file metadata + * level should return `true` here and emit their pruning predicate via [[toParquetFilter]]. + */ def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean + + /** + * Translate this spatial filter into a Parquet [[FilterPredicate]] that the Parquet reader can + * evaluate against row-group statistics. Returns `None` if the filter cannot be expressed as a + * Parquet predicate (e.g. arbitrary JTS predicates on a geometry column). + */ + def toParquetFilter: Option[FilterPredicate] = None + def simpleString: String } @@ -40,6 +56,14 @@ object GeoParquetSpatialFilter { left.evaluate(columns) && right.evaluate(columns) } + override def toParquetFilter: Option[FilterPredicate] = + (left.toParquetFilter, right.toParquetFilter) match { + case (Some(l), Some(r)) => Some(FilterApi.and(l, r)) + case (Some(l), None) => Some(l) + case (None, Some(r)) => Some(r) + case _ => None + } + override def simpleString: String = s"(${left.simpleString}) AND (${right.simpleString})" } @@ -47,6 +71,14 @@ object GeoParquetSpatialFilter { extends GeoParquetSpatialFilter { override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean = left.evaluate(columns) || right.evaluate(columns) + + // OR pushdown to Parquet requires both sides translate; otherwise we'd drop matching rows. + override def toParquetFilter: Option[FilterPredicate] = + for { + l <- left.toParquetFilter + r <- right.toParquetFilter + } yield FilterApi.or(l, r) + override def simpleString: String = s"(${left.simpleString}) OR (${right.simpleString})" } @@ -88,4 +120,87 @@ object GeoParquetSpatialFilter { } override def simpleString: String = s"$columnName ${predicateType.name} $queryWindow" } + + /** + * Semantic kind of a Box2D leaf predicate. 
Determines which inequality system is emitted as a + * Parquet filter against the four (xmin, ymin, xmax, ymax) leaf columns of a Box2D-typed + * column. + */ + sealed trait Box2DPredicateKind { + def simpleName: String + } + object Box2DPredicateKind { + + /** `ST_BoxIntersects(box_col, lit)` — symmetric, same regardless of argument order. */ + case object Intersects extends Box2DPredicateKind { + override def simpleName: String = "INTERSECTS" + } + + /** `ST_BoxContains(box_col, lit)` — the column box must contain the literal box. */ + case object ColumnContainsLiteral extends Box2DPredicateKind { + override def simpleName: String = "CONTAINS" + } + + /** `ST_BoxContains(lit, box_col)` — the literal box must contain the column box. */ + case object LiteralContainsColumn extends Box2DPredicateKind { + override def simpleName: String = "CONTAINED_BY" + } + } + + /** + * Pushdown filter for predicates that operate on a Box2D-typed column (e.g. + * `ST_BoxIntersects(box_col, lit_box)` or `ST_BoxContains(box_col, lit_box)`). + * + * Pruning is performed by translating the predicate into per-leaf inequalities on the Box2D + * column's four `Double` fields (`xmin`, `ymin`, `xmax`, `ymax`) and pushing the result down as + * a Parquet [[FilterPredicate]]. Parquet's row-group statistics machinery then skips row groups + * whose per-column min/max bounds disprove the predicate. + * + * File-metadata evaluation returns `true` (i.e. don't prune at the GeoParquet metadata layer) + * because that path relied on the geometry column's bbox and is unsound when the GeoParquet 1.1 + * spec permits coverings to be conservatively wider than per-row envelopes. The Parquet-stats + * path uses the Box2D column's actual recorded min/max, so it is sound for any writer. 
+ */ + case class Box2DLeafFilter( + box2dColumnName: String, + predicateKind: Box2DPredicateKind, + queryBox: Box2D) + extends GeoParquetSpatialFilter { + + override def evaluate(columns: Map[String, GeometryFieldMetaData]): Boolean = true + + override def toParquetFilter: Option[FilterPredicate] = { + val xmin = FilterApi.doubleColumn(s"$box2dColumnName.xmin") + val ymin = FilterApi.doubleColumn(s"$box2dColumnName.ymin") + val xmax = FilterApi.doubleColumn(s"$box2dColumnName.xmax") + val ymax = FilterApi.doubleColumn(s"$box2dColumnName.ymax") + val qxMin = java.lang.Double.valueOf(queryBox.getXMin) + val qyMin = java.lang.Double.valueOf(queryBox.getYMin) + val qxMax = java.lang.Double.valueOf(queryBox.getXMax) + val qyMax = java.lang.Double.valueOf(queryBox.getYMax) + + val predicate = predicateKind match { + case Box2DPredicateKind.Intersects => + // Intersection: row's xmax >= lit.xmin && xmin <= lit.xmax && ymax >= lit.ymin && ymin <= lit.ymax + FilterApi.and( + FilterApi.and(FilterApi.gtEq(xmax, qxMin), FilterApi.ltEq(xmin, qxMax)), + FilterApi.and(FilterApi.gtEq(ymax, qyMin), FilterApi.ltEq(ymin, qyMax))) + case Box2DPredicateKind.ColumnContainsLiteral => + // Column contains literal: row's xmin <= lit.xmin && xmax >= lit.xmax && ymin <= lit.ymin && ymax >= lit.ymax + FilterApi.and( + FilterApi.and(FilterApi.ltEq(xmin, qxMin), FilterApi.gtEq(xmax, qxMax)), + FilterApi.and(FilterApi.ltEq(ymin, qyMin), FilterApi.gtEq(ymax, qyMax))) + case Box2DPredicateKind.LiteralContainsColumn => + // Literal contains column: row's xmin >= lit.xmin && xmax <= lit.xmax && ymin >= lit.ymin && ymax <= lit.ymax + FilterApi.and( + FilterApi.and(FilterApi.gtEq(xmin, qxMin), FilterApi.ltEq(xmax, qxMax)), + FilterApi.and(FilterApi.gtEq(ymin, qyMin), FilterApi.ltEq(ymax, qyMax))) + } + Some(predicate) + } + + override def simpleString: String = + s"$box2dColumnName ${predicateKind.simpleName} BOX(${queryBox.getXMin} ${queryBox.getYMin}, " + + s"${queryBox.getXMax} 
${queryBox.getYMax})" + } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala index 7c8fefca473..31ccea3de6d 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialFilterPushDownForGeoParquet.scala @@ -18,6 +18,7 @@ */ package org.apache.spark.sql.sedona_sql.optimization +import org.apache.sedona.common.geometryObjects.Box2D import org.apache.sedona.common.sphere.Haversine import org.apache.sedona.core.spatialOperator.SpatialPredicate import org.apache.sedona.sql.utils.GeometrySerializer @@ -42,10 +43,13 @@ import org.apache.spark.sql.execution.datasources.PushableColumnBase import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetFileFormatBase import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetSpatialFilter import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetSpatialFilter.AndFilter +import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetSpatialFilter.Box2DLeafFilter +import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetSpatialFilter.Box2DPredicateKind import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetSpatialFilter.LeafFilter import org.apache.spark.sql.execution.datasources.geoparquet.GeoParquetSpatialFilter.OrFilter +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.sedona_sql.expressions.{ST_AsEWKT, ST_Buffer, ST_Contains, ST_CoveredBy, ST_Covers, ST_Crosses, ST_DWithin, ST_Distance, ST_DistanceSphere, ST_DistanceSpheroid, ST_Equals, ST_Intersects, ST_OrderingEquals, ST_Overlaps, ST_Touches, ST_Within} +import 
org.apache.spark.sql.sedona_sql.expressions.{ST_AsEWKT, ST_BoxContains, ST_BoxIntersects, ST_Buffer, ST_Contains, ST_CoveredBy, ST_Covers, ST_Crosses, ST_DWithin, ST_Distance, ST_DistanceSphere, ST_DistanceSpheroid, ST_Equals, ST_Intersects, ST_OrderingEquals, ST_Overlaps, ST_Touches, ST_Within} import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates import org.apache.spark.sql.types.DoubleType import org.locationtech.jts.geom.Geometry @@ -144,6 +148,24 @@ class SpatialFilterPushDownForGeoParquet(sparkSession: SparkSession) extends Rul SpatialPredicate.INTERSECTS, GeometryUDT.deserialize(value)) + // Box2D predicates push down as Parquet row-group filters on the Box2D column's underlying + // (xmin, ymin, xmax, ymax) double leaves. Pruning is done by Parquet's stats-based skipping + // against the column's recorded min/max, which is sound regardless of how the writer chose + // the per-row Box2D values. + case ST_BoxIntersects(_) => + // Intersects is symmetric — both argument orders produce the same predicate. 
+ for { + (name, value) <- resolveNameAndLiteral(predicate.children, pushableColumn) + queryBox <- extractBox2DLiteral(value) + } yield Box2DLeafFilter(unquote(name), Box2DPredicateKind.Intersects, queryBox) + + case ST_BoxContains(Seq(pushableColumn(name), Literal(v, _))) => + extractBox2DLiteral(v).map(qb => + Box2DLeafFilter(unquote(name), Box2DPredicateKind.ColumnContainsLiteral, qb)) + case ST_BoxContains(Seq(Literal(v, _), pushableColumn(name))) => + extractBox2DLiteral(v).map(qb => + Box2DLeafFilter(unquote(name), Box2DPredicateKind.LiteralContainsColumn, qb)) + case LessThan(ST_Distance(distArgs), Literal(d, DoubleType)) => for ((name, value) <- resolveNameAndLiteral(distArgs, pushableColumn)) yield distanceFilter(name, GeometryUDT.deserialize(value), d.asInstanceOf[Double]) @@ -256,6 +278,24 @@ class SpatialFilterPushDownForGeoParquet(sparkSession: SparkSession) extends Rul parseColumnPath(name).mkString(".") } + /** + * Extract a [[Box2D]] from a Catalyst literal value. Box2DUDT serializes to an InternalRow of + * four doubles; if the value is something else, the predicate is not pushable. Inverted bounds + * (xmin>xmax or ymin>ymax) are rejected here so the predicate falls back to runtime evaluation + * and surfaces the expected IllegalArgumentException — pushing them through Parquet would + * silently prune all matching rows before the throw fires. 
+ */ + private def extractBox2DLiteral(value: Any): Option[Box2D] = value match { + case row: InternalRow if row.numFields == 4 => + val xmin = row.getDouble(0) + val ymin = row.getDouble(1) + val xmax = row.getDouble(2) + val ymax = row.getDouble(3) + if (xmin > xmax || ymin > ymax) None + else Some(new Box2D(xmin, ymin, xmax, ymax)) + case _ => None + } + private def resolveNameAndLiteral( expressions: Seq[Expression], pushableColumn: PushableColumnBase): Option[(String, Any)] = { diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala index 2f37488d1a5..9eec5b7070f 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/GeoParquetSpatialFilterPushDownSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.SimpleMode import org.apache.spark.sql.execution.datasources.geoparquet.{GeoParquetFileFormat, GeoParquetMetaData, GeoParquetSpatialFilter} +import org.apache.spark.sql.functions.expr import org.locationtech.jts.geom.Coordinate import org.locationtech.jts.geom.Geometry import org.locationtech.jts.geom.GeometryFactory @@ -317,6 +318,114 @@ class GeoParquetSpatialFilterPushDownSuite extends TestBaseScala with TableDrive assert(getPushedDownSpatialFilter(dfFiltered).isEmpty) } } + + it("Push down ST_BoxIntersects against a Box2D column") { + val (box2dDf, box2dDir) = setupBox2DCoveringFixture() + try { + // Q1 region only (region 1, center +10/+10) + val q1Filter = + "ST_BoxIntersects(geom_bbox, ST_MakeBox2D(ST_Point(5.0, 5.0), ST_Point(15.0, 15.0)))" + verifyBox2DFilter(box2dDf, q1Filter) + + // Window covering Q2 and Q4 (negative X) + val leftHalfFilter = + "ST_BoxIntersects(geom_bbox, 
ST_MakeBox2D(ST_Point(-20.0, -20.0), ST_Point(-1.0, 20.0)))" + verifyBox2DFilter(box2dDf, leftHalfFilter) + + // Disjoint window prunes everything + val disjointFilter = + "ST_BoxIntersects(geom_bbox, ST_MakeBox2D(ST_Point(100.0, 100.0), ST_Point(200.0, 200.0)))" + verifyBox2DFilter(box2dDf, disjointFilter) + + // Reverse argument order: ST_BoxIntersects(lit, col) is symmetric. + val reversedFilter = + "ST_BoxIntersects(ST_MakeBox2D(ST_Point(5.0, 5.0), ST_Point(15.0, 15.0)), geom_bbox)" + verifyBox2DFilter(box2dDf, reversedFilter) + } finally { + FileUtils.deleteDirectory(new File(box2dDir).getParentFile) + } + } + + it("ST_BoxIntersects with inverted-bound literal falls back to runtime throw") { + val (box2dDf, box2dDir) = setupBox2DCoveringFixture() + try { + // xmin > xmax: must not push down — otherwise Parquet's row-group filter could prune all + // matches and hide the expected IllegalArgumentException from the predicate's runtime + // evaluation. + val invertedFilter = + "ST_BoxIntersects(geom_bbox, ST_MakeBox2D(ST_Point(20.0, -20.0), ST_Point(-20.0, 20.0)))" + val dfFiltered = box2dDf.where(invertedFilter) + assert( + getPushedDownSpatialFilter(dfFiltered).isEmpty, + "Inverted-bound Box2D literal must not be pushed down") + val ex = intercept[Exception](dfFiltered.collect()) + assert( + Iterator + .iterate(ex: Throwable)(_.getCause) + .takeWhile(_ != null) + .exists(_.isInstanceOf[IllegalArgumentException]), + s"Expected IllegalArgumentException in cause chain, got: $ex") + } finally { + FileUtils.deleteDirectory(new File(box2dDir).getParentFile) + } + } + + it("Push down ST_BoxContains against a Box2D column") { + val (box2dDf, box2dDir) = setupBox2DCoveringFixture() + try { + // ST_BoxContains(box_col, lit_box) — the column box must contain the literal box. A tiny + // query inside Q1 is contained only by rows from region 1. 
+      val containsFilter = +        "ST_BoxContains(geom_bbox, ST_MakeBox2D(ST_Point(9.0, 9.0), ST_Point(10.0, 10.0)))" +      verifyBox2DFilter(box2dDf, containsFilter) + +      // Reverse argument order: ST_BoxContains(lit_box, col) — the literal box must contain the +      // column box. The 12x12 window in Q1 contains the 2x2 polygons centered at (5,5), (5,15), +      // (15,5), (15,15) only partially; only rows whose envelopes lie entirely inside the window +      // survive. +      val reversedFilter = +        "ST_BoxContains(ST_MakeBox2D(ST_Point(4.0, 4.0), ST_Point(16.0, 16.0)), geom_bbox)" +      verifyBox2DFilter(box2dDf, reversedFilter) +    } finally { +      FileUtils.deleteDirectory(new File(box2dDir).getParentFile) +    } +  } +  } + +  private def setupBox2DCoveringFixture(): (DataFrame, String) = { +    val box2dParent = +      Files.createTempDirectory("sedona_geoparquet_box2d_").toFile.getAbsolutePath +    val box2dDir = box2dParent + "/data" +    val withBox = df.withColumn("geom_bbox", expr("ST_Box2D(geom)")) +    withBox.coalesce(1).write.partitionBy("region").format("geoparquet").save(box2dDir) +    val box2dDf = sparkSession.read.format("geoparquet").load(box2dDir) +    (box2dDf, box2dDir) +  } + +  private def verifyBox2DFilter(box2dDf: DataFrame, condition: String): Unit = { +    val dfFiltered = box2dDf.where(condition) + +    // Pushdown is attached and translates to a Parquet row-group filter. +    val pushed = getPushedDownSpatialFilter(dfFiltered) +    assert(pushed.isDefined, s"Expected spatial filter push-down for: $condition") +    assert( +      pushed.get.toParquetFilter.isDefined, +      s"Expected a Parquet FilterPredicate for: $condition") + +    // Correctness: pushdown must not drop any matching rows. Compare against a run with the +    // spatial filter rule disabled (so no Parquet predicate is injected from Sedona). 
+ val expectedResult = + withConf(Map("spark.sedona.geoparquet.spatialFilterPushDown" -> "false")) { + box2dDf + .where(condition) + .orderBy("region", "id") + .select("region", "id") + .collect() + .toSeq + } + val actualResult = + dfFiltered.orderBy("region", "id").select("region", "id").collect().toSeq + assert(expectedResult == actualResult, s"Result mismatch under push-down for: $condition") } /**