From 125d021afaf17f102460330d284f2727dd0ff9c0 Mon Sep 17 00:00:00 2001
From: "george.robertson1" <50412379+georgeRobertson@users.noreply.github.com>
Date: Thu, 5 Mar 2026 23:43:46 +0000
Subject: [PATCH] feat: add option in csv readers to clean and null empty strings

---
 .../implementations/duckdb/readers/csv.py  | 16 +++-
 .../implementations/spark/readers/csv.py   | 13 +++-
 .../test_readers/test_ddb_csv.py           | 76 +++++++++++++++++++
 .../test_readers/test_spark/test_spark.py  | 56 ++++++++++++++
 4 files changed, 159 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_core_engine/test_backends/test_readers/test_spark/test_spark.py

diff --git a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py
index ff65d9f..43edb6a 100644
--- a/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py
+++ b/src/dve/core_engine/backends/implementations/duckdb/readers/csv.py
@@ -46,6 +46,7 @@ def __init__(
         field_check: bool = False,
         field_check_error_code: Optional[str] = "ExpectedVsActualFieldMismatch",
         field_check_error_message: Optional[str] = "The submitted header is missing fields",
+        null_empty_strings: bool = False,
         **_,
     ):
         self.header = header
@@ -55,6 +56,7 @@ def __init__(
         self.field_check = field_check
         self.field_check_error_code = field_check_error_code
         self.field_check_error_message = field_check_error_message
+        self.null_empty_strings = null_empty_strings
 
         super().__init__()
 
@@ -109,7 +111,16 @@ def read_to_relation(  # pylint: disable=unused-argument
             }
             reader_options["columns"] = ddb_schema
 
-        return read_csv(resource, **reader_options)
+        rel = read_csv(resource, **reader_options)
+
+        if self.null_empty_strings:
+            cleaned_cols = ",".join([
+                f"NULLIF(TRIM({c}), '') AS {c}"
+                for c in reader_options["columns"].keys()
+            ])
+            rel = rel.select(cleaned_cols)
+
+        return rel
 
 
 class PolarsToDuckDBCSVReader(DuckDBCSVReader):
@@ -147,6 +158,9 @@ def read_to_relation(  # pylint: disable=unused-argument
         # redundant
         df = pl.scan_csv(resource, **reader_options).select(list(polars_types.keys()))  # type: ignore # pylint: disable=W0612
 
+        if self.null_empty_strings:
+            df = df.select([pl.col(c).str.strip_chars().replace("", None) for c in df.columns])
+
         return ddb.sql("SELECT * FROM df")
 
 
diff --git a/src/dve/core_engine/backends/implementations/spark/readers/csv.py b/src/dve/core_engine/backends/implementations/spark/readers/csv.py
index 95db464..e629517 100644
--- a/src/dve/core_engine/backends/implementations/spark/readers/csv.py
+++ b/src/dve/core_engine/backends/implementations/spark/readers/csv.py
@@ -3,6 +3,7 @@
 from collections.abc import Iterator
 from typing import Any, Optional
 
+import pyspark.sql.functions as psf
 from pydantic import BaseModel
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.types import StructType
@@ -30,6 +31,7 @@ def __init__(
         header: bool = True,
         multi_line: bool = False,
         encoding: str = "utf-8-sig",
+        null_empty_strings: bool = False,
         spark_session: Optional[SparkSession] = None,
         **_,
     ) -> None:
@@ -40,6 +42,7 @@
         self.quote_char = quote_char
         self.header = header
         self.multi_line = multi_line
+        self.null_empty_strings = null_empty_strings
         self.spark_session = spark_session if spark_session else SparkSession.builder.getOrCreate()  # type: ignore # pylint: disable=C0301
 
         super().__init__()
@@ -70,8 +73,16 @@
             "multiLine": self.multi_line,
         }
 
-        return (
+        df = (
             self.spark_session.read.format("csv")
             .options(**kwargs)  # type: ignore
             .load(resource, schema=spark_schema)
         )
+
+        if self.null_empty_strings:
+            df = df.select(*[
+                psf.trim(psf.col(c.name)).alias(c.name)
+                for c in spark_schema.fields
+            ]).replace("", None)
+
+        return df
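
Illustration (not part of the patch): the DuckDB reader implements the option as a plain SQL projection over the relation it has already read. A minimal standalone sketch of the same NULLIF/TRIM pattern, using a hypothetical one-column relation:

    import duckdb

    # Whitespace-only values come through read_csv unchanged; the projection
    # trims them and maps the resulting empty string to NULL.
    rel = duckdb.sql("SELECT * FROM (VALUES ('fine'), (' '), ('  ')) t(test_col)")
    cleaned = rel.select("NULLIF(TRIM(test_col), '') AS test_col")
    print(cleaned.fetchall())  # [('fine',), (None,), (None,)]

The Spark path reaches the same result differently: it trims every schema field with psf.trim and then maps "" to null via DataFrame.replace, which assumes the schema's columns are all string-typed (trim would implicitly cast anything else to string).
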
diff --git a/tests/test_core_engine/test_backends/test_readers/test_ddb_csv.py b/tests/test_core_engine/test_backends/test_readers/test_ddb_csv.py
index 8f9d40d..f195a0d 100644
--- a/tests/test_core_engine/test_backends/test_readers/test_ddb_csv.py
+++ b/tests/test_core_engine/test_backends/test_readers/test_ddb_csv.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory
 
+import polars as pl
 import pytest
 from duckdb import DuckDBPyRelation, default_connection
 from pydantic import BaseModel
@@ -33,6 +34,10 @@ class SimpleHeaderModel(BaseModel):
     header_2: str
 
 
+class VerySimpleModel(BaseModel):
+    test_col: str
+
+
 @pytest.fixture
 def temp_dir():
     with TemporaryDirectory(prefix="ddb_test_csv_reader") as temp_dir:
@@ -157,3 +162,74 @@ def test_ddb_csv_repeating_header_reader_with_more_than_one_set_of_distinct_valu
 
     with pytest.raises(MessageBearingError):
         reader.read_to_relation(str(file_uri), "test", SimpleHeaderModel)
+
+
+def test_DuckDBCSVReader_with_null_empty_strings(temp_dir):
+    test_df = pl.DataFrame({"test_col": ["fine", " ", "  "]})
+    file_uri = temp_dir.joinpath("test_empty_string1.csv").as_posix()
+    test_df.write_csv(
+        file_uri,
+        include_header=True,
+        quote_style="always"
+    )
+
+    reader = DuckDBCSVReader(
+        header=True,
+        delim=",",
+        quotechar='"',
+        connection=default_connection,
+        null_empty_strings=True,
+    )
+
+    entity = reader.read_to_relation(file_uri, "test", VerySimpleModel)
+
+    assert entity.shape[0] == 3
+    assert entity.filter("test_col IS NULL").shape[0] == 2
+
+
+def test_DuckDBCSVRepeatingHeaderReader_with_null_empty_strings(temp_dir):
+    test_df = pl.DataFrame({
+        "header_1": ["fine"], "header_2": [" "],
+    })
+    file_uri = temp_dir.joinpath("test_empty_string2.csv").as_posix()
+    test_df.write_csv(
+        file_uri,
+        include_header=True,
+        quote_style="always"
+    )
+
+    reader = DuckDBCSVRepeatingHeaderReader(
+        header=True,
+        delim=",",
+        quotechar='"',
+        connection=default_connection,
+        null_empty_strings=True,
+    )
+
+    entity = reader.read_to_relation(file_uri, "test", SimpleHeaderModel)
+
+    assert entity.shape[0] == 1
+    assert entity.filter("header_2 IS NULL").shape[0] == 1
+
+
+def test_PolarsToDuckDBCSVReader_with_null_empty_strings(temp_dir):
+    test_df = pl.DataFrame({"test_col": ["fine", " ", "  "]})
+    file_uri = temp_dir.joinpath("test_empty_string3.csv").as_posix()
+    test_df.write_csv(
+        file_uri,
+        include_header=True,
+        quote_style="always"
+    )
+
+    reader = PolarsToDuckDBCSVReader(
+        header=True,
+        delim=",",
+        quotechar='"',
+        connection=default_connection,
+        null_empty_strings=True,
+    )
+
+    entity = reader.read_to_relation(file_uri, "test", VerySimpleModel)
+
+    assert entity.shape[0] == 3
+    assert entity.filter("test_col IS NULL").shape[0] == 2
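
For reference, a minimal sketch (assuming a recent Polars with Expr.replace) of the strip-then-null expression that PolarsToDuckDBCSVReader applies and the test above exercises:

    import polars as pl

    df = pl.DataFrame({"test_col": ["fine", " ", "  "]})
    # strip_chars drops surrounding whitespace; replace maps "" to null
    cleaned = df.select(pl.col("test_col").str.strip_chars().replace("", None))
    assert cleaned["test_col"].null_count() == 2
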
diff --git a/tests/test_core_engine/test_backends/test_readers/test_spark/test_spark.py b/tests/test_core_engine/test_backends/test_readers/test_spark/test_spark.py
new file mode 100644
index 0000000..d5dccc1
--- /dev/null
+++ b/tests/test_core_engine/test_backends/test_readers/test_spark/test_spark.py
@@ -0,0 +1,56 @@
+"""Test Spark readers"""
+
+# pylint: disable=W0621
+# pylint: disable=C0116
+# pylint: disable=C0103
+# pylint: disable=C0115
+
+import tempfile
+from pathlib import Path
+
+import polars as pl
+import pytest
+from pydantic import BaseModel
+from pyspark.sql import DataFrame, Row, SparkSession
+from pyspark.sql.types import StringType, StructField, StructType
+
+from dve.core_engine.backends.implementations.spark.readers.csv import SparkCSVReader
+
+
+class SparkCSVTestModel(BaseModel):
+    test_col: str
+
+
+@pytest.fixture
+def spark_null_csv_resource():
+    test_df = pl.DataFrame({"test_col": ["fine", " ", "  "]})
+
+    with tempfile.TemporaryDirectory() as tdir:
+        resource_uri = Path(tdir, "test_spark_csv_reader.csv").as_posix()
+        test_df.write_csv(resource_uri, include_header=True, quote_style="always")
+
+        yield resource_uri
+
+
+def test_SparkCSVReader_clean_empty_strings(spark: SparkSession, spark_null_csv_resource):
+    resource_uri = spark_null_csv_resource
+    expected_df = spark.createDataFrame(
+        [
+            Row(
+                test_col="fine",
+            ),
+            Row(
+                test_col=None,
+            ),
+            Row(test_col=None),
+        ],
+        StructType([StructField("test_col", StringType())]),
+    )
+
+    reader = SparkCSVReader(null_empty_strings=True, spark_session=spark)
+
+    result_df: DataFrame = reader.read_to_dataframe(
+        resource=resource_uri, entity_name="test", schema=SparkCSVTestModel
+    )
+
+    assert result_df.exceptAll(expected_df).count() == 0
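
Usage, for illustration only (the resource path, entity name, and Submission model are hypothetical; the schema mirrors the all-string Pydantic models used in the tests above):

    from pydantic import BaseModel

    from dve.core_engine.backends.implementations.spark.readers.csv import SparkCSVReader

    class Submission(BaseModel):  # hypothetical schema; fields assumed string-typed
        test_col: str

    reader = SparkCSVReader(null_empty_strings=True)
    df = reader.read_to_dataframe(
        resource="submission.csv", entity_name="submission", schema=Submission
    )
    # Whitespace-only cells arrive as NULL rather than "" or " ".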