diff --git a/.gitignore b/.gitignore index 38d49af..2d0c882 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ __pycache__/ dist/ .claude/ + +.vscode/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 11327be..7a83011 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,40 @@ # Changelog +## 8.0.0 - 2026-03-20 + +### Breaking + +- `kg_build` postprocessor subprocess entry points renamed: + - `postprocessor_runner.py` → `postprocessors/oneshot.py` + - `postprocessor_worker.py` → `postprocessors/persistent.py` +- `SubprocessPostprocessor` split into `OneshotSubprocessPostprocessor` and `PersistentSubprocessPostprocessor`; any host code referencing the old class name must be updated. +- `PostprocessorService` is now profile-agnostic; profile resolution no longer happens inside the service. +- `utils.get_me` / `utils.reset_me` module files renamed to `_get_me.py` / `_reset_me.py`; direct submodule imports (not recommended) must be updated. + +### Added + +- `ShaclValidationService` — runs SHACL validation in a dedicated process pool via `PreparedShaclValidator`, wired into `ProfileImportProtocol`. +- Separate pool-size settings for postprocessors and SHACL validation. +- In-process postprocessor runtime (`inprocess`) for single-process execution. +- SHACL process-pool queue-wait and execution-time tracking in timing logs. +- `morph_kgc` subprocess pool for true RML-mapping parallelism, bypassing `pyparsing` lock contention and the GIL. + - Configurable pool size via `morph_kgc_pool_size` / `MORPH_KGC_POOL_SIZE`. + - Subprocess queue-wait tracked separately in timing logs. +- `PostprocessorResult` dataclass — replaces implicit tuple return from postprocessing stage. +- `ImportAnnotationPostprocessor` and `RootIdReconcilerPostprocessor` extracted as named processors. +- `first_level_subjects` graph utility helper. +- Slice verification tooling extended with `run_slice_smoke_imports.py` and `run_slice_tests.py`. 
+ +### Changed + +- Postprocessors reorganised into `postprocessors/` subpackage (`processors/`, `PostprocessorService`, loader helpers). +- `ProfileImportProtocol.__init__` decomposed into focused `_init_*` factory methods; class surface significantly reduced. +- `morph_kgc` RML mapping stage runs in subprocess pool instead of a thread executor. +- SHACL validation and postprocessors offloaded to dedicated thread/process pools; ingestion runs in an executor to avoid blocking the event loop. +- Persistent `ApiClient` reused across requests instead of one per graph; `ApiClient` is closed on protocol shutdown. +- Lazy-export guards remapped to modules with real third-party dependencies so `ModuleNotFoundError` fires correctly when an extra is absent. +- `python-liquid` added to the `workflow` extra (required by `graph.ttl_liquid`). + ## 7.0.0 - 2026-03-15 ### Breaking diff --git a/poetry.lock b/poetry.lock index f0bc229..b3ac36a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.3.1 and should not be changed by hand. 
[[package]] name = "advertools" @@ -288,7 +288,7 @@ description = "Internationalization utilities" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"graph\" or extra == \"kg-build\" or extra == \"legacy\" or extra == \"all\" or extra == \"ingestion\"" +markers = "extra == \"workflow\" or extra == \"graph\" or extra == \"kg-build\" or extra == \"legacy\" or extra == \"all\" or extra == \"ingestion\"" files = [ {file = "babel-2.18.0-py3-none-any.whl", hash = "sha256:e2b422b277c2b9a9630c1d7903c2a00d0830c409c59ac8cae9081c92f1aeba35"}, {file = "babel-2.18.0.tar.gz", hash = "sha256:b80b99a14bd085fcacfa15c9165f651fbb3406e66cc603abf11c5750937c992d"}, @@ -1702,8 +1702,11 @@ files = [ {file = "lxml-5.4.0-cp36-cp36m-win_amd64.whl", hash = "sha256:7ce1a171ec325192c6a636b64c94418e71a1964f56d002cc28122fceff0b6121"}, {file = "lxml-5.4.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:795f61bcaf8770e1b37eec24edf9771b307df3af74d1d6f27d812e15a9ff3872"}, {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:29f451a4b614a7b5b6c2e043d7b64a15bd8304d7e767055e8ab68387a8cacf4e"}, + {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:891f7f991a68d20c75cb13c5c9142b2a3f9eb161f1f12a9489c82172d1f133c0"}, {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4aa412a82e460571fad592d0f93ce9935a20090029ba08eca05c614f99b0cc92"}, + {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:ac7ba71f9561cd7d7b55e1ea5511543c0282e2b6450f122672a2694621d63b7e"}, {file = "lxml-5.4.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:c5d32f5284012deaccd37da1e2cd42f081feaa76981f0eaa474351b68df813c5"}, + {file = "lxml-5.4.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:ce31158630a6ac85bddd6b830cffd46085ff90498b397bd0a259f59d27a12188"}, {file = 
"lxml-5.4.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:31e63621e073e04697c1b2d23fcb89991790eef370ec37ce4d5d469f40924ed6"}, {file = "lxml-5.4.0-cp37-cp37m-win32.whl", hash = "sha256:be2ba4c3c5b7900246a8f866580700ef0d538f2ca32535e991027bdaba944063"}, {file = "lxml-5.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:09846782b1ef650b321484ad429217f5154da4d6e786636c38e434fa32e94e49"}, @@ -1788,7 +1791,7 @@ description = "Safely add untrusted strings to HTML/XML markup." optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"graph\" or extra == \"kg-build\" or extra == \"legacy\" or extra == \"all\"" +markers = "extra == \"workflow\" or extra == \"graph\" or extra == \"kg-build\" or extra == \"legacy\" or extra == \"all\"" files = [ {file = "markupsafe-3.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f981d352f04553a7171b8e44369f2af4055f888dfb147d55e42d29e29e74559"}, {file = "markupsafe-3.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e1c1493fb6e50ab01d20a22826e57520f1284df32f2d8601fdd90b6304601419"}, @@ -3185,7 +3188,7 @@ description = "A Python engine for the Liquid template language." 
optional = true python-versions = ">=3.7" groups = ["main"] -markers = "extra == \"graph\" or extra == \"kg-build\" or extra == \"legacy\" or extra == \"all\"" +markers = "extra == \"workflow\" or extra == \"graph\" or extra == \"kg-build\" or extra == \"legacy\" or extra == \"all\"" files = [ {file = "python_liquid-2.1.0-py3-none-any.whl", hash = "sha256:d3bbcddff4e1a73287b59218df3471613598271e69ac3d17d97e000f4b984e3e"}, {file = "python_liquid-2.1.0.tar.gz", hash = "sha256:a4c2abb24ac40ded8c9ba844ebbfbe78a3e41c6fe10a7bbe94144582569b73d0"}, @@ -3249,6 +3252,13 @@ optional = false python-versions = ">=3.8" groups = ["dev"] files = [ + {file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"}, + {file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"}, + {file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"}, + {file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"}, {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"}, @@ -4345,9 +4355,9 @@ legacy = ["google-auth", "gql", "gspread", "lxml", "pandas", "playwright", "pyco render = ["lxml", "playwright"] structured-data = ["advertools", "lxml", "morph-kgc", "playwright", "pyshacl", "rdflib", "requests", "tqdm"] validation = ["pyshacl", "rdflib", "requests", "tqdm"] -workflow = ["advertools", "google-auth", "gql", "gspread", "lxml", "pandas", "playwright", "pydantic-core", "rdflib", "tqdm"] +workflow = ["advertools", "google-auth", "gql", "gspread", "lxml", "pandas", "playwright", "pydantic-core", "python-liquid", "rdflib", "tqdm"] [metadata] lock-version = "2.1" python-versions = ">=3.10, <3.15" -content-hash = "0810a8470047131214fc3655380b14044bb11660895b114d5f61fc0e0263d1bc" +content-hash = "a119ca316866d292b70b03bb5e509c4eded8fb9d581ed3c5e541961e6aee98a8" diff --git a/pyproject.toml b/pyproject.toml index 96799a1..3d7ee87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wordlift-sdk" -version = "7.0.1" +version = "8.0.0" description = "Python toolkit for orchestrating WordLift imports and structured data workflows." 
authors = ["David Riccitelli "] readme = "README.md" @@ -81,6 +81,7 @@ workflow = [ "pandas", "playwright", "pydantic-core", + "python-liquid", "rdflib", "tqdm", ] diff --git a/tests/kg_build/postprocessors/__init__.py b/tests/kg_build/postprocessors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/kg_build/postprocessors/processors/__init__.py b/tests/kg_build/postprocessors/processors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/kg_build/test_id_allocator.py b/tests/kg_build/postprocessors/processors/test_id_allocator.py similarity index 94% rename from tests/kg_build/test_id_allocator.py rename to tests/kg_build/postprocessors/processors/test_id_allocator.py index bbd77a6..c420626 100644 --- a/tests/kg_build/test_id_allocator.py +++ b/tests/kg_build/postprocessors/processors/test_id_allocator.py @@ -2,8 +2,11 @@ from rdflib import Graph, Literal, RDF, URIRef -import wordlift_sdk.kg_build.id_allocator as id_allocator_module -from wordlift_sdk.kg_build.id_allocator import IdAllocator, normalize_slug +import wordlift_sdk.kg_build.postprocessors.processors.id_allocator as id_allocator_module +from wordlift_sdk.kg_build.postprocessors.processors.id_allocator import ( + IdAllocator, + normalize_slug, +) def _graph(subject: URIRef) -> Graph: diff --git a/tests/kg_build/test_kg_build_id_generator.py b/tests/kg_build/postprocessors/processors/test_id_generator.py similarity index 99% rename from tests/kg_build/test_kg_build_id_generator.py rename to tests/kg_build/postprocessors/processors/test_id_generator.py index d59d394..b46acfa 100644 --- a/tests/kg_build/test_kg_build_id_generator.py +++ b/tests/kg_build/postprocessors/processors/test_id_generator.py @@ -3,7 +3,9 @@ from rdflib import Graph, Literal, RDF, URIRef from rdflib.namespace import XSD -from wordlift_sdk.kg_build.id_generator import CanonicalIdGenerator +from wordlift_sdk.kg_build.postprocessors.processors.id_generator import ( + CanonicalIdGenerator, 
+) from wordlift_sdk.kg_build.iri_lookup import IriLookup from wordlift_sdk.kg_build.id_policy import DEFAULT_ID_POLICY, IdPolicy diff --git a/tests/kg_build/test_id_postprocessor.py b/tests/kg_build/postprocessors/processors/test_id_postprocessor.py similarity index 96% rename from tests/kg_build/test_id_postprocessor.py rename to tests/kg_build/postprocessors/processors/test_id_postprocessor.py index 9f2a0c5..d5a94ee 100644 --- a/tests/kg_build/test_id_postprocessor.py +++ b/tests/kg_build/postprocessors/processors/test_id_postprocessor.py @@ -4,7 +4,9 @@ from rdflib import Graph, Literal, RDF, URIRef -from wordlift_sdk.kg_build.id_postprocessor import CanonicalIdsPostprocessor +from wordlift_sdk.kg_build.postprocessors.processors.id_postprocessor import ( + CanonicalIdsPostprocessor, +) def test_id_postprocessor_no_dataset_uri_returns_original_graph() -> None: diff --git a/tests/kg_build/test_postprocessor_runner_helpers.py b/tests/kg_build/postprocessors/test_oneshot_helpers.py similarity index 93% rename from tests/kg_build/test_postprocessor_runner_helpers.py rename to tests/kg_build/postprocessors/test_oneshot_helpers.py index 1083eaa..1034147 100644 --- a/tests/kg_build/test_postprocessor_runner_helpers.py +++ b/tests/kg_build/postprocessors/test_oneshot_helpers.py @@ -4,7 +4,7 @@ from rdflib import Graph, Literal, URIRef -from wordlift_sdk.kg_build import postprocessor_runner as runner +from wordlift_sdk.kg_build.postprocessors import oneshot as runner def test_load_class_variants(monkeypatch) -> None: diff --git a/tests/kg_build/test_postprocessor_runner_main.py b/tests/kg_build/postprocessors/test_oneshot_main.py similarity index 97% rename from tests/kg_build/test_postprocessor_runner_main.py rename to tests/kg_build/postprocessors/test_oneshot_main.py index 0b3bc8e..224010f 100644 --- a/tests/kg_build/test_postprocessor_runner_main.py +++ b/tests/kg_build/postprocessors/test_oneshot_main.py @@ -6,7 +6,7 @@ from rdflib import Graph, Literal, URIRef 
-from wordlift_sdk.kg_build import postprocessor_runner as runner +from wordlift_sdk.kg_build.postprocessors import oneshot as runner def _graph() -> Graph: diff --git a/tests/kg_build/test_postprocessor_worker.py b/tests/kg_build/postprocessors/test_persistent.py similarity index 98% rename from tests/kg_build/test_postprocessor_worker.py rename to tests/kg_build/postprocessors/test_persistent.py index 7f359a3..0cfe3b7 100644 --- a/tests/kg_build/test_postprocessor_worker.py +++ b/tests/kg_build/postprocessors/test_persistent.py @@ -7,7 +7,7 @@ from rdflib import Graph, Literal, URIRef -from wordlift_sdk.kg_build import postprocessor_worker as worker +from wordlift_sdk.kg_build.postprocessors import persistent as worker def _graph() -> Graph: diff --git a/tests/kg_build/test_postprocessors.py b/tests/kg_build/postprocessors/test_postprocessors.py similarity index 93% rename from tests/kg_build/test_postprocessors.py rename to tests/kg_build/postprocessors/test_postprocessors.py index 92afadc..ed812ab 100644 --- a/tests/kg_build/test_postprocessors.py +++ b/tests/kg_build/postprocessors/test_postprocessors.py @@ -12,7 +12,7 @@ import pytest from rdflib import Dataset, Graph, Literal, URIRef -from wordlift_sdk.kg_build.postprocessor_runner import ( +from wordlift_sdk.kg_build.postprocessors.oneshot import ( _build_context, _read_graph_nquads, ) @@ -20,11 +20,14 @@ LoadedPostprocessor, PostprocessorContext, PostprocessorSpec, - SubprocessPostprocessor, - _build_runner_payload, close_loaded_postprocessors, load_postprocessors_for_profile, ) +from wordlift_sdk.kg_build.postprocessors.graph_io import _build_runner_payload +from wordlift_sdk.kg_build.postprocessors.subprocess import ( + OneshotSubprocessPostprocessor, + PersistentSubprocessPostprocessor, +) PROJECT_ROOT = Path(__file__).resolve().parents[2] _current_pythonpath = os.environ.get("PYTHONPATH", "") @@ -162,8 +165,8 @@ class = "test_pp:ProfileTwo" first = loaded[0].handler second = loaded[1].handler - assert 
isinstance(second, SubprocessPostprocessor) - assert isinstance(first, SubprocessPostprocessor) + assert isinstance(second, OneshotSubprocessPostprocessor) + assert isinstance(first, OneshotSubprocessPostprocessor) assert first.spec.python == "/profile/python" assert first.spec.timeout_seconds == 17 assert first.spec.keep_temp_on_error is True @@ -190,7 +193,7 @@ class = "test_pp:BaseOne" assert [item.name for item in loaded] == ["test_pp:BaseOne"] first = loaded[0].handler - assert isinstance(first, SubprocessPostprocessor) + assert isinstance(first, OneshotSubprocessPostprocessor) assert first.spec.python == "/base/python" assert first.spec.timeout_seconds == 11 assert first.spec.keep_temp_on_error is False @@ -219,8 +222,7 @@ class = "test_pp:ProfileOne" runtime="persistent", ) assert len(loaded) == 1 - assert isinstance(loaded[0].handler, SubprocessPostprocessor) - assert loaded[0].handler.runtime == "persistent" + assert isinstance(loaded[0].handler, PersistentSubprocessPostprocessor) def test_subprocess_execution_and_nquads_exchange(tmp_path: Path) -> None: @@ -249,7 +251,7 @@ def process_graph(self, graph, context): enabled=True, keep_temp_on_error=False, ) - processor = SubprocessPostprocessor(spec=spec, root_dir=root) + processor = OneshotSubprocessPostprocessor(spec=spec, root_dir=root) output = processor.process_graph(_sample_graph(), _sample_context()) assert output is not None @@ -291,11 +293,7 @@ def process_graph(self, graph, context): enabled=True, keep_temp_on_error=False, ) - processor = SubprocessPostprocessor( - spec=spec, - root_dir=root, - runtime="persistent", - ) + processor = PersistentSubprocessPostprocessor(spec=spec, root_dir=root) first = processor.process_graph(_sample_graph(), _sample_context()) second = processor.process_graph(_sample_graph(), _sample_context()) @@ -351,7 +349,12 @@ def process_graph(self, graph, context): enabled=True, keep_temp_on_error=False, ) - processor = SubprocessPostprocessor(spec=spec, root_dir=root, 
runtime=runtime) + cls = ( + PersistentSubprocessPostprocessor + if runtime == "persistent" + else OneshotSubprocessPostprocessor + ) + processor = cls(spec=spec, root_dir=root) try: output = processor.process_graph( _sample_graph(), @@ -405,7 +408,12 @@ def process_graph(self, graph, context): enabled=True, keep_temp_on_error=False, ) - processor = SubprocessPostprocessor(spec=spec, root_dir=root, runtime=runtime) + cls = ( + PersistentSubprocessPostprocessor + if runtime == "persistent" + else OneshotSubprocessPostprocessor + ) + processor = cls(spec=spec, root_dir=root) try: output = processor.process_graph( _sample_graph(), @@ -471,7 +479,7 @@ def process_graph(self, graph, context): [ sys.executable, "-m", - "wordlift_sdk.kg_build.postprocessor_runner", + "wordlift_sdk.kg_build.postprocessors.oneshot", "--class", "test_pp:AddRunnerTriple", "--input-graph", @@ -517,7 +525,7 @@ def process_graph(self, graph, context): enabled=True, keep_temp_on_error=False, ) - processor = SubprocessPostprocessor(spec=spec, root_dir=root) + processor = OneshotSubprocessPostprocessor(spec=spec, root_dir=root) with pytest.raises(subprocess.TimeoutExpired): processor.process_graph(_sample_graph(), _sample_context()) @@ -543,11 +551,7 @@ def process_graph(self, graph, context): enabled=True, keep_temp_on_error=False, ) - processor = SubprocessPostprocessor( - spec=spec, - root_dir=root, - runtime="persistent", - ) + processor = PersistentSubprocessPostprocessor(spec=spec, root_dir=root) with pytest.raises(subprocess.TimeoutExpired): processor.process_graph(_sample_graph(), _sample_context()) @@ -571,7 +575,7 @@ def process_graph(self, graph, context): enabled=True, keep_temp_on_error=True, ) - processor = SubprocessPostprocessor(spec=spec, root_dir=root) + processor = OneshotSubprocessPostprocessor(spec=spec, root_dir=root) with pytest.raises(RuntimeError): processor.process_graph(_sample_graph(), _sample_context()) @@ -607,7 +611,7 @@ def process_graph(self, graph, context): 
enabled=True, keep_temp_on_error=True, ) - processor = SubprocessPostprocessor(spec=spec, root_dir=root) + processor = OneshotSubprocessPostprocessor(spec=spec, root_dir=root) secret = "top-secret-key" with pytest.raises(RuntimeError): @@ -683,7 +687,7 @@ def test_subprocess_uses_inherited_environment_without_pythonpath_injection( enabled=True, keep_temp_on_error=False, ) - processor = SubprocessPostprocessor(spec=spec, root_dir=root) + processor = OneshotSubprocessPostprocessor(spec=spec, root_dir=root) captured: dict[str, object] = {} def fake_run(*args, **kwargs): diff --git a/tests/kg_build/postprocessors/test_service.py b/tests/kg_build/postprocessors/test_service.py new file mode 100644 index 0000000..df0cf49 --- /dev/null +++ b/tests/kg_build/postprocessors/test_service.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import pytest +from rdflib import Graph, Literal, URIRef + +from wordlift_sdk.kg_build.postprocessors.service import PostprocessorService +from wordlift_sdk.kg_build.postprocessors.types import ( + LoadedPostprocessor, + PostprocessorContext, +) + + +def _sample_graph() -> Graph: + g = Graph() + g.add( + ( + URIRef("https://example.com/s"), + URIRef("https://example.com/p"), + Literal("v"), + ) + ) + return g + + +def _sample_context() -> PostprocessorContext: + return PostprocessorContext( + profile_name="test", + profile={}, + url="https://example.com/page", + account=SimpleNamespace(dataset_uri="https://data.example.com"), + account_key=None, + exports={}, + response=SimpleNamespace( + id=None, web_page=SimpleNamespace(url=None, html=None) + ), + existing_web_page_id=None, + ) + + +def _make_service(pool_size: int = 1, processors=None) -> PostprocessorService: + if processors is None: + + class _Passthrough: + def process_graph(self, graph: Graph, context) -> Graph: + return graph + + processors = [LoadedPostprocessor(name="passthrough", handler=_Passthrough())] + + return 
PostprocessorService( + postprocessors_factory=lambda: processors, + pool_size=pool_size, + ) + + +def test_apply_returns_result_with_graph_and_timings() -> None: + service = _make_service() + result = asyncio.run(service.apply(_sample_graph(), _sample_context())) + service.close() + + assert isinstance(result.graph, Graph) + assert len(result.graph) == 1 + assert result.queue_wait_ms >= 0 + assert result.postprocessors_ms >= 0 + + +def test_apply_runs_processors_in_order() -> None: + additions: list[int] = [] + + class _Mark: + def __init__(self, n: int) -> None: + self._n = n + + def process_graph(self, graph: Graph, context) -> Graph: + additions.append(self._n) + graph.add( + ( + URIRef(f"https://example.com/s{self._n}"), + URIRef("https://example.com/p"), + Literal(self._n), + ) + ) + return graph + + processors = [ + LoadedPostprocessor(name="first", handler=_Mark(1)), + LoadedPostprocessor(name="second", handler=_Mark(2)), + ] + service = PostprocessorService( + postprocessors_factory=lambda: processors, + pool_size=1, + ) + result = asyncio.run(service.apply(_sample_graph(), _sample_context())) + service.close() + + assert additions == [1, 2] + assert len(result.graph) == 3 # original + 2 added + + +def test_close_calls_close_on_closeable_handlers() -> None: + class _Closeable: + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + def process_graph(self, graph: Graph, context) -> Graph: + return graph + + handler = _Closeable() + service = PostprocessorService( + postprocessors_factory=lambda: [LoadedPostprocessor(name="c", handler=handler)], + pool_size=1, + ) + service.close() + + assert handler.closed is True + + +def test_pool_isolates_slots() -> None: + """Each slot in the pool should be an independent list of processors.""" + slot_ids: list[int] = [] + + class _Recorder: + def __init__(self, slot_id: int) -> None: + self._slot_id = slot_id + + def process_graph(self, graph: Graph, context) -> Graph: 
+ slot_ids.append(self._slot_id) + return graph + + slot_counter = [0] + + def factory() -> list[LoadedPostprocessor]: + slot_counter[0] += 1 + sid = slot_counter[0] + return [LoadedPostprocessor(name=f"slot-{sid}", handler=_Recorder(sid))] + + pool_size = 2 + service = PostprocessorService(postprocessors_factory=factory, pool_size=pool_size) + try: + # Run both slots sequentially + asyncio.run(service.apply(_sample_graph(), _sample_context())) + asyncio.run(service.apply(_sample_graph(), _sample_context())) + finally: + service.close() + + # Both slots should have been used (order may vary but both IDs present) + assert len(slot_ids) == 2 + assert set(slot_ids) == {1, 2} + + +@pytest.mark.asyncio +async def test_apply_async_returns_correct_graph() -> None: + service = _make_service() + graph = _sample_graph() + result = await service.apply(graph, _sample_context()) + service.close() + + assert isinstance(result.graph, Graph) + assert len(result.graph) == 1 diff --git a/tests/kg_build/test_graph_utils.py b/tests/kg_build/test_graph_utils.py new file mode 100644 index 0000000..4f1dec4 --- /dev/null +++ b/tests/kg_build/test_graph_utils.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from rdflib import Graph, Literal, URIRef + +from wordlift_sdk.kg_build.graph_utils import first_level_subjects + +DATASET = "https://data.example.com" + + +def _uri(path: str) -> URIRef: + return URIRef(f"{DATASET}/{path}") + + +def _ext(path: str) -> URIRef: + return URIRef(f"https://external.example.com/{path}") + + +def test_empty_graph_returns_empty_set() -> None: + assert first_level_subjects(Graph(), DATASET) == set() + + +def test_dataset_uri_match_returns_two_segment_subjects() -> None: + g = Graph() + canonical = _uri("articles/my-article") # 2 segments → first-level + deep = _uri("articles/my-article/comments/1") # 4 segments → not first-level + g.add((canonical, URIRef("https://schema.org/name"), Literal("Article"))) + g.add((deep, 
URIRef("https://schema.org/name"), Literal("Comment"))) + + result = first_level_subjects(g, DATASET) + assert canonical in result + assert deep not in result + + +def test_dataset_uri_match_ignores_single_segment() -> None: + g = Graph() + one_seg = _uri("articles") # 1 segment → not first-level by-id + two_seg = _uri("articles/slug") # 2 segments → first-level + g.add((one_seg, URIRef("https://schema.org/name"), Literal("Collection"))) + g.add((two_seg, URIRef("https://schema.org/name"), Literal("Item"))) + + result = first_level_subjects(g, DATASET) + assert two_seg in result + assert one_seg not in result + + +def test_fallback_to_unreferenced_subjects_when_no_dataset_match() -> None: + g = Graph() + root = _ext("root") + child = _ext("child") + # child is referenced by root, so root is the unreferenced subject + g.add((root, URIRef("https://schema.org/hasPart"), child)) + g.add((child, URIRef("https://schema.org/name"), Literal("Child"))) + + # No dataset_uri prefix match; fall back to "not referenced" logic + result = first_level_subjects(g, "") + assert root in result + assert child not in result + + +def test_fallback_returns_all_when_everything_is_referenced() -> None: + g = Graph() + a = _ext("a") + b = _ext("b") + # mutual references: both are referenced + g.add((a, URIRef("https://schema.org/hasPart"), b)) + g.add((b, URIRef("https://schema.org/hasPart"), a)) + + result = first_level_subjects(g, "") + assert result == {a, b} + + +def test_blank_dataset_uri_uses_reference_fallback() -> None: + g = Graph() + page = _ext("page") + product = _ext("product") + g.add((page, URIRef("https://schema.org/mentions"), product)) + g.add((product, URIRef("https://schema.org/name"), Literal("Product"))) + + result = first_level_subjects(g, "") + assert page in result + assert product not in result + + +def test_dataset_uri_prefix_no_match_falls_back_gracefully() -> None: + g = Graph() + ext_subject = _ext("item") + g.add((ext_subject, 
URIRef("https://schema.org/name"), Literal("External"))) + + # dataset_uri set but no subject matches the prefix + result = first_level_subjects(g, DATASET) + assert ext_subject in result + + +def test_literal_objects_are_not_counted_as_subjects() -> None: + g = Graph() + s = _uri("things/item") + g.add((s, URIRef("https://schema.org/name"), Literal("Name"))) + + result = first_level_subjects(g, DATASET) + assert s in result diff --git a/tests/kg_build/test_kpi.py b/tests/kg_build/test_kpi.py index 905a54b..69ef8e2 100644 --- a/tests/kg_build/test_kpi.py +++ b/tests/kg_build/test_kpi.py @@ -1,6 +1,7 @@ from rdflib import Graph, Literal, RDF, URIRef from wordlift_sdk.kg_build.kpi import KgBuildKpiCollector +from wordlift_sdk.validation.shacl_validation_service import ValidationOutcome def test_kpi_collector_records_graph_and_validation() -> None: @@ -14,11 +15,13 @@ def test_kpi_collector_records_graph_and_validation() -> None: collector.record_graph(graph) collector.record_validation( - passed=False, - warning_count=2, - error_count=1, - warning_sources={"google-article": 2}, - error_sources={"google-product": 1}, + ValidationOutcome( + passed=False, + warning_sources={"google-article": 2}, + error_sources={"google-product": 1}, + queue_wait_ms=0, + validation_ms=0, + ) ) summary = collector.summary("demo") diff --git a/tests/kg_build/test_profile_inheritance.py b/tests/kg_build/test_profile_inheritance.py index 8bb90e1..fbdea15 100644 --- a/tests/kg_build/test_profile_inheritance.py +++ b/tests/kg_build/test_profile_inheritance.py @@ -45,7 +45,8 @@ def test_runtime_inherits_from_base_when_selected_missing(tmp_path: Path) -> Non ) assert profile.settings["postprocessor_runtime"] == "persistent" - assert protocol._postprocessor_runtime == "persistent" + # Verify the protocol accepted the inherited runtime (service is initialised without error) + assert protocol._postprocessor_service is not None def test_validation_settings_parse_into_profile_settings(tmp_path: 
Path) -> None: diff --git a/tests/kg_build/test_protocol.py b/tests/kg_build/test_protocol.py index 3e46fed..43efe7e 100644 --- a/tests/kg_build/test_protocol.py +++ b/tests/kg_build/test_protocol.py @@ -7,7 +7,6 @@ from jinja2 import UndefinedError from rdflib import BNode, Graph, Literal, RDF, URIRef from wordlift_client import WebPage, WebPageScrapeResponse -from wordlift_sdk.validation.shacl import ValidationResult from wordlift_sdk.kg_build.config.loader import ProfileDefinition, ProfileMappingRoute import wordlift_sdk.kg_build.protocol as protocol_module @@ -16,6 +15,17 @@ _path_contains_part, _resolve_postprocessor_runtime, ) +from wordlift_sdk.kg_build.rml_mapping import MappingResult +from wordlift_sdk.kg_build.postprocessors.types import PostprocessorResult +from wordlift_sdk.kg_build.postprocessors.processors.graph_annotation import ( + ImportAnnotationPostprocessor, +) +from wordlift_sdk.kg_build.postprocessors.processors.id_postprocessor import ( + CanonicalIdsPostprocessor, + RootIdReconcilerPostprocessor, + _find_web_page_iri as _find_web_page_iri_impl, +) +from wordlift_sdk.validation.shacl_validation_service import ValidationOutcome def _make_profile() -> ProfileDefinition: @@ -60,7 +70,7 @@ def _make_context() -> SimpleNamespace: return SimpleNamespace( account=SimpleNamespace(dataset_uri="https://data.example.com/dataset"), client_configuration=SimpleNamespace(api_key={}), - graph_queue=SimpleNamespace(put=AsyncMock()), + graph_queue=SimpleNamespace(put=AsyncMock(), close=AsyncMock()), configuration_provider=SimpleNamespace( get_value=lambda *_args, **_kwargs: None ), @@ -71,13 +81,73 @@ def _make_context_without_dataset() -> SimpleNamespace: return SimpleNamespace( account=SimpleNamespace(dataset_uri=None), client_configuration=SimpleNamespace(api_key={}), - graph_queue=SimpleNamespace(put=AsyncMock()), + graph_queue=SimpleNamespace(put=AsyncMock(), close=AsyncMock()), configuration_provider=SimpleNamespace( get_value=lambda *_args, **_kwargs: 
None ), ) +def _make_mapping_result(graph: Graph) -> MappingResult: + return MappingResult(graph=graph, queue_wait_ms=0, mapping_ms=0) + + +def _make_validation_outcome( + *, + passed: bool, + warning_sources: dict | None = None, + error_sources: dict | None = None, +) -> ValidationOutcome: + return ValidationOutcome( + passed=passed, + warning_sources=warning_sources or {}, + error_sources=error_sources or {}, + queue_wait_ms=0, + validation_ms=0, + ) + + +def _passthrough_pp() -> AsyncMock: + return AsyncMock( + side_effect=lambda g, url, resp, ewi, eih: PostprocessorResult( + graph=g, queue_wait_ms=0, postprocessors_ms=0 + ) + ) + + +def _annotating_pp( + dataset_uri: str = "https://data.example.com/dataset", + import_hash_mode: str = "on", +) -> AsyncMock: + async def _stage(graph, url, resp, ewi, eih): + ctx = SimpleNamespace( + account=SimpleNamespace(dataset_uri=dataset_uri), + existing_import_hash=eih, + import_hash_mode=import_hash_mode, + ) + g = ImportAnnotationPostprocessor().process_graph(graph, ctx) + return PostprocessorResult(graph=g, queue_wait_ms=0, postprocessors_ms=0) + + return AsyncMock(side_effect=_stage) + + +def _reconciling_pp( + dataset_uri: str = "https://data.example.com/dataset", +) -> AsyncMock: + async def _stage(graph, url, resp, ewi, eih): + ctx = SimpleNamespace( + account=SimpleNamespace(dataset_uri=dataset_uri), + existing_import_hash=eih, + import_hash_mode="on", + existing_web_page_id=ewi, + ) + g = RootIdReconcilerPostprocessor().process_graph(graph, ctx) + g = ImportAnnotationPostprocessor().process_graph(g, ctx) + return PostprocessorResult(graph=g, queue_wait_ms=0, postprocessors_ms=0) + + return AsyncMock(side_effect=_stage) + + def _make_graph(subject: str) -> Graph: graph = Graph() s = URIRef(subject) @@ -130,12 +200,13 @@ async def test_profile_protocol_reconciles_to_existing_id_and_sets_source(): protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = 
MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_graph("https://example.com/mapped-web-page") + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result( + _make_graph("https://example.com/mapped-web-page") + ) ) + protocol._run_postprocessing_stage = _reconciling_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -169,12 +240,11 @@ async def test_profile_protocol_put_strategy_writes_to_graph_queue() -> None: protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_dataset_scoped_graph() + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_dataset_scoped_graph()) ) + protocol._run_postprocessing_stage = _passthrough_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -197,7 +267,7 @@ async def test_static_templates_use_graph_queue_when_put_strategy_enabled() -> N ) protocol._template_graph = _make_dataset_scoped_graph() protocol._template_exports = {} - protocol._validate_graph_if_enabled = MagicMock(return_value=None) + protocol._shacl_validator.validate = AsyncMock(return_value=None) protocol._emit_progress = 
MagicMock() protocol._kpi.record_graph = MagicMock() protocol.patcher.patch_all = AsyncMock() @@ -220,8 +290,7 @@ async def test_profile_protocol_put_strategy_honors_import_hash_write_mode() -> protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) + protocol._run_postprocessing_stage = _passthrough_pp() protocol.patcher.patch_all = AsyncMock() graph = _make_dataset_scoped_graph() child = URIRef("https://data.example.com/dataset/entities/article-1/faq/1") @@ -233,7 +302,7 @@ async def test_profile_protocol_put_strategy_honors_import_hash_write_mode() -> ) ) graph.add((child, RDF.type, URIRef("https://schema.org/Question"))) - protocol.rml_service.apply_mapping = AsyncMock(return_value=graph) + protocol._run_mapping_stage = AsyncMock(return_value=_make_mapping_result(graph)) response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -264,17 +333,22 @@ async def test_profile_protocol_put_strategy_skips_when_import_hash_matches() -> protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() graph = _make_dataset_scoped_graph() - protocol._set_source(graph, existing_web_page_id=None) + # Pre-annotate so the expected hash matches what the pipeline will produce + ann_ctx = SimpleNamespace( + account=context.account, + existing_import_hash=None, + import_hash_mode="on", + ) + 
ImportAnnotationPostprocessor().process_graph(graph, ann_ctx) expected_hash = protocol.patcher._compute_import_hash( URIRef("https://data.example.com/dataset/web-pages/1"), graph, "https://data.example.com/dataset", ) - protocol.rml_service.apply_mapping = AsyncMock(return_value=graph) + protocol._run_mapping_stage = AsyncMock(return_value=_make_mapping_result(graph)) + protocol._run_postprocessing_stage = _annotating_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -297,12 +371,11 @@ async def test_profile_protocol_put_strategy_honors_import_hash_off_mode() -> No protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_dataset_scoped_graph() + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_dataset_scoped_graph()) ) + protocol._run_postprocessing_stage = _passthrough_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -335,13 +408,12 @@ async def test_profile_protocol_sets_source_on_mapped_subject_when_existing_id_m protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() mapped_subject = 
"https://example.com/mapped-web-page" - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_graph(mapped_subject) + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_graph(mapped_subject)) ) + protocol._run_postprocessing_stage = _annotating_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -367,12 +439,11 @@ async def test_profile_protocol_sets_source_only_on_first_level_uri_subjects(): protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_multi_entity_graph() + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_multi_entity_graph()) ) + protocol._run_postprocessing_stage = _annotating_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -423,11 +494,8 @@ async def test_callback_runs_canonical_ids_after_postprocessors() -> None: Literal("https://translated.com/developers"), ) ) - protocol.rml_service.apply_mapping = AsyncMock(return_value=mapped_graph) - def _inject_service_product_and_fragment_offer( - graph: Graph, *_args, **_kwargs - ) -> Graph: + async def _pp_with_injection(graph, url, resp, ewi, eih): graph.add((root, RDF.type, URIRef("http://schema.org/Product"))) graph.add((root, RDF.type, URIRef("http://schema.org/Service"))) graph.add( @@ -444,11 +512,17 @@ def _inject_service_product_and_fragment_offer( URIRef(f"{root}#aggregate-offer-usd"), ) ) - return graph + ctx = SimpleNamespace( + 
account=SimpleNamespace(dataset_uri="https://data.example.com/dataset"), + extensions=None, + ) + g = CanonicalIdsPostprocessor().process_graph(graph, ctx) + return PostprocessorResult(graph=g, queue_wait_ms=0, postprocessors_ms=0) - protocol._apply_postprocessors = MagicMock( - side_effect=_inject_service_product_and_fragment_offer + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(mapped_graph) ) + protocol._run_postprocessing_stage = AsyncMock(side_effect=_pp_with_injection) response = WebPageScrapeResponse( web_page=WebPage(url="https://translated.com/developers", html="") @@ -484,12 +558,11 @@ async def test_profile_protocol_applies_existing_import_hash_to_all_uri_subjects protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_multi_entity_graph() + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_multi_entity_graph()) ) + protocol._run_postprocessing_stage = _annotating_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -520,15 +593,14 @@ async def test_profile_protocol_sets_source_when_web_page_absent_but_uri_subject protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) + protocol._run_postprocessing_stage = 
_annotating_pp() protocol.patcher.patch_all = AsyncMock() graph = Graph() article = URIRef("https://example.com/entities/article-only") graph.add((article, RDF.type, URIRef("http://schema.org/Article"))) graph.add((article, URIRef("http://schema.org/headline"), Literal("Title"))) - protocol.rml_service.apply_mapping = AsyncMock(return_value=graph) + protocol._run_mapping_stage = AsyncMock(return_value=_make_mapping_result(graph)) response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -554,8 +626,7 @@ async def test_profile_protocol_sets_source_by_dataset_id_depth() -> None: protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) + protocol._run_postprocessing_stage = _annotating_pp() protocol.patcher.patch_all = AsyncMock() graph = Graph() @@ -567,7 +638,7 @@ async def test_profile_protocol_sets_source_by_dataset_id_depth() -> None: graph.add((entity, RDF.type, URIRef("https://schema.org/Article"))) graph.add((entity, URIRef("https://schema.org/hasPart"), child)) graph.add((child, RDF.type, URIRef("https://schema.org/Question"))) - protocol.rml_service.apply_mapping = AsyncMock(return_value=graph) + protocol._run_mapping_stage = AsyncMock(return_value=_make_mapping_result(graph)) response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -598,8 +669,7 @@ async def test_profile_protocol_does_not_set_source_on_blank_nodes(): protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - 
protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) + protocol._run_postprocessing_stage = _annotating_pp() protocol.patcher.patch_all = AsyncMock() graph = Graph() @@ -608,7 +678,7 @@ async def test_profile_protocol_does_not_set_source_on_blank_nodes(): graph.add((article, RDF.type, URIRef("http://schema.org/Article"))) graph.add((blank, RDF.type, URIRef("http://schema.org/Thing"))) graph.add((article, URIRef("http://schema.org/mentions"), blank)) - protocol.rml_service.apply_mapping = AsyncMock(return_value=graph) + protocol._run_mapping_stage = AsyncMock(return_value=_make_mapping_result(graph)) response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -640,12 +710,11 @@ def fake_loader(*, root_dir, profile_name, runtime=None): return [] monkeypatch.setattr(protocol_module, "load_postprocessors_for_profile", fake_loader) - protocol = ProfileImportProtocol( + ProfileImportProtocol( context=_make_context(), profile=_make_profile_with_settings({"POSTPROCESSOR_RUNTIME": "persistent"}), root_dir=Path.cwd(), ) - assert protocol._postprocessor_runtime == "persistent" assert captured["runtime"] == "persistent" @@ -675,7 +744,10 @@ def test_build_pp_context_exposes_resolved_profile_and_account_key() -> None: ) context = protocol._build_pp_context( - "https://example.com/page", response, existing_web_page_id=None + "https://example.com/page", + response, + existing_web_page_id=None, + existing_import_hash=None, ) assert context.account_key == "profile-secret" @@ -699,44 +771,37 @@ def test_build_pp_context_preserves_custom_profile_settings() -> None: ) context = protocol._build_pp_context( - "https://example.com/page", response, existing_web_page_id=None + "https://example.com/page", + response, + existing_web_page_id=None, + existing_import_hash=None, ) assert context.profile["settings"]["disable_article_markup"] is True -def test_apply_postprocessors_fails_fast_when_account_key_missing() -> None: +def 
test_account_key_resolved_from_profile_api_key() -> None: + profile = ProfileDefinition( + **{ + **_make_profile().__dict__, + "api_key": "profile-secret", + } + ) protocol = ProfileImportProtocol( context=_make_context(), - profile=_make_profile(), + profile=profile, root_dir=Path.cwd(), ) + assert protocol._account_key == "profile-secret" - class _NeverRun: - name = "never-run" - called = False - - def run(self, graph, context): - self.called = True - return graph - - handler = _NeverRun() - protocol._postprocessors = [handler] # type: ignore[assignment] - response = WebPageScrapeResponse( - web_page=WebPage(url="https://example.com/page", html="") +def test_account_key_is_none_when_no_key_configured() -> None: + protocol = ProfileImportProtocol( + context=_make_context(), + profile=_make_profile(), + root_dir=Path.cwd(), ) - graph = _make_graph("https://example.com/mapped-web-page") - - with pytest.raises(RuntimeError, match="Postprocessor runtime requires an API key"): - protocol._apply_postprocessors( - graph, - "https://example.com/page", - response, - existing_web_page_id=None, - ) - - assert handler.called is False + assert protocol._account_key is None def test_protocol_helpers_runtime_and_path_part() -> None: @@ -792,7 +857,7 @@ async def test_callback_returns_early_when_mapping_has_no_triples() -> None: protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol.rml_service.apply_mapping = AsyncMock(return_value=Graph()) + protocol._run_mapping_stage = AsyncMock(return_value=_make_mapping_result(Graph())) protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( @@ -803,21 +868,16 @@ async def test_callback_returns_early_when_mapping_has_no_triples() -> None: protocol.patcher.patch_all.assert_not_called() -def test_close_invokes_postprocessor_cleanup(monkeypatch: pytest.MonkeyPatch) -> 
None: - called: dict[str, object] = {} - - def fake_close(postprocessors): - called["value"] = postprocessors - - monkeypatch.setattr(protocol_module, "close_loaded_postprocessors", fake_close) +def test_close_invokes_postprocessor_service_close() -> None: protocol = ProfileImportProtocol( context=_make_context(), profile=_make_profile(), root_dir=Path.cwd(), ) - protocol._postprocessors = ["x"] # type: ignore[assignment] - protocol.close() - assert called["value"] == ["x"] + mock_close = MagicMock() + protocol._postprocessor_service.close = mock_close + asyncio.run(protocol.close()) + mock_close.assert_called_once() def test_resolve_path_and_overlay_paths(tmp_path: Path) -> None: @@ -1018,88 +1078,69 @@ def test_get_mapping_content_uses_cache_and_requires_dataset() -> None: protocol2._get_mapping_content(path) -def test_apply_postprocessors_runs_all_processors() -> None: +def test_postprocessor_factory_builds_required_processors( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Verify the factory used by PostprocessorService includes the standard processors.""" + + def fake_loader(*, root_dir, profile_name, runtime=None): + return [] + + monkeypatch.setattr(protocol_module, "load_postprocessors_for_profile", fake_loader) protocol = ProfileImportProtocol( context=_make_context(), - profile=_make_profile_with_settings({"api_key": "x"}), + profile=_make_profile(), root_dir=Path.cwd(), ) - response = WebPageScrapeResponse( - web_page=WebPage(url="https://example.com/page", html="") - ) - graph = _make_graph("https://example.com/page") - - class _P1: - name = "p1" - - def run(self, g, _ctx): - g.add( - ( - URIRef("https://example.com/page"), - URIRef("https://schema.org/name"), - Literal("a"), - ) - ) - return g - - class _P2: - name = "p2" - - def run(self, g, _ctx): - return g - - protocol._postprocessors = [_P1(), _P2()] # type: ignore[assignment] - protocol._resolve_postprocessor_account_key = MagicMock(return_value="secret") - out = 
protocol._apply_postprocessors( - graph, "https://example.com/page", response, None - ) - assert len(out) >= len(graph) + # Get one slot from the pool to inspect the processors + processors = list(protocol._postprocessor_service._queue.get_nowait()) + names = [p.name for p in processors] + assert "root_id_reconciler" in names + assert "canonical_ids" in names + assert "import_annotation" in names -def test_resolve_postprocessor_account_key_priority( +def test_resolve_account_key_priority( monkeypatch: pytest.MonkeyPatch, ) -> None: - protocol = ProfileImportProtocol( - context=_make_context(), - profile=_make_profile(), - root_dir=Path.cwd(), + profile = _make_profile() + context = _make_context() + + profile_with_key = ProfileDefinition( + **{**profile.__dict__, "api_key": "profile-key"} ) - protocol.profile = ProfileDefinition( - **{**protocol.profile.__dict__, "api_key": "profile-key"} + assert ( + protocol_module._resolve_account_key(profile_with_key, context) == "profile-key" ) - assert protocol._resolve_postprocessor_account_key() == "profile-key" - protocol.profile = ProfileDefinition( - **{**protocol.profile.__dict__, "api_key": None} - ) - protocol.context.client_configuration.api_key = {"ApiKey": "runtime-key"} - assert protocol._resolve_postprocessor_account_key() == "runtime-key" + context.client_configuration.api_key = {"ApiKey": "runtime-key"} + assert protocol_module._resolve_account_key(profile, context) == "runtime-key" - protocol.context.client_configuration.api_key = {} - protocol.context.configuration_provider = SimpleNamespace( + context.client_configuration.api_key = {} + context.configuration_provider = SimpleNamespace( get_value=lambda name: "provider-key" if name == "WORDLIFT_KEY" else None ) - assert protocol._resolve_postprocessor_account_key() == "provider-key" + assert protocol_module._resolve_account_key(profile, context) == "provider-key" - protocol.context.configuration_provider = SimpleNamespace( + context.configuration_provider = 
SimpleNamespace( get_value=lambda _name: (_ for _ in ()).throw(RuntimeError("nope")) ) monkeypatch.setenv("WORDLIFT_API_KEY", "env-key") - assert protocol._resolve_postprocessor_account_key() == "env-key" + assert protocol_module._resolve_account_key(profile, context) == "env-key" monkeypatch.delenv("WORDLIFT_API_KEY", raising=False) def test_clean_key_write_debug_and_reconcile(tmp_path: Path) -> None: + assert protocol_module._clean_key(None) is None + assert protocol_module._clean_key(" ") is None + assert protocol_module._clean_key(" x ") == "x" + protocol = ProfileImportProtocol( context=_make_context(), profile=_make_profile(), root_dir=tmp_path, debug_dir=tmp_path / "debug", ) - assert protocol._clean_key(None) is None - assert protocol._clean_key(" ") is None - assert protocol._clean_key(" x ") == "x" - graph = _make_graph("https://example.com/old") protocol._write_debug_graph(graph, "https://example.com/page") protocol._write_debug_source_documents( @@ -1113,8 +1154,14 @@ def test_clean_key_write_debug_and_reconcile(tmp_path: Path) -> None: child = URIRef("https://example.com/child") https_graph.add((old, RDF.type, URIRef("https://schema.org/WebPage"))) https_graph.add((child, URIRef("https://schema.org/about"), old)) - assert protocol._find_web_page_iri(https_graph) == old - protocol._reconcile_root_id(https_graph, str(new)) + assert _find_web_page_iri_impl(https_graph) == old + ctx = SimpleNamespace( + existing_web_page_id=str(new), + account=SimpleNamespace(dataset_uri=""), + existing_import_hash=None, + import_hash_mode="on", + ) + RootIdReconcilerPostprocessor().process_graph(https_graph, ctx) assert (new, RDF.type, URIRef("https://schema.org/WebPage")) in https_graph assert (child, URIRef("https://schema.org/about"), new) in https_graph @@ -1132,17 +1179,15 @@ async def test_callback_writes_html_xhtml_and_ttl_debug_artifacts( protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = 
MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - async def _apply_mapping(**kwargs): - debug_output = kwargs.get("debug_output") + async def _mapping_stage(response, url, ewi, debug_output): if isinstance(debug_output, dict): debug_output["xhtml"] = "Converted" - return _make_graph("https://example.com/mapped-web-page") + return _make_mapping_result(_make_graph("https://example.com/mapped-web-page")) - protocol.rml_service.apply_mapping = AsyncMock(side_effect=_apply_mapping) + protocol._run_mapping_stage = AsyncMock(side_effect=_mapping_stage) + protocol._run_postprocessing_stage = _passthrough_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="Raw") @@ -1190,14 +1235,10 @@ def test_protocol_setting_parsers_and_progress_error_logging( profile=profile, root_dir=Path.cwd(), ) - assert protocol._shacl_mode == "warn" - assert protocol._shacl_shape_specs == [ - "google-article.ttl", - "https://example.com/custom-shape.ttl", - ] + assert protocol._shacl_validator.mode.value == "warn" assert protocol._import_hash_mode == "write" - assert protocol._resolve_list_setting(["a", " ", "b"]) == ["a", "b"] - assert protocol._resolve_list_setting(123) == ["123"] + assert protocol_module._resolve_list_setting(["a", " ", "b"]) == ["a", "b"] + assert protocol_module._resolve_list_setting(123) == ["123"] protocol._on_progress = lambda _payload: (_ for _ in ()).throw(RuntimeError("boom")) with caplog.at_level("WARNING"): @@ -1265,8 +1306,8 @@ async def test_patch_static_templates_fail_validation_raises() -> None: protocol._template_graph = graph protocol._template_exports = {} protocol.patcher.patch_all = AsyncMock() - protocol._validate_graph = 
MagicMock( - return_value=_make_validation_result(conforms=False) + protocol._shacl_validator.validate = AsyncMock( + return_value=_make_validation_outcome(passed=False) ) with pytest.raises( @@ -1281,11 +1322,6 @@ async def test_patch_static_templates_fail_validation_raises() -> None: def test_find_web_page_iri_returns_none_when_missing() -> None: - protocol = ProfileImportProtocol( - context=_make_context(), - profile=_make_profile(), - root_dir=Path.cwd(), - ) graph = Graph() graph.add( ( @@ -1294,61 +1330,17 @@ def test_find_web_page_iri_returns_none_when_missing() -> None: URIRef("https://schema.org/Thing"), ) ) - assert protocol._find_web_page_iri(graph) is None + assert _find_web_page_iri_impl(graph) is None -def _make_validation_result( - *, - conforms: bool, - warning_shapes: list[URIRef] | None = None, - error_shapes: list[URIRef] | None = None, - shape_map: dict[URIRef, str] | None = None, -) -> ValidationResult: - warning_shapes = warning_shapes or [] - error_shapes = error_shapes or [] - shape_map = shape_map or {} - report = Graph() - sh_result_severity = URIRef("http://www.w3.org/ns/shacl#resultSeverity") - sh_warning = URIRef("http://www.w3.org/ns/shacl#Warning") - sh_violation = URIRef("http://www.w3.org/ns/shacl#Violation") - sh_source_shape = URIRef("http://www.w3.org/ns/shacl#sourceShape") - - for index, shape in enumerate(warning_shapes): - node = URIRef(f"https://example.com/report/w/{index}") - report.add((node, sh_result_severity, sh_warning)) - report.add((node, sh_source_shape, shape)) - for index, shape in enumerate(error_shapes): - node = URIRef(f"https://example.com/report/e/{index}") - report.add((node, sh_result_severity, sh_violation)) - report.add((node, sh_source_shape, shape)) - - return ValidationResult( - conforms=conforms, - report_text="report", - report_graph=report, - data_graph=Graph(), - shape_source_map=shape_map, - warning_count=len(warning_shapes), - ) - - -def test_summarize_validation_aggregates_sources() -> None: - 
protocol = ProfileImportProtocol( - context=_make_context(), - profile=_make_profile(), - root_dir=Path.cwd(), - ) - article_shape = URIRef("https://shape.example/article") - product_shape = URIRef("https://shape.example/product") - result = _make_validation_result( - conforms=False, - warning_shapes=[article_shape], - error_shapes=[article_shape, product_shape], - shape_map={article_shape: "google-article", product_shape: "google-product"}, +def test_validation_outcome_to_dict_aggregates_sources() -> None: + outcome = _make_validation_outcome( + passed=False, + warning_sources={"google-article": 1}, + error_sources={"google-article": 1, "google-product": 1}, ) - summary = protocol._summarize_validation(result) + summary = outcome.to_dict() assert summary == { - "total": 1, "pass": False, "fail": True, "warnings": {"count": 1, "sources": {"google-article": 1}}, @@ -1376,21 +1368,16 @@ async def test_profile_protocol_emits_progress_and_validation_in_warn_mode() -> protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_dataset_scoped_graph()) + ) + protocol._run_postprocessing_stage = _passthrough_pp() protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_dataset_scoped_graph() - ) - protocol._validate_graph = MagicMock( - return_value=_make_validation_result( - conforms=False, - warning_shapes=[URIRef("https://shape.example/w")], - error_shapes=[URIRef("https://shape.example/e")], - shape_map={ - URIRef("https://shape.example/w"): "google-article", - URIRef("https://shape.example/e"): "google-product", - }, + 
protocol._shacl_validator.validate = AsyncMock( + return_value=_make_validation_outcome( + passed=False, + warning_sources={"google-article": 1}, + error_sources={"google-product": 1}, ) ) @@ -1404,7 +1391,6 @@ async def test_profile_protocol_emits_progress_and_validation_in_warn_mode() -> assert payload["kind"] == "graph" assert payload["url"] == "https://example.com/page" assert payload["validation"] == { - "total": 1, "pass": False, "fail": True, "warnings": {"count": 1, "sources": {"google-article": 1}}, @@ -1440,14 +1426,13 @@ async def test_profile_protocol_validation_fail_mode_raises() -> None: protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_dataset_scoped_graph() + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_dataset_scoped_graph()) ) - protocol._validate_graph = MagicMock( - return_value=_make_validation_result(conforms=False) + protocol._run_postprocessing_stage = _passthrough_pp() + protocol.patcher.patch_all = AsyncMock() + protocol._shacl_validator.validate = AsyncMock( + return_value=_make_validation_outcome(passed=False) ) response = WebPageScrapeResponse( @@ -1475,12 +1460,11 @@ async def test_profile_protocol_emits_null_validation_when_disabled() -> None: protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - 
protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_dataset_scoped_graph() + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_dataset_scoped_graph()) ) + protocol._run_postprocessing_stage = _passthrough_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -1504,12 +1488,11 @@ async def test_profile_protocol_passes_import_hash_mode_to_patcher() -> None: protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_dataset_scoped_graph() + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_dataset_scoped_graph()) ) + protocol._run_postprocessing_stage = _passthrough_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -1535,12 +1518,11 @@ async def test_profile_protocol_emits_graph_and_static_template_events() -> None protocol._template_exports = {} protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_dataset_scoped_graph() + protocol._run_mapping_stage = AsyncMock( + 
return_value=_make_mapping_result(_make_dataset_scoped_graph()) ) + protocol._run_postprocessing_stage = _passthrough_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -1560,12 +1542,11 @@ async def test_profile_protocol_collects_run_level_kpis() -> None: protocol._patch_static_templates_once = AsyncMock() protocol._resolve_mapping_path = MagicMock(return_value=Path("mapping.yarrrml")) protocol._get_mapping_content = MagicMock(return_value="mapping") - protocol._core_ids.process_graph = MagicMock(side_effect=lambda g, _: g) - protocol._apply_postprocessors = MagicMock(side_effect=lambda g, *_: g) - protocol.patcher.patch_all = AsyncMock() - protocol.rml_service.apply_mapping = AsyncMock( - return_value=_make_dataset_scoped_graph() + protocol._run_mapping_stage = AsyncMock( + return_value=_make_mapping_result(_make_dataset_scoped_graph()) ) + protocol._run_postprocessing_stage = _annotating_pp() + protocol.patcher.patch_all = AsyncMock() response = WebPageScrapeResponse( web_page=WebPage(url="https://example.com/page", html="") @@ -1604,7 +1585,7 @@ def test_protocol_validation_mode_normalization_and_deprecation( ), root_dir=Path.cwd(), ) - assert strict_protocol._shacl_mode == "fail" + assert strict_protocol._shacl_validator.mode.value == "fail" assert "Deprecated SHACL validation mode 'strict' detected" in caplog.text with caplog.at_level("WARNING"): @@ -1615,7 +1596,7 @@ def test_protocol_validation_mode_normalization_and_deprecation( ), root_dir=Path.cwd(), ) - assert unknown_protocol._shacl_mode == "warn" + assert unknown_protocol._shacl_validator.mode.value == "warn" assert "Unsupported SHACL validation mode" in caplog.text with caplog.at_level("WARNING"): diff --git a/tests/kg_build/test_rml_mapping.py b/tests/kg_build/test_rml_mapping.py index 7d4d72e..810e31c 100644 --- a/tests/kg_build/test_rml_mapping.py +++ b/tests/kg_build/test_rml_mapping.py @@ -7,7 +7,6 @@ 
import pytest from rdflib import Graph -import wordlift_sdk.kg_build.rml_mapping as rml_module from wordlift_sdk.kg_build.rml_mapping import RmlMappingService @@ -34,50 +33,46 @@ def _context(dataset_uri: str | None): @pytest.mark.asyncio -async def test_apply_mapping_from_content_success( - monkeypatch: pytest.MonkeyPatch, -) -> None: - service = RmlMappingService(_context("https://data.example.com")) +async def test_apply_mapping_from_content_success() -> None: + service = RmlMappingService( + _context("https://data.example.com"), pipeline=_Pipeline() + ) service._html_converter.convert = MagicMock(return_value="") - monkeypatch.setattr(rml_module, "MaterializationPipeline", _Pipeline) debug_output: dict[str, str] = {} - graph = await service.apply_mapping( + result = await service.apply_mapping( html="", url="https://example.com/page", mapping_file_path="demo.yarrrml", mapping_content="m: 1", debug_output=debug_output, ) - assert isinstance(graph, Graph) - assert len(graph) > 0 + assert isinstance(result.graph, Graph) + assert len(result.graph) > 0 assert debug_output["xhtml"] == "" @pytest.mark.asyncio async def test_apply_mapping_file_not_found_returns_none() -> None: service = RmlMappingService(_context("https://data.example.com")) - out = await service.apply_mapping( + result = await service.apply_mapping( html="", url="https://example.com", mapping_file_path=Path("/no/such/file.yarrrml"), ) - assert out is None + assert result.graph is None @pytest.mark.asyncio -async def test_apply_mapping_missing_dataset_uri_returns_none( - monkeypatch: pytest.MonkeyPatch, -) -> None: - service = RmlMappingService(_context(None)) - monkeypatch.setattr(rml_module, "MaterializationPipeline", _Pipeline) - out = await service.apply_mapping( +async def test_apply_mapping_missing_dataset_uri_returns_none() -> None: + service = RmlMappingService(_context(None), pipeline=_Pipeline()) + result = await service.apply_mapping( html="", url="https://example.com", 
mapping_file_path="x", mapping_content="m: 1", ) - assert out is None + assert result.graph is None def test_normalize_schema_uris() -> None: diff --git a/tests/test_dataset_resolver.py b/tests/test_dataset_resolver.py index e42f838..03c945b 100644 --- a/tests/test_dataset_resolver.py +++ b/tests/test_dataset_resolver.py @@ -46,15 +46,16 @@ def __init__(self, *args, **kwargs) -> None: sys.modules.setdefault("wordlift_client.models", _models_module) sys.modules.setdefault("wordlift_client.models.ask_request", _ask_module) -_pyshacl = types.ModuleType("pyshacl") +try: + import pyshacl as _pyshacl_real # noqa: F401 +except ImportError: + _pyshacl = types.ModuleType("pyshacl") + def _stub_validate(*_args, **_kwargs): + return None, None, None -def _stub_validate(*_args, **_kwargs): - return None, None, None - - -_pyshacl.validate = _stub_validate -sys.modules.setdefault("pyshacl", _pyshacl) + _pyshacl.validate = _stub_validate + sys.modules["pyshacl"] = _pyshacl from wordlift_sdk.structured_data.dataset_resolver import DatasetResolver # noqa: E402 diff --git a/tests/test_google_search_console_data_import_helpers.py b/tests/test_google_search_console_data_import_helpers.py index 025e148..e6601a1 100644 --- a/tests/test_google_search_console_data_import_helpers.py +++ b/tests/test_google_search_console_data_import_helpers.py @@ -2,6 +2,8 @@ import asyncio import importlib +import sys +import types from datetime import datetime, timedelta from types import SimpleNamespace @@ -16,6 +18,8 @@ raise_error_if_account_analytics_not_configured, ) +_ENTITIES_MOD = "wordlift_sdk.deprecated.create_entities_with_top_query_dataframe" + gsc_import_mod = importlib.import_module( "wordlift_sdk.google_search_console.create_google_search_console_data_import" ) @@ -43,9 +47,9 @@ async def test_create_google_search_console_data_import_only_imports_stale_rows( async def _fake_entities_df(key, url_list): return source_df - monkeypatch.setattr( - gsc_import_mod, 
"create_entities_with_top_query_dataframe", _fake_entities_df - ) + stub = types.ModuleType(_ENTITIES_MOD) + stub.create_entities_with_top_query_dataframe = _fake_entities_df # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, _ENTITIES_MOD, stub) called_urls: list[str] = [] @@ -100,9 +104,9 @@ async def test_create_google_search_console_data_import_skips_when_no_stale( async def _fake_entities_df(key, url_list): return source_df - monkeypatch.setattr( - gsc_import_mod, "create_entities_with_top_query_dataframe", _fake_entities_df - ) + stub = types.ModuleType(_ENTITIES_MOD) + stub.create_entities_with_top_query_dataframe = _fake_entities_df # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, _ENTITIES_MOD, stub) calls: dict[str, int] = {"gather": 0} diff --git a/tests/test_lazy_exports.py b/tests/test_lazy_exports.py index 093c5c9..e7f020f 100644 --- a/tests/test_lazy_exports.py +++ b/tests/test_lazy_exports.py @@ -6,31 +6,71 @@ import pytest +# Modules that own ProcessPoolExecutors must not be evicted — dropping them +# causes function-identity mismatches when the pool tries to pickle workers. 
+_PRESERVE_MODULES = frozenset( + [ + "wordlift_sdk.structured_data.engine", + "wordlift_sdk.validation.shacl_validation_service", + "wordlift_sdk.workflow.url_handler.ingestion_web_page_scrape_url_handler", + "wordlift_sdk.workflow.url_handler.web_page_scrape_url_handler", + ] +) + + def _drop_modules(prefix: str) -> None: for name in list(sys.modules): + if name in _PRESERVE_MODULES: + continue if name == prefix or name.startswith(f"{prefix}."): sys.modules.pop(name, None) -def test_root_package_import_is_lazy(): +def test_root_package_import_is_lazy(monkeypatch: pytest.MonkeyPatch): _drop_modules("wordlift_sdk") package = importlib.import_module("wordlift_sdk") assert "wordlift_sdk.main" not in sys.modules + import types + + stub_main = types.ModuleType("wordlift_sdk.main") + stub_main.run_kg_import_workflow = object() # type: ignore[attr-defined] + + def fake_import_module(name: str): + if name == "wordlift_sdk.main": + sys.modules["wordlift_sdk.main"] = stub_main + return stub_main + return importlib.import_module(name) + + monkeypatch.setattr("wordlift_sdk._lazy_exports.import_module", fake_import_module) + package.run_kg_import_workflow assert "wordlift_sdk.main" in sys.modules -def test_feature_package_import_is_lazy(): +def test_feature_package_import_is_lazy(monkeypatch: pytest.MonkeyPatch): _drop_modules("wordlift_sdk.render") package = importlib.import_module("wordlift_sdk.render") assert "wordlift_sdk.render.html_renderer" not in sys.modules + import types + + stub_renderer = types.ModuleType("wordlift_sdk.render.html_renderer") + stub_renderer.HtmlRenderer = object() # type: ignore[attr-defined] + + def fake_import_module(name: str): + if name == "wordlift_sdk.render.html_renderer": + sys.modules["wordlift_sdk.render.html_renderer"] = stub_renderer + return stub_renderer + return importlib.import_module(name) + + monkeypatch.setattr("wordlift_sdk._lazy_exports.import_module", fake_import_module) + package.HtmlRenderer assert 
"wordlift_sdk.render.html_renderer" in sys.modules diff --git a/tests/test_merchant_listing_defined_region_validation.py b/tests/test_merchant_listing_defined_region_validation.py index d17c799..43d9474 100644 --- a/tests/test_merchant_listing_defined_region_validation.py +++ b/tests/test_merchant_listing_defined_region_validation.py @@ -1,30 +1,13 @@ -import importlib.util import json -import sys from pathlib import Path -import pytest - from wordlift_sdk.validation import shacl from wordlift_sdk.validation.shacl import extract_validation_issues -def _load_real_validate(monkeypatch: pytest.MonkeyPatch): - if "pyshacl" in sys.modules: - monkeypatch.delitem(sys.modules, "pyshacl", raising=False) - spec = importlib.util.find_spec("pyshacl") - if spec is None or spec.loader is None: - raise RuntimeError("pyshacl is required for this test.") - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.validate - - def test_merchant_listing_defined_region_address_country_only_is_warning_only( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch + tmp_path: Path, ) -> None: - monkeypatch.setattr(shacl, "validate", _load_real_validate(monkeypatch)) - payload = { "@context": {"@vocab": "http://schema.org/"}, "@id": "https://data.wordlift.io/wl1506344/merchant-return-policys/shipping-policy/offer-shipping-details/offer-shipping-details-1/defined-regions/defined-region", @@ -51,10 +34,8 @@ def test_merchant_listing_defined_region_address_country_only_is_warning_only( def test_defined_region_address_country_only_conforms_with_default_shapes( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch + tmp_path: Path, ) -> None: - monkeypatch.setattr(shacl, "validate", _load_real_validate(monkeypatch)) - payload = { "@context": {"@vocab": "http://schema.org/"}, "@id": "https://data.wordlift.io/wl1506344/merchant-return-policys/shipping-policy/offer-shipping-details/offer-shipping-details-1/defined-regions/defined-region", diff --git 
a/tests/test_product_snippet_validation.py b/tests/test_product_snippet_validation.py index 9054125..b202ba1 100644 --- a/tests/test_product_snippet_validation.py +++ b/tests/test_product_snippet_validation.py @@ -1,29 +1,12 @@ from pathlib import Path -import importlib.util import json -import sys - -import pytest from wordlift_sdk.validation import shacl -def _load_real_validate(monkeypatch: pytest.MonkeyPatch): - if "pyshacl" in sys.modules: - monkeypatch.delitem(sys.modules, "pyshacl", raising=False) - spec = importlib.util.find_spec("pyshacl") - if spec is None or spec.loader is None: - raise RuntimeError("pyshacl is required for this test.") - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.validate - - def test_product_snippet_offers_satisfies_one_of( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch + tmp_path: Path, ) -> None: - monkeypatch.setattr(shacl, "validate", _load_real_validate(monkeypatch)) - fixture = Path("tests/fixtures/product_snippet_offers.jsonld") data = json.loads(fixture.read_text(encoding="utf-8")) @@ -43,10 +26,8 @@ def test_product_snippet_offers_satisfies_one_of( def test_product_snippet_aggregate_offer_satisfies_one_of( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch + tmp_path: Path, ) -> None: - monkeypatch.setattr(shacl, "validate", _load_real_validate(monkeypatch)) - fixture = Path("tests/fixtures/product_snippet_aggregate_offer.jsonld") data = json.loads(fixture.read_text(encoding="utf-8")) diff --git a/tests/test_recommended_one_of_validation.py b/tests/test_recommended_one_of_validation.py index f4c89c2..29ce149 100644 --- a/tests/test_recommended_one_of_validation.py +++ b/tests/test_recommended_one_of_validation.py @@ -1,25 +1,10 @@ -import importlib.util import json -import sys from pathlib import Path -import pytest - from wordlift_sdk.validation import shacl from wordlift_sdk.validation.shacl import extract_validation_issues -def _load_real_validate(monkeypatch: 
pytest.MonkeyPatch): - if "pyshacl" in sys.modules: - monkeypatch.delitem(sys.modules, "pyshacl", raising=False) - spec = importlib.util.find_spec("pyshacl") - if spec is None or spec.loader is None: - raise RuntimeError("pyshacl is required for this test.") - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.validate - - def _write_jsonld(tmp_path: Path, name: str, payload: dict) -> Path: path = tmp_path / name path.write_text(json.dumps(payload), encoding="utf-8") @@ -31,10 +16,8 @@ def _messages_for(result) -> list[str]: def test_dataset_recommended_either_or_is_warning_only( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch + tmp_path: Path, ) -> None: - monkeypatch.setattr(shacl, "validate", _load_real_validate(monkeypatch)) - missing_payload = { "@context": {"@vocab": "http://schema.org/"}, "@type": "Dataset", @@ -74,10 +57,8 @@ def test_dataset_recommended_either_or_is_warning_only( def test_offer_shipping_details_recommended_either_or_is_warning_only( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch + tmp_path: Path, ) -> None: - monkeypatch.setattr(shacl, "validate", _load_real_validate(monkeypatch)) - missing_payload = { "@context": {"@vocab": "http://schema.org/"}, "@type": "OfferShippingDetails", @@ -116,10 +97,8 @@ def test_offer_shipping_details_recommended_either_or_is_warning_only( def test_product_offer_price_currency_recommended_either_or_is_warning_only( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch + tmp_path: Path, ) -> None: - monkeypatch.setattr(shacl, "validate", _load_real_validate(monkeypatch)) - missing_payload = { "@context": {"@vocab": "http://schema.org/"}, "@type": "Product", diff --git a/tests/test_structured_data_engine_class.py b/tests/test_structured_data_engine_class.py index f4a272e..b208a78 100644 --- a/tests/test_structured_data_engine_class.py +++ b/tests/test_structured_data_engine_class.py @@ -44,15 +44,16 @@ def __init__(self, *args, **kwargs) -> None: 
sys.modules.setdefault("wordlift_client.models", _models_module) sys.modules.setdefault("wordlift_client.models.ask_request", _ask_module) -_pyshacl = types.ModuleType("pyshacl") +try: + import pyshacl as _pyshacl_real # noqa: F401 +except ImportError: + _pyshacl = types.ModuleType("pyshacl") + def _stub_validate(*_args, **_kwargs): + return None, None, None -def _stub_validate(*_args, **_kwargs): - return None, None, None - - -_pyshacl.validate = _stub_validate -sys.modules.setdefault("pyshacl", _pyshacl) + _pyshacl.validate = _stub_validate + sys.modules["pyshacl"] = _pyshacl from wordlift_sdk.structured_data.structured_data_engine import ( # noqa: E402 StructuredDataEngine, diff --git a/tests/test_structured_data_engine_validation_helpers.py b/tests/test_structured_data_engine_validation_helpers.py index 884b483..f3c7700 100644 --- a/tests/test_structured_data_engine_validation_helpers.py +++ b/tests/test_structured_data_engine_validation_helpers.py @@ -977,23 +977,7 @@ def test_normalize_agent_yarrrml_additional_parser_branches(monkeypatch): assert any(m["name"] == "main" for m in mappings) -def test_materialize_graph_and_xpath_first_text_branches(monkeypatch): - real_import = builtins.__import__ - - def _missing_morph(name, *args, **kwargs): - if name == "morph_kgc": - raise ImportError("missing") - return real_import(name, *args, **kwargs) - - monkeypatch.setattr(builtins, "__import__", _missing_morph) - try: - engine._materialize_graph(Path("mapping.yarrrml")) - assert False, "expected RuntimeError" - except RuntimeError as exc: - assert "morph-kgc is required" in str(exc) - finally: - monkeypatch.setattr(builtins, "__import__", real_import) - +def test_materialize_graph_and_xpath_first_text_branches(): class _Doc: def __init__(self): self.calls = 0 diff --git a/tests/test_structured_data_materialization_generic.py b/tests/test_structured_data_materialization_generic.py index 8d387f8..f4f492b 100644 --- a/tests/test_structured_data_materialization_generic.py 
+++ b/tests/test_structured_data_materialization_generic.py @@ -50,15 +50,16 @@ def __init__(self, *args, **kwargs) -> None: sys.modules.setdefault("wordlift_client.models", _models_module) sys.modules.setdefault("wordlift_client.models.ask_request", _ask_module) -_pyshacl = types.ModuleType("pyshacl") +try: + import pyshacl as _pyshacl_real # noqa: F401 +except ImportError: + _pyshacl = types.ModuleType("pyshacl") + def _stub_validate(*_args, **_kwargs): + return None, None, None -def _stub_validate(*_args, **_kwargs): - return None, None, None - - -_pyshacl.validate = _stub_validate -sys.modules.setdefault("pyshacl", _pyshacl) + _pyshacl.validate = _stub_validate + sys.modules["pyshacl"] = _pyshacl from wordlift_sdk.structured_data.engine import ( # noqa: E402 materialize_yarrrml_jsonld, @@ -577,12 +578,17 @@ def test_unsupported_xpath_or_function_raises_actionable_error( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - fake_morph = types.SimpleNamespace( - materialize=lambda _cfg: (_ for _ in ()).throw( - ValueError("XPathEvalError: Unsupported function local-namez()") - ) - ) - monkeypatch.setitem(sys.modules, "morph_kgc", fake_morph) + import wordlift_sdk.structured_data.engine as _engine + + class _FakeFuture: + def result(self): + raise ValueError("XPathEvalError: Unsupported function local-namez()") + + class _FakePool: + def submit(self, fn, *args, **kwargs): + return _FakeFuture() + + monkeypatch.setattr(_engine, "_get_morph_kgc_pool", lambda: _FakePool()) mapping = """ prefixes: diff --git a/tests/test_structured_data_workflows.py b/tests/test_structured_data_workflows.py index 373802c..0b7ecad 100644 --- a/tests/test_structured_data_workflows.py +++ b/tests/test_structured_data_workflows.py @@ -52,15 +52,16 @@ def __init__(self, *args, **kwargs) -> None: sys.modules.setdefault("wordlift_client.models", _models_module) sys.modules.setdefault("wordlift_client.models.ask_request", _ask_module) -_pyshacl = types.ModuleType("pyshacl") +try: + 
import pyshacl as _pyshacl_real # noqa: F401 +except ImportError: + _pyshacl = types.ModuleType("pyshacl") + def _stub_validate(*_args, **_kwargs): + return None, None, None -def _stub_validate(*_args, **_kwargs): - return None, None, None - - -_pyshacl.validate = _stub_validate -sys.modules.setdefault("pyshacl", _pyshacl) + _pyshacl.validate = _stub_validate + sys.modules["pyshacl"] = _pyshacl from wordlift_sdk.structured_data import ( # noqa: E402 CreateRequest, diff --git a/tests/tools/run_slice_tests.py b/tests/tools/run_slice_tests.py index 72e3111..d2ef093 100644 --- a/tests/tools/run_slice_tests.py +++ b/tests/tools/run_slice_tests.py @@ -41,7 +41,6 @@ "tests/ingestion", "tests/test_google_sheets_url_provider.py", "tests/test_list_url_provider.py", - "tests/test_ingestion_source_bridge.py", "tests/url_provider/test_sitemap_url_provider.py", ], "structured-data": [ diff --git a/tests/kg_build/test_ingestion_bridge_url_handler.py b/tests/workflow/test_ingestion_bridge_url_handler.py similarity index 95% rename from tests/kg_build/test_ingestion_bridge_url_handler.py rename to tests/workflow/test_ingestion_bridge_url_handler.py index 60fe267..5ad4fca 100644 --- a/tests/kg_build/test_ingestion_bridge_url_handler.py +++ b/tests/workflow/test_ingestion_bridge_url_handler.py @@ -6,6 +6,7 @@ import pytest +import wordlift_sdk.workflow.url_handler.ingestion_web_page_scrape_url_handler as _handler_mod from wordlift_sdk.ingestion.errors import LoaderRuntimeError from wordlift_sdk.ingestion.loaders import PlaywrightLoaderAdapter from wordlift_sdk.url_source import Url @@ -37,7 +38,8 @@ async def test_ingestion_bridge_handler_calls_callback( ) monkeypatch.setattr( - "wordlift_sdk.workflow.url_handler.ingestion_web_page_scrape_url_handler.run_ingestion", + _handler_mod, + "run_ingestion", lambda settings: SimpleNamespace( pages=[ SimpleNamespace( @@ -88,7 +90,8 @@ async def test_ingestion_bridge_handler_raises_on_failed_ingestion( ) monkeypatch.setattr( - 
"wordlift_sdk.workflow.url_handler.ingestion_web_page_scrape_url_handler.run_ingestion", + _handler_mod, + "run_ingestion", lambda settings: SimpleNamespace( pages=[], events=[ @@ -124,7 +127,8 @@ async def test_ingestion_bridge_handler_raises_and_skips_callback_on_http_404( ) monkeypatch.setattr( - "wordlift_sdk.workflow.url_handler.ingestion_web_page_scrape_url_handler.run_ingestion", + _handler_mod, + "run_ingestion", lambda settings: SimpleNamespace( pages=[ SimpleNamespace( @@ -164,7 +168,8 @@ async def test_ingestion_bridge_handler_raises_and_skips_callback_on_http_500( ) monkeypatch.setattr( - "wordlift_sdk.workflow.url_handler.ingestion_web_page_scrape_url_handler.run_ingestion", + _handler_mod, + "run_ingestion", lambda settings: SimpleNamespace( pages=[ SimpleNamespace( @@ -205,7 +210,8 @@ async def test_ingestion_bridge_handler_surfaces_failed_meta_diagnostics( caplog.set_level("ERROR") monkeypatch.setattr( - "wordlift_sdk.workflow.url_handler.ingestion_web_page_scrape_url_handler.run_ingestion", + _handler_mod, + "run_ingestion", lambda settings: SimpleNamespace( pages=[], events=[ @@ -265,7 +271,8 @@ async def test_ingestion_bridge_handler_meta_fallback_keeps_old_message( ) monkeypatch.setattr( - "wordlift_sdk.workflow.url_handler.ingestion_web_page_scrape_url_handler.run_ingestion", + _handler_mod, + "run_ingestion", lambda settings: SimpleNamespace( pages=[], events=[ @@ -306,7 +313,8 @@ async def test_ingestion_bridge_handler_truncates_diagnostics_payload( long_message = "token=abc123 " + ("x" * 10000) monkeypatch.setattr( - "wordlift_sdk.workflow.url_handler.ingestion_web_page_scrape_url_handler.run_ingestion", + _handler_mod, + "run_ingestion", lambda settings: SimpleNamespace( pages=[], events=[ diff --git a/tests/kg_build/test_web_page_scrape_url_handler.py b/tests/workflow/test_web_page_scrape_url_handler.py similarity index 100% rename from tests/kg_build/test_web_page_scrape_url_handler.py rename to 
tests/workflow/test_web_page_scrape_url_handler.py diff --git a/wordlift_sdk/google_search_console/create_google_search_console_data_import.py b/wordlift_sdk/google_search_console/create_google_search_console_data_import.py index 00bd0fc..41f566d 100644 --- a/wordlift_sdk/google_search_console/create_google_search_console_data_import.py +++ b/wordlift_sdk/google_search_console/create_google_search_console_data_import.py @@ -9,7 +9,6 @@ from twisted.mail.scripts.mailmail import Configuration from wordlift_client import AnalyticsImportRequest -from ..deprecated import create_entities_with_top_query_dataframe from ..utils import create_delayed logger = logging.getLogger(__name__) @@ -19,6 +18,8 @@ async def create_google_search_console_data_import( configuration: Configuration, key: str, url_list: list[str] ) -> None: # Get the entities data with the top query. + from ..deprecated import create_entities_with_top_query_dataframe + entities_with_top_query_df = await create_entities_with_top_query_dataframe( key=key, url_list=url_list ) diff --git a/wordlift_sdk/graph/audit/_entity_matrix.py b/wordlift_sdk/graph/audit/_entity_matrix.py index 1c8101e..1dab968 100644 --- a/wordlift_sdk/graph/audit/_entity_matrix.py +++ b/wordlift_sdk/graph/audit/_entity_matrix.py @@ -15,7 +15,7 @@ _find_webpage_urls, ) from wordlift_sdk.validation.shacl import ( - _normalize_schema_org_uris, # type: ignore[attr-defined] + _normalize_schema_org_uris as normalize_schema_org_uris, # type: ignore[attr-defined] ) _SCHEMA_ORG_PREFIXES = ("http://schema.org/", "https://schema.org/") @@ -120,7 +120,7 @@ def build_entity_matrix( excl: set[str] = set(exclude_types or []) load_result = load_graph(path) - normalized = _normalize_schema_org_uris(load_result.graph) + normalized = normalize_schema_org_uris(load_result.graph) webpage_urls = _find_webpage_urls(normalized) if not webpage_urls: diff --git a/wordlift_sdk/kg_build/__init__.py b/wordlift_sdk/kg_build/__init__.py index 1692a5e..17c9c4d 100644 
--- a/wordlift_sdk/kg_build/__init__.py +++ b/wordlift_sdk/kg_build/__init__.py @@ -1,8 +1,6 @@ from __future__ import annotations -from importlib import import_module -from typing import Any - +from .._lazy_exports import resolve_attr __all__ = [ "ProfileConfig", @@ -59,18 +57,21 @@ "wordlift_sdk.kg_build.cloud_flow", "get_debug_output_dir", ), - "run_cloud_workflow": ("wordlift_sdk.kg_build.cloud_flow", "run_cloud_workflow"), + "run_cloud_workflow": ("wordlift_sdk.kg_build.protocol", "run_cloud_workflow"), "KgBuildApplicationContainer": ( "wordlift_sdk.kg_build.container", "KgBuildApplicationContainer", ), - "IdAllocator": ("wordlift_sdk.kg_build.id_allocator", "IdAllocator"), + "IdAllocator": ( + "wordlift_sdk.kg_build.postprocessors.processors.id_allocator", + "IdAllocator", + ), "CanonicalIdGenerator": ( - "wordlift_sdk.kg_build.id_generator", + "wordlift_sdk.kg_build.postprocessors.processors.id_generator", "CanonicalIdGenerator", ), "CanonicalIdsPostprocessor": ( - "wordlift_sdk.kg_build.id_postprocessor", + "wordlift_sdk.kg_build.postprocessors.processors.id_postprocessor", "CanonicalIdsPostprocessor", ), "IriLookup": ("wordlift_sdk.kg_build.iri_lookup", "IriLookup"), @@ -139,12 +140,10 @@ } -def __getattr__(name: str) -> Any: - target = _EXPORTS.get(name) - if target is None: - raise AttributeError( - f"module 'wordlift_sdk.kg_build' has no attribute '{name}'" - ) - module_name, attr_name = target - module = import_module(module_name) - return getattr(module, attr_name) +def __getattr__(name: str): + return resolve_attr( + name=name, + module_name="wordlift_sdk.kg_build", + exports=_EXPORTS, + extra="kg-build", + ) diff --git a/wordlift_sdk/kg_build/graph_utils.py b/wordlift_sdk/kg_build/graph_utils.py new file mode 100644 index 0000000..df35268 --- /dev/null +++ b/wordlift_sdk/kg_build/graph_utils.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from rdflib import Graph, URIRef + + +def first_level_subjects(graph: Graph, dataset_uri: str) -> 
set[URIRef]: + """Return the first-level URIRef subjects of *graph*. + + When *dataset_uri* is set, first-level subjects are those whose IRI matches + ``//`` (exactly two non-empty path segments after the + base URI). Falls back to subjects that are not referenced as objects by any + other triple; if every subject is referenced, returns all subjects. + """ + subjects = {s for s in graph.subjects() if isinstance(s, URIRef)} + if dataset_uri: + first_level_by_id = { + s + for s in subjects + if str(s).startswith(f"{dataset_uri}/") + and len([p for p in str(s)[len(dataset_uri) + 1 :].split("/") if p]) == 2 + } + if first_level_by_id: + return first_level_by_id + + referenced = { + obj + for _, _, obj in graph.triples((None, None, None)) + if isinstance(obj, URIRef) and obj in subjects + } + first_level = subjects - referenced + return first_level or subjects diff --git a/wordlift_sdk/kg_build/id_postprocessor.py b/wordlift_sdk/kg_build/id_postprocessor.py deleted file mode 100644 index 7660a13..0000000 --- a/wordlift_sdk/kg_build/id_postprocessor.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import annotations - -from rdflib import Graph - -from .id_generator import CanonicalIdGenerator -from .iri_lookup import IriLookup - - -class CanonicalIdsPostprocessor: - """Postprocessor adapter that applies canonical ID generation to a graph.""" - - def __init__( - self, - generator: CanonicalIdGenerator | None = None, - iri_lookup: IriLookup | None = None, - context_key: str = "kg_build.iri_lookup", - strategy: str = "legacy", - ) -> None: - self._generator = generator or CanonicalIdGenerator(strategy=strategy) - self._iri_lookup = iri_lookup - self._context_key = context_key - - def process_graph(self, graph: Graph, context) -> Graph: - dataset_uri = str(getattr(context.account, "dataset_uri", "")).rstrip("/") - if not dataset_uri: - return graph - iri_lookup = self._iri_lookup or self._lookup_from_context(context) - return self._generator.apply(graph, dataset_uri, 
iri_lookup=iri_lookup) - - def _lookup_from_context(self, context) -> IriLookup | None: - extensions = getattr(context, "extensions", None) - if not isinstance(extensions, dict): - return None - lookup = extensions.get(self._context_key) - if lookup is None or not hasattr(lookup, "iri_for_subject"): - return None - return lookup diff --git a/wordlift_sdk/kg_build/kpi.py b/wordlift_sdk/kg_build/kpi.py index 5edea07..f6c0822 100644 --- a/wordlift_sdk/kg_build/kpi.py +++ b/wordlift_sdk/kg_build/kpi.py @@ -6,6 +6,8 @@ from rdflib import Graph, RDF, URIRef +from wordlift_sdk.validation.shacl_validation_service import ValidationOutcome + @dataclass class KgBuildKpiCollector: @@ -98,26 +100,18 @@ def record_graph(self, graph: Graph) -> None: self._property_assertions_total += 1 self._properties_by_predicate[str(predicate)] += 1 - def record_validation( - self, - *, - passed: bool, - warning_count: int, - error_count: int, - warning_sources: dict[str, int] | Counter[str] | None = None, - error_sources: dict[str, int] | Counter[str] | None = None, - ) -> None: + def record_validation(self, outcome: ValidationOutcome) -> None: self._validation_total += 1 - if passed: + if outcome.passed: self._validation_pass += 1 else: self._validation_fail += 1 - self._warning_count += warning_count - self._error_count += error_count - if warning_sources: - self._warning_sources.update(warning_sources) - if error_sources: - self._error_sources.update(error_sources) + self._warning_count += outcome.warning_count + self._error_count += outcome.error_count + if outcome.warning_sources: + self._warning_sources.update(outcome.warning_sources) + if outcome.error_sources: + self._error_sources.update(outcome.error_sources) def summary(self, profile_name: str) -> dict[str, object]: entities_by_type = { diff --git a/wordlift_sdk/kg_build/postprocessors.py b/wordlift_sdk/kg_build/postprocessors.py deleted file mode 100644 index b67de6b..0000000 --- a/wordlift_sdk/kg_build/postprocessors.py +++ 
/dev/null @@ -1,620 +0,0 @@ -from __future__ import annotations - -import json -import logging -import select -import shutil -import subprocess -import tempfile -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Protocol, runtime_checkable - -from rdflib import Dataset, Graph - -logger = logging.getLogger(__name__) - -try: - import tomllib -except ModuleNotFoundError: # pragma: no cover - import tomli as tomllib - -_RUNTIME_ONESHOT = "oneshot" -_RUNTIME_PERSISTENT = "persistent" - - -@dataclass(frozen=True) -class PostprocessorContext: - profile_name: str - profile: dict[str, Any] - url: str - account: Any - account_key: str | None - exports: dict[str, Any] - response: Any - existing_web_page_id: str | None - ids: Any | None = None - - -@runtime_checkable -class GraphPostprocessor(Protocol): - def process_graph( - self, graph: Graph, context: PostprocessorContext - ) -> Graph | None: ... - - -@dataclass(frozen=True) -class LoadedPostprocessor: - name: str - handler: GraphPostprocessor - - def run(self, graph: Graph, context: PostprocessorContext) -> Graph: - result = self.handler.process_graph(graph, context) - return graph if result is None else result - - -@dataclass(frozen=True) -class PostprocessorSpec: - class_path: str - python: str - timeout_seconds: int - enabled: bool - keep_temp_on_error: bool - - -class PersistentWorkerTransportError(RuntimeError): - pass - - -class PersistentWorkerJobError(RuntimeError): - pass - - -class PersistentPostprocessorClient: - def __init__(self, *, spec: PostprocessorSpec, root_dir: Path) -> None: - self._spec = spec - self._root_dir = root_dir - self._process: subprocess.Popen[str] | None = None - self._next_job_id = 0 - - def close(self) -> None: - process = self._process - self._process = None - if process is None: - return - - try: - if process.poll() is None and process.stdin is not None: - process.stdin.write(json.dumps({"op": "shutdown"}) + "\n") - process.stdin.flush() - 
except Exception: - pass - - self._terminate(process) - - def process_graph( - self, - *, - input_graph_path: Path, - output_graph_path: Path, - context_payload: dict[str, Any], - ) -> None: - for attempt in range(2): - try: - self._process_graph_once( - input_graph_path=input_graph_path, - output_graph_path=output_graph_path, - context_payload=context_payload, - ) - return - except PersistentWorkerTransportError: - self.close() - if attempt == 1: - raise - - def _process_graph_once( - self, - *, - input_graph_path: Path, - output_graph_path: Path, - context_payload: dict[str, Any], - ) -> None: - process = self._ensure_started() - self._next_job_id += 1 - job_id = self._next_job_id - - payload = { - "op": "process", - "id": job_id, - "input_graph": str(input_graph_path), - "output_graph": str(output_graph_path), - "context": context_payload, - } - - try: - assert process.stdin is not None - process.stdin.write( - json.dumps(payload, ensure_ascii=True, default=str) + "\n" - ) - process.stdin.flush() - except Exception as exc: - raise PersistentWorkerTransportError( - f"Postprocessor worker stdin failed: {self._spec.class_path}" - ) from exc - - message = self._read_message( - process, timeout_seconds=self._spec.timeout_seconds - ) - if message.get("id") != job_id: - raise PersistentWorkerTransportError( - f"Postprocessor worker returned invalid response id for {self._spec.class_path}." 
- ) - if message.get("ok") is True: - return - - error = str(message.get("error") or "unknown worker error") - raise PersistentWorkerJobError( - f"Postprocessor failed: {self._spec.class_path}\n{error}".strip() - ) - - def _ensure_started(self) -> subprocess.Popen[str]: - process = self._process - if process is not None and process.poll() is None: - return process - - cmd = [ - self._spec.python, - "-m", - "wordlift_sdk.kg_build.postprocessor_worker", - "--class", - self._spec.class_path, - ] - process = subprocess.Popen( - cmd, - text=True, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=str(self._root_dir), - bufsize=1, - ) - - try: - ready = self._read_message( - process, timeout_seconds=min(self._spec.timeout_seconds, 10) - ) - except Exception: - self._terminate(process) - raise - - if ready.get("op") != "ready" or ready.get("ok") is not True: - stderr = self._read_stderr(process) - self._terminate(process) - raise PersistentWorkerTransportError( - f"Postprocessor worker failed to start: {self._spec.class_path}" - + (f"\n{stderr}" if stderr else "") - ) - - self._process = process - return process - - def _read_message( - self, - process: subprocess.Popen[str], - *, - timeout_seconds: int, - ) -> dict[str, Any]: - if process.stdout is None: - raise PersistentWorkerTransportError("Worker stdout is unavailable.") - - ready, _, _ = select.select([process.stdout], [], [], timeout_seconds) - if not ready: - self._terminate(process) - cmd = ( - process.args if isinstance(process.args, list) else [str(process.args)] - ) - raise subprocess.TimeoutExpired(cmd=cmd, timeout=timeout_seconds) - - line = process.stdout.readline() - if not line: - stderr = self._read_stderr(process) - self._terminate(process) - raise PersistentWorkerTransportError( - f"Postprocessor worker exited unexpectedly: {self._spec.class_path}" - + (f"\n{stderr}" if stderr else "") - ) - - try: - return json.loads(line) - except json.JSONDecodeError as exc: - raise 
PersistentWorkerTransportError( - "Postprocessor worker returned invalid JSON response." - ) from exc - - def _read_stderr(self, process: subprocess.Popen[str]) -> str: - if process.stderr is None: - return "" - try: - return (process.stderr.read() or "").strip() - except Exception: - return "" - - def _terminate(self, process: subprocess.Popen[str]) -> None: - if process.poll() is None: - process.kill() - try: - process.wait(timeout=5) - except subprocess.TimeoutExpired: - pass - - -@dataclass -class SubprocessPostprocessor: - spec: PostprocessorSpec - root_dir: Path - runtime: str = _RUNTIME_ONESHOT - _persistent_client: PersistentPostprocessorClient | None = field( - init=False, - default=None, - repr=False, - ) - - def close(self) -> None: - if self._persistent_client is not None: - self._persistent_client.close() - self._persistent_client = None - - def process_graph( - self, graph: Graph, context: PostprocessorContext - ) -> Graph | None: - payload = _build_runner_payload(context) - temp_dir_path = Path(tempfile.mkdtemp(prefix="worai_pp_")) - failed = False - try: - input_graph_path = temp_dir_path / "input_graph.nq" - output_graph_path = temp_dir_path / "output_graph.nq" - context_path = temp_dir_path / "context.json" - - _write_graph_nquads(graph, input_graph_path) - context_path.write_text( - json.dumps(payload, ensure_ascii=True, default=str), - encoding="utf-8", - ) - - if self.runtime == _RUNTIME_PERSISTENT: - self._run_persistent( - input_graph_path=input_graph_path, - output_graph_path=output_graph_path, - context_payload=payload, - ) - else: - self._run_oneshot( - input_graph_path=input_graph_path, - output_graph_path=output_graph_path, - context_path=context_path, - ) - - if not output_graph_path.exists(): - failed = True - raise RuntimeError( - "Postprocessor did not produce output graph: " - f"{self.spec.class_path}" - ) - - return _read_graph_nquads(output_graph_path) - except Exception: - failed = True - raise - finally: - if failed and 
self.spec.keep_temp_on_error: - debug_dir = self.root_dir / "output" / "postprocessor_debug" - debug_dir.mkdir(parents=True, exist_ok=True) - target = debug_dir / ( - self.spec.class_path.replace(":", "_").replace(".", "_") - ) - if target.exists(): - shutil.rmtree(target) - shutil.copytree(temp_dir_path, target) - _redact_debug_context(target / "context.json") - if temp_dir_path.exists(): - shutil.rmtree(temp_dir_path, ignore_errors=True) - - def _run_oneshot( - self, - *, - input_graph_path: Path, - output_graph_path: Path, - context_path: Path, - ) -> None: - cmd = [ - self.spec.python, - "-m", - "wordlift_sdk.kg_build.postprocessor_runner", - "--class", - self.spec.class_path, - "--input-graph", - str(input_graph_path), - "--output-graph", - str(output_graph_path), - "--context", - str(context_path), - ] - completed = subprocess.run( - cmd, - text=True, - capture_output=True, - cwd=str(self.root_dir), - timeout=self.spec.timeout_seconds, - check=False, - ) - if completed.returncode != 0: - stderr = (completed.stderr or "").strip() - raise RuntimeError( - f"Postprocessor failed: {self.spec.class_path} " - f"(exit={completed.returncode})" + (f"\n{stderr}" if stderr else "") - ) - - def _run_persistent( - self, - *, - input_graph_path: Path, - output_graph_path: Path, - context_payload: dict[str, Any], - ) -> None: - if self._persistent_client is None: - self._persistent_client = PersistentPostprocessorClient( - spec=self.spec, - root_dir=self.root_dir, - ) - self._persistent_client.process_graph( - input_graph_path=input_graph_path, - output_graph_path=output_graph_path, - context_payload=context_payload, - ) - - -def _as_bool(value: Any, default: bool) -> bool: - if value is None: - return default - if isinstance(value, bool): - return value - raise TypeError("Expected boolean value.") - - -def _as_str(value: Any, default: str) -> str: - if value is None: - return default - if not isinstance(value, str) or not value.strip(): - raise TypeError("Expected non-empty 
string value.") - return value - - -def _as_positive_int(value: Any, default: int) -> int: - if value is None: - return default - if not isinstance(value, int) or value <= 0: - raise TypeError("Expected positive integer value.") - return value - - -def _normalize_runtime(value: str | None) -> str: - runtime = (value or _RUNTIME_ONESHOT).strip().lower() - if runtime not in {_RUNTIME_ONESHOT, _RUNTIME_PERSISTENT}: - raise ValueError("POSTPROCESSOR_RUNTIME must be one of: oneshot, persistent.") - return runtime - - -def _load_manifest_specs(manifest_path: Path) -> list[PostprocessorSpec]: - if not manifest_path.exists(): - return [] - with open(manifest_path, "rb") as f: - doc = tomllib.load(f) - - default_python = _as_str(doc.get("python"), "./.venv/bin/python") - default_timeout = _as_positive_int(doc.get("timeout_seconds"), 120) - default_enabled = _as_bool(doc.get("enabled"), True) - default_keep_temp = _as_bool(doc.get("keep_temp_on_error"), False) - - rows = doc.get("postprocessors") - if rows is None: - return [] - if not isinstance(rows, list): - raise TypeError( - f"{manifest_path}: 'postprocessors' must be an array of tables." - ) - - specs: list[PostprocessorSpec] = [] - for index, row in enumerate(rows, start=1): - if not isinstance(row, dict): - raise TypeError( - f"{manifest_path}: postprocessors[{index}] must be a table." - ) - class_path = row.get("class") - if not isinstance(class_path, str) or ":" not in class_path: - raise TypeError( - f"{manifest_path}: postprocessors[{index}].class must be " - "'package.module:ClassName'." 
- ) - spec = PostprocessorSpec( - class_path=class_path.strip(), - python=_as_str(row.get("python"), default_python), - timeout_seconds=_as_positive_int( - row.get("timeout_seconds"), default_timeout - ), - enabled=_as_bool(row.get("enabled"), default_enabled), - keep_temp_on_error=_as_bool( - row.get("keep_temp_on_error"), default_keep_temp - ), - ) - specs.append(spec) - return specs - - -def _build_runner_payload(context: PostprocessorContext) -> dict[str, Any]: - account = getattr(context, "account", None) - dataset_uri = str(getattr(account, "dataset_uri", "")).rstrip("/") - country_code = str(getattr(account, "country_code", "")).strip().lower() - account_key = ( - str(context.account_key).strip() - if getattr(context, "account_key", None) is not None - else "" - ) - profile = dict(getattr(context, "profile", {}) or {}) - if "settings" not in profile or not isinstance(profile.get("settings"), dict): - profile["settings"] = {} - profile_settings = dict(profile.get("settings", {}) or {}) - profile_settings.setdefault("api_url", "https://api.wordlift.io") - profile["settings"] = profile_settings - response = getattr(context, "response", None) - web_page = getattr(response, "web_page", None) if response else None - return { - "profile_name": context.profile_name, - "profile": profile, - "url": context.url, - "dataset_uri": dataset_uri, - "country_code": country_code, - "account_key": account_key or None, - "exports": context.exports, - "existing_web_page_id": context.existing_web_page_id, - "response": { - "id": getattr(response, "id", None) or context.existing_web_page_id, - "web_page": { - "url": getattr(web_page, "url", None), - "html": getattr(web_page, "html", None), - }, - }, - } - - -def load_postprocessors_for_profile( - *, - root_dir: Path, - profile_name: str, - runtime: str | None = None, -) -> list[LoadedPostprocessor]: - base_manifest = root_dir / "profiles" / "_base" / "postprocessors.toml" - profile_manifest = root_dir / "profiles" / profile_name / 
"postprocessors.toml" - - selected_manifest: Path | None - if profile_manifest.exists(): - selected_manifest = profile_manifest - elif base_manifest.exists(): - selected_manifest = base_manifest - else: - selected_manifest = None - - specs = _load_manifest_specs(selected_manifest) if selected_manifest else [] - - resolved_runtime = _normalize_runtime(runtime) - loaded: list[LoadedPostprocessor] = [] - for spec in specs: - if not spec.enabled: - continue - loaded.append( - LoadedPostprocessor( - name=spec.class_path, - handler=SubprocessPostprocessor( - spec=spec, - root_dir=root_dir, - runtime=resolved_runtime, - ), - ) - ) - - logger.info( - "Loaded %s postprocessors for profile '%s' from manifest: %s (runtime=%s)", - len(loaded), - profile_name, - selected_manifest or "none", - resolved_runtime, - ) - logger.debug( - "Postprocessor manifest precedence for profile '%s': selected=%s base=%s chosen=%s", - profile_name, - profile_manifest, - base_manifest, - selected_manifest or "none", - ) - return loaded - - -def load_postprocessors( - manifest_path: Path, - *, - root_dir: Path, - runtime: str | None = None, -) -> list[LoadedPostprocessor]: - specs = _load_manifest_specs(manifest_path) - resolved_runtime = _normalize_runtime(runtime) - loaded: list[LoadedPostprocessor] = [] - for spec in specs: - if not spec.enabled: - continue - loaded.append( - LoadedPostprocessor( - name=spec.class_path, - handler=SubprocessPostprocessor( - spec=spec, - root_dir=root_dir, - runtime=resolved_runtime, - ), - ) - ) - return loaded - - -def close_loaded_postprocessors(postprocessors: list[LoadedPostprocessor]) -> None: - for processor in postprocessors: - close = getattr(processor.handler, "close", None) - if callable(close): - close() - - -def _write_graph_nquads(graph: Graph, path: Path) -> None: - dataset = Dataset() - for triple in graph: - dataset.add(triple) - dataset.serialize(destination=path, format="nquads") - - -def _read_graph_nquads(path: Path) -> Graph: - dataset = 
Dataset() - dataset.parse(path, format="nquads") - graph = Graph() - for triple in dataset.triples((None, None, None)): - graph.add(triple) - return graph - - -def _redact_debug_context(path: Path) -> None: - if not path.exists(): - return - try: - payload = json.loads(path.read_text(encoding="utf-8")) - except Exception: - return - if not isinstance(payload, dict): - return - if payload.get("account_key"): - payload["account_key"] = "***REDACTED***" - profile = payload.get("profile") - if isinstance(profile, dict) and profile.get("api_key"): - profile["api_key"] = "***REDACTED***" - settings = ( - profile.get("settings") - if isinstance(profile, dict) and isinstance(profile.get("settings"), dict) - else None - ) - if settings and settings.get("api_key"): - settings["api_key"] = "***REDACTED***" - if settings and settings.get("wordlift_key"): - settings["wordlift_key"] = "***REDACTED***" - if settings and settings.get("WORDLIFT_KEY"): - settings["WORDLIFT_KEY"] = "***REDACTED***" - if settings and settings.get("WORDLIFT_API_KEY"): - settings["WORDLIFT_API_KEY"] = "***REDACTED***" - payload["profile"] = profile - path.write_text( - json.dumps(payload, ensure_ascii=True, default=str), - encoding="utf-8", - ) diff --git a/wordlift_sdk/kg_build/postprocessors/__init__.py b/wordlift_sdk/kg_build/postprocessors/__init__.py new file mode 100644 index 0000000..b05f6f4 --- /dev/null +++ b/wordlift_sdk/kg_build/postprocessors/__init__.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import logging +from pathlib import Path + +from .graph_io import close_loaded_postprocessors +from .subprocess import ( + _build_handler, + _normalize_runtime, +) +from .types import ( + Closeable, + GraphPostprocessor, + LoadedPostprocessor, + PostprocessorContext, + PostprocessorResult, + PostprocessorRuntime, + PostprocessorSpec, + PersistentWorkerJobError, + PersistentWorkerTransportError, +) + +logger = logging.getLogger(__name__) + +try: + import tomllib +except 
ModuleNotFoundError: # pragma: no cover + import tomli as tomllib + + +def _as_bool(value, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + raise TypeError("Expected boolean value.") + + +def _as_str(value, default: str) -> str: + if value is None: + return default + if not isinstance(value, str) or not value.strip(): + raise TypeError("Expected non-empty string value.") + return value + + +def _as_positive_int(value, default: int) -> int: + if value is None: + return default + if not isinstance(value, int) or value <= 0: + raise TypeError("Expected positive integer value.") + return value + + +def _load_manifest_specs(manifest_path: Path) -> list[PostprocessorSpec]: + if not manifest_path.exists(): + return [] + with open(manifest_path, "rb") as f: + doc = tomllib.load(f) + + default_python = _as_str(doc.get("python"), "./.venv/bin/python") + default_timeout = _as_positive_int(doc.get("timeout_seconds"), 120) + default_enabled = _as_bool(doc.get("enabled"), True) + default_keep_temp = _as_bool(doc.get("keep_temp_on_error"), False) + + rows = doc.get("postprocessors") + if rows is None: + return [] + if not isinstance(rows, list): + raise TypeError( + f"{manifest_path}: 'postprocessors' must be an array of tables." + ) + + specs: list[PostprocessorSpec] = [] + for index, row in enumerate(rows, start=1): + if not isinstance(row, dict): + raise TypeError( + f"{manifest_path}: postprocessors[{index}] must be a table." + ) + class_path = row.get("class") + if not isinstance(class_path, str) or ":" not in class_path: + raise TypeError( + f"{manifest_path}: postprocessors[{index}].class must be " + "'package.module:ClassName'." 
+ ) + specs.append(PostprocessorSpec( + class_path=class_path.strip(), + python=_as_str(row.get("python"), default_python), + timeout_seconds=_as_positive_int(row.get("timeout_seconds"), default_timeout), + enabled=_as_bool(row.get("enabled"), default_enabled), + keep_temp_on_error=_as_bool(row.get("keep_temp_on_error"), default_keep_temp), + )) + return specs + + +def _load_from_specs( + specs: list[PostprocessorSpec], + root_dir: Path, + runtime: PostprocessorRuntime, +) -> list[LoadedPostprocessor]: + return [ + LoadedPostprocessor( + name=spec.class_path, + handler=_build_handler(spec, root_dir, runtime), + ) + for spec in specs + if spec.enabled + ] + + +def load_postprocessors_for_profile( + *, + root_dir: Path, + profile_name: str, + runtime: str | None = None, +) -> list[LoadedPostprocessor]: + base_manifest = root_dir / "profiles" / "_base" / "postprocessors.toml" + profile_manifest = root_dir / "profiles" / profile_name / "postprocessors.toml" + + if profile_manifest.exists(): + selected_manifest: Path | None = profile_manifest + elif base_manifest.exists(): + selected_manifest = base_manifest + else: + selected_manifest = None + + logger.debug( + "Postprocessor manifest precedence for profile '%s': profile=%s base=%s chosen=%s", + profile_name, + profile_manifest, + base_manifest, + selected_manifest or "none", + ) + return load_postprocessors(selected_manifest, root_dir=root_dir, runtime=runtime) + + +def load_postprocessors( + manifest_path: Path | None, + *, + root_dir: Path, + runtime: str | None = None, +) -> list[LoadedPostprocessor]: + specs = _load_manifest_specs(manifest_path) if manifest_path else [] + resolved_runtime = _normalize_runtime(runtime) + loaded = _load_from_specs(specs, root_dir, resolved_runtime) + logger.info( + "Loaded %s postprocessors from manifest: %s (runtime=%s)", + len(loaded), + manifest_path or "none", + resolved_runtime, + ) + return loaded + + +__all__ = [ + "Closeable", + "GraphPostprocessor", + "LoadedPostprocessor", 
+ "PostprocessorContext", + "PostprocessorResult", + "PostprocessorRuntime", + "PostprocessorSpec", + "PersistentWorkerJobError", + "PersistentWorkerTransportError", + "close_loaded_postprocessors", + "load_postprocessors", + "load_postprocessors_for_profile", +] diff --git a/wordlift_sdk/kg_build/postprocessors/graph_io.py b/wordlift_sdk/kg_build/postprocessors/graph_io.py new file mode 100644 index 0000000..866189c --- /dev/null +++ b/wordlift_sdk/kg_build/postprocessors/graph_io.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from rdflib import Dataset, Graph + +from .types import Closeable, LoadedPostprocessor, PostprocessorContext + + +def _build_runner_payload(context: PostprocessorContext) -> dict[str, Any]: + account = getattr(context, "account", None) + dataset_uri = str(getattr(account, "dataset_uri", "")).rstrip("/") + country_code = str(getattr(account, "country_code", "")).strip().lower() + account_key = ( + str(context.account_key).strip() + if getattr(context, "account_key", None) is not None + else "" + ) + profile = dict(getattr(context, "profile", {}) or {}) + if "settings" not in profile or not isinstance(profile.get("settings"), dict): + profile["settings"] = {} + profile_settings = dict(profile.get("settings", {}) or {}) + profile_settings.setdefault("api_url", "https://api.wordlift.io") + profile["settings"] = profile_settings + response = getattr(context, "response", None) + web_page = getattr(response, "web_page", None) if response else None + return { + "profile_name": context.profile_name, + "profile": profile, + "url": context.url, + "dataset_uri": dataset_uri, + "country_code": country_code, + "account_key": account_key or None, + "exports": context.exports, + "existing_web_page_id": context.existing_web_page_id, + "response": { + "id": getattr(response, "id", None) or context.existing_web_page_id, + "web_page": { + "url": getattr(web_page, "url", None), + 
"html": getattr(web_page, "html", None), + }, + }, + } + + +def close_loaded_postprocessors(postprocessors: list[LoadedPostprocessor]) -> None: + for processor in postprocessors: + if isinstance(processor.handler, Closeable): + processor.handler.close() + + +def _write_graph_nquads(graph: Graph, path: Path) -> None: + dataset = Dataset() + for triple in graph: + dataset.add(triple) + dataset.serialize(destination=path, format="nquads") + + +def _read_graph_nquads(path: Path) -> Graph: + dataset = Dataset() + dataset.parse(path, format="nquads") + graph = Graph() + for triple in dataset.triples((None, None, None)): + graph.add(triple) + return graph + + +def _redact_debug_context(path: Path) -> None: + if not path.exists(): + return + try: + payload = json.loads(path.read_text(encoding="utf-8")) + except Exception: + return + if not isinstance(payload, dict): + return + if payload.get("account_key"): + payload["account_key"] = "***REDACTED***" + profile = payload.get("profile") + if isinstance(profile, dict) and profile.get("api_key"): + profile["api_key"] = "***REDACTED***" + settings = ( + profile.get("settings") + if isinstance(profile, dict) and isinstance(profile.get("settings"), dict) + else None + ) + if settings and settings.get("api_key"): + settings["api_key"] = "***REDACTED***" + if settings and settings.get("wordlift_key"): + settings["wordlift_key"] = "***REDACTED***" + if settings and settings.get("WORDLIFT_KEY"): + settings["WORDLIFT_KEY"] = "***REDACTED***" + if settings and settings.get("WORDLIFT_API_KEY"): + settings["WORDLIFT_API_KEY"] = "***REDACTED***" + payload["profile"] = profile + path.write_text( + json.dumps(payload, ensure_ascii=True, default=str), + encoding="utf-8", + ) diff --git a/wordlift_sdk/kg_build/postprocessor_runner.py b/wordlift_sdk/kg_build/postprocessors/oneshot.py similarity index 96% rename from wordlift_sdk/kg_build/postprocessor_runner.py rename to wordlift_sdk/kg_build/postprocessors/oneshot.py index f85fce6..6b8ceda 
100644 --- a/wordlift_sdk/kg_build/postprocessor_runner.py +++ b/wordlift_sdk/kg_build/postprocessors/oneshot.py @@ -10,8 +10,8 @@ from rdflib import Dataset, Graph -from .id_allocator import IdAllocator -from .postprocessors import PostprocessorContext +from .types import PostprocessorContext +from .processors.id_allocator import IdAllocator def _build_context(payload: dict[str, Any]) -> PostprocessorContext: @@ -90,7 +90,7 @@ def main() -> None: output_graph = graph if result is None else result _write_graph_nquads(output_graph, Path(args.output_graph)) except Exception as exc: # pragma: no cover - process boundary - print(f"[postprocessor_runner] {exc}", file=sys.stderr) + print(f"[postprocessors.runner] {exc}", file=sys.stderr) raise SystemExit(1) from exc diff --git a/wordlift_sdk/kg_build/postprocessor_worker.py b/wordlift_sdk/kg_build/postprocessors/persistent.py similarity index 98% rename from wordlift_sdk/kg_build/postprocessor_worker.py rename to wordlift_sdk/kg_build/postprocessors/persistent.py index 3a62fbb..eb04efb 100644 --- a/wordlift_sdk/kg_build/postprocessor_worker.py +++ b/wordlift_sdk/kg_build/postprocessors/persistent.py @@ -12,7 +12,7 @@ from rdflib import Dataset, Graph -from .postprocessor_runner import _build_context +from .oneshot import _build_context def _load_class(class_path: str): diff --git a/wordlift_sdk/kg_build/postprocessors/processors/__init__.py b/wordlift_sdk/kg_build/postprocessors/processors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/wordlift_sdk/kg_build/postprocessors/processors/graph_annotation.py b/wordlift_sdk/kg_build/postprocessors/processors/graph_annotation.py new file mode 100644 index 0000000..5769615 --- /dev/null +++ b/wordlift_sdk/kg_build/postprocessors/processors/graph_annotation.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from rdflib import Graph, Literal, URIRef + +from ...graph_utils import first_level_subjects + +SEOVOC_SOURCE = 
URIRef("https://w3id.org/seovoc/source") +SEOVOC_IMPORT_HASH = URIRef("https://w3id.org/seovoc/importHash") + + +class ImportAnnotationPostprocessor: + """Stamps first-level graph subjects with web-page-import provenance metadata. + + Sets seovoc:source to 'web-page-import' on every first-level subject, and + optionally propagates the existing import hash to all URIRef subjects when + import_hash_mode is not 'off'. Both are needed before graph persistence so + the KG can track provenance and skip unchanged imports. + + Reads from context: + - account.dataset_uri — for first-level subject resolution + - existing_import_hash — hash from a prior import of the same page + - import_hash_mode — 'on' | 'write' | 'off' + """ + + def process_graph(self, graph: Graph, context) -> Graph: + dataset_uri = str( + getattr(getattr(context, "account", None), "dataset_uri", "") or "" + ).rstrip("/") + for subject in first_level_subjects(graph, dataset_uri): + graph.set((subject, SEOVOC_SOURCE, Literal("web-page-import"))) + + import_hash_mode = getattr(context, "import_hash_mode", "on") + if import_hash_mode == "off": + return graph + existing_import_hash = getattr(context, "existing_import_hash", None) + if not existing_import_hash: + return graph + for subject in (s for s in graph.subjects() if isinstance(s, URIRef)): + graph.set((subject, SEOVOC_IMPORT_HASH, Literal(existing_import_hash))) + + return graph diff --git a/wordlift_sdk/kg_build/id_allocator.py b/wordlift_sdk/kg_build/postprocessors/processors/id_allocator.py similarity index 99% rename from wordlift_sdk/kg_build/id_allocator.py rename to wordlift_sdk/kg_build/postprocessors/processors/id_allocator.py index d597272..de4e6ce 100644 --- a/wordlift_sdk/kg_build/id_allocator.py +++ b/wordlift_sdk/kg_build/postprocessors/processors/id_allocator.py @@ -6,7 +6,7 @@ from rdflib import Graph, Literal, RDF, URIRef -from .id_policy import DEFAULT_ID_POLICY, IdPolicy +from ...id_policy import DEFAULT_ID_POLICY, IdPolicy SCHEMA = 
"http://schema.org/" diff --git a/wordlift_sdk/kg_build/id_generator.py b/wordlift_sdk/kg_build/postprocessors/processors/id_generator.py similarity index 99% rename from wordlift_sdk/kg_build/id_generator.py rename to wordlift_sdk/kg_build/postprocessors/processors/id_generator.py index 3741c6f..d063f59 100644 --- a/wordlift_sdk/kg_build/id_generator.py +++ b/wordlift_sdk/kg_build/postprocessors/processors/id_generator.py @@ -7,8 +7,8 @@ from rdflib import Graph, Literal, RDF, URIRef -from .id_policy import DEFAULT_ID_POLICY, IdPolicy -from .iri_lookup import IriLookup +from ...id_policy import DEFAULT_ID_POLICY, IdPolicy +from ...iri_lookup import IriLookup SCHEMA = "http://schema.org/" diff --git a/wordlift_sdk/kg_build/postprocessors/processors/id_postprocessor.py b/wordlift_sdk/kg_build/postprocessors/processors/id_postprocessor.py new file mode 100644 index 0000000..ae51a92 --- /dev/null +++ b/wordlift_sdk/kg_build/postprocessors/processors/id_postprocessor.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from rdflib import Graph, RDF, URIRef + +from .id_generator import CanonicalIdGenerator +from ...iri_lookup import IriLookup + + +def _find_web_page_iri(graph: Graph) -> URIRef | None: + for subject in graph.subjects(RDF.type, URIRef("http://schema.org/WebPage")): + return subject + for subject in graph.subjects(RDF.type, URIRef("https://schema.org/WebPage")): + return subject + return None + + +def _swap_iris(graph: Graph, old_iri: URIRef, new_iri: URIRef) -> None: + for subject, predicate, obj in list(graph.triples((old_iri, None, None))): + graph.remove((subject, predicate, obj)) + graph.add((new_iri, predicate, obj)) + for subject, predicate, obj in list(graph.triples((None, None, old_iri))): + graph.remove((subject, predicate, obj)) + graph.add((subject, predicate, new_iri)) + + +class RootIdReconcilerPostprocessor: + """Rewrites the WebPage node IRI to match the existing web page ID. 
+ + When a page has been imported before, the mapping may generate a different + IRI than the one already stored. This postprocessor swaps all triples + referencing the old IRI to use the canonical one from the system. + Runs before custom postprocessors so they always see the correct subject. + """ + + def process_graph(self, graph: Graph, context) -> Graph: + root_id = getattr(context, "existing_web_page_id", None) + if not root_id: + return graph + old_iri = _find_web_page_iri(graph) + if old_iri and str(old_iri) != root_id: + _swap_iris(graph, old_iri, URIRef(root_id)) + return graph + + +class CanonicalIdsPostprocessor: + """Postprocessor adapter that applies canonical ID generation to a graph.""" + + def __init__( + self, + generator: CanonicalIdGenerator | None = None, + iri_lookup: IriLookup | None = None, + context_key: str = "kg_build.iri_lookup", + strategy: str = "legacy", + ) -> None: + self._generator = generator or CanonicalIdGenerator(strategy=strategy) + self._iri_lookup = iri_lookup + self._context_key = context_key + + def process_graph(self, graph: Graph, context) -> Graph: + dataset_uri = str(getattr(context.account, "dataset_uri", "")).rstrip("/") + if not dataset_uri: + return graph + iri_lookup = self._iri_lookup or self._lookup_from_context(context) + return self._generator.apply(graph, dataset_uri, iri_lookup=iri_lookup) + + def _lookup_from_context(self, context) -> IriLookup | None: + extensions = getattr(context, "extensions", None) + if not isinstance(extensions, dict): + return None + lookup = extensions.get(self._context_key) + if lookup is None or not hasattr(lookup, "iri_for_subject"): + return None + return lookup diff --git a/wordlift_sdk/kg_build/postprocessors/service.py b/wordlift_sdk/kg_build/postprocessors/service.py new file mode 100644 index 0000000..f508bb6 --- /dev/null +++ b/wordlift_sdk/kg_build/postprocessors/service.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import asyncio +import functools +import 
# ---------------------------------------------------------------------------
# wordlift_sdk/kg_build/postprocessors/service.py  (remainder of the new file)
#
# This region of the patch was flattened onto a few very long lines; the
# definitions are restored below as conventionally formatted Python.
# Imports declared by the patch for this module (visible part): logging, time,
# concurrent.futures.ThreadPoolExecutor, collections.abc.Iterable,
# typing.Callable, rdflib.Graph, .graph_io.close_loaded_postprocessors,
# .types.{LoadedPostprocessor, PostprocessorContext, PostprocessorResult}.
# ---------------------------------------------------------------------------

# Both service.py and subprocess.py define this same module-level logger.
logger = logging.getLogger(__name__)


class PostprocessorService:
    """Executes an ordered list of postprocessors against a graph.

    Completely agnostic to profiles and pipeline composition — callers are
    responsible for assembling the postprocessor list and building the context.

    The service owns ``pool_size`` independent postprocessor sets (one per
    call to ``postprocessors_factory``) and a thread pool of the same size.
    ``apply`` borrows one set, runs it on a pool thread and hands it back.
    """

    def __init__(
        self,
        *,
        postprocessors_factory: Callable[[], Iterable[LoadedPostprocessor]],
        pool_size: int,
    ) -> None:
        """Create the thread pool and pre-build one postprocessor set per slot.

        Args:
            postprocessors_factory: Called ``pool_size`` times; each call
                yields the ordered postprocessors for one pool slot.
            pool_size: Number of slots (and worker threads).
        """
        self._executor = ThreadPoolExecutor(
            max_workers=pool_size, thread_name_prefix="worai_pp"
        )
        self._queue: asyncio.Queue = asyncio.Queue()
        for _ in range(pool_size):
            # A slot is iterated once per job, so materialise the factory's
            # result: a generator or other one-shot iterable would otherwise
            # be exhausted after the first job and silently do nothing.
            self._queue.put_nowait(list(postprocessors_factory()))
        logger.info("Created postprocessor pool (pool_size=%d)", pool_size)

    async def apply(
        self,
        graph: Graph,
        context: PostprocessorContext,
    ) -> PostprocessorResult:
        """Run one postprocessor set over *graph* on a pool thread.

        Waits for a free slot, measures that wait, then executes ``_run`` in
        the executor so the event loop is not blocked.

        Returns:
            The ``PostprocessorResult`` produced by ``_run``.
        """
        wait_started = time.perf_counter()
        postprocessors = await self._queue.get()
        queue_wait_ms = int((time.perf_counter() - wait_started) * 1000)
        try:
            # We are inside a coroutine, so a loop is guaranteed to be running;
            # get_running_loop() never creates or guesses a loop.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self._executor,
                functools.partial(
                    self._run, graph, context, postprocessors, queue_wait_ms
                ),
            )
        finally:
            # NOTE(review): if the awaiting task is cancelled, the slot is
            # handed back here although the pool thread cannot be interrupted
            # and may still be using it.
            self._queue.put_nowait(postprocessors)

    def close(self) -> None:
        """Close every idle postprocessor set and shut the thread pool down.

        Sets that are currently checked out by ``apply`` are not in the queue
        and are therefore not closed here.  A failure while closing one set is
        logged and does not prevent the remaining sets, or the executor, from
        being released.
        """
        try:
            while True:
                try:
                    postprocessors = self._queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                try:
                    close_loaded_postprocessors(postprocessors)
                except Exception:
                    logger.exception(
                        "Failed to close a postprocessor set; continuing shutdown"
                    )
        finally:
            self._executor.shutdown(wait=False)

    def _run(
        self,
        graph: Graph,
        context: PostprocessorContext,
        postprocessors: Iterable[LoadedPostprocessor],
        queue_wait_ms: int,
    ) -> PostprocessorResult:
        """Apply *postprocessors* in order (runs on a pool thread).

        Each processor receives the graph returned by the previous one; the
        per-processor and total durations are reported in milliseconds.
        """
        run_started = time.perf_counter()
        for processor in postprocessors:
            step_started = time.perf_counter()
            graph = processor.run(graph, context)
            logger.info(
                "Applied postprocessor '%s' for %s [%dms]",
                processor.name,
                context.url,
                int((time.perf_counter() - step_started) * 1000),
            )
        return PostprocessorResult(
            graph=graph,
            queue_wait_ms=queue_wait_ms,
            postprocessors_ms=int((time.perf_counter() - run_started) * 1000),
        )


# ---------------------------------------------------------------------------
# wordlift_sdk/kg_build/postprocessors/subprocess.py  (new file, mode 100644)
#
# Imports declared by the patch for this module: from __future__ import
# annotations; asyncio, importlib, inspect, json, logging, select, shutil,
# subprocess, tempfile; dataclasses.{dataclass, field}; pathlib.Path;
# typing.Any; rdflib.Graph; .types.{GraphPostprocessor, PostprocessorContext,
# PostprocessorRuntime, PostprocessorSpec, PersistentWorkerJobError,
# PersistentWorkerTransportError, _SubprocessRunner}.
# ---------------------------------------------------------------------------


class PersistentPostprocessorClient:
    """Line-delimited JSON client for one long-lived postprocessor worker.

    The worker is started lazily as ``<spec.python> -m
    wordlift_sdk.kg_build.postprocessors.persistent --class <class_path>`` and
    must print a ``{"op": "ready", "ok": true}`` line before it is used.  Each
    job is one JSON line written to the worker's stdin and answered by one
    JSON line on its stdout.
    """

    # How long close() lets the worker act on the "shutdown" request before
    # it is killed.
    _SHUTDOWN_GRACE_SECONDS = 1.0

    def __init__(self, *, spec: PostprocessorSpec, root_dir: Path) -> None:
        self._spec = spec
        self._root_dir = root_dir
        self._process: subprocess.Popen[str] | None = None
        self._next_job_id = 0

    def close(self) -> None:
        """Ask the worker to shut down, then make sure it is gone.

        Safe to call repeatedly and when no worker was ever started.
        """
        process = self._process
        self._process = None
        if process is None:
            return

        try:
            if process.poll() is None and process.stdin is not None:
                process.stdin.write(json.dumps({"op": "shutdown"}) + "\n")
                process.stdin.flush()
                # Without this wait the kill below would race the request and
                # the worker would never get to exit on its own.
                process.wait(timeout=self._SHUTDOWN_GRACE_SECONDS)
        except Exception:
            # Best effort only (broken pipe, grace period elapsed, ...): the
            # unconditional terminate below is the real guarantee.
            logger.debug(
                "Graceful shutdown of postprocessor worker failed: %s",
                self._spec.class_path,
                exc_info=True,
            )

        self._terminate(process)

    def process_graph(
        self,
        *,
        input_graph_path: Path,
        output_graph_path: Path,
        context_payload: dict[str, Any],
    ) -> None:
        """Run one job, retrying once after a transport failure.

        A ``PersistentWorkerTransportError`` closes the worker and the job is
        attempted a second time on a fresh one; a second transport failure is
        re-raised.  Any other exception propagates immediately.
        """
        for attempt in range(2):
            try:
                self._process_graph_once(
                    input_graph_path=input_graph_path,
                    output_graph_path=output_graph_path,
                    context_payload=context_payload,
                )
                return
            except PersistentWorkerTransportError:
                self.close()
                if attempt == 1:
                    raise

    def _process_graph_once(
        self,
        *,
        input_graph_path: Path,
        output_graph_path: Path,
        context_payload: dict[str, Any],
    ) -> None:
        """Send a single ``process`` request and validate the reply.

        Raises:
            PersistentWorkerTransportError: The request could not be written
                or the reply carries a different job id.
            PersistentWorkerJobError: The worker answered with ``ok`` not true.
        """
        process = self._ensure_started()
        self._next_job_id += 1
        job_id = self._next_job_id

        payload = {
            "op": "process",
            "id": job_id,
            "input_graph": str(input_graph_path),
            "output_graph": str(output_graph_path),
            "context": context_payload,
        }

        if process.stdin is None:
            raise PersistentWorkerTransportError("Worker stdin is unavailable.")
        try:
            process.stdin.write(
                json.dumps(payload, ensure_ascii=True, default=str) + "\n"
            )
            process.stdin.flush()
        except Exception as exc:
            raise PersistentWorkerTransportError(
                f"Postprocessor worker stdin failed: {self._spec.class_path}"
            ) from exc

        message = self._read_message(
            process, timeout_seconds=self._spec.timeout_seconds
        )
        if message.get("id") != job_id:
            raise PersistentWorkerTransportError(
                f"Postprocessor worker returned invalid response id for {self._spec.class_path}."
            )
        if message.get("ok") is True:
            return

        error = str(message.get("error") or "unknown worker error")
        raise PersistentWorkerJobError(
            f"Postprocessor failed: {self._spec.class_path}\n{error}".strip()
        )

    def _ensure_started(self) -> subprocess.Popen[str]:
        """Return the running worker, starting a new one if necessary.

        Raises:
            PersistentWorkerTransportError: The worker did not announce itself
                with a valid ``ready`` message.
            subprocess.TimeoutExpired: No ``ready`` message arrived in time.
        """
        process = self._process
        if process is not None and process.poll() is None:
            return process

        cmd = [
            self._spec.python,
            "-m",
            "wordlift_sdk.kg_build.postprocessors.persistent",
            "--class",
            self._spec.class_path,
        ]
        process = subprocess.Popen(
            cmd,
            text=True,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=str(self._root_dir),
            bufsize=1,
        )

        try:
            ready = self._read_message(
                process, timeout_seconds=min(self._spec.timeout_seconds, 60)
            )
        except Exception:
            self._terminate(process)
            raise

        if ready.get("op") != "ready" or ready.get("ok") is not True:
            # Stop the worker *before* draining stderr: reading to EOF from a
            # process that is still alive would block indefinitely.
            stderr = self._terminate_and_read_stderr(process)
            raise PersistentWorkerTransportError(
                f"Postprocessor worker failed to start: {self._spec.class_path}"
                + (f"\n{stderr}" if stderr else "")
            )

        self._process = process
        return process

    def _read_message(
        self,
        process: subprocess.Popen[str],
        *,
        timeout_seconds: int,
    ) -> dict[str, Any]:
        """Read one JSON object line from the worker's stdout.

        Raises:
            subprocess.TimeoutExpired: Nothing became readable within
                *timeout_seconds*; the worker is terminated first.
            PersistentWorkerTransportError: stdout is missing or at EOF, or
                the line is not a JSON object.
        """
        if process.stdout is None:
            raise PersistentWorkerTransportError("Worker stdout is unavailable.")

        # NOTE(review): select() only sees the OS pipe, not text already
        # buffered by the stream wrapper; this relies on the worker sending
        # exactly one line per request.
        readable, _, _ = select.select([process.stdout], [], [], timeout_seconds)
        if not readable:
            self._terminate(process)
            cmd = (
                process.args if isinstance(process.args, list) else [str(process.args)]
            )
            raise subprocess.TimeoutExpired(cmd=cmd, timeout=timeout_seconds)

        line = process.stdout.readline()
        if not line:
            # EOF on stdout does not prove the process has exited, so stop it
            # before draining stderr (a read-to-EOF would otherwise hang).
            stderr = self._terminate_and_read_stderr(process)
            raise PersistentWorkerTransportError(
                f"Postprocessor worker exited unexpectedly: {self._spec.class_path}"
                + (f"\n{stderr}" if stderr else "")
            )

        try:
            message = json.loads(line)
        except json.JSONDecodeError as exc:
            raise PersistentWorkerTransportError(
                "Postprocessor worker returned invalid JSON response."
            ) from exc
        if not isinstance(message, dict):
            # Callers use .get(); reject valid JSON that is not an object.
            raise PersistentWorkerTransportError(
                "Postprocessor worker returned invalid JSON response."
            )
        return message

    def _read_stderr(self, process: subprocess.Popen[str]) -> str:
        """Return the worker's stderr text, or ``""`` if it cannot be read.

        Reads until EOF, so it must only be called once the worker has been
        stopped.
        """
        if process.stderr is None:
            return ""
        try:
            return (process.stderr.read() or "").strip()
        except Exception:
            return ""

    def _terminate(self, process: subprocess.Popen[str]) -> None:
        """Kill the worker if it is still running and release its pipes."""
        self._stop(process)
        self._close_pipes(process)

    def _terminate_and_read_stderr(self, process: subprocess.Popen[str]) -> str:
        """Stop the worker, collect its stderr output, then release its pipes."""
        self._stop(process)
        stderr = self._read_stderr(process)
        self._close_pipes(process)
        return stderr

    @staticmethod
    def _stop(process: subprocess.Popen[str]) -> None:
        """Kill *process* if needed and wait up to five seconds for it to exit."""
        if process.poll() is None:
            process.kill()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            pass

    @staticmethod
    def _close_pipes(process: subprocess.Popen[str]) -> None:
        """Close the stdin/stdout/stderr pipe objects so no descriptors leak."""
        for stream in (process.stdin, process.stdout, process.stderr):
            if stream is None:
                continue
            try:
                stream.close()
            except (OSError, ValueError):
                # Closing stdin may flush into a dead pipe; nothing to recover.
                pass


def _preserve_debug_copy(
    spec: PostprocessorSpec, root_dir: Path, temp_dir_path: Path
) -> None:
    """Copy a failed run's temp dir to ``<root_dir>/output/postprocessor_debug``.

    Best effort: any problem is logged and swallowed so that it can neither
    replace the error that triggered the copy nor prevent temp-dir cleanup.
    """
    from .graph_io import _redact_debug_context

    try:
        debug_dir = root_dir / "output" / "postprocessor_debug"
        debug_dir.mkdir(parents=True, exist_ok=True)
        target = debug_dir / (spec.class_path.replace(":", "_").replace(".", "_"))
        if target.exists():
            shutil.rmtree(target)
        shutil.copytree(temp_dir_path, target)
        _redact_debug_context(target / "context.json")
    except Exception:
        logger.exception(
            "Could not preserve postprocessor debug artifacts for %s",
            spec.class_path,
        )


def _run_subprocess(
    spec: PostprocessorSpec,
    root_dir: Path,
    graph: Graph,
    payload: dict[str, Any],
    runner: _SubprocessRunner,
) -> Graph | None:
    """Shared scaffolding for subprocess-based postprocessors.

    Handles temp-dir lifecycle, graph serialization, output verification,
    and debug-copy on failure. *runner* is called with the prepared paths
    and is responsible only for the actual subprocess execution step.

    Raises:
        RuntimeError: *runner* returned without creating the output graph.
        Exception: Whatever *runner* or the graph (de)serialisation raises.
    """
    from .graph_io import _read_graph_nquads, _write_graph_nquads

    temp_dir_path = Path(tempfile.mkdtemp(prefix="worai_pp_"))
    try:
        input_graph_path = temp_dir_path / "input_graph.nq"
        output_graph_path = temp_dir_path / "output_graph.nq"
        context_path = temp_dir_path / "context.json"

        _write_graph_nquads(graph, input_graph_path)
        context_path.write_text(
            json.dumps(payload, ensure_ascii=True, default=str),
            encoding="utf-8",
        )

        runner(
            input_graph_path=input_graph_path,
            output_graph_path=output_graph_path,
            context_path=context_path,
            context_payload=payload,
        )

        if not output_graph_path.exists():
            raise RuntimeError(
                f"Postprocessor did not produce output graph: {spec.class_path}"
            )

        return _read_graph_nquads(output_graph_path)
    except Exception:
        if spec.keep_temp_on_error:
            _preserve_debug_copy(spec, root_dir, temp_dir_path)
        raise
    finally:
        # Always runs, including when the debug copy itself had problems.
        shutil.rmtree(temp_dir_path, ignore_errors=True)


@dataclass(frozen=True)
class OneshotSubprocessPostprocessor:
    """Runs the postprocessor in a fresh Python subprocess for every graph."""

    spec: PostprocessorSpec
    root_dir: Path

    def process_graph(
        self, graph: Graph, context: PostprocessorContext
    ) -> Graph | None:
        """Serialise *graph*, run the one-shot subprocess, read the result."""
        from .graph_io import _build_runner_payload

        return _run_subprocess(
            self.spec, self.root_dir, graph, _build_runner_payload(context), self._run
        )

    def _run(
        self,
        *,
        input_graph_path: Path,
        output_graph_path: Path,
        context_path: Path,
        **_: Any,
    ) -> None:
        """Execute the ``oneshot`` module and fail on a non-zero exit code.

        Raises:
            RuntimeError: The subprocess exited with a non-zero status; its
                stderr, if any, is appended to the message.
            subprocess.TimeoutExpired: ``spec.timeout_seconds`` elapsed.
        """
        cmd = [
            self.spec.python,
            "-m",
            "wordlift_sdk.kg_build.postprocessors.oneshot",
            "--class",
            self.spec.class_path,
            "--input-graph",
            str(input_graph_path),
            "--output-graph",
            str(output_graph_path),
            "--context",
            str(context_path),
        ]
        completed = subprocess.run(
            cmd,
            text=True,
            capture_output=True,
            cwd=str(self.root_dir),
            timeout=self.spec.timeout_seconds,
            check=False,
        )
        if completed.returncode != 0:
            stderr = (completed.stderr or "").strip()
            raise RuntimeError(
                f"Postprocessor failed: {self.spec.class_path} "
                f"(exit={completed.returncode})" + (f"\n{stderr}" if stderr else "")
            )

    def close(self) -> None:
        pass  # oneshot processors have no persistent resources to release


@dataclass
class PersistentSubprocessPostprocessor:
    """Runs the postprocessor through a reusable worker subprocess."""

    spec: PostprocessorSpec
    root_dir: Path
    # Created on first use and dropped again by close().
    _client: PersistentPostprocessorClient | None = field(
        init=False,
        default=None,
        repr=False,
    )

    def close(self) -> None:
        """Shut the worker down, if one was started."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def process_graph(
        self, graph: Graph, context: PostprocessorContext
    ) -> Graph | None:
        """Serialise *graph*, hand it to the persistent worker, read the result."""
        from .graph_io import _build_runner_payload

        return _run_subprocess(
            self.spec, self.root_dir, graph, _build_runner_payload(context), self._run
        )

    def _run(
        self,
        *,
        input_graph_path: Path,
        output_graph_path: Path,
        context_payload: dict[str, Any],
        **_: Any,
    ) -> None:
        """Forward the prepared paths and context to the worker client."""
        if self._client is None:
            self._client = PersistentPostprocessorClient(
                spec=self.spec,
                root_dir=self.root_dir,
            )
        self._client.process_graph(
            input_graph_path=input_graph_path,
            output_graph_path=output_graph_path,
            context_payload=context_payload,
        )


@dataclass(frozen=True)
class InProcessPostprocessor:
    """Imports and runs the postprocessor class inside the current process."""

    class_path: str  # "package.module:ClassName"

    def process_graph(
        self, graph: Graph, context: PostprocessorContext
    ) -> Graph | None:
        """Instantiate the configured class and call its ``process_graph``.

        An awaitable result is driven to completion with ``asyncio.run``.

        Raises:
            ValueError: ``class_path`` contains no ``:`` separator.
        """
        module_name, separator, class_name = self.class_path.partition(":")
        if not separator:
            raise ValueError(
                f"Invalid postprocessor class path {self.class_path!r}; "
                "expected 'package.module:ClassName'."
            )
        module = importlib.import_module(module_name)
        klass = getattr(module, class_name)
        processor = klass()
        result = processor.process_graph(graph, context)
        if inspect.isawaitable(result):
            result = asyncio.run(result)
        return result


def _build_handler(
    spec: PostprocessorSpec, root_dir: Path, runtime: PostprocessorRuntime
) -> GraphPostprocessor:
    """Return the handler implementing *runtime*; anything else means oneshot."""
    if runtime == PostprocessorRuntime.INPROCESS:
        return InProcessPostprocessor(class_path=spec.class_path)
    if runtime == PostprocessorRuntime.PERSISTENT:
        return PersistentSubprocessPostprocessor(spec=spec, root_dir=root_dir)
    return OneshotSubprocessPostprocessor(spec=spec, root_dir=root_dir)


def _normalize_runtime(value: str | None) -> PostprocessorRuntime:
    """Parse a runtime setting (case/whitespace-insensitive; empty → oneshot).

    Raises:
        ValueError: *value* names no known runtime.
    """
    raw = (value or PostprocessorRuntime.ONESHOT.value).strip().lower()
    try:
        return PostprocessorRuntime(raw)
    except ValueError:
        # "from None": the enum's own lookup error adds nothing for the user.
        raise ValueError(
            "POSTPROCESSOR_RUNTIME must be one of: oneshot, persistent, inprocess."
        ) from None


# ---------------------------------------------------------------------------
# wordlift_sdk/kg_build/postprocessors/types.py  (new file, mode 100644)
#
# Imports declared by the patch for this module: from __future__ import
# annotations; dataclasses.dataclass; enum.Enum; pathlib.Path;
# typing.{Any, Protocol, runtime_checkable}; rdflib.Graph.
# ---------------------------------------------------------------------------


class PostprocessorRuntime(str, Enum):
    """Execution strategies selectable for custom postprocessors."""

    ONESHOT = "oneshot"
    PERSISTENT = "persistent"
    INPROCESS = "inprocess"


@dataclass(frozen=True)
class PostprocessorContext:
    """Immutable per-page data handed to every postprocessor."""

    profile_name: str
    profile: dict[str, Any]
    url: str
    account: Any
    account_key: str | None
    exports: dict[str, Any]
    response: Any
    existing_web_page_id: str | None
    existing_import_hash: str | None = None
    import_hash_mode: str = "on"
    ids: Any | None = None


@dataclass(frozen=True)
class PostprocessorSpec:
    """Configuration of one custom postprocessor."""

    class_path: str  # "package.module:ClassName"
    python: str  # interpreter used to launch subprocess runtimes
    timeout_seconds: int
    enabled: bool
    keep_temp_on_error: bool  # keep a debug copy of the temp dir on failure


class _SubprocessRunner(Protocol):
    """Callable that performs the actual subprocess step for ``_run_subprocess``."""

    def __call__(
        self,
        *,
        input_graph_path: Path,
        output_graph_path: Path,
        context_path: Path,
        context_payload: dict[str, Any],
    ) -> None: ...


@runtime_checkable
class Closeable(Protocol):
    """Anything exposing a ``close()`` method."""

    def close(self) -> None: ...


@runtime_checkable
class GraphPostprocessor(Protocol):
    """Anything exposing ``process_graph(graph, context)``."""

    def process_graph(
        self, graph: Graph, context: PostprocessorContext
    ) -> Graph | None: ...
+ + +@dataclass(frozen=True) +class PostprocessorResult: + graph: Graph + queue_wait_ms: int + postprocessors_ms: int + + +@dataclass(frozen=True) +class LoadedPostprocessor: + name: str + handler: GraphPostprocessor + + def run(self, graph: Graph, context: PostprocessorContext) -> Graph: + result = self.handler.process_graph(graph, context) + return graph if result is None else result + + +class PersistentWorkerTransportError(RuntimeError): + pass + + +class PersistentWorkerJobError(RuntimeError): + pass diff --git a/wordlift_sdk/kg_build/protocol.py b/wordlift_sdk/kg_build/protocol.py index fd8ddc0..d05b3e2 100644 --- a/wordlift_sdk/kg_build/protocol.py +++ b/wordlift_sdk/kg_build/protocol.py @@ -4,52 +4,133 @@ import hashlib import logging import os -import tempfile +from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict from pathlib import Path from types import SimpleNamespace from typing import Any from jinja2 import UndefinedError -from rdflib import Graph, Literal, RDF, URIRef +from rdflib import Graph, URIRef from wordlift_client.models.web_page_scrape_response import WebPageScrapeResponse from wordlift_sdk.protocol import Context from wordlift_sdk.protocol.web_page_import_protocol import ( WebPageImportProtocolInterface, ) -from wordlift_sdk.validation.shacl import ( - ValidationResult, - resolve_shape_specs, - validate_file, +from wordlift_sdk.validation.shacl import resolve_shape_specs +from wordlift_sdk.validation.shacl_validation_service import ( + ShaclValidationService, + ValidationMode, + ValidationOutcome, ) +from .cloud_flow import run_cloud_workflow as run_cloud_workflow # noqa: F401 from .config import ProfileDefinition from .entity_patcher import EntityPatcher -from .id_allocator import IdAllocator -from .id_postprocessor import CanonicalIdsPostprocessor from .kpi import KgBuildKpiCollector from .postprocessors import ( + LoadedPostprocessor, PostprocessorContext, - close_loaded_postprocessors, + PostprocessorResult, 
load_postprocessors_for_profile, ) -from .rml_mapping import RmlMappingService +from .postprocessors.processors.graph_annotation import ImportAnnotationPostprocessor +from .graph_utils import first_level_subjects +from .postprocessors.processors.id_allocator import IdAllocator +from .postprocessors.processors.id_postprocessor import ( + CanonicalIdsPostprocessor, + RootIdReconcilerPostprocessor, +) +from .postprocessors.service import PostprocessorService +from .rml_mapping import MappingResult, RmlMappingService from .templates import JinjaRdfTemplateReifier, TemplateTextRenderer +from wordlift_sdk.structured_data.engine import init_morph_kgc_pool logger = logging.getLogger(__name__) -SEOVOC_SOURCE = URIRef("https://w3id.org/seovoc/source") -SEOVOC_IMPORT_HASH = URIRef("https://w3id.org/seovoc/importHash") def _path_contains_part(path: str, part: str) -> bool: return part in Path(path).parts -def _resolve_postprocessor_runtime(settings: dict[str, Any]) -> str: - value = settings.get("postprocessor_runtime") +def _clean_key(value: Any) -> str | None: + key = str(value).strip() if value is not None else "" + return key or None + + +def _resolve_account_key(profile: Any, context: Any) -> str | None: + if key := _clean_key(getattr(profile, "api_key", None)): + return key + api_key_map = getattr( + getattr(context, "client_configuration", None), "api_key", None + ) + if isinstance(api_key_map, dict): + if key := _clean_key(api_key_map.get("ApiKey")): + return key + provider = getattr(context, "configuration_provider", None) + if provider is not None: + for name in ("WORDLIFT_KEY", "WORDLIFT_API_KEY"): + try: + if key := _clean_key(provider.get_value(name)): + return key + except Exception: + pass + for name in ("WORDLIFT_KEY", "WORDLIFT_API_KEY"): + if key := _clean_key(os.getenv(name)): + return key + return None + + +def _resolve_list_setting(value: Any) -> list[str]: if value is None: - value = settings.get("POSTPROCESSOR_RUNTIME") - return str(value or 
"persistent") + return [] + if isinstance(value, str): + return [part.strip() for part in value.split(",") if part.strip()] + if isinstance(value, (list, tuple)): + return [text for item in value if (text := str(item).strip())] + return [str(value).strip()] if str(value).strip() else [] + + +def _resolve_validation_mode(value: Any) -> ValidationMode: + if value is None: + return ValidationMode.WARN + mode = str(value).strip().lower() + if mode == "strict": + logger.warning( + "Deprecated SHACL validation mode 'strict' detected; using 'fail'." + ) + return ValidationMode.FAIL + try: + return ValidationMode(mode) + except ValueError: + logger.warning("Unsupported SHACL validation mode '%s'; using 'warn'.", mode) + return ValidationMode.WARN + + +def _resolve_import_hash_mode(value: Any) -> str: + if value is None: + return "on" + mode = str(value).strip().lower() + if mode in {"on", "write", "off"}: + return mode + logger.warning("Unsupported import hash mode '%s'; using 'on'.", mode) + return "on" + + +def _setting(settings: dict, name: str, fallback: str, default: Any) -> Any: + """Read a profile setting by snake_case name, falling back to UPPER_CASE, then default.""" + v = settings.get(name) + if v is None: + v = settings.get(fallback) + return default if v is None else v + + +def _resolve_postprocessor_runtime(settings: dict[str, Any]) -> str: + return str( + _setting( + settings, "postprocessor_runtime", "POSTPROCESSOR_RUNTIME", "persistent" + ) + ) class ProfileImportProtocol(WebPageImportProtocolInterface): @@ -76,6 +157,95 @@ def __init__( self._graph_write_strategy = graph_write_strategy self.profile_dir = self.root_dir / "profiles" / self.profile.name + + settings = dict(self.profile.settings) + _pool_size = int(_setting(settings, "concurrency", "CONCURRENCY", 4)) + self._init_postprocessor_service(settings, context, _pool_size) + self._init_mapping_service(settings, context, _pool_size) + self._init_shacl_validator(settings, _pool_size) + 
self._init_graph_writer(settings, context) + self._kpi = KgBuildKpiCollector( + dataset_uri=getattr(self.context.account, "dataset_uri", None), + validation_enabled=self._shacl_validator.mode != ValidationMode.OFF, + ) + logger.debug( + "Resolved mappings for profile '%s': effective_dir=%s (origin=%s), routes=%s (origin=%s), overlay_dirs=%s", + self.profile.name, + self.mappings_dir, + self.profile.origins.get("mappings_dir", "default"), + len(self.profile.routes), + self.profile.origins.get("routes", "default"), + [str(p) for p in self._mapping_dirs], + ) + + def _init_postprocessor_service( + self, settings: dict, context: Context, pool_size: int + ) -> None: + canonical_id_strategy = ( + str( + _setting( + settings, "canonical_id_strategy", "CANONICAL_ID_STRATEGY", "legacy" + ) + ) + .strip() + .lower() + ) + core_ids = CanonicalIdsPostprocessor(strategy=canonical_id_strategy) + runtime = _resolve_postprocessor_runtime(settings) + logger.info( + "Resolved postprocessor runtime for profile '%s': %s (origin=%s)", + self.profile.name, + runtime, + self.profile.origins.get("postprocessor_runtime", "default"), + ) + pp_pool_size = int( + _setting( + settings, + "postprocessor_pool_size", + "POSTPROCESSOR_POOL_SIZE", + pool_size, + ) + ) + logger.info( + "Postprocessor pool size for profile '%s': %d (concurrency=%d)", + self.profile.name, + pp_pool_size, + pool_size, + ) + account_key = _resolve_account_key(self.profile, context) + root_dir = self.root_dir + profile = self.profile + + def _postprocessors_factory() -> list[LoadedPostprocessor]: + leading = [ + LoadedPostprocessor( + name="root_id_reconciler", + handler=RootIdReconcilerPostprocessor(), + ) + ] + custom = load_postprocessors_for_profile( + root_dir=root_dir, + profile_name=profile.name, + runtime=runtime, + ) + trailing = [ + LoadedPostprocessor(name="canonical_ids", handler=core_ids), + LoadedPostprocessor( + name="import_annotation", + handler=ImportAnnotationPostprocessor(), + ), + ] + return leading 
+ custom + trailing + + self._account_key = account_key + self._postprocessor_service = PostprocessorService( + postprocessors_factory=_postprocessors_factory, + pool_size=pp_pool_size, + ) + + def _init_mapping_service( + self, settings: dict, context: Context, pool_size: int + ) -> None: self.templates_dir = self._resolve_path(self.profile.templates_dir) self.mappings_dir = self._resolve_path(self.profile.mappings_dir) self._template_dirs = self._resolve_overlay_paths( @@ -84,88 +254,70 @@ def __init__( self._mapping_dirs = self._resolve_overlay_paths( self.profile.mapping_overlay_dirs or (self.profile.mappings_dir,) ) - - self.rml_service = RmlMappingService(context) - self.patcher = EntityPatcher(context) self.template_reifier = JinjaRdfTemplateReifier(self._template_dirs) self.text_renderer = TemplateTextRenderer() - self._template_graph: Graph | None = None self._template_exports: dict[str, Any] | None = None self._mapping_cache: dict[Path, str] = {} self._static_templates_patched = False self._static_templates_lock = asyncio.Lock() - canonical_id_strategy = ( - str( - self.profile.settings.get( - "canonical_id_strategy", - self.profile.settings.get("CANONICAL_ID_STRATEGY", "legacy"), - ) + self.rml_service = RmlMappingService(context) + mapping_pool_size = int( + _setting( + settings, "mapping_pool_size", "MAPPING_POOL_SIZE", os.cpu_count() or 4 ) - .strip() - .lower() - ) - self._core_ids = CanonicalIdsPostprocessor(strategy=canonical_id_strategy) - self._postprocessor_runtime = _resolve_postprocessor_runtime( - dict(self.profile.settings) ) logger.info( - "Resolved postprocessor runtime for profile '%s': %s (origin=%s)", + "Mapping pool size for profile '%s': %d", self.profile.name, - self._postprocessor_runtime, - self.profile.origins.get("postprocessor_runtime", "default"), + mapping_pool_size, ) - self._postprocessors = load_postprocessors_for_profile( - root_dir=self.root_dir, - profile_name=self.profile.name, - runtime=self._postprocessor_runtime, + 
init_morph_kgc_pool(mapping_pool_size) + # Wraps apply_mapping calls so they run in a thread rather than blocking + # the asyncio event loop. The thread itself blocks on the morph_kgc + # ProcessPoolExecutor slot, leaving the event loop free for I/O. + self._mapping_executor = ThreadPoolExecutor( + max_workers=pool_size, thread_name_prefix="worai_ml" ) - self._shacl_mode = self._resolve_validation_mode( - self.profile.settings.get( - "shacl_validate_mode", - self.profile.settings.get("SHACL_VALIDATE_MODE", "warn"), - ) + + def _init_shacl_validator(self, settings: dict, pool_size: int) -> None: + mode = _resolve_validation_mode( + _setting(settings, "shacl_validate_mode", "SHACL_VALIDATE_MODE", "warn") ) - shacl_builtin_shapes = self._resolve_list_setting( - self.profile.settings.get( - "shacl_builtin_shapes", - self.profile.settings.get("SHACL_BUILTIN_SHAPES"), - ) + builtin_shapes = _resolve_list_setting( + _setting(settings, "shacl_builtin_shapes", "SHACL_BUILTIN_SHAPES", None) ) - shacl_exclude_builtin_shapes = self._resolve_list_setting( - self.profile.settings.get( + exclude_builtin_shapes = _resolve_list_setting( + _setting( + settings, "shacl_exclude_builtin_shapes", - self.profile.settings.get("SHACL_EXCLUDE_BUILTIN_SHAPES"), + "SHACL_EXCLUDE_BUILTIN_SHAPES", + None, ) ) - shacl_extra_shapes = self._resolve_list_setting( - self.profile.settings.get( - "shacl_extra_shapes", self.profile.settings.get("SHACL_EXTRA_SHAPES") - ) + extra_shapes = _resolve_list_setting( + _setting(settings, "shacl_extra_shapes", "SHACL_EXTRA_SHAPES", None) ) - self._shacl_shape_specs = resolve_shape_specs( - builtin_shapes=shacl_builtin_shapes or None, - exclude_builtin_shapes=shacl_exclude_builtin_shapes or None, - extra_shapes=shacl_extra_shapes or None, + shape_specs = resolve_shape_specs( + builtin_shapes=builtin_shapes or None, + exclude_builtin_shapes=exclude_builtin_shapes or None, + extra_shapes=extra_shapes or None, ) - self._import_hash_mode = 
self._resolve_import_hash_mode( - self.profile.settings.get( - "import_hash_mode", - self.profile.settings.get("IMPORT_HASH_MODE", "on"), + shacl_pool_size = int( + _setting( + settings, "shacl_pool_size", "SHACL_POOL_SIZE", max(2, pool_size // 2) ) ) - self._kpi = KgBuildKpiCollector( - dataset_uri=getattr(self.context.account, "dataset_uri", None), - validation_enabled=self._shacl_mode != "off", + self._shacl_validator = ShaclValidationService( + shape_specs=shape_specs or None, + mode=mode, + pool_size=shacl_pool_size, ) - logger.debug( - "Resolved mappings for profile '%s': effective_dir=%s (origin=%s), routes=%s (origin=%s), overlay_dirs=%s", - self.profile.name, - self.mappings_dir, - self.profile.origins.get("mappings_dir", "default"), - len(self.profile.routes), - self.profile.origins.get("routes", "default"), - [str(p) for p in self._mapping_dirs], + + def _init_graph_writer(self, settings: dict, context: Context) -> None: + self.patcher = EntityPatcher(context) + self._import_hash_mode = _resolve_import_hash_mode( + _setting(settings, "import_hash_mode", "IMPORT_HASH_MODE", "on") ) async def callback( @@ -179,79 +331,162 @@ async def callback( if hasattr(response, "web_page") and response.web_page else "Unknown URL" ) - if hasattr(response, "errors") and response.errors: logger.error("Cloud callback error for %s: %s", url, response.errors) return - if not response.web_page or not response.web_page.html: logger.warning("No HTML content for %s, skipping mapping", url) return await self._patch_static_templates_once() - mapping_path = self._resolve_mapping_path(url) - rendered_mapping = self._get_mapping_content(mapping_path) - mapping_response = self._mapping_response(response, existing_web_page_id) debug_output: dict[str, str] | None = {} if self.debug_dir else None - - graph = await self.rml_service.apply_mapping( - html=response.web_page.html, - url=url, - mapping_file_path=mapping_path, - mapping_content=rendered_mapping, - response=mapping_response, - 
debug_output=debug_output, + mapping = await self._run_mapping_stage( + response, url, existing_web_page_id, debug_output ) - if not graph or len(graph) == 0: + if not mapping.graph or len(mapping.graph) == 0: logger.warning("No triples produced for %s", url) return - if existing_web_page_id: - self._reconcile_root_id(graph, existing_web_page_id) - graph = self._apply_postprocessors(graph, url, response, existing_web_page_id) - # Canonical IDs must run after custom postprocessors so any nodes minted - # by local logic are normalized before graph sync patching. - graph = self._core_ids.process_graph( - graph, self._build_pp_context(url, response, existing_web_page_id) + pp_result = await self._run_postprocessing_stage( + mapping.graph, url, response, existing_web_page_id, existing_import_hash ) - self._set_source(graph, existing_web_page_id) - self._set_existing_import_hash(graph, existing_import_hash) if self.debug_dir: xhtml = (debug_output or {}).get("xhtml") self._write_debug_source_documents( url=url, html=response.web_page.html, xhtml=xhtml ) - self._write_debug_graph(graph, url) + self._write_debug_graph(pp_result.graph, url) - validation_payload = self._validate_graph_if_enabled(graph, url) - graph_metrics = self._kpi.graph_metrics(graph) + outcome: ValidationOutcome | None = await self._shacl_validator.validate( + pp_result.graph + ) + if outcome is not None: + logger.info( + "SHACL validation for %s: pass=%s warnings=%d errors=%d", + url, + outcome.passed, + outcome.warning_count, + outcome.error_count, + ) + self._kpi.record_validation(outcome) + self._kpi.record_graph(pp_result.graph) self._emit_progress( { "kind": "graph", "profile": self.profile.name, "url": url, - "graph": graph_metrics, - "validation": validation_payload, + "graph": self._kpi.graph_metrics(pp_result.graph), + "validation": outcome.to_dict() if outcome else None, } ) - self._kpi.record_graph(graph) if ( - validation_payload is not None - and self._shacl_mode == "fail" - and not 
validation_payload["pass"] + outcome is not None + and self._shacl_validator.mode == ValidationMode.FAIL + and outcome.failed ): raise RuntimeError(f"SHACL validation failed for {url} in fail mode.") - await self._write_graph(graph) - logger.info("Wrote %s triples for %s", len(graph), url) + await self._write_graph(pp_result.graph) + logger.info( + "Wrote %s triples for %s [mapping_wait=%dms mapping=%dms postprocessor_wait=%dms postprocessors=%dms validation_wait=%dms validation=%dms]", + len(pp_result.graph), + url, + mapping.queue_wait_ms, + mapping.mapping_ms, + pp_result.queue_wait_ms, + pp_result.postprocessors_ms, + outcome.queue_wait_ms if outcome else 0, + outcome.validation_ms if outcome else 0, + ) - def close(self) -> None: - close_loaded_postprocessors(self._postprocessors) + async def close(self) -> None: + self._postprocessor_service.close() + self._mapping_executor.shutdown(wait=False) + self._shacl_validator.close() + await self.context.graph_queue.close() def get_kpi_summary(self) -> dict[str, object]: return self._kpi.summary(self.profile.name) + @property + def _dataset_uri(self) -> str: + return str(getattr(self.context.account, "dataset_uri", "") or "").rstrip("/") + + @staticmethod + def _url_hash(url: str) -> str: + return hashlib.sha256(url.encode("utf-8")).hexdigest() + + async def _run_mapping_stage( + self, + response: WebPageScrapeResponse, + url: str, + existing_web_page_id: str | None, + debug_output: dict[str, str] | None, + ) -> MappingResult: + mapping_path = self._resolve_mapping_path(url) + rendered_mapping = self._get_mapping_content(mapping_path) + mapping_response = self._mapping_response(response, existing_web_page_id) + + def _run() -> MappingResult: + # apply_mapping has no awaits — all work is synchronous (morph_kgc). + # Run in a thread so the event loop stays free for I/O while the + # thread waits for its morph_kgc subprocess slot. 
+ return asyncio.run( + self.rml_service.apply_mapping( + html=response.web_page.html, + url=url, + mapping_file_path=mapping_path, + mapping_content=rendered_mapping, + response=mapping_response, + debug_output=debug_output, + ) + ) + + return await asyncio.get_event_loop().run_in_executor( + self._mapping_executor, _run + ) + + async def _run_postprocessing_stage( + self, + graph: Graph, + url: str, + response: WebPageScrapeResponse, + existing_web_page_id: str | None, + existing_import_hash: str | None, + ) -> PostprocessorResult: + context = self._build_pp_context( + url, response, existing_web_page_id, existing_import_hash + ) + return await self._postprocessor_service.apply(graph, context) + + def _build_pp_context( + self, + url: str, + response: WebPageScrapeResponse, + existing_web_page_id: str | None, + existing_import_hash: str | None, + ) -> PostprocessorContext: + dataset_uri = self._dataset_uri + ids = IdAllocator(dataset_uri) if dataset_uri else None + profile_payload = asdict(self.profile) + profile_settings = dict(profile_payload.get("settings", {}) or {}) + profile_settings.setdefault("api_url", "https://api.wordlift.io") + profile_payload["settings"] = profile_settings + return PostprocessorContext( + profile_name=self.profile.name, + profile=profile_payload, + url=url, + account=self.context.account, + account_key=self._account_key, + exports=self._template_exports or {}, + response=response, + existing_web_page_id=existing_web_page_id, + existing_import_hash=existing_import_hash, + import_hash_mode=self._import_hash_mode, + ids=ids, + ) + def _resolve_path(self, raw_path: str) -> Path: path = Path(raw_path) if path.is_absolute(): @@ -284,22 +519,28 @@ async def _patch_static_templates_once(self) -> None: self._ensure_templates_loaded() if self._template_graph and len(self._template_graph) > 0: - validation_payload = self._validate_graph_if_enabled( - self._template_graph, "static_templates" - ) + outcome = await 
self._shacl_validator.validate(self._template_graph) + if outcome is not None: + logger.info( + "SHACL validation for static_templates: pass=%s warnings=%d errors=%d", + outcome.passed, + outcome.warning_count, + outcome.error_count, + ) + self._kpi.record_validation(outcome) self._emit_progress( { "kind": "static_templates", "profile": self.profile.name, "graph": self._kpi.graph_metrics(self._template_graph), - "validation": validation_payload, + "validation": outcome.to_dict() if outcome else None, } ) self._kpi.record_graph(self._template_graph) if ( - validation_payload is not None - and self._shacl_mode == "fail" - and not validation_payload["pass"] + outcome is not None + and self._shacl_validator.mode == ValidationMode.FAIL + and outcome.failed ): raise RuntimeError( "SHACL validation failed for static templates in fail mode." @@ -321,13 +562,13 @@ def _ensure_templates_loaded(self) -> None: if self._template_graph is not None and self._template_exports is not None: return - dataset_uri = getattr(self.context.account, "dataset_uri", None) + dataset_uri = self._dataset_uri if not dataset_uri: raise RuntimeError("Dataset URI not available on context.account.") base_context = { "account": self.context.account, - "dataset_uri": str(dataset_uri).rstrip("/"), + "dataset_uri": dataset_uri, } exports, exports_summary = self.text_renderer.load_exports_with_summary( self._template_dirs, base_context @@ -393,7 +634,7 @@ def _get_mapping_content(self, mapping_path: Path) -> str: if cached is not None: return cached - dataset_uri = getattr(self.context.account, "dataset_uri", None) + dataset_uri = self._dataset_uri if not dataset_uri: raise RuntimeError("Dataset URI not available on context.account.") @@ -401,7 +642,7 @@ def _get_mapping_content(self, mapping_path: Path) -> str: context = { "account": self.context.account, - "dataset_uri": str(dataset_uri).rstrip("/"), + "dataset_uri": dataset_uri, "exports": self._template_exports or {}, } template_path = 
self.text_renderer.resolve_mapping_template(mapping_path) @@ -418,9 +659,7 @@ async def _write_graph(self, graph: Graph) -> None: await self.patcher.patch_all(graph, import_hash_mode=self._import_hash_mode) def _prepare_graph_for_put(self, graph: Graph) -> bool: - dataset_uri = str( - getattr(self.context.account, "dataset_uri", "") or "" - ).rstrip("/") + dataset_uri = self._dataset_uri if not dataset_uri: return False @@ -432,23 +671,23 @@ def _prepare_graph_for_put(self, graph: Graph) -> bool: if not subjects: return False - first_level_subjects = { + page_subjects = { subject - for subject in self._first_level_subjects(graph) + for subject in first_level_subjects(graph, dataset_uri) if subject in subjects } - if not first_level_subjects: + if not page_subjects: return False if self._import_hash_mode == "off": return True - representative = next(iter(first_level_subjects)) + representative = next(iter(page_subjects)) existing_hash = self.patcher._existing_import_hash(representative, graph) import_hash = self.patcher._compute_import_hash( representative, graph, dataset_uri ) - for subject in first_level_subjects: + for subject in page_subjects: self.patcher._set_import_hash(subject, graph, import_hash) return not ( @@ -457,93 +696,10 @@ def _prepare_graph_for_put(self, graph: Graph) -> bool: and existing_hash == import_hash ) - def _apply_postprocessors( - self, - graph: Graph, - url: str, - response: WebPageScrapeResponse, - existing_web_page_id: str | None, - ) -> Graph: - if not self._postprocessors: - return graph - - pp_context = self._build_pp_context(url, response, existing_web_page_id) - if not pp_context.account_key: - raise RuntimeError( - "Postprocessor runtime requires an API key. Configure one via profile " - "'api_key', WORDLIFT_KEY, or WORDLIFT_API_KEY." 
- ) - - for processor in self._postprocessors: - graph = processor.run(graph, pp_context) - logger.info("Applied postprocessor '%s' for %s", processor.name, url) - return graph - - def _build_pp_context( - self, - url: str, - response: WebPageScrapeResponse, - existing_web_page_id: str | None, - ) -> PostprocessorContext: - dataset_uri = str(getattr(self.context.account, "dataset_uri", "")).rstrip("/") - ids = IdAllocator(dataset_uri) if dataset_uri else None - profile_payload = asdict(self.profile) - profile_settings = dict(profile_payload.get("settings", {}) or {}) - profile_settings.setdefault("api_url", "https://api.wordlift.io") - profile_payload["settings"] = profile_settings - return PostprocessorContext( - profile_name=self.profile.name, - profile=profile_payload, - url=url, - account=self.context.account, - account_key=self._resolve_postprocessor_account_key(), - exports=self._template_exports or {}, - response=response, - existing_web_page_id=existing_web_page_id, - ids=ids, - ) - - def _resolve_postprocessor_account_key(self) -> str | None: - profile_key = self._clean_key(self.profile.api_key) - if profile_key: - return profile_key - - client_config = getattr(self.context, "client_configuration", None) - if client_config is not None: - api_key_map = getattr(client_config, "api_key", None) - if isinstance(api_key_map, dict): - runtime_key = self._clean_key(api_key_map.get("ApiKey")) - if runtime_key: - return runtime_key - - provider = getattr(self.context, "configuration_provider", None) - if provider is not None: - for name in ("WORDLIFT_KEY", "WORDLIFT_API_KEY"): - try: - key = self._clean_key(provider.get_value(name)) - except Exception: - key = None - if key: - return key - - for name in ("WORDLIFT_KEY", "WORDLIFT_API_KEY"): - key = self._clean_key(os.getenv(name)) - if key: - return key - - return None - - @staticmethod - def _clean_key(value: Any) -> str | None: - if value is None: - return None - key = str(value).strip() - return key or None - def 
_write_debug_graph(self, graph: Graph, url: str) -> None: assert self.debug_dir is not None self.debug_dir.mkdir(parents=True, exist_ok=True) - safe_name = hashlib.sha256(url.encode("utf-8")).hexdigest() + safe_name = self._url_hash(url) debug_file = self.debug_dir / f"{safe_name}.ttl" graph.serialize(destination=debug_file, format="turtle") @@ -552,81 +708,13 @@ def _write_debug_source_documents( ) -> None: assert self.debug_dir is not None self.debug_dir.mkdir(parents=True, exist_ok=True) - safe_name = hashlib.sha256(url.encode("utf-8")).hexdigest() + safe_name = self._url_hash(url) html_file = self.debug_dir / f"{safe_name}.html" html_file.write_text(html, encoding="utf-8") if xhtml: xhtml_file = self.debug_dir / f"{safe_name}.xhtml" xhtml_file.write_text(xhtml, encoding="utf-8") - def _reconcile_root_id(self, graph: Graph, root_id: str) -> None: - old_iri = self._find_web_page_iri(graph) - if old_iri and str(old_iri) != root_id: - self._swap_iris(graph, old_iri, URIRef(root_id)) - - def _find_web_page_iri(self, graph: Graph) -> URIRef | None: - for subject in graph.subjects(RDF.type, URIRef("http://schema.org/WebPage")): - return subject - for subject in graph.subjects(RDF.type, URIRef("https://schema.org/WebPage")): - return subject - return None - - def _swap_iris(self, graph: Graph, old_iri: URIRef, new_iri: URIRef) -> None: - for subject, predicate, obj in list(graph.triples((old_iri, None, None))): - graph.remove((subject, predicate, obj)) - graph.add((new_iri, predicate, obj)) - for subject, predicate, obj in list(graph.triples((None, None, old_iri))): - graph.remove((subject, predicate, obj)) - graph.add((subject, predicate, new_iri)) - - def _set_source(self, graph: Graph, existing_web_page_id: str | None) -> None: - del existing_web_page_id - for subject in self._first_level_subjects(graph): - graph.set((subject, SEOVOC_SOURCE, Literal("web-page-import"))) - - def _set_existing_import_hash(self, graph: Graph, import_hash: str | None) -> None: - if 
self._import_hash_mode == "off": - return - if not import_hash: - return - subjects = { - subject for subject in graph.subjects() if isinstance(subject, URIRef) - } - for subject in subjects: - graph.set((subject, SEOVOC_IMPORT_HASH, Literal(import_hash))) - - def _first_level_subjects(self, graph: Graph) -> set[URIRef]: - subjects = { - subject for subject in graph.subjects() if isinstance(subject, URIRef) - } - dataset_uri = str( - getattr(self.context.account, "dataset_uri", "") or "" - ).rstrip("/") - if dataset_uri: - first_level_by_id = { - subject - for subject in subjects - if str(subject).startswith(f"{dataset_uri}/") - and len( - [ - part - for part in str(subject)[len(dataset_uri) + 1 :].split("/") - if part - ] - ) - == 2 - } - if first_level_by_id: - return first_level_by_id - - referenced = { - obj - for _, _, obj in graph.triples((None, None, None)) - if isinstance(obj, URIRef) and obj in subjects - } - first_level = subjects - referenced - return first_level or subjects - def _mapping_response( self, response: WebPageScrapeResponse, @@ -640,95 +728,6 @@ def _mapping_response( web_page=response.web_page, ) - def _validate_graph_if_enabled( - self, graph: Graph, url: str - ) -> dict[str, Any] | None: - if self._shacl_mode == "off": - return None - result = self._validate_graph(graph) - summary = self._summarize_validation(result) - self._kpi.record_validation( - passed=summary["pass"], - warning_count=summary["warnings"]["count"], - error_count=summary["errors"]["count"], - warning_sources=summary["warnings"]["sources"], - error_sources=summary["errors"]["sources"], - ) - logger.info( - "SHACL validation for %s: pass=%s warnings=%s errors=%s", - url, - summary["pass"], - summary["warnings"]["count"], - summary["errors"]["count"], - ) - return summary - - def _validate_graph(self, graph: Graph) -> ValidationResult: - with tempfile.NamedTemporaryFile(mode="w", suffix=".ttl", delete=False) as f: - tmp = Path(f.name) - try: - 
graph.serialize(destination=tmp, format="turtle") - return validate_file( - str(tmp), - shape_specs=self._shacl_shape_specs - if self._shacl_shape_specs - else None, - ) - finally: - try: - tmp.unlink(missing_ok=True) - except Exception: - logger.debug("Failed to remove temporary SHACL graph file: %s", tmp) - - def _summarize_validation(self, result: ValidationResult) -> dict[str, Any]: - sh = URIRef("http://www.w3.org/ns/shacl#") - sh_warning = URIRef(f"{sh}Warning") - sh_violation = URIRef(f"{sh}Violation") - sh_source_shape = URIRef(f"{sh}sourceShape") - - warning_sources: dict[str, int] = {} - error_sources: dict[str, int] = {} - warning_count = 0 - error_count = 0 - - for report_node in result.report_graph.subjects( - URIRef(f"{sh}resultSeverity"), sh_warning - ): - warning_count += 1 - shape = next( - result.report_graph.objects(report_node, sh_source_shape), None - ) - label = result.shape_source_map.get(shape, "unknown") - warning_sources[str(label)] = warning_sources.get(str(label), 0) + 1 - - for report_node in result.report_graph.subjects( - URIRef(f"{sh}resultSeverity"), sh_violation - ): - error_count += 1 - shape = next( - result.report_graph.objects(report_node, sh_source_shape), None - ) - label = result.shape_source_map.get(shape, "unknown") - error_sources[str(label)] = error_sources.get(str(label), 0) + 1 - - return { - "total": 1, - "pass": bool(result.conforms), - "fail": not bool(result.conforms), - "warnings": { - "count": warning_count, - "sources": dict( - sorted(warning_sources.items(), key=lambda item: item[0]) - ), - }, - "errors": { - "count": error_count, - "sources": dict( - sorted(error_sources.items(), key=lambda item: item[0]) - ), - }, - } - def _emit_progress(self, payload: dict[str, Any]) -> None: if not callable(self._on_progress): return @@ -736,40 +735,3 @@ def _emit_progress(self, payload: dict[str, Any]) -> None: self._on_progress(payload) except Exception: logger.warning("Failed to emit kg_build progress payload.", 
exc_info=True) - - def _resolve_list_setting(self, value: Any) -> list[str]: - if value is None: - return [] - if isinstance(value, str): - return [part.strip() for part in value.split(",") if part.strip()] - if isinstance(value, (list, tuple)): - specs: list[str] = [] - for item in value: - text = str(item).strip() - if text: - specs.append(text) - return specs - return [str(value).strip()] if str(value).strip() else [] - - def _resolve_validation_mode(self, value: Any) -> str: - if value is None: - return "warn" - mode = str(value).strip().lower() - if mode == "strict": - logger.warning( - "Deprecated SHACL validation mode 'strict' detected; using 'fail'." - ) - return "fail" - if mode in {"off", "warn", "fail"}: - return mode - logger.warning("Unsupported SHACL validation mode '%s'; using 'warn'.", mode) - return "warn" - - def _resolve_import_hash_mode(self, value: Any) -> str: - if value is None: - return "on" - mode = str(value).strip().lower() - if mode in {"on", "write", "off"}: - return mode - logger.warning("Unsupported import hash mode '%s'; using 'on'.", mode) - return "on" diff --git a/wordlift_sdk/kg_build/rml_mapping.py b/wordlift_sdk/kg_build/rml_mapping.py index 5b40a91..6d9ef39 100644 --- a/wordlift_sdk/kg_build/rml_mapping.py +++ b/wordlift_sdk/kg_build/rml_mapping.py @@ -4,23 +4,39 @@ import logging import os import tempfile +import time +from dataclasses import dataclass from pathlib import Path from typing import Any from rdflib import Graph from wordlift_sdk.protocol import Context +from wordlift_sdk.structured_data.engine import _morph_kgc_tls from wordlift_sdk.structured_data.materialization import MaterializationPipeline from wordlift_sdk.utils.html_converter import HtmlConverter logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class MappingResult: + graph: Graph | None + queue_wait_ms: int + mapping_ms: int + + class RmlMappingService: - def __init__(self, context: Context) -> None: + def __init__( + self, + context: 
Context, + pipeline: MaterializationPipeline | None = None, + html_converter: HtmlConverter | None = None, + ) -> None: self._context = context - self._html_converter = HtmlConverter() + self._pipeline = pipeline or MaterializationPipeline() + self._html_converter = html_converter or HtmlConverter() - def to_xhtml(self, html: str) -> str: + def _to_xhtml(self, html: str) -> str: return self._html_converter.convert(html) async def apply_mapping( @@ -32,9 +48,11 @@ async def apply_mapping( mapping_content: str | None = None, response: object | None = None, debug_output: dict[str, str] | None = None, - ) -> Graph | None: + ) -> MappingResult: + queue_wait_ms = 0 + _t_start = time.perf_counter() try: - xhtml_str = xhtml or self.to_xhtml(html) + xhtml_str = xhtml or self._to_xhtml(html) if debug_output is not None: debug_output["xhtml"] = xhtml_str @@ -50,27 +68,31 @@ async def apply_mapping( resolved_mapping_content = f.read() except FileNotFoundError: logger.error("Mapping file not found: %s", mapping_file_path) - return None + return MappingResult( + graph=None, + queue_wait_ms=queue_wait_ms, + mapping_ms=int((time.perf_counter() - _t_start) * 1000), + ) dataset_uri = getattr(self._context.account, "dataset_uri", None) if not dataset_uri: raise RuntimeError("Dataset URI not available on context.account.") - pipeline = MaterializationPipeline() - normalized_yarrrml, mappings = pipeline.normalize( + normalized_yarrrml, mappings = self._pipeline.normalize( resolved_mapping_content, url, Path(data_path), response=response, ) - jsonld_raw = pipeline.materialize( + jsonld_raw = self._pipeline.materialize( normalized_yarrrml, Path(data_path), Path(temp_dir), url=url, response=response, ) - jsonld_data = pipeline.postprocess( + queue_wait_ms = getattr(_morph_kgc_tls, "mapping_wait_ms", 0) + jsonld_data = self._pipeline.postprocess( jsonld_raw, mappings, xhtml_str, @@ -93,7 +115,12 @@ async def apply_mapping( "No triples generated from mapping %s.", mapping_file_path ) - 
return graph + return MappingResult( + graph=graph, + queue_wait_ms=queue_wait_ms, + mapping_ms=int((time.perf_counter() - _t_start) * 1000) + - queue_wait_ms, + ) except Exception as exc: logger.error( @@ -102,7 +129,11 @@ async def apply_mapping( exc, exc_info=True, ) - return None + return MappingResult( + graph=None, + queue_wait_ms=queue_wait_ms, + mapping_ms=int((time.perf_counter() - _t_start) * 1000), + ) def _normalize_schema_uris(self, payload: Any): if isinstance(payload, dict): diff --git a/wordlift_sdk/protocol/graph/graph_queue.py b/wordlift_sdk/protocol/graph/graph_queue.py index 1ad8e33..56e818a 100644 --- a/wordlift_sdk/protocol/graph/graph_queue.py +++ b/wordlift_sdk/protocol/graph/graph_queue.py @@ -8,7 +8,13 @@ from rdflib import Graph from rdflib.compare import to_isomorphic from wordlift_client import Configuration -from tenacity import retry, retry_if_exception_type, wait_fixed, after_log +from tenacity import ( + retry, + retry_if_exception_type, + wait_fixed, + after_log, + stop_after_attempt, +) logger = logging.getLogger(__name__) @@ -20,9 +26,37 @@ class GraphQueue: def __init__(self, client_configuration: Configuration): self.client_configuration = client_configuration self.hashes = set() + self._api_client: wordlift_client.ApiClient | None = None + self._api_client_lock: asyncio.Lock | None = None + + async def _get_api_client(self) -> wordlift_client.ApiClient: + # Lazy-init the lock (must be created on the event loop). + if self._api_client_lock is None: + self._api_client_lock = asyncio.Lock() + if self._api_client is not None: + return self._api_client + async with self._api_client_lock: + if self._api_client is None: + # ApiClient.__init__ calls ssl.create_default_context() synchronously + # and must run on the event loop thread (it calls asyncio internals). + # Creating it once and caching avoids repeated SSL cert loading per put(). 
+ client = wordlift_client.ApiClient( + configuration=self.client_configuration + ) + await client.__aenter__() + self._api_client = client + return self._api_client + + async def close(self) -> None: + if self._api_client is not None: + try: + await self._api_client.__aexit__(None, None, None) + except Exception: + pass + self._api_client = None @retry( - # stop=stop_after_attempt(5), # Retry up to 5 times + stop=stop_after_attempt(5), retry=retry_if_exception_type( asyncio.TimeoutError | aiohttp.client_exceptions.ServerDisconnectedError @@ -39,23 +73,22 @@ def __init__(self, client_configuration: Configuration): reraise=True, ) async def put(self, graph: Graph) -> None: - hash = GraphQueue.hash_graph(graph) + loop = asyncio.get_event_loop() + hash = await loop.run_in_executor(None, GraphQueue.hash_graph, graph) if hash not in self.hashes: self.hashes.add(hash) - async with wordlift_client.ApiClient( - configuration=self.client_configuration - ) as api_client: - api_instance = wordlift_client.EntitiesApi(api_client) + api_client = await self._get_api_client() + api_instance = wordlift_client.EntitiesApi(api_client) - try: - await api_instance.create_or_update_entities( - graph.serialize(format="turtle"), - _content_type="text/turtle", - ) - except Exception as e: - logger.error(f"Failed to create entities: {e}", exc_info=e) - raise e + try: + await api_instance.create_or_update_entities( + graph.serialize(format="turtle"), + _content_type="text/turtle", + ) + except Exception as e: + logger.error(f"Failed to create entities: {e}", exc_info=e) + raise e @staticmethod def hash_graph(graph: Graph) -> str: diff --git a/wordlift_sdk/structured_data/__init__.py b/wordlift_sdk/structured_data/__init__.py index 1e4d977..c5181c9 100644 --- a/wordlift_sdk/structured_data/__init__.py +++ b/wordlift_sdk/structured_data/__init__.py @@ -21,12 +21,12 @@ _EXPORTS = { - "CreateRequest": ("wordlift_sdk.structured_data.models", "CreateRequest"), + "CreateRequest": 
("wordlift_sdk.structured_data.orchestrator", "CreateRequest"), "CreateWorkflow": ( "wordlift_sdk.structured_data.orchestrator", "CreateWorkflow", ), - "GenerateRequest": ("wordlift_sdk.structured_data.models", "GenerateRequest"), + "GenerateRequest": ("wordlift_sdk.structured_data.orchestrator", "GenerateRequest"), "GenerateWorkflow": ( "wordlift_sdk.structured_data.orchestrator", "GenerateWorkflow", diff --git a/wordlift_sdk/structured_data/engine.py b/wordlift_sdk/structured_data/engine.py index 2fec73d..3d27952 100644 --- a/wordlift_sdk/structured_data/engine.py +++ b/wordlift_sdk/structured_data/engine.py @@ -6,7 +6,10 @@ import hashlib import json import logging +import multiprocessing +import os import re +from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass from importlib import resources from pathlib import Path @@ -29,6 +32,58 @@ from wordlift_sdk.validation.shacl import ValidationResult, validate_file +import threading +import time as _time + + +# Top-level worker — must be module-level to be picklable for ProcessPoolExecutor. +# Accepts submit_time so it can measure queue wait (time spent waiting for a +# free subprocess slot). Returns (ntriples, queue_wait_ms). +def _morph_kgc_worker(config: str, submit_time: float) -> tuple[str, int]: + import morph_kgc as _mkgc + import time as _t + + queue_wait_ms = int((_t.time() - submit_time) * 1000) + ntriples = _mkgc.materialize(config).serialize(format="nt") + return ntriples, queue_wait_ms + + +# Thread-local used to pass mapping_wait_ms out of _materialize_graph without +# changing the return type of the public materialization API. +# Consumed by rml_mapping.RmlMappingService.apply_mapping — callers above that +# layer receive the timing as a regular return value. +_morph_kgc_tls = threading.local() + +# Lazy process pool — created on first use in the main process only. 
+# Worker subprocesses import this module but never call _get_morph_kgc_pool(), +# so they do NOT create their own pools (no recursive process explosion). +_morph_kgc_pool: ProcessPoolExecutor | None = None + + +def init_morph_kgc_pool(max_workers: int) -> None: + """Pre-create the morph_kgc process pool with a specific worker count. + Call once from the protocol __init__ before any mapping work starts. + Subsequent calls are no-ops (pool is only created once). + """ + global _morph_kgc_pool + if _morph_kgc_pool is not None: + return + ctx = multiprocessing.get_context("spawn") + _morph_kgc_pool = ProcessPoolExecutor(max_workers=max_workers, mp_context=ctx) + + +def _get_morph_kgc_pool() -> ProcessPoolExecutor: + global _morph_kgc_pool + if _morph_kgc_pool is None: + # Fallback if init_morph_kgc_pool was never called. + ctx = multiprocessing.get_context("spawn") + _morph_kgc_pool = ProcessPoolExecutor( + max_workers=os.cpu_count() or 4, + mp_context=ctx, + ) + return _morph_kgc_pool + + _SCHEMA_BASE = "https://schema.org" _SCHEMA_HTTP = "http://schema.org/" _AGENT_BASE_URL = "https://api.wordlift.io/agent" @@ -1341,22 +1396,32 @@ def _normalize_materialization_error(error: Exception) -> RuntimeError: def _materialize_graph(mapping_path: Path) -> Graph: - try: - import morph_kgc - except ImportError as exc: - raise RuntimeError( - "morph-kgc is required. Install with: pip install morph-kgc" - ) from exc - config = ( "[CONFIGURATION]\n" "output_format = N-TRIPLES\n" + # Disable morph_kgc internal multiprocessing: on Linux it uses fork() which + # deadlocks when the parent process already has threads running (asyncio pool, + # SHACL ProcessPoolExecutor). The outer pipeline handles concurrency. + "number_of_processes = 1\n" "\n" "[DataSource1]\n" f"mappings = {mapping_path}\n" ) try: - return morph_kgc.materialize(config) + # Submit to subprocess pool — each worker has isolated pyparsing state, + # so calls are genuinely parallel across CPU cores with no lock needed. 
+ # .result() blocks the calling thread (not the asyncio event loop). + ntriples, queue_wait_ms = ( + _get_morph_kgc_pool() + .submit(_morph_kgc_worker, config, _time.time()) + .result() + ) + # Store wait time in thread-local so protocol.py can read it without + # changing the return type of this function. + _morph_kgc_tls.mapping_wait_ms = queue_wait_ms + graph = Graph() + graph.parse(data=ntriples, format="nt") + return graph except RuntimeError: raise except Exception as exc: diff --git a/wordlift_sdk/utils/__init__.py b/wordlift_sdk/utils/__init__.py index df6297e..18d34b3 100644 --- a/wordlift_sdk/utils/__init__.py +++ b/wordlift_sdk/utils/__init__.py @@ -36,8 +36,8 @@ "create_entity_patch_request", ), "create_delayed": ("wordlift_sdk.utils.delayed", "create_delayed"), - "get_me": ("wordlift_sdk.utils.get_me", "get_me"), - "reset_me": ("wordlift_sdk.utils.reset_me", "reset_me"), + "get_me": ("wordlift_sdk.utils._get_me", "get_me"), + "reset_me": ("wordlift_sdk.utils._reset_me", "reset_me"), "HtmlConverter": ("wordlift_sdk.utils.html_converter", "HtmlConverter"), "AutoConcurrencyController": ( "wordlift_sdk.utils.auto_concurrency", diff --git a/wordlift_sdk/utils/get_me.py b/wordlift_sdk/utils/_get_me.py similarity index 100% rename from wordlift_sdk/utils/get_me.py rename to wordlift_sdk/utils/_get_me.py diff --git a/wordlift_sdk/utils/reset_me.py b/wordlift_sdk/utils/_reset_me.py similarity index 100% rename from wordlift_sdk/utils/reset_me.py rename to wordlift_sdk/utils/_reset_me.py diff --git a/wordlift_sdk/validation/__init__.py b/wordlift_sdk/validation/__init__.py index bcc616f..40701c1 100644 --- a/wordlift_sdk/validation/__init__.py +++ b/wordlift_sdk/validation/__init__.py @@ -17,6 +17,9 @@ "prepare_shapes", "validate_file", "validate_jsonld_from_url", + "ShaclValidationService", + "ValidationMode", + "ValidationOutcome", ] @@ -51,6 +54,18 @@ "wordlift_sdk.validation.shacl", "validate_jsonld_from_url", ), + "ShaclValidationService": ( + 
"wordlift_sdk.validation.shacl_validation_service", + "ShaclValidationService", + ), + "ValidationMode": ( + "wordlift_sdk.validation.shacl_validation_service", + "ValidationMode", + ), + "ValidationOutcome": ( + "wordlift_sdk.validation.shacl_validation_service", + "ValidationOutcome", + ), } diff --git a/wordlift_sdk/validation/shacl_validation_service.py b/wordlift_sdk/validation/shacl_validation_service.py new file mode 100644 index 0000000..7ad8499 --- /dev/null +++ b/wordlift_sdk/validation/shacl_validation_service.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import functools +import logging +import time +from concurrent.futures import ProcessPoolExecutor +from dataclasses import dataclass +from enum import Enum +from typing import Any + +from rdflib import Graph +from rdflib.namespace import SH + +from wordlift_sdk.validation.shacl import PreparedShaclValidator + +logger = logging.getLogger(__name__) + +DEFAULT_VALIDATION_TIMEOUT_SECONDS = 120.0 + + +class ValidationMode(str, Enum): + OFF = "off" + WARN = "warn" + FAIL = "fail" + + +# Module-level worker state — one copy per subprocess, initialised by _init_worker. +# Must be module-level for picklability by ProcessPoolExecutor. 
+_worker_validator: PreparedShaclValidator | None = None + + +def _init_worker(shape_specs: list[str] | None) -> None: + global _worker_validator + _worker_validator = PreparedShaclValidator.from_shape_specs(shape_specs) + + +def _validate_in_worker(ntriples: str, submit_time: float) -> dict: + queue_wait_ms = int((time.time() - submit_time) * 1000) + t_start = time.perf_counter() + + data_graph = Graph() + data_graph.parse(data=ntriples, format="nt") + + result = _worker_validator.validate_graph(data_graph) + source_map = _worker_validator.prepared_shapes.shape_source_map + + warning_sources: dict[str, int] = {} + error_sources: dict[str, int] = {} + for node in result.report_graph.subjects(SH.resultSeverity, SH.Warning): + shape = next(result.report_graph.objects(node, SH.sourceShape), None) + label = source_map.get(shape, "unknown") + warning_sources[str(label)] = warning_sources.get(str(label), 0) + 1 + for node in result.report_graph.subjects(SH.resultSeverity, SH.Violation): + shape = next(result.report_graph.objects(node, SH.sourceShape), None) + label = source_map.get(shape, "unknown") + error_sources[str(label)] = error_sources.get(str(label), 0) + 1 + + return { + "passed": bool(result.conforms), + "warning_sources": dict(sorted(warning_sources.items())), + "error_sources": dict(sorted(error_sources.items())), + "queue_wait_ms": queue_wait_ms, + "validation_ms": int((time.perf_counter() - t_start) * 1000), + } + + +@dataclass +class ValidationOutcome: + passed: bool + warning_sources: dict[str, int] + error_sources: dict[str, int] + queue_wait_ms: int + validation_ms: int + + @property + def failed(self) -> bool: + return not self.passed + + @property + def warning_count(self) -> int: + return sum(self.warning_sources.values()) + + @property + def error_count(self) -> int: + return sum(self.error_sources.values()) + + def to_dict(self) -> dict[str, Any]: + return { + "pass": self.passed, + "fail": self.failed, + "warnings": {"count": self.warning_count, 
"sources": self.warning_sources}, + "errors": {"count": self.error_count, "sources": self.error_sources}, + } + + +class ShaclValidationService: + def __init__( + self, + shape_specs: list[str] | None, + mode: ValidationMode, + pool_size: int = 1, + timeout_seconds: float = DEFAULT_VALIDATION_TIMEOUT_SECONDS, + ) -> None: + self._mode = mode + self._timeout_seconds = timeout_seconds + self._executor: ProcessPoolExecutor | None = None + if mode != ValidationMode.OFF: + self._executor = ProcessPoolExecutor( + max_workers=pool_size, + initializer=_init_worker, + initargs=(shape_specs,), + ) + logger.info( + "Created SHACL process pool with %d workers (mode=%s)", + pool_size, + mode, + ) + + @property + def mode(self) -> ValidationMode: + return self._mode + + async def validate(self, graph: Graph) -> ValidationOutcome | None: + """Validate *graph* against the configured SHACL shapes. + + Returns ``None`` when validation is disabled (mode=off) or skipped due + to a timeout or broken executor. 
+ """ + if self._mode == ValidationMode.OFF or self._executor is None: + return None + ntriples = graph.serialize(format="nt") + loop = asyncio.get_event_loop() + try: + result = await asyncio.wait_for( + loop.run_in_executor( + self._executor, + functools.partial(_validate_in_worker, ntriples, time.time()), + ), + timeout=self._timeout_seconds, + ) + except (asyncio.TimeoutError, concurrent.futures.BrokenExecutor) as exc: + logger.warning("SHACL validation skipped: %s (%s)", type(exc).__name__, exc) + return None + return ValidationOutcome( + passed=result["passed"], + warning_sources=result["warning_sources"], + error_sources=result["error_sources"], + queue_wait_ms=result["queue_wait_ms"], + validation_ms=result["validation_ms"], + ) + + def close(self) -> None: + if self._executor is not None: + self._executor.shutdown(wait=False) + self._executor = None diff --git a/wordlift_sdk/workflow/url_handler/ingestion_web_page_scrape_url_handler.py b/wordlift_sdk/workflow/url_handler/ingestion_web_page_scrape_url_handler.py index a228e26..6c7a719 100644 --- a/wordlift_sdk/workflow/url_handler/ingestion_web_page_scrape_url_handler.py +++ b/wordlift_sdk/workflow/url_handler/ingestion_web_page_scrape_url_handler.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +import functools import json import logging import re @@ -43,7 +45,10 @@ def __init__( async def __call__(self, url: Url) -> None: settings = self._build_settings(url) - result = run_ingestion(settings) + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + None, functools.partial(run_ingestion, settings) + ) if not result.pages: failed = [