Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.sedona.spark.SedonaContext
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.parser.ParserFactory
import org.apache.spark.sql.sedona_sql.optimization.Box2DCastResolutionRule
import org.slf4j.{Logger, LoggerFactory}

class SedonaSqlExtensions extends (SparkSessionExtensions => Unit) {
Expand All @@ -36,6 +37,11 @@ class SedonaSqlExtensions extends (SparkSessionExtensions => Unit) {
_ => ()
})

// Resolve geometry↔Box2D casts during analysis so the analyzer accepts
// `CAST(geom AS box2d)` / `CAST(box AS geometry)` despite Spark's stock cast resolver
// refusing arbitrary UDT-to-UDT casts.
e.injectResolutionRule(_ => new Box2DCastResolutionRule)

// Inject Sedona SQL parser
if (enableParser) {
// Try to inject the Sedona SQL parser but gracefully handle initialization failures.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,8 @@ private[apache] case class ST_MakeBox2D(inputExpressions: Seq[Expression])

/**
* Convert a Box2D to a closed rectangular polygon Geometry. Equivalent to PostGIS {@code
* box2d::geometry}. Exposed as a function rather than a Catalyst implicit cast because UDT-to-UDT
* implicit casts require Catalyst-level work; ST_GeomFromBox2D lives alongside the other
* ST_GeomFrom* constructors.
* box2d::geometry}. `CAST(box AS geometry)` is also accepted (resolved to this expression by the
* Box2D cast resolution rule).
*
* @param inputExpressions
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.optimization

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, GeometryUDT}
import org.apache.spark.sql.sedona_sql.expressions.{ST_Box2D, ST_GeomFromBox2D}

/**
* Analyzer rule that resolves Catalyst casts between Sedona UDTs that Spark's stock cast resolver
* does not handle. Specifically:
*
* - `CAST(geom AS box2d)` → `ST_Box2D(geom)` (planar bounding box of the geometry)
* - `CAST(box AS geometry)` → `ST_GeomFromBox2D(box)` (rectangular polygon from a Box2D)
*
* Spark's `Cast.canCast` returns `false` for arbitrary UDT-to-UDT casts, so without this rule the
* analyzer would reject the cast. We rewrite during analysis (before `CheckAnalysis`) so the
* downstream optimizer and codegen path see the expression tree of an ordinary Sedona expression.
*
* Implicit type coercion (e.g. passing a Geometry into a Box2D-typed function argument without an
* explicit cast) is intentionally out of scope here; it requires hooking into Catalyst's type
* coercion rules and is tracked separately.
*/
class Box2DCastResolutionRule extends Rule[LogicalPlan] {

  /**
   * Rewrite every resolved `Cast` between the Sedona UDTs into the equivalent Sedona
   * expression; all other casts are left for Spark's own resolver.
   */
  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    // Only inspect casts whose child is resolved — before resolution the child's
    // dataType is not reliably known.
    case cast: Cast if cast.child.resolved =>
      (cast.child.dataType, cast.dataType) match {
        // Geometry → Box2D: planar bounding box of the geometry.
        case (_: GeometryUDT, _: Box2DUDT) => ST_Box2D(Seq(cast.child))
        // Box2D → Geometry: closed rectangular polygon covering the box.
        case (_: Box2DUDT, _: GeometryUDT) => ST_GeomFromBox2D(Seq(cast.child))
        // Returning the original node records no change in the transform.
        case _ => cast
      }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, GeometryUDT}
import org.apache.spark.sql.sedona_sql.expressions.{ST_Box2D, ST_GeomFromBox2D}
import org.apache.spark.sql.sedona_sql.optimization.Box2DCastResolutionRule
import org.apache.spark.sql.types.LongType
import org.scalatest.funspec.AnyFunSpec

class Box2DCastResolutionRuleSuite extends AnyFunSpec {

  private val castRule = new Box2DCastResolutionRule

  /**
   * Run `expr` through the rule inside a one-column projection over a relation that
   * exposes `attr`, and return the rewritten output expression with the alias stripped.
   */
  private def rewrittenOutput(attr: AttributeReference, expr: Expression): Expression = {
    val plan = Project(Seq(Alias(expr, "out")()), LocalRelation(attr))
    castRule(plan)
      .asInstanceOf[Project]
      .projectList
      .head
      .asInstanceOf[Alias]
      .child
  }

  describe("Box2DCastResolutionRule") {
    it("rewrites Cast(geometry-typed expression, Box2DUDT) into ST_Box2D") {
      val geomAttr = AttributeReference("g", GeometryUDT(), nullable = true)()
      val result = rewrittenOutput(geomAttr, Cast(geomAttr, Box2DUDT))
      assert(result.isInstanceOf[ST_Box2D])
      assert(result.asInstanceOf[ST_Box2D].inputExpressions == Seq(geomAttr))
      assert(result.dataType.isInstanceOf[Box2DUDT])
    }

    it("rewrites Cast(Box2D-typed expression, GeometryUDT) into ST_GeomFromBox2D") {
      val boxAttr = AttributeReference("b", Box2DUDT, nullable = true)()
      val result = rewrittenOutput(boxAttr, Cast(boxAttr, GeometryUDT()))
      assert(result.isInstanceOf[ST_GeomFromBox2D])
      assert(result.asInstanceOf[ST_GeomFromBox2D].inputExpressions == Seq(boxAttr))
      assert(result.dataType.isInstanceOf[GeometryUDT])
    }

    it("leaves unrelated casts untouched") {
      val geomAttr = AttributeReference("g", GeometryUDT(), nullable = true)()
      // A plain numeric cast in the same projection must come back unrewritten.
      val result = rewrittenOutput(geomAttr, Cast(Literal(1), LongType))
      assert(result.isInstanceOf[Cast])
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ package org.apache.sedona.sql.parser

import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.execution.SparkSqlAstBuilder
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, GeometryUDT}
import org.apache.spark.sql.types.DataType

class SedonaSqlAstBuilder extends SparkSqlAstBuilder {

/**
* Override the method to handle the geometry data type
* @param ctx
* @return
* Recognize Sedona UDT names (GEOMETRY, BOX2D) as primitive data types so SQL `CAST(... AS
* geometry)` / `CAST(... AS box2d)` parse to the matching UDT.
*/
override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = {
  // Use Locale.ROOT so the keyword match does not depend on the JVM default
  // locale (e.g. Turkish dotted/dotless-i case-mapping rules would break it).
  ctx.getText.toUpperCase(java.util.Locale.ROOT) match {
    case "GEOMETRY" => GeometryUDT()
    case "BOX2D" => Box2DUDT
    // Anything else is an ordinary Spark primitive type; defer to the stock builder.
    case _ => super.visitPrimitiveDataType(ctx)
  }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.sql

import org.apache.sedona.common.geometryObjects.Box2D
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, GeometryUDT}

class Box2DCastSuite extends TestBaseScala {

  /**
   * Probe whether SQL-level `CAST(... AS box2d)` / `CAST(... AS geometry)` is parseable in the
   * current session. That syntax needs Sedona's `SedonaSqlAstBuilder`, but the test base
   * randomizes `spark.sedona.enableParserExtension` across CI runs and `SparkContext` is a
   * JVM-wide singleton, so the effective value may differ from this suite's session config.
   * Parsing one tiny CAST matches exactly what the SQL tests depend on; the lazy val caches
   * the answer for the rest of the suite. DataFrame `.cast(...)` tests run unconditionally
   * because the resolution rule is always injected.
   */
  private lazy val sqlCastSupported: Boolean =
    try {
      sparkSession
        .sql("SELECT CAST(ST_GeomFromText('POINT (0 0)') AS box2d) AS b")
        .collect()
      true
    } catch {
      case _: org.apache.spark.sql.catalyst.parser.ParseException => false
    }

  describe("Geometry ↔ Box2D Catalyst cast") {

    it("DataFrame .cast(Box2DUDT) rewrites to ST_Box2D") {
      import sparkSession.implicits._
      val geoms = Seq("LINESTRING (0 0, 10 20)")
        .toDF("wkt")
        .select(expr("ST_GeomFromText(wkt)").alias("g"))
      val rows = geoms.select(col("g").cast(Box2DUDT).alias("b")).collect()
      assert(rows.head.getAs[Box2D]("b") == new Box2D(0.0, 0.0, 10.0, 20.0))
    }

    it("DataFrame .cast(GeometryUDT) rewrites to ST_GeomFromBox2D") {
      val boxes =
        sparkSession.sql("SELECT ST_MakeBox2D(ST_Point(0.0, 0.0), ST_Point(2.0, 4.0)) AS b")
      val rows = boxes
        .select(col("b").cast(GeometryUDT()).alias("g"))
        .selectExpr("ST_AsText(g) AS wkt")
        .collect()
      assert(rows.head.getString(0) == "POLYGON ((0 0, 0 4, 2 4, 2 0, 0 0))")
    }

    it("DataFrame round-trip Geometry → Box2D → Geometry yields the envelope polygon") {
      import sparkSession.implicits._
      val geoms = Seq("LINESTRING (0 0, 5 10)")
        .toDF("wkt")
        .select(expr("ST_GeomFromText(wkt)").alias("g"))
      val rows = geoms
        .select(col("g").cast(Box2DUDT).cast(GeometryUDT()).alias("env"))
        .selectExpr("ST_AsText(env) AS wkt")
        .collect()
      assert(rows.head.getString(0) == "POLYGON ((0 0, 0 10, 5 10, 5 0, 0 0))")
    }

    it("DataFrame .cast(Box2DUDT) on NULL geometry returns null") {
      val rows = sparkSession
        .sql("SELECT ST_GeomFromText(NULL) AS g")
        .select(col("g").cast(Box2DUDT).alias("b"))
        .collect()
      assert(rows.head.getAs[Box2D]("b") == null)
    }

    it("SQL CAST(geom AS box2d) returns the planar bbox") {
      assume(
        sqlCastSupported,
        "Sedona SQL parser extension is required for `CAST(... AS box2d)` syntax")
      val rows = sparkSession
        .sql("SELECT CAST(ST_GeomFromText('LINESTRING (0 0, 10 20)') AS box2d) AS b")
        .collect()
      assert(rows.head.getAs[Box2D]("b") == new Box2D(0.0, 0.0, 10.0, 20.0))
    }

    it("SQL CAST(box AS geometry) returns the rectangular polygon") {
      assume(
        sqlCastSupported,
        "Sedona SQL parser extension is required for `CAST(... AS geometry)` syntax")
      val rows = sparkSession
        .sql("SELECT ST_AsText(CAST(ST_MakeBox2D(ST_Point(0.0, 0.0), ST_Point(2.0, 4.0)) AS geometry)) AS w")
        .collect()
      assert(rows.head.getString(0) == "POLYGON ((0 0, 0 4, 2 4, 2 0, 0 0))")
    }

    it("SQL round-trip Geometry → Box2D → Geometry yields the envelope polygon") {
      assume(
        sqlCastSupported,
        "Sedona SQL parser extension is required for `CAST(... AS ...)` between UDTs")
      val rows = sparkSession
        .sql("SELECT ST_AsText(CAST(CAST(ST_GeomFromText('LINESTRING (0 0, 5 10)') AS box2d) AS geometry)) AS w")
        .collect()
      assert(rows.head.getString(0) == "POLYGON ((0 0, 0 10, 5 10, 5 0, 0 0))")
    }

    it("SQL CAST(NULL geometry AS box2d) returns null") {
      assume(
        sqlCastSupported,
        "Sedona SQL parser extension is required for `CAST(... AS box2d)` syntax")
      val rows = sparkSession
        .sql("SELECT CAST(ST_GeomFromText(NULL) AS box2d) AS b")
        .collect()
      assert(rows.head.getAs[Box2D]("b") == null)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ package org.apache.sedona.sql.parser

import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.execution.SparkSqlAstBuilder
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.UDT.{Box2DUDT, GeometryUDT}
import org.apache.spark.sql.types.DataType

class SedonaSqlAstBuilder extends SparkSqlAstBuilder {

/**
* Override the method to handle the geometry data type
* @param ctx
* @return
* Recognize Sedona UDT names (GEOMETRY, BOX2D) as primitive data types so SQL `CAST(... AS
* geometry)` / `CAST(... AS box2d)` parse to the matching UDT.
*/
override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = {
  // Use Locale.ROOT so the keyword match does not depend on the JVM default
  // locale (e.g. Turkish dotted/dotless-i case-mapping rules would break it).
  ctx.getText.toUpperCase(java.util.Locale.ROOT) match {
    case "GEOMETRY" => GeometryUDT()
    case "BOX2D" => Box2DUDT
    // Anything else is an ordinary Spark primitive type; defer to the stock builder.
    case _ => super.visitPrimitiveDataType(ctx)
  }
}
Expand Down
Loading
Loading