diff --git a/mindee/geometry/point.py b/mindee/geometry/point.py index 6d71ca72..a5d0905c 100644 --- a/mindee/geometry/point.py +++ b/mindee/geometry/point.py @@ -9,5 +9,8 @@ class Point(NamedTuple): y: float """Y coordinate""" + def __str__(self) -> str: + return f"({self.x},{self.y})" + Points = Sequence[Point] diff --git a/mindee/geometry/polygon.py b/mindee/geometry/polygon.py index d5c61ec3..bd733cfd 100644 --- a/mindee/geometry/polygon.py +++ b/mindee/geometry/polygon.py @@ -50,6 +50,9 @@ def is_point_in_y(self, point: Point) -> bool: min_y, max_y = get_min_max_y(self) return is_point_in_y(point, min_y, max_y) + def __str__(self): + return "(" + ", ".join(str(p) for p in self) + ")" + def is_point_in_polygon_x(point: Point, polygon: Polygon) -> bool: """ diff --git a/mindee/parsing/v2/inference.py b/mindee/parsing/v2/inference.py index 477cc41c..341f7cf6 100644 --- a/mindee/parsing/v2/inference.py +++ b/mindee/parsing/v2/inference.py @@ -20,7 +20,8 @@ def __init__(self, raw_response: StringDict): def __str__(self) -> str: return ( f"Inference\n#########" - f"\n{self.model}" + f"\n{self.job}" + f"\n\n{self.model}" f"\n\n{self.file}" f"\n\n{self.active_options}" f"\n\n{self.result}\n" diff --git a/mindee/v2/parsing/inference/base_inference.py b/mindee/v2/parsing/inference/base_inference.py index 78462f0f..ccde42df 100644 --- a/mindee/v2/parsing/inference/base_inference.py +++ b/mindee/v2/parsing/inference/base_inference.py @@ -4,11 +4,14 @@ from mindee.parsing.common.string_dict import StringDict from mindee.parsing.v2.inference_file import InferenceFile from mindee.parsing.v2.inference_model import InferenceModel +from mindee.v2.parsing.inference.inference_job import InferenceJob class BaseInference(ABC): """Base class for V2 inference objects.""" + job: InferenceJob + """Job the inference belongs to.""" model: InferenceModel """Model info for the inference.""" file: InferenceFile @@ -18,6 +21,7 @@ class BaseInference(ABC): def __init__(self, raw_response: StringDict): self.id = raw_response["id"] + self.job = InferenceJob(raw_response["job"]) self.model = InferenceModel(raw_response["model"]) self.file = InferenceFile(raw_response["file"]) diff --git a/mindee/v2/parsing/inference/inference_job.py b/mindee/v2/parsing/inference/inference_job.py new file mode 100644 index 00000000..ac877545 --- /dev/null +++ b/mindee/v2/parsing/inference/inference_job.py @@ -0,0 +1,14 @@ +from mindee.parsing.common.string_dict import StringDict + + +class InferenceJob: + """Inference Job info.""" + + id: str + """UUID of the Job.""" + + def __init__(self, raw_response: StringDict) -> None: + self.id = raw_response["id"] + + def __str__(self) -> str: + return f"Job\n===\n:ID: {self.id}" diff --git a/mindee/v2/product/classification/classification_inference.py b/mindee/v2/product/classification/classification_inference.py index ae83660b..9e00ad60 100644 --- a/mindee/v2/product/classification/classification_inference.py +++ b/mindee/v2/product/classification/classification_inference.py @@ -16,4 +16,10 @@ def __init__(self, raw_response: StringDict) -> None: self.result = ClassificationResult(raw_response["result"]) def __str__(self) -> str: - return f"Inference\n#########\n{self.model}\n{self.file}\n{self.result}\n" + return ( + f"Inference\n#########" + f"\n{self.job}" + f"\n\n{self.model}" + f"\n\n{self.file}" + f"\n\n{self.result}\n" + ) diff --git a/mindee/v2/product/crop/crop_inference.py b/mindee/v2/product/crop/crop_inference.py index 5e3dd19f..0b4c5c8c 100644 --- a/mindee/v2/product/crop/crop_inference.py +++ b/mindee/v2/product/crop/crop_inference.py @@ -16,4 +16,10 @@ def __init__(self, raw_response: StringDict) -> None: self.result = CropResult(raw_response["result"]) def __str__(self) -> str: - return f"Inference\n#########\n{self.model}\n{self.file}\n{self.result}\n" + return ( + f"Inference\n#########" + f"\n{self.job}" + f"\n\n{self.model}" + f"\n\n{self.file}" + f"\n\n{self.result}\n" + ) diff --git a/mindee/v2/product/crop/crop_result.py b/mindee/v2/product/crop/crop_result.py index 3e3878be..12a40305 100644 --- a/mindee/v2/product/crop/crop_result.py +++ b/mindee/v2/product/crop/crop_result.py @@ -15,6 +15,6 @@ def __init__(self, raw_response: StringDict) -> None: def __str__(self) -> str: crops = "\n" if len(self.crops) > 0: - crops += "\n\n".join([str(crop) for crop in self.crops]) - out_str = f"Crops\n======{crops}" + crops += "\n".join([str(crop) for crop in self.crops]) + out_str = f"Crops\n====={crops}" return out_str diff --git a/mindee/v2/product/ocr/ocr_inference.py b/mindee/v2/product/ocr/ocr_inference.py index ffe7c888..60eda267 100644 --- a/mindee/v2/product/ocr/ocr_inference.py +++ b/mindee/v2/product/ocr/ocr_inference.py @@ -16,4 +16,10 @@ def __init__(self, raw_response: StringDict) -> None: self.result = OCRResult(raw_response["result"]) def __str__(self) -> str: - return f"Inference\n#########\n{self.model}\n{self.file}\n{self.result}\n" + return ( + f"Inference\n#########" + f"\n{self.job}" + f"\n\n{self.model}" + f"\n\n{self.file}" + f"\n\n{self.result}\n" + ) diff --git a/mindee/v2/product/split/split_inference.py b/mindee/v2/product/split/split_inference.py index 37aa6edb..7b3d562c 100644 --- a/mindee/v2/product/split/split_inference.py +++ b/mindee/v2/product/split/split_inference.py @@ -16,4 +16,10 @@ def __init__(self, raw_response: StringDict) -> None: self.result = SplitResult(raw_response["result"]) def __str__(self) -> str: - return f"Inference\n#########\n{self.model}\n{self.file}\n{self.result}\n" + return ( + f"Inference\n#########" + f"\n{self.job}" + f"\n\n{self.model}" + f"\n\n{self.file}" + f"\n\n{self.result}\n" + ) diff --git a/tests/data b/tests/data index 37f2e3de..4fd64fdd 160000 --- a/tests/data +++ b/tests/data @@ -1 +1 @@ -Subproject commit 37f2e3de48918e3b1a0e4604a9292aaeae05c637 +Subproject commit 4fd64fdd462c5f29e49b24aacf6e74c7d9aa1ab3 diff --git a/tests/v2/product/classification/test_classification_response.py b/tests/v2/product/classification/test_classification_response.py index 13fee268..68ca836f 100644 --- a/tests/v2/product/classification/test_classification_response.py +++ b/tests/v2/product/classification/test_classification_response.py @@ -1,6 +1,5 @@ import pytest -from mindee import LocalResponse from mindee.v2.product.classification.classification_classifier import ( ClassificationClassifier, ) @@ -9,24 +8,19 @@ ClassificationResponse, ) from mindee.v2.product.classification.classification_result import ClassificationResult -from tests.utils import V2_PRODUCT_DATA_DIR +from tests.v2.product.utils import get_product_samples @pytest.mark.v2 def test_classification_single(): - input_inference = LocalResponse( - V2_PRODUCT_DATA_DIR / "classification" / "classification_single.json" + json_sample, _ = get_product_samples( + product="classification", file_name="classification_single" ) - classification_response = input_inference.deserialize_response( - ClassificationResponse - ) - assert isinstance(classification_response.inference, ClassificationInference) - assert isinstance(classification_response.inference.result, ClassificationResult) + response = ClassificationResponse(json_sample) + assert isinstance(response.inference, ClassificationInference) + assert isinstance(response.inference.result, ClassificationResult) assert isinstance( - classification_response.inference.result.classification, + response.inference.result.classification, ClassificationClassifier, ) - assert ( - classification_response.inference.result.classification.document_type - == "invoice" - ) + assert response.inference.result.classification.document_type == "invoice" diff --git a/tests/v2/product/crop/test_crop_response.py b/tests/v2/product/crop/test_crop_response.py index 52565e0b..635a89ea 100644 --- a/tests/v2/product/crop/test_crop_response.py +++ b/tests/v2/product/crop/test_crop_response.py @@ -1,61 +1,69 @@ import pytest -from mindee import LocalResponse from mindee.v2.product.crop.crop_box import CropBox from mindee.v2.product.crop import CropInference from mindee.v2.product.crop.crop_response import CropResponse from mindee.v2.product.crop.crop_result import CropResult -from tests.utils import V2_PRODUCT_DATA_DIR + +from tests.v2.product.utils import get_product_samples @pytest.mark.v2 def test_crop_single(): - input_inference = LocalResponse(V2_PRODUCT_DATA_DIR / "crop" / "crop_single.json") - crop_response = input_inference.deserialize_response(CropResponse) - assert isinstance(crop_response.inference, CropInference) - assert crop_response.inference.result.crops - assert len(crop_response.inference.result.crops[0].location.polygon) == 4 - assert crop_response.inference.result.crops[0].location.polygon[0][0] == 0.15 - assert crop_response.inference.result.crops[0].location.polygon[0][1] == 0.254 - assert crop_response.inference.result.crops[0].location.polygon[1][0] == 0.85 - assert crop_response.inference.result.crops[0].location.polygon[1][1] == 0.254 - assert crop_response.inference.result.crops[0].location.polygon[2][0] == 0.85 - assert crop_response.inference.result.crops[0].location.polygon[2][1] == 0.947 - assert crop_response.inference.result.crops[0].location.polygon[3][0] == 0.15 - assert crop_response.inference.result.crops[0].location.polygon[3][1] == 0.947 - assert crop_response.inference.result.crops[0].location.page == 0 - assert crop_response.inference.result.crops[0].object_type == "invoice" + json_sample, rst_sample = get_product_samples( + product="crop", file_name="crop_single" + ) + response = CropResponse(json_sample) + assert isinstance(response.inference, CropInference) + assert response.inference.result.crops + assert len(response.inference.result.crops[0].location.polygon) == 4 + assert response.inference.result.crops[0].location.polygon[0][0] == 0.15 + assert response.inference.result.crops[0].location.polygon[0][1] == 0.254 + assert response.inference.result.crops[0].location.polygon[1][0] == 0.85 + assert response.inference.result.crops[0].location.polygon[1][1] == 0.254 + assert response.inference.result.crops[0].location.polygon[2][0] == 0.85 + assert response.inference.result.crops[0].location.polygon[2][1] == 0.947 + assert response.inference.result.crops[0].location.polygon[3][0] == 0.15 + assert response.inference.result.crops[0].location.polygon[3][1] == 0.947 + assert response.inference.result.crops[0].location.page == 0 + assert response.inference.result.crops[0].object_type == "invoice" + + assert rst_sample == str(response) @pytest.mark.v2 def test_crop_multiple(): - input_inference = LocalResponse(V2_PRODUCT_DATA_DIR / "crop" / "crop_multiple.json") - crop_response = input_inference.deserialize_response(CropResponse) - assert isinstance(crop_response.inference, CropInference) - assert isinstance(crop_response.inference.result, CropResult) - assert isinstance(crop_response.inference.result.crops[0], CropBox) - assert len(crop_response.inference.result.crops) == 2 - - assert len(crop_response.inference.result.crops[0].location.polygon) == 4 - assert crop_response.inference.result.crops[0].location.polygon[0][0] == 0.214 - assert crop_response.inference.result.crops[0].location.polygon[0][1] == 0.079 - assert crop_response.inference.result.crops[0].location.polygon[1][0] == 0.476 - assert crop_response.inference.result.crops[0].location.polygon[1][1] == 0.079 - assert crop_response.inference.result.crops[0].location.polygon[2][0] == 0.476 - assert crop_response.inference.result.crops[0].location.polygon[2][1] == 0.979 - assert crop_response.inference.result.crops[0].location.polygon[3][0] == 0.214 - assert crop_response.inference.result.crops[0].location.polygon[3][1] == 0.979 - assert crop_response.inference.result.crops[0].location.page == 0 - assert crop_response.inference.result.crops[0].object_type == "invoice" - - assert len(crop_response.inference.result.crops[1].location.polygon) == 4 - assert crop_response.inference.result.crops[1].location.polygon[0][0] == 0.547 - assert crop_response.inference.result.crops[1].location.polygon[0][1] == 0.15 - assert crop_response.inference.result.crops[1].location.polygon[1][0] == 0.862 - assert crop_response.inference.result.crops[1].location.polygon[1][1] == 0.15 - assert crop_response.inference.result.crops[1].location.polygon[2][0] == 0.862 - assert crop_response.inference.result.crops[1].location.polygon[2][1] == 0.97 - assert crop_response.inference.result.crops[1].location.polygon[3][0] == 0.547 - assert crop_response.inference.result.crops[1].location.polygon[3][1] == 0.97 - assert crop_response.inference.result.crops[1].location.page == 0 - assert crop_response.inference.result.crops[1].object_type == "invoice" + json_sample, rst_sample = get_product_samples( + product="crop", file_name="crop_multiple" + ) + response = CropResponse(json_sample) + assert isinstance(response.inference, CropInference) + assert isinstance(response.inference.result, CropResult) + assert isinstance(response.inference.result.crops[0], CropBox) + assert len(response.inference.result.crops) == 2 + + assert len(response.inference.result.crops[0].location.polygon) == 4 + assert response.inference.result.crops[0].location.polygon[0][0] == 0.214 + assert response.inference.result.crops[0].location.polygon[0][1] == 0.079 + assert response.inference.result.crops[0].location.polygon[1][0] == 0.476 + assert response.inference.result.crops[0].location.polygon[1][1] == 0.079 + assert response.inference.result.crops[0].location.polygon[2][0] == 0.476 + assert response.inference.result.crops[0].location.polygon[2][1] == 0.979 + assert response.inference.result.crops[0].location.polygon[3][0] == 0.214 + assert response.inference.result.crops[0].location.polygon[3][1] == 0.979 + assert response.inference.result.crops[0].location.page == 0 + assert response.inference.result.crops[0].object_type == "invoice" + + assert len(response.inference.result.crops[1].location.polygon) == 4 + assert response.inference.result.crops[1].location.polygon[0][0] == 0.547 + assert response.inference.result.crops[1].location.polygon[0][1] == 0.15 + assert response.inference.result.crops[1].location.polygon[1][0] == 0.862 + assert response.inference.result.crops[1].location.polygon[1][1] == 0.15 + assert response.inference.result.crops[1].location.polygon[2][0] == 0.862 + assert response.inference.result.crops[1].location.polygon[2][1] == 0.97 + assert response.inference.result.crops[1].location.polygon[3][0] == 0.547 + assert response.inference.result.crops[1].location.polygon[3][1] == 0.97 + assert response.inference.result.crops[1].location.page == 0 + assert response.inference.result.crops[1].object_type == "invoice" + + assert rst_sample == str(response) diff --git a/tests/v2/product/extraction/test_extraction_response.py b/tests/v2/product/extraction/test_extraction_response.py index 912d0846..3d0a06a3 100644 --- a/tests/v2/product/extraction/test_extraction_response.py +++ b/tests/v2/product/extraction/test_extraction_response.py @@ -1,6 +1,4 @@ import json -from pathlib import Path -from typing import Tuple import pytest @@ -16,34 +14,14 @@ from mindee.parsing.v2.inference_model import InferenceModel from mindee.parsing.v2.rag_metadata import RagMetadata from tests.utils import V2_PRODUCT_DATA_DIR - - -def _get_samples(json_path: Path, rst_path: Path) -> Tuple[dict, str]: - with json_path.open("r", encoding="utf-8") as fh: - json_sample = json.load(fh) - try: - with rst_path.open("r", encoding="utf-8") as fh: - rst_sample = fh.read() - except FileNotFoundError: - rst_sample = "" - return json_sample, rst_sample - - -def _get_inference_samples(name: str) -> Tuple[dict, str]: - json_path = V2_PRODUCT_DATA_DIR / "extraction" / f"{name}.json" - rst_path = V2_PRODUCT_DATA_DIR / "extraction" / f"{name}.rst" - return _get_samples(json_path, rst_path) - - -def _get_product_samples(product, name: str) -> Tuple[dict, str]: - json_path = V2_PRODUCT_DATA_DIR / "extraction" / product / f"{name}.json" - rst_path = V2_PRODUCT_DATA_DIR / "extraction" / product / f"{name}.rst" - return _get_samples(json_path, rst_path) +from tests.v2.product.utils import get_product_samples @pytest.mark.v2 def test_deep_nested_fields(): - json_sample, rst_sample = _get_inference_samples("deep_nested_fields") + json_sample, _ = get_product_samples( + product="extraction", file_name="deep_nested_fields" + ) response = InferenceResponse(json_sample) assert isinstance(response.inference, Inference) assert isinstance(response.inference.result.fields["field_simple"], SimpleField) @@ -115,7 +93,9 @@ def test_deep_nested_fields(): @pytest.mark.v2 def test_standard_field_types(): - json_sample, rst_sample = _get_inference_samples("standard_field_types") + json_sample, rst_sample = get_product_samples( + product="extraction", file_name="standard_field_types" + ) response = InferenceResponse(json_sample) assert isinstance(response.inference, Inference) @@ -148,7 +128,9 @@ def test_standard_field_types(): @pytest.mark.v2 def test_standard_field_object(): - json_sample, _ = _get_inference_samples("standard_field_types") + json_sample, _ = get_product_samples( + product="extraction", file_name="standard_field_types" + ) response = InferenceResponse(json_sample) object_field = response.inference.result.fields["field_object"] @@ -168,7 +150,9 @@ def test_standard_field_object(): @pytest.mark.v2 def test_standard_field_object_list(): - json_sample, _ = _get_inference_samples("standard_field_types") + json_sample, _ = get_product_samples( + product="extraction", file_name="standard_field_types" + ) response = InferenceResponse(json_sample) assert isinstance(response.inference, Inference) @@ -181,7 +165,9 @@ def test_standard_field_object_list(): @pytest.mark.v2 def test_standard_field_simple_list(): - json_sample, _ = _get_inference_samples("standard_field_types") + json_sample, _ = get_product_samples( + product="extraction", file_name="standard_field_types" + ) response = InferenceResponse(json_sample) assert isinstance(response.inference, Inference) @@ -194,7 +180,7 @@ def test_standard_field_simple_list(): @pytest.mark.v2 def test_raw_texts(): - json_sample, _ = _get_inference_samples("raw_texts") + json_sample, _ = get_product_samples(product="extraction", file_name="raw_texts") response = InferenceResponse(json_sample) assert isinstance(response.inference, Inference) @@ -210,7 +196,7 @@ def test_raw_texts(): @pytest.mark.v2 def test_rag_metadata_when_matched(): """RAG metadata when matched.""" - json_sample, _ = _get_inference_samples("rag_matched") + json_sample, _ = get_product_samples(product="extraction", file_name="rag_matched") response = InferenceResponse(json_sample) rag = response.inference.result.rag assert isinstance(rag, RagMetadata) @@ -221,7 +207,9 @@ def test_rag_metadata_when_matched(): @pytest.mark.v2 def test_rag_metadata_when_not_matched(): """RAG metadata when not matched.""" - json_sample, _ = _get_inference_samples("rag_not_matched") + json_sample, _ = get_product_samples( + product="extraction", file_name="rag_not_matched" + ) response = InferenceResponse(json_sample) rag = response.inference.result.rag assert isinstance(rag, RagMetadata) @@ -231,7 +219,9 @@ def test_rag_metadata_when_not_matched(): @pytest.mark.v2 def test_full_inference_response(): - json_sample, rst_sample = _get_product_samples("financial_document", "complete") + json_sample, _ = get_product_samples( + product="extraction/financial_document", file_name="complete" + ) response = InferenceResponse(json_sample) assert isinstance(response.inference, Inference) @@ -265,8 +255,8 @@ def test_field_locations_and_confidence() -> None: Validate that the first location polygon for the ``date`` field is correctly deserialized together with the associated confidence level. """ - json_sample, _ = _get_product_samples( - "financial_document", "complete_with_coordinates" + json_sample, _ = get_product_samples( + product="extraction/financial_document", file_name="complete_with_coordinates" ) response = InferenceResponse(json_sample) @@ -307,7 +297,9 @@ def test_field_locations_and_confidence() -> None: @pytest.mark.v2 def test_text_context_field_is_false() -> None: - json_sample, _ = _get_product_samples("financial_document", "complete") + json_sample, _ = get_product_samples( + product="extraction/financial_document", file_name="complete" + ) response = InferenceResponse(json_sample) assert isinstance(response.inference.active_options, InferenceActiveOptions) assert response.inference.active_options.text_context is False diff --git a/tests/v2/product/ocr/test_ocr_response.py b/tests/v2/product/ocr/test_ocr_response.py index ce8ff234..74b45c62 100644 --- a/tests/v2/product/ocr/test_ocr_response.py +++ b/tests/v2/product/ocr/test_ocr_response.py @@ -1,55 +1,47 @@ import pytest -from mindee import LocalResponse from mindee.v2.product.ocr.ocr_page import OCRPage from mindee.v2.product.ocr import OCRInference from mindee.v2.product.ocr.ocr_response import OCRResponse from mindee.v2.product.ocr.ocr_result import OCRResult -from tests.utils import V2_PRODUCT_DATA_DIR + +from tests.v2.product.utils import get_product_samples @pytest.mark.v2 def test_ocr_single(): - input_inference = LocalResponse(V2_PRODUCT_DATA_DIR / "ocr" / "ocr_single.json") - ocr_response = input_inference.deserialize_response(OCRResponse) - assert isinstance(ocr_response.inference, OCRInference) - assert ocr_response.inference.result.pages - assert len(ocr_response.inference.result.pages) == 1 - assert ocr_response.inference.result.pages[0].words[0].content == "Shipper:" + json_sample, _ = get_product_samples(product="ocr", file_name="ocr_single") + response = OCRResponse(json_sample) + assert isinstance(response.inference, OCRInference) + assert response.inference.result.pages + assert len(response.inference.result.pages) == 1 + assert response.inference.result.pages[0].words[0].content == "Shipper:" assert ( - ocr_response.inference.result.pages[0].words[0].polygon[0][0] - == 0.09742441209406495 + response.inference.result.pages[0].words[0].polygon[0][0] == 0.09742441209406495 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[0][1] - == 0.07007125890736342 + response.inference.result.pages[0].words[0].polygon[0][1] == 0.07007125890736342 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[1][0] - == 0.15621500559910415 + response.inference.result.pages[0].words[0].polygon[1][0] == 0.15621500559910415 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[1][1] - == 0.07046714172604909 + response.inference.result.pages[0].words[0].polygon[1][1] == 0.07046714172604909 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[2][0] - == 0.15621500559910415 + response.inference.result.pages[0].words[0].polygon[2][0] == 0.15621500559910415 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[2][1] - == 0.08155186064924783 + response.inference.result.pages[0].words[0].polygon[2][1] == 0.08155186064924783 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[3][0] - == 0.09742441209406495 + response.inference.result.pages[0].words[0].polygon[3][0] == 0.09742441209406495 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[3][1] - == 0.08155186064924783 + response.inference.result.pages[0].words[0].polygon[3][1] == 0.08155186064924783 ) - assert len(ocr_response.inference.result.pages[0].words) == 305 - assert ocr_response.inference.result.pages[0].content == ( + assert len(response.inference.result.pages[0].words) == 305 + assert response.inference.result.pages[0].content == ( "Shipper: GLOBAL FREIGHT SOLUTIONS INC. 123 OCEAN DRIVE SHANGHAI, CHINA TEL: " "86-21-12345678 FAX: 86-21-87654321\nConsignee: PACIFIC TRADING CO. 789 TRADE " "STREET SINGAPORE 567890 SINGAPORE TEL: 65-65432100 FAX: 65-65432101\nNotify " @@ -83,50 +75,43 @@ def test_ocr_single(): @pytest.mark.v2 def test_ocr_multiple(): - input_inference = LocalResponse(V2_PRODUCT_DATA_DIR / "ocr" / "ocr_multiple.json") - ocr_response = input_inference.deserialize_response(OCRResponse) - assert isinstance(ocr_response.inference, OCRInference) - assert isinstance(ocr_response.inference.result, OCRResult) - assert isinstance(ocr_response.inference.result.pages[0], OCRPage) - assert len(ocr_response.inference.result.pages) == 3 + json_sample, _ = get_product_samples(product="ocr", file_name="ocr_multiple") + response = OCRResponse(json_sample) + + assert isinstance(response.inference, OCRInference) + assert isinstance(response.inference.result, OCRResult) + assert isinstance(response.inference.result.pages[0], OCRPage) + assert len(response.inference.result.pages) == 3 - assert len(ocr_response.inference.result.pages[0].words) == 295 - assert ocr_response.inference.result.pages[0].words[0].content == "FICTIOCORP" + assert len(response.inference.result.pages[0].words) == 295 + assert response.inference.result.pages[0].words[0].content == "FICTIOCORP" assert ( - ocr_response.inference.result.pages[0].words[0].polygon[0][0] - == 0.06649402824332337 + response.inference.result.pages[0].words[0].polygon[0][0] == 0.06649402824332337 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[0][1] - == 0.03957449719523875 + response.inference.result.pages[0].words[0].polygon[0][1] == 0.03957449719523875 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[1][0] - == 0.23219061218068954 + response.inference.result.pages[0].words[0].polygon[1][0] == 0.23219061218068954 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[1][1] - == 0.03960015049938432 + response.inference.result.pages[0].words[0].polygon[1][1] == 0.03960015049938432 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[2][0] - == 0.23219061218068954 + response.inference.result.pages[0].words[0].polygon[2][0] == 0.23219061218068954 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[2][1] - == 0.06770762074155151 + response.inference.result.pages[0].words[0].polygon[2][1] == 0.06770762074155151 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[3][0] - == 0.06649402824332337 + response.inference.result.pages[0].words[0].polygon[3][0] == 0.06649402824332337 ) assert ( - ocr_response.inference.result.pages[0].words[0].polygon[3][1] - == 0.06770762074155151 + response.inference.result.pages[0].words[0].polygon[3][1] == 0.06770762074155151 ) - assert len(ocr_response.inference.result.pages[1].words) == 450 - assert ocr_response.inference.result.pages[1].words[0].content == "KEOLIO" + assert len(response.inference.result.pages[1].words) == 450 + assert response.inference.result.pages[1].words[0].content == "KEOLIO" - assert len(ocr_response.inference.result.pages[2].words) == 355 - assert ocr_response.inference.result.pages[2].words[0].content == "KEOLIO" + assert len(response.inference.result.pages[2].words) == 355 + assert response.inference.result.pages[2].words[0].content == "KEOLIO" diff --git a/tests/v2/product/split/test_split_response.py b/tests/v2/product/split/test_split_response.py index 38537a8c..4ce2ad8b 100644 --- a/tests/v2/product/split/test_split_response.py +++ b/tests/v2/product/split/test_split_response.py @@ -1,47 +1,45 @@ import pytest -from mindee import LocalResponse from mindee.v2.product.split.split_range import SplitRange from mindee.v2.product.split import SplitInference from mindee.v2.product.split.split_response import SplitResponse from mindee.v2.product.split.split_result import SplitResult -from tests.utils import V2_PRODUCT_DATA_DIR +from tests.v2.product.utils import get_product_samples @pytest.mark.v2 def test_split_single(): - input_inference = LocalResponse(V2_PRODUCT_DATA_DIR / "split" / "split_single.json") - split_response = input_inference.deserialize_response(SplitResponse) - assert isinstance(split_response.inference, SplitInference) - assert split_response.inference.result.splits - assert len(split_response.inference.result.splits[0].page_range) == 2 - assert split_response.inference.result.splits[0].page_range[0] == 0 - assert split_response.inference.result.splits[0].page_range[1] == 0 - assert split_response.inference.result.splits[0].document_type == "receipt" + json_sample, _ = get_product_samples(product="split", file_name="split_single") + response = SplitResponse(json_sample) + + assert isinstance(response.inference, SplitInference) + assert response.inference.result.splits + assert len(response.inference.result.splits[0].page_range) == 2 + assert response.inference.result.splits[0].page_range[0] == 0 + assert response.inference.result.splits[0].page_range[1] == 0 + assert response.inference.result.splits[0].document_type == "receipt" @pytest.mark.v2 def test_split_multiple(): - input_inference = LocalResponse( - V2_PRODUCT_DATA_DIR / "split" / "split_multiple.json" - ) - split_response = input_inference.deserialize_response(SplitResponse) - assert isinstance(split_response.inference, SplitInference) - assert isinstance(split_response.inference.result, SplitResult) - assert isinstance(split_response.inference.result.splits[0], SplitRange) - assert len(split_response.inference.result.splits) == 3 - - assert len(split_response.inference.result.splits[0].page_range) == 2 - assert split_response.inference.result.splits[0].page_range[0] == 0 - assert split_response.inference.result.splits[0].page_range[1] == 0 - assert split_response.inference.result.splits[0].document_type == "invoice" - - assert len(split_response.inference.result.splits[1].page_range) == 2 - assert split_response.inference.result.splits[1].page_range[0] == 1 - assert split_response.inference.result.splits[1].page_range[1] == 3 - assert split_response.inference.result.splits[1].document_type == "invoice" - - assert len(split_response.inference.result.splits[2].page_range) == 2 - assert split_response.inference.result.splits[2].page_range[0] == 4 - assert split_response.inference.result.splits[2].page_range[1] == 4 - assert split_response.inference.result.splits[2].document_type == "invoice" + json_sample, _ = get_product_samples(product="split", file_name="split_multiple") + response = SplitResponse(json_sample) + assert isinstance(response.inference, SplitInference) + assert isinstance(response.inference.result, SplitResult) + assert isinstance(response.inference.result.splits[0], SplitRange) + assert len(response.inference.result.splits) == 3 + + assert len(response.inference.result.splits[0].page_range) == 2 + assert response.inference.result.splits[0].page_range[0] == 0 + assert response.inference.result.splits[0].page_range[1] == 0 + assert response.inference.result.splits[0].document_type == "invoice" + + assert len(response.inference.result.splits[1].page_range) == 2 + assert response.inference.result.splits[1].page_range[0] == 1 + assert response.inference.result.splits[1].page_range[1] == 3 + assert response.inference.result.splits[1].document_type == "invoice" + + assert len(response.inference.result.splits[2].page_range) == 2 + assert response.inference.result.splits[2].page_range[0] == 4 + assert response.inference.result.splits[2].page_range[1] == 4 + assert response.inference.result.splits[2].document_type == "invoice" diff --git a/tests/v2/product/utils.py b/tests/v2/product/utils.py new file mode 100644 index 00000000..0a3a14bf --- /dev/null +++ b/tests/v2/product/utils.py @@ -0,0 +1,22 @@ +from typing import Tuple +from pathlib import Path +import json + +from tests.utils import V2_PRODUCT_DATA_DIR + + +def get_samples(json_path: Path, rst_path: Path) -> Tuple[dict, str]: + with json_path.open("r", encoding="utf-8") as fh: + json_sample = json.load(fh) + try: + with rst_path.open("r", encoding="utf-8") as fh: + rst_sample = fh.read() + except FileNotFoundError: + rst_sample = "" + return json_sample, rst_sample + + +def get_product_samples(product: str, file_name: str) -> Tuple[dict, str]: + json_path = V2_PRODUCT_DATA_DIR / product / f"{file_name}.json" + rst_path = V2_PRODUCT_DATA_DIR / product / f"{file_name}.rst" + return get_samples(json_path, rst_path)