diff --git a/src/dve/core_engine/backends/base/contract.py b/src/dve/core_engine/backends/base/contract.py index a431120..fc7da4d 100644 --- a/src/dve/core_engine/backends/base/contract.py +++ b/src/dve/core_engine/backends/base/contract.py @@ -339,7 +339,7 @@ def read_raw_entities( reader_metadata = contract_metadata.reader_metadata[entity_name] extension = "." + ( get_file_suffix(resource) or "" - ) # Already checked that extension supported. + ).lower() # Already checked that extension supported. reader_config = reader_metadata[extension] reader_type = get_reader(reader_config.reader) diff --git a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py index c10aed7..af815ce 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/reference_data.py +++ b/src/dve/core_engine/backends/implementations/duckdb/reference_data.py @@ -46,4 +46,4 @@ def load_parquet_file(self, uri: str) -> DuckDBPyRelation: @mark_refdata_file_extension("arrow") def load_arrow_file(self, uri: str) -> DuckDBPyRelation: """Load an arrow ipc file into a duckdb relation""" - return self.connection.from_arrow(ipc.open_file(uri).read_all()) # type:ignore + return self.connection.from_arrow(ipc.open_stream(uri).read_all()) # type:ignore diff --git a/src/dve/metadata_parser/domain_types.py b/src/dve/metadata_parser/domain_types.py index 3153d26..6c4a5c4 100644 --- a/src/dve/metadata_parser/domain_types.py +++ b/src/dve/metadata_parser/domain_types.py @@ -173,33 +173,67 @@ def permissive_nhs_number(warn_on_test_numbers: bool = False): return type("NHSNumber", (NHSNumber, *NHSNumber.__bases__), dict_) -# TODO: Make the spacing configurable. 
Not all downstream consumers want a single space class Postcode(types.ConstrainedStr): """Postcode constrained string""" regex: re.Pattern = POSTCODE_REGEX strip_whitespace = True + apply_normalize = True @staticmethod - def normalize(postcode: str) -> Optional[str]: + def normalize(_postcode: str) -> Optional[str]: """Strips internal and external spaces""" - postcode = postcode.replace(" ", "") - if not postcode or postcode.lower() in NULL_POSTCODES: + _postcode = _postcode.replace(" ", "") + if not _postcode or _postcode.lower() in NULL_POSTCODES: return None - postcode = postcode.replace(" ", "") - return " ".join((postcode[0:-3], postcode[-3:])).upper() + _postcode = _postcode.replace(" ", "") + return " ".join((_postcode[0:-3], _postcode[-3:])).upper() @classmethod def validate(cls, value: str) -> Optional[str]: # type: ignore """Validates the given postcode""" - stripped = cls.normalize(value) - if not stripped: + if cls.apply_normalize and value: + value = cls.normalize(value) # type: ignore + + if not value: return None - if not cls.regex.match(stripped): + if not cls.regex.match(value): raise ValueError("Invalid Postcode submitted") - return stripped + return value + + +@lru_cache() +@validate_arguments +def postcode( + # pylint: disable=R0913 + strip_whitespace: Optional[bool] = True, + to_upper: Optional[bool] = False, + to_lower: Optional[bool] = False, + strict: Optional[bool] = False, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + curtail_length: Optional[int] = None, + regex: Optional[str] = POSTCODE_REGEX, # type: ignore + apply_normalize: Optional[bool] = True, +) -> type[Postcode]: + """Return a Postcode constrained-string class configured with + the given constraint options. 
+ + """ + dict_ = Postcode.__dict__.copy() + dict_["strip_whitespace"] = strip_whitespace + dict_["to_upper"] = to_upper + dict_["to_lower"] = to_lower + dict_["strict"] = strict + dict_["min_length"] = min_length + dict_["max_length"] = max_length + dict_["curtail_length"] = curtail_length + dict_["regex"] = regex + dict_["apply_normalize"] = apply_normalize + + return type("Postcode", (Postcode, *Postcode.__bases__), dict_) class OrgID(_SimpleRegexValidator): @@ -482,6 +516,11 @@ def validate(cls, value: Union[dt.time, dt.datetime, str]) -> dt.time | None: return new_time + @classmethod + def __get_validators__(cls) -> Iterator[classmethod]: + """Gets all validators""" + yield cls.validate # type: ignore + @lru_cache() @validate_arguments diff --git a/src/dve/pipeline/utils.py b/src/dve/pipeline/utils.py index a7e88aa..37f0cc7 100644 --- a/src/dve/pipeline/utils.py +++ b/src/dve/pipeline/utils.py @@ -47,7 +47,7 @@ def load_config( def load_reader(dataset: Dataset, model_name: str, file_extension: str): """Loads the readers for the diven feed, model name and file extension""" - reader_config = dataset[model_name].reader_config[f".{file_extension}"] + reader_config = dataset[model_name].reader_config[f".{file_extension.lower()}"] reader = _READER_REGISTRY[reader_config.reader](**reader_config.kwargs_) return reader diff --git a/tests/features/books.feature b/tests/features/books.feature index 9bc0611..f13658a 100644 --- a/tests/features/books.feature +++ b/tests/features/books.feature @@ -5,7 +5,7 @@ Feature: Pipeline tests using the books dataset introduces more complex transformations that require aggregation. 
Scenario: Validate complex nested XML data (spark) - Given I submit the books file nested_books.xml for processing + Given I submit the books file nested_books.XML for processing And A spark pipeline is configured with schema file 'nested_books.dischema.json' And I add initial audit entries for the submission Then the latest audit record for the submission is marked with processing status file_transformation @@ -32,7 +32,7 @@ Feature: Pipeline tests using the books dataset | number_warnings | 0 | Scenario: Validate complex nested XML data (duckdb) - Given I submit the books file nested_books.xml for processing + Given I submit the books file nested_books.XML for processing And A duckdb pipeline is configured with schema file 'nested_books_ddb.dischema.json' And I add initial audit entries for the submission Then the latest audit record for the submission is marked with processing status file_transformation diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py index 9e49338..0300808 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py @@ -50,7 +50,7 @@ def test_duckdb_data_contract_csv(temp_csv_file): "description": "test", "callable": "formattedtime", "constraints": { - "time_format": "%Y-%m-%d", + "time_format": "%H:%M:%S", "timezone_treatment": "forbid" } } diff --git a/tests/test_core_engine/test_engine.py b/tests/test_core_engine/test_engine.py index 5118cbd..ef23d71 100644 --- a/tests/test_core_engine/test_engine.py +++ b/tests/test_core_engine/test_engine.py @@ -99,8 +99,8 @@ def test_dummy_books_run(self, spark, temp_dir: str): with test_instance: _, errors_uri = test_instance.run_pipeline( entity_locations={ - "header": 
get_test_file_path("books/nested_books.xml").as_posix(), - "nested_books": get_test_file_path("books/nested_books.xml").as_posix(), + "header": get_test_file_path("books/nested_books.XML").as_posix(), + "nested_books": get_test_file_path("books/nested_books.XML").as_posix(), } ) diff --git a/tests/test_model_generation/test_domain_types.py b/tests/test_model_generation/test_domain_types.py index 6ceee74..a494494 100644 --- a/tests/test_model_generation/test_domain_types.py +++ b/tests/test_model_generation/test_domain_types.py @@ -98,6 +98,24 @@ def test_postcode(postcode, expected): assert model.postcode == expected +@pytest.mark.parametrize( + ("postcode", "should_error"), + [ + ("LS479AJ", True), + ("PostcodeIamNot", True), + ("LS47 9AJ", False) + ] +) +def test_postcode_errors_with_apply_normalize_disabled(postcode: str, should_error: bool): + postcode_type = hct.postcode(apply_normalize=False) + + if should_error: + with pytest.raises(ValueError, match="Invalid Postcode submitted"): + assert postcode_type.validate(postcode) + else: + assert postcode_type.validate(postcode) + + @pytest.mark.parametrize(("org_id", "expected"), [("AB123", "AB123"), ("ABCDE", "ABCDE")]) def test_org_id_passes(org_id, expected): model = ATestModel(org_id=org_id) @@ -347,7 +365,8 @@ def test_formattedtime( ["23:00:00", "%H:%M:%S", "require",], ["23:00:00Z", "%I:%M:%S", "forbid",], [dt.datetime(2025, 12, 1, 13, 0, 5, tzinfo=UTC), "%H:%M:%S", "forbid",], - [dt.time(13, 0, 5, tzinfo=UTC), "%H:%M:%S", "forbid",] + [dt.time(13, 0, 5, tzinfo=UTC), "%H:%M:%S", "forbid",], + ["12:00", "%H:%M:%S", "forbid",], ] ) def test_formattedtime_raises( @@ -360,3 +379,24 @@ def test_formattedtime_raises( time_type = hct.formattedtime(time_format, timezone_treatment) with pytest.raises(ValueError): time_type.validate(time_to_validate) # pylint: disable=W0106 + + +class StrictTimeModel(BaseModel): + time_val: hct.formattedtime(time_format="%H:%M:%S", timezone_treatment="forbid") + + 
+@pytest.mark.parametrize( + ["time_to_validate", "expected_to_error"], + [ + ("12:00:00", False), + ("120000", True), + ("12:00", True), + ("12", True), + ] +) +def test_formattedtime_against_model(time_to_validate: str, expected_to_error: bool): + if expected_to_error: + with pytest.raises(ValueError): + StrictTimeModel(time_val=time_to_validate) + else: + StrictTimeModel(time_val=time_to_validate) diff --git a/tests/testdata/books/nested_books.xml b/tests/testdata/books/nested_books.XML similarity index 100% rename from tests/testdata/books/nested_books.xml rename to tests/testdata/books/nested_books.XML diff --git a/tests/testdata/movies/refdata/movies_sequels.arrow b/tests/testdata/movies/refdata/movies_sequels.arrow index 89ec37f..9978fc8 100644 Binary files a/tests/testdata/movies/refdata/movies_sequels.arrow and b/tests/testdata/movies/refdata/movies_sequels.arrow differ