Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
field_check: bool = False,
field_check_error_code: Optional[str] = "ExpectedVsActualFieldMismatch",
field_check_error_message: Optional[str] = "The submitted header is missing fields",
null_empty_strings: bool = False,
**_,
):
self.header = header
Expand All @@ -55,6 +56,7 @@ def __init__(
self.field_check = field_check
self.field_check_error_code = field_check_error_code
self.field_check_error_message = field_check_error_message
self.null_empty_strings = null_empty_strings

super().__init__()

Expand Down Expand Up @@ -109,7 +111,16 @@ def read_to_relation( # pylint: disable=unused-argument
}

reader_options["columns"] = ddb_schema
return read_csv(resource, **reader_options)
rel = read_csv(resource, **reader_options)

if self.null_empty_strings:
cleaned_cols = ",".join([
f"NULLIF(TRIM({c}), '') as {c}"
for c in reader_options["columns"].keys()
])
rel = rel.select(cleaned_cols)

return rel


class PolarsToDuckDBCSVReader(DuckDBCSVReader):
Expand Down Expand Up @@ -147,6 +158,9 @@ def read_to_relation( # pylint: disable=unused-argument
# redundant
df = pl.scan_csv(resource, **reader_options).select(list(polars_types.keys())) # type: ignore # pylint: disable=W0612

if self.null_empty_strings:
df = df.select([pl.col(c).str.strip_chars().replace("", None) for c in df.columns])

return ddb.sql("SELECT * FROM df")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Iterator
from typing import Any, Optional

import pyspark.sql.functions as psf
from pydantic import BaseModel
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType
Expand Down Expand Up @@ -30,6 +31,7 @@ def __init__(
header: bool = True,
multi_line: bool = False,
encoding: str = "utf-8-sig",
null_empty_strings: bool = False,
spark_session: Optional[SparkSession] = None,
**_,
) -> None:
Expand All @@ -40,6 +42,7 @@ def __init__(
self.quote_char = quote_char
self.header = header
self.multi_line = multi_line
self.null_empty_strings = null_empty_strings
self.spark_session = spark_session if spark_session else SparkSession.builder.getOrCreate() # type: ignore # pylint: disable=C0301

super().__init__()
Expand Down Expand Up @@ -70,8 +73,16 @@ def read_to_dataframe(
"multiLine": self.multi_line,
}

return (
df = (
self.spark_session.read.format("csv")
.options(**kwargs) # type: ignore
.load(resource, schema=spark_schema)
)

if self.null_empty_strings:
df = df.select(*[
psf.trim(psf.col(c.name)).alias(c.name)
for c in spark_schema.fields
]).replace("", None)

return df
76 changes: 76 additions & 0 deletions tests/test_core_engine/test_backends/test_readers/test_ddb_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory

import polars as pl
import pytest
from duckdb import DuckDBPyRelation, default_connection
from pydantic import BaseModel
Expand Down Expand Up @@ -33,6 +34,10 @@ class SimpleHeaderModel(BaseModel):
header_2: str


class VerySimpleModel(BaseModel):
    """Single-column pydantic schema used by the empty-string-nulling reader tests."""

    # The lone expected CSV column, read as a string.
    test_col: str


@pytest.fixture
def temp_dir():
with TemporaryDirectory(prefix="ddb_test_csv_reader") as temp_dir:
Expand Down Expand Up @@ -157,3 +162,74 @@ def test_ddb_csv_repeating_header_reader_with_more_than_one_set_of_distinct_valu

with pytest.raises(MessageBearingError):
reader.read_to_relation(str(file_uri), "test", SimpleHeaderModel)


def test_DuckDBCSVReader_with_null_empty_strings(temp_dir):
    """Whitespace-only CSV cells come back as NULL when null_empty_strings=True."""
    # One real value plus two whitespace-only cells that should be nulled.
    source = pl.DataFrame({"test_col": ["fine", " ", " "]})
    csv_path = temp_dir.joinpath("test_empty_string1.csv").as_posix()
    source.write_csv(csv_path, include_header=True, quote_style="always")

    reader = DuckDBCSVReader(
        header=True,
        delim=",",
        quotechar='"',
        connection=default_connection,
        null_empty_strings=True,
    )

    relation = reader.read_to_relation(csv_path, "test", VerySimpleModel)

    # All three rows survive; only the whitespace-only cells become NULL.
    assert relation.shape[0] == 3
    assert relation.filter("test_col IS NULL").shape[0] == 2


def test_DuckDBCSVRepeatingHeaderReader_with_null_empty_strings(temp_dir):
    """The repeating-header reader also nulls whitespace-only cells."""
    # header_2 holds a whitespace-only value that should become NULL.
    source = pl.DataFrame({
        "header_1": ["fine",], "header_2": [" "],
    })
    csv_path = temp_dir.joinpath("test_empty_string2.csv").as_posix()
    source.write_csv(csv_path, include_header=True, quote_style="always")

    reader = DuckDBCSVRepeatingHeaderReader(
        header=True,
        delim=",",
        quotechar='"',
        connection=default_connection,
        null_empty_strings=True,
    )

    relation = reader.read_to_relation(csv_path, "test", SimpleHeaderModel)

    # Single data row; its whitespace-only header_2 cell is NULL.
    assert relation.shape[0] == 1
    assert relation.filter("header_2 IS NULL").shape[0] == 1


def test_PolarsToDuckDBCSVReader_with_null_empty_strings(temp_dir):
    """The polars-backed reader nulls whitespace-only cells like the native one."""
    # One real value plus two whitespace-only cells that should be nulled.
    source = pl.DataFrame({"test_col": ["fine", " ", " "]})
    csv_path = temp_dir.joinpath("test_empty_string3.csv").as_posix()
    source.write_csv(csv_path, include_header=True, quote_style="always")

    reader = PolarsToDuckDBCSVReader(
        header=True,
        delim=",",
        quotechar='"',
        connection=default_connection,
        null_empty_strings=True,
    )

    relation = reader.read_to_relation(csv_path, "test", VerySimpleModel)

    # All three rows survive; only the whitespace-only cells become NULL.
    assert relation.shape[0] == 3
    assert relation.filter("test_col IS NULL").shape[0] == 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Test Spark readers"""

# pylint: disable=W0621
# pylint: disable=C0116
# pylint: disable=C0103
# pylint: disable=C0115

import tempfile
from pathlib import Path

import polars as pl
import pytest
from pydantic import BaseModel
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.types import StringType, StructField, StructType

from dve.core_engine.backends.implementations.spark.readers.csv import SparkCSVReader


class SparkCSVTestModel(BaseModel):
    """Single-column pydantic schema passed to SparkCSVReader in these tests."""

    # The lone expected CSV column, read as a string.
    test_col: str


@pytest.fixture
def spark_null_csv_resource():
    """Yield the path of a temp CSV holding one real value and two whitespace-only cells."""
    frame = pl.DataFrame({"test_col": ["fine", " ", " "]})

    with tempfile.TemporaryDirectory() as tdir:
        target = Path(tdir, "test_spark_csv_reader.csv").as_posix()
        frame.write_csv(target, include_header=True, quote_style="always")

        # Yield inside the context so the directory outlives the test body.
        yield target


def test_SparkCSVReader_clean_empty_strings(spark: SparkSession, spark_null_csv_resource):
    """SparkCSVReader(null_empty_strings=True) turns whitespace-only cells into NULLs.

    The fixture CSV holds one real value ("fine") and two whitespace-only
    cells, so the cleaned frame should contain exactly one non-null row.
    """
    resource_uri = spark_null_csv_resource

    # Fix: the schema field name must match the model's column ("test_col");
    # the previous "test_field" did not line up with the Rows or the reader output.
    expected_df = spark.createDataFrame(
        [
            Row(test_col="fine"),
            Row(test_col=None),
            Row(test_col=None),
        ],
        StructType([StructField("test_col", StringType())]),
    )

    reader = SparkCSVReader(null_empty_strings=True, spark_session=spark)

    result_df: DataFrame = reader.read_to_dataframe(
        resource=resource_uri, entity_name="test", schema=SparkCSVTestModel
    )

    # Symmetric multiset comparison: a one-sided exceptAll would pass even if
    # result_df were missing rows, so check both directions.
    assert result_df.exceptAll(expected_df).count() == 0
    assert expected_df.exceptAll(result_df).count() == 0