Skip to content

Commit a9be844

Browse files
chenliu0831 and claude committed
Fix connect() for Spark Connect and simplify CI workflow
- Handle Spark Connect session type in connect() (separate class from pyspark.sql.SparkSession)
- Remove manual server start/stop from CI; conftest fixture handles it
- Accept NaN for non-numeric profile stats from Spark

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent fb45401 commit a9be844

3 files changed

Lines changed: 15 additions & 21 deletions

File tree

.github/workflows/base.yml

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,24 +51,9 @@ jobs:
5151
run: |
5252
pytest tests/v2/test_unit.py -v
5353
54-
- name: Start Spark Connect Server
55-
run: |
56-
$SPARK_HOME/sbin/start-connect-server.sh \
57-
--packages org.apache.spark:spark-connect_2.12:3.5.0 \
58-
--jars $PWD/deequ_2.12-2.1.0b-spark-3.5.jar \
59-
--conf spark.connect.extensions.relation.classes=com.amazon.deequ.connect.DeequRelationPlugin
60-
# Wait for server to start
61-
sleep 20
62-
# Verify server is running
63-
ps aux | grep SparkConnectServer | grep -v grep
64-
6554
- name: Run V2 integration tests
6655
env:
6756
SPARK_REMOTE: "sc://localhost:15002"
57+
DEEQU_JAR: ${{ github.workspace }}/deequ_2.12-2.1.0b-spark-3.5.jar
6858
run: |
6959
pytest tests/v2/ -v --ignore=tests/v2/test_unit.py
70-
71-
- name: Stop Spark Connect Server
72-
if: always()
73-
run: |
74-
$SPARK_HOME/sbin/stop-connect-server.sh || true

pydeequ/engines/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,14 +380,21 @@ def connect(
380380
except ImportError:
381381
pass
382382

383-
# Try Spark
383+
# Try Spark (regular and Connect sessions are separate classes)
384384
try:
385385
from pyspark.sql import SparkSession
386386
if isinstance(connection, SparkSession):
387387
from pydeequ.engines.spark import SparkEngine
388388
return SparkEngine(connection, table=table, dataframe=dataframe)
389389
except ImportError:
390390
pass
391+
try:
392+
from pyspark.sql.connect.session import SparkSession as ConnectSession
393+
if isinstance(connection, ConnectSession):
394+
from pydeequ.engines.spark import SparkEngine
395+
return SparkEngine(connection, table=table, dataframe=dataframe)
396+
except ImportError:
397+
pass
391398

392399
raise ValueError(
393400
f"Unsupported connection type: {type(connection).__name__}. "

tests/v2/test_profiles.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,16 @@ def test_numeric_statistics(self, engine, profiler_df):
9090
assert age_profile["std_dev"] is not None
9191

9292
def test_non_numeric_has_null_stats(self, engine, profiler_df):
93-
"""Test non-numeric columns have null for numeric stats."""
93+
"""Test non-numeric columns have null/NaN for numeric stats."""
9494
result = ColumnProfilerRunner(engine).onData(dataframe=profiler_df).run()
9595
rows = {r["column"]: r for r in result.to_dict('records')}
9696

9797
name_profile = rows["name"]
98-
assert name_profile["mean"] is None
99-
assert name_profile["minimum"] is None
100-
assert name_profile["maximum"] is None
98+
# Spark returns NaN for non-numeric stats, DuckDB returns None
99+
import math
100+
assert name_profile["mean"] is None or math.isnan(name_profile["mean"])
101+
assert name_profile["minimum"] is None or math.isnan(name_profile["minimum"])
102+
assert name_profile["maximum"] is None or math.isnan(name_profile["maximum"])
101103

102104

103105
class TestKLLProfiling:

0 commit comments

Comments (0)