diff --git a/.vscode/tasks.json b/.vscode/tasks.json
index 3057043..ad986dd 100644
--- a/.vscode/tasks.json
+++ b/.vscode/tasks.json
@@ -18,13 +18,13 @@
         {
             "label": "pytest integration",
             "type": "shell",
-            "command": "cd tests/integration; PYTHONPATH=../..: pytest -m sqlite3 -vv",
+            "command": "cd tests/integration; PYTHONPATH=../..: pytest -m 'not postgres' -vv",
             "problemMatcher": []
         },
         {
             "label": "pytest integration last failed",
             "type": "shell",
-            "command": "cd tests/integration; PYTHONPATH=../..: pytest -m sqlite3 -vv --last-failed",
+            "command": "cd tests/integration; PYTHONPATH=../..: pytest -m 'not postgres' -vv --last-failed",
             "problemMatcher": []
         },
         {
diff --git a/lufa/api_v1.py b/lufa/api_v1.py
index 303cbd8..bd12b95 100644
--- a/lufa/api_v1.py
+++ b/lufa/api_v1.py
@@ -8,7 +8,7 @@
 from lufa.auth import ro_token_required, sanitize, token_required, with_json_data
 from lufa.decorators import debug_only
 from lufa.provider import get_api_repository, get_awx_client, get_database_manager
-from lufa.repository.api_repository import LufaKeyError
+from lufa.repository.api_repository import JobExport, LufaKeyError
 
 MALFORMED_JSON = {"error": "Malformed json"}
 
@@ -243,7 +243,7 @@ def jobs_export(tower_job_id: int):
         "tasks": list,
     },
 )
-def import_job(data: dict):
+def import_job(data: JobExport):
     """
     Imports a job from export data.
 
diff --git a/lufa/database.py b/lufa/database.py
index f0d8a52..010f7eb 100644
--- a/lufa/database.py
+++ b/lufa/database.py
@@ -94,7 +94,7 @@ def init_db(self) -> None:
 
     def get_db_now(self) -> str:
         cur = self.get_db_connection().cursor()
-        cur.execute("select datetime('now') as now;")
+        cur.execute("select strftime('%Y-%m-%d %H:%M', datetime('now')) as now;")
         return cur.fetchone()["now"]
 
 
diff --git a/lufa/repository/api_repository.py b/lufa/repository/api_repository.py
index bfd92e3..ba3cd70 100644
--- a/lufa/repository/api_repository.py
+++ b/lufa/repository/api_repository.py
@@ -2,7 +2,7 @@
 import logging
 import sqlite3
 from abc import ABC, abstractmethod
-from typing import Optional, TypeAlias, TypedDict
+from typing import Optional, TypeAlias, TypedDict, cast
 
 from psycopg2.errors import ForeignKeyViolation, InvalidDatetimeFormat, InvalidTextRepresentation
 
@@ -29,7 +29,11 @@ class Callback(TypedDict):
     ansible_host: str
     state: JobState
     module: str
-    result_dump: str
+    result_dump: JSon
+
+
+class CallbackExport(Callback):
+    timestamp: TimeStamp
 
 
 class Task(TypedDict):
@@ -38,6 +42,12 @@ class Task(TypedDict):
     task_name: str
 
 
+class TaskExport(TypedDict):
+    ansible_uuid: str
+    task_name: str
+    callbacks: list[CallbackExport]
+
+
 class TowerJobStats(TypedDict):
     ansible_host: str
     ok: int
@@ -59,6 +69,41 @@ class JobTemplateComplianceStates(TypedDict):
     organisation: str
 
 
+class FullJob(TypedDict):
+    tower_job_id: int
+    tower_job_template_id: int
+    tower_job_template_name: str
+    ansible_limit: str
+    tower_user_name: str
+    awx_tags: list[str]
+    extra_vars: JSon
+    artifacts: JSon
+    tower_schedule_id: int
+    tower_schedule_name: str
+    tower_workflow_job_id: int
+    tower_workflow_job_name: str
+    start_time: TimeStamp
+    end_time: TimeStamp | None
+    state: JobState
+
+
+class TowerJobTemplate(TypedDict):
+    tower_job_template_id: int
+    tower_job_template_name: str
+    playbook_path: str
+    compliance_interval: int
+    awx_organisation: str
+    template_infos: str | None
+
+
+class JobExport(TypedDict):
+    exported_at: TimeStamp
+    job: FullJob
+    job_template: TowerJobTemplate
+    stats: list[TowerJobStats]
+    tasks: list[TaskExport]
+
+
 class ApiRepository(ABC):
     @abstractmethod
     def get_all_noncompliant_hosts(self) -> dict[str, list[JobTemplateComplianceStates]]:
@@ -136,12 +181,12 @@ def add_stats(self, tower_job_id: int, stats: list[TowerJobStats]) -> None:
         pass
 
     @abstractmethod
-    def export_job(self, tower_job_id: int) -> dict:
+    def export_job(self, tower_job_id: int) -> JobExport:
         """Exports complete job data with tasks and callbacks"""
         pass
 
     @abstractmethod
-    def import_job(self, job_data: dict) -> int:
+    def import_job(self, export_data: JobExport) -> int:
         """Imports a job from a dict
         Returns the tower_job_id of the imported job.
         """
@@ -392,7 +437,7 @@ def update_job(
 
         conn.commit()
 
-    def export_job(self, tower_job_id: int) -> dict:
+    def export_job(self, tower_job_id: int) -> JobExport:
         """Exports complete job data with tasks and callbacks"""
         conn = self.db_manager.get_db_connection()
         cursor = conn.cursor()
@@ -486,10 +531,10 @@ def export_job(self, tower_job_id: int) -> dict:
                 }
             )
 
-        tasks_with_callbacks = list(tasks_dict.values())
+        tasks_with_callbacks = cast(list[TaskExport], list(tasks_dict.values()))
 
         # build export structure
-        export_data = {
+        export_data: JobExport = {
             "exported_at": self.db_manager.get_db_now(),
             "job": {
                 "tower_job_id": job["tower_job_id"],
@@ -522,7 +567,7 @@ def export_job(self, tower_job_id: int) -> dict:
 
         return export_data
 
-    def import_job(self, export_data: dict) -> int:
+    def import_job(self, export_data: JobExport) -> int:
         """
         Imports a job from export data.
 
@@ -548,7 +593,7 @@ def import_job(self, export_data: dict) -> int:
 
         try:
             # insert/update job_template
-            template_infos_value = template_data.get("template_infos", {})
+            template_infos_value = cast(str, template_data.get("template_infos", {}))
             if template_infos_value is not None:
                 template_infos_json = json.dumps(template_infos_value)
             else:
@@ -675,7 +720,7 @@ def import_job(self, export_data: dict) -> int:
 
             # insert callbacks for this task
             for callback in task.get("callbacks", []):
-                result_dump = json.dumps(callback["result_dump"])
+                result_dump = callback["result_dump"]
 
                 cursor.execute(
                     """
@@ -959,7 +1004,7 @@ def update_job(
                 raise LufaKeyError("tower_job_id", tower_job_id) from ex
 
         db_conn.commit()
 
-    def export_job(self, tower_job_id: int) -> dict:
+    def export_job(self, tower_job_id: int) -> JobExport:
         conn = self.db_manager.get_db_connection()
         cursor = conn.cursor()
@@ -1029,10 +1074,10 @@ def export_job(self, tower_job_id: int) -> dict:
             (tower_job_id,),
         )
 
-        tasks_with_callbacks = cursor.fetchall()
+        tasks_with_callbacks: list[TaskExport] = cursor.fetchall()
 
         # build export structure
-        export_data = {
+        export_data: JobExport = {
             "exported_at": self.db_manager.get_db_now(),
             "job": {
                 "tower_job_id": job["tower_job_id"],
@@ -1041,8 +1086,8 @@
                 "ansible_limit": job["ansible_limit"],
                 "tower_user_name": job["tower_user_name"],
                 "awx_tags": job["awx_tags"],
-                "extra_vars": job["extra_vars"],
-                "artifacts": job["artifacts"],
+                "extra_vars": json.dumps(job["extra_vars"]),
+                "artifacts": json.dumps(job["artifacts"]),
                 "tower_schedule_id": job["tower_schedule_id"],
                 "tower_schedule_name": job["tower_schedule_name"],
                 "tower_workflow_job_id": job["tower_workflow_job_id"],
@@ -1064,14 +1109,18 @@
                 {
                     "ansible_uuid": task["ansible_uuid"],
                     "task_name": task["task_name"],
-                    "callbacks": task["callbacks"] if task["callbacks"] else [],
+                    "callbacks": (
+                        [{**cb, "result_dump": json.dumps(cb["result_dump"])} for cb in task["callbacks"]]
+                        if task["callbacks"]
+                        else []
+                    ),
                 }
                 for task in tasks_with_callbacks
             ],
         }
 
         return export_data
 
-    def import_job(self, export_data: dict) -> int:
+    def import_job(self, export_data: JobExport) -> int:
         """Imports a job from a dict
         Returns the tower_job_id of the imported job.
@@ -1121,14 +1170,9 @@ def import_job(self, export_data: dict) -> int:
             )
 
             # insert job
-            # ensure extra_vars and artifacts are strings
-            extra_vars = job_data.get("extra_vars", "{}")
-            if isinstance(extra_vars, dict):
-                extra_vars = json.dumps(extra_vars)
-
-            artifacts = job_data.get("artifacts", "{}")
-            if isinstance(artifacts, dict):
-                artifacts = json.dumps(artifacts)
+            # ensure extra_vars and artifacts are dicts
+            extra_vars = cast(dict, job_data.get("extra_vars", {}))
+            artifacts = cast(dict, job_data.get("artifacts", {}))
 
             cursor.execute(
                 """
@@ -1212,7 +1256,7 @@ def import_job(self, export_data: dict) -> int:
 
             # insert callbacks for this task
             for callback in task.get("callbacks", []):
-                result_dump = json.dumps(callback["result_dump"])
+                result_dump = callback["result_dump"]
                 cursor.execute(
                     """
                     INSERT INTO task_callbacks (task_ansible_uuid,
diff --git a/tests/integration/api_repository/test_job_import_export.py b/tests/integration/api_repository/test_job_import_export.py
index caeabae..4aa71dc 100644
--- a/tests/integration/api_repository/test_job_import_export.py
+++ b/tests/integration/api_repository/test_job_import_export.py
@@ -1,8 +1,10 @@
+from typing import cast
+
 import pytest
 
-from lufa.repository.api_repository import ApiRepository
+from lufa.repository.api_repository import ApiRepository, JobExport, TowerJobStats
 from lufa.repository.backend_repository import ResourceNotFoundError
-from tests.integration.conftest import HostIntependantTowerJobStats, LufaFactory
+from tests.integration.conftest import ApiRepositoryToBackend, HostIntependantTowerJobStats, LufaFactory
 
 HOST1 = "host1.example.com"
 HOST2 = "host2.example.com"
@@ -18,7 +20,7 @@ def test_export_nonexistent_job_raises_error(self, api_repository: ApiRepository
 
     def test_export_job_with_tasks_and_callbacks(
         self,
-        api_repository,
+        api_repository: ApiRepository,
         lufa_factory: LufaFactory,
         single_any_stat: HostIntependantTowerJobStats,
     ):
@@ -50,7 +52,7 @@ def test_export_job_with_tasks_and_callbacks(
             result_dump='{"changed": false}',
         )
 
-        result = api_repository.export_job(job.tower_job_id)
+        result: JobExport = api_repository.export_job(job.tower_job_id)
 
         # verify tasks
         assert len(result["tasks"]) == 2
@@ -58,6 +60,12 @@
         assert "Install packages" in task_names
         assert "Configure service" in task_names
 
+        assert type(result["job"]["extra_vars"]) is str
+        assert type(result["job"]["artifacts"]) is str
+        for task in result["tasks"]:
+            for cb in task["callbacks"]:
+                assert type(cb["result_dump"]) is str
+
         # verify callbacks
         for task in result["tasks"]:
             if task["task_name"] == "Install packages":
@@ -71,68 +79,15 @@
 class TestImportJob:
     """Test importing a job from API repository"""
 
-    def test_import_existing_job_fails(
-        self,
-        api_repository: ApiRepository,
-        lufa_factory: LufaFactory,
-    ):
-        """Test that importing a job with an ID that already exists raises an error."""
-        # create an existing job
-        existing_job = lufa_factory.add_tower_template().add_job()
-
-        # manually build import data with the same job ID
-        import_data = {
-            "exported_at": "2026-03-23T10:00:00",
-            "job": {
-                "tower_job_id": existing_job.tower_job_id,  # Same ID as existing job
-                "tower_job_template_id": 100,
-                "tower_job_template_name": "Test Template",
-                "ansible_limit": "*.example.com",
-                "tower_user_name": "testuser",
-                "awx_tags": ["tag1"],
-                "extra_vars": "{}",
-                "artifacts": "{}",
-                "tower_schedule_id": None,
-                "tower_schedule_name": None,
-                "tower_workflow_job_id": None,
-                "tower_workflow_job_name": None,
-                "start_time": "2026-03-23T09:00:00",
-                "end_time": None,
-                "state": "started",
-            },
-            "job_template": {
-                "tower_job_template_id": 100,
-                "tower_job_template_name": "Test Template",
-                "playbook_path": "test.yml",
-                "compliance_interval": 0,
-                "awx_organisation": "Default",
-                "template_infos": None,
-            },
-            "stats": [],
-            "tasks": [],
-        }
-
-        # try to import with the same tower_job_id - should fail
-        with pytest.raises((ValueError, Exception)) as exc_info:
-            api_repository.import_job(import_data)
-
-        # verify error message mentions the job already exists
-        assert "already exists" in str(exc_info.value).lower()
-
-    def test_import_job_with_tasks_and_callbacks(
-        self,
-        api_repository: ApiRepository,
-        lufa_factory: LufaFactory,
-    ):
-        """Test importing a complete job with tasks and callbacks."""
-
+    @pytest.fixture
+    def import_data(self):
         new_job_id = 99000
         new_template_id = 88000
         task1_uuid = "aaaaaaaa-1111-2222-3333-aaaaaaaaaaaa"
         task2_uuid = "bbbbbbbb-1111-2222-3333-bbbbbbbbbbbb"
         task3_uuid = "cccccccc-1111-2222-3333-cccccccccccc"
 
-        import_data = {
+        return {
             "exported_at": "2026-03-23T10:00:00",
             "job": {
                 "tower_job_id": new_job_id,
@@ -147,8 +102,8 @@
                 "tower_schedule_name": "Daily Schedule",
                 "tower_workflow_job_id": 67890,
                 "tower_workflow_job_name": "Import Workflow",
-                "start_time": "2026-03-23T09:00:00",
-                "end_time": "2026-03-23T09:30:00",
+                "start_time": "2026-03-23T09:00:00.123",
+                "end_time": "2026-03-23T09:30:00.123",
                 "state": "success",
             },
             "job_template": {
@@ -190,14 +145,14 @@
                         "ansible_host": HOST1,
                         "state": "ok",
                         "module": "setup",
-                        "timestamp": "2026-03-23T09:05:00",
+                        "timestamp": "2026-03-23T09:05:00.123",
                         "result_dump": '{"ansible_facts": {"os_family": "Debian"}}',
                     },
                     {
                         "ansible_host": HOST2,
                         "state": "ok",
                         "module": "setup",
-                        "timestamp": "2026-03-23T09:05:01",
+                        "timestamp": "2026-03-23T09:05:01.123",
                         "result_dump": '{"ansible_facts": {"os_family": "RedHat"}}',
                     },
                 ],
@@ -210,14 +165,14 @@
                         "ansible_host": HOST1,
                         "state": "ok",
                         "module": "apt",
-                        "timestamp": "2026-03-23T09:10:00",
+                        "timestamp": "2026-03-23T09:10:00.123",
                         "result_dump": '{"changed": true, "packages": ["nginx", "postgresql"]}',
                     },
                     {
                         "ansible_host": HOST2,
                         "state": "ok",
                         "module": "yum",
-                        "timestamp": "2026-03-23T09:10:01",
+                        "timestamp": "2026-03-23T09:10:01.123",
                         "result_dump": '{"changed": true, "packages": ["nginx", "postgresql"]}',
                     },
                 ],
@@ -230,30 +185,146 @@
                     {
                         "ansible_host": HOST1,
                         "state": "ok",
                         "module": "systemd",
-                        "timestamp": "2026-03-23T09:15:00",
-                        "result_dump": '{"changed": false, "status": {"running": true}}',
+                        "timestamp": "2026-03-23T09:15:00.123",
+                        "result_dump": '{"status": {"running": true}, "changed": false}',
                     },
                     {
                         "ansible_host": HOST2,
                         "state": "changed",
                         "module": "systemd",
-                        "timestamp": "2026-03-23T09:15:01",
-                        "result_dump": '{"changed": true, "status": {"running": true}}',
+                        "timestamp": "2026-03-23T09:15:01.123",
+                        "result_dump": '{"status": {"running": true}, "changed": true}',
                     },
                 ],
             },
         ],
     }
 
+    def test_reexporting_is_mostly_invariant(self, api_repository: ApiRepository, import_data: JobExport):
+        """Test that re-exporting an imported job returns the original export with only the specified allowed changes."""
+        # import the job
         imported_job_id = api_repository.import_job(import_data)
+        reexport = api_repository.export_job(imported_job_id)
 
-        assert imported_job_id == new_job_id
+        assert import_data["exported_at"] < reexport["exported_at"]
+        self.assert_mostly_equal(reexport, import_data)
+
+    def test_import_existing_job_fails(
+        self,
+        api_repository: ApiRepository,
+        lufa_factory: LufaFactory,
+    ):
+        """Test that importing a job with an ID that already exists raises an error."""
+        # create an existing job
+        existing_job = lufa_factory.add_tower_template().add_job()
+
+        # manually build import data with the same job ID
+        import_data: JobExport = {
+            "exported_at": "2026-03-23T10:00:00",
+            "job": {
+                "tower_job_id": existing_job.tower_job_id,  # Same ID as existing job
+                "tower_job_template_id": 100,
+                "tower_job_template_name": "Test Template",
+                "ansible_limit": "*.example.com",
+                "tower_user_name": "testuser",
+                "awx_tags": ["tag1"],
+                "extra_vars": "{}",
+                "artifacts": "{}",
+                "tower_schedule_id": 42,
+                "tower_schedule_name": "abc",
+                "tower_workflow_job_id": 42,
+                "tower_workflow_job_name": "abc",
+                "start_time": "2026-03-23T09:00:00",
+                "end_time": None,
+                "state": "started",
+            },
+            "job_template": {
+                "tower_job_template_id": 100,
+                "tower_job_template_name": "Test Template",
+                "playbook_path": "test.yml",
+                "compliance_interval": 0,
+                "awx_organisation": "Default",
+                "template_infos": None,
+            },
+            "stats": [],
+            "tasks": [],
+        }
+
+        # try to import with the same tower_job_id - should fail
+        with pytest.raises((ValueError, Exception)) as exc_info:
+            api_repository.import_job(import_data)
+
+        # verify error message mentions the job already exists
+        assert "already exists" in str(exc_info.value).lower()
+
+    def test_import_job_with_tasks_and_callbacks(
+        self,
+        api_repository: ApiRepository,
+        import_data: JobExport,
+    ):
+        """Test importing a complete job with tasks and callbacks."""
+
+        # import the job
+        imported_job_id = api_repository.import_job(import_data)
+
+        assert imported_job_id == import_data["job"]["tower_job_id"]
 
         # verify job exists
         assert api_repository.job_exists(imported_job_id)
 
         # verify tasks were imported
-        assert api_repository.tasks_exists(task1_uuid)
-        assert api_repository.tasks_exists(task2_uuid)
-        assert api_repository.tasks_exists(task3_uuid)
+        assert api_repository.tasks_exists(import_data["tasks"][0]["ansible_uuid"])
+        assert api_repository.tasks_exists(import_data["tasks"][1]["ansible_uuid"])
+        assert api_repository.tasks_exists(import_data["tasks"][2]["ansible_uuid"])
+
+    def test_export_job_with_tasks_and_callbacks(
+        self,
+        api_repository: ApiRepository,
+        api_repository_to_backend: ApiRepositoryToBackend,
+        lufa_factory: LufaFactory,
+        single_any_stat: HostIntependantTowerJobStats,
+    ):
+        """Test export of a complete job with tasks and callbacks"""
+
+        job = lufa_factory.add_tower_template().add_job().with_stats(HOST1, single_any_stat).with_end_time()
+
+        # add tasks
+        task1_uuid = "aaaaaaaa-1111-1111-1111-111111111111"
+        task2_uuid = "bbbbbbbb-2222-2222-2222-222222222222"
+        api_repository.add_task(task1_uuid, job.tower_job_id, "Install packages")
+        api_repository.add_task(task2_uuid, job.tower_job_id, "Configure service")
+
+        # add callbacks for task1
+        api_repository.add_callback(
+            task_ansible_uuid=task1_uuid,
+            ansible_host=HOST1,
+            state="ok",
+            module="apt",
+            result_dump='{"changed": true}',
+        )
+
+        # add callbacks for task2
+        api_repository.add_callback(
+            task_ansible_uuid=task2_uuid,
+            ansible_host=HOST1,
+            state="ok",
+            module="template",
+            result_dump='{"changed": false}',
+        )
+
+        initial_export = api_repository.export_job(job.tower_job_id)
+        for to_backend in api_repository_to_backend():  # also covers postgres-to-postgres within a single DB
+            to_backend.import_job(initial_export)
+            reexport = to_backend.export_job(job.tower_job_id)
+
+            self.assert_mostly_equal(reexport, initial_export)
+
+    def assert_mostly_equal(self, export: JobExport, original: JobExport) -> None:
+        """Assert that two JobExports are equal apart from the specified allowed changes."""
+        assert original["exported_at"] <= export["exported_at"]
+        export["stats"] = [cast(TowerJobStats, dict(item)) for item in export["stats"]]
+        original["tasks"].sort(key=lambda x: x["ansible_uuid"])
+        export["tasks"].sort(key=lambda x: x["ansible_uuid"])
+        masked = {"exported_at": "masked"}
+        assert {**export, **masked} == {**original, **masked}
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index dc8985e..4df3433 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -4,13 +4,14 @@
 import tempfile
 from dataclasses import dataclass
 from datetime import datetime, timedelta
+from functools import partial
 from importlib.resources import files
-from typing import Iterable, Literal, NamedTuple, TypedDict, cast
+from typing import Callable, Iterable, Literal, NamedTuple, TypeAlias, TypedDict, cast
 
 import pytest
 
 from lufa.database import DatabaseManager, NumDays, PostgresDatabaseManager, SqliteDatabaseManager
-from lufa.provider import AppConfig, DbConfig, PostgresConfig, SqliteConfig
+from lufa.provider import DbConfig, PostgresConfig, SqliteConfig
 from lufa.repository.api_repository import ApiRepository, PostgresApiRepository, SqliteApiRepository, TowerJobStats
 from lufa.repository.backend_repository import BackendRepository, PostgresBackendRepository, SqliteBackendRepository
 from lufa.repository.user_repository import PostgresUserRepository, SqliteUserRepository, UserRepository
@@ -19,13 +20,15 @@
 
 
 def pytest_generate_tests(metafunc: pytest.Metafunc):
-    if "mark_db_backend" in metafunc.fixturenames:
+    for key in ["mark_db_backend", "mark_db_to_backend"]:
+        if key not in metafunc.fixturenames:
+            continue
         relevant_marks = [pytest.mark.sqlite3, pytest.mark.postgres]
         found_mark_names = [
             m.name for m in metafunc.definition.iter_markers() if m.name in [r.name for r in relevant_marks]
         ]
         metafunc.parametrize(
-            "mark_db_backend",
+            key,
             (
                 # test that not marked with supported db backends get all
                 [pytest.param(r.name, marks=r) for r in relevant_marks]
@@ -85,7 +88,11 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
     metafunc.parametrize("single_any_stat", [pytest.param(p.param, id=p.title) for p in success_stats[:1]])
 
 
-@pytest.fixture(scope=SCOPE)
+@pytest.fixture(scope=SCOPE, name="db_config")
+def db_config_fixture(mark_db_backend: str) -> Iterable[DbConfig]:
+    yield from db_config(mark_db_backend)
+
+
 def db_config(mark_db_backend: str) -> Iterable[DbConfig]:
     match mark_db_backend:
         case pytest.mark.sqlite3.name:
@@ -116,14 +123,22 @@ def db_config(mark_db_backend: str) -> Iterable[DbConfig]:
             raise NotImplementedError(f"Unknown DB backend marker: pytest.mark.{mark_db_backend}")
 
 
-@pytest.fixture
+@pytest.fixture(name="db_manager")
+def db_manager_fixture(empty_db: DatabaseManager) -> DatabaseManager:
+    return db_manager(empty_db)
+
+
 def db_manager(empty_db: DatabaseManager) -> DatabaseManager:
     empty_db.init_db()
     return empty_db
 
 
-@pytest.fixture
-def empty_db(mark_db_backend: str, db_config: AppConfig) -> Iterable[DatabaseManager]:
+@pytest.fixture(name="empty_db")
+def empty_db_fixture(mark_db_backend: str, db_config: DbConfig) -> Iterable[DatabaseManager]:
+    yield from empty_db(mark_db_backend, db_config)
+
+
+def empty_db(mark_db_backend: str, db_config: DbConfig) -> Iterable[DatabaseManager]:
     match mark_db_backend:
         case pytest.mark.sqlite3.name:
             sqlite = SqliteDatabaseManager(db_config["DB_DATABASE"], str((files("lufa").joinpath("schema_sqlite.sql"))))
@@ -137,11 +152,12 @@ def empty_db(mark_db_backend: str, db_config: AppConfig) -> Iterable[DatabaseMan
             sqlite.close_db()
 
         case pytest.mark.postgres.name:
+            postgres_config = cast(PostgresConfig, db_config)
             postgres = PostgresDatabaseManager(
-                host=db_config["DB_HOST"],
-                user=db_config["DB_USER"],
-                database=db_config["DB_DATABASE"],
-                password=db_config["DB_PASSWORD"],
+                host=postgres_config["DB_HOST"],
+                user=postgres_config["DB_USER"],
+                database=postgres_config["DB_DATABASE"],
+                password=postgres_config["DB_PASSWORD"],
                 init_script=str(files("lufa").joinpath("schema.sql")),
             )
             with open(
@@ -153,7 +169,11 @@
             postgres.close_db()
 
 
-@pytest.fixture
+@pytest.fixture(name="api_repository")
+def api_repository_fixture(mark_db_backend: str, db_manager: DatabaseManager) -> ApiRepository:
+    return api_repository(mark_db_backend, db_manager)
+
+
 def api_repository(mark_db_backend: str, db_manager: DatabaseManager) -> ApiRepository:
     if mark_db_backend == pytest.mark.sqlite3.name:
         return SqliteApiRepository(db_manager)
@@ -162,6 +182,21 @@ def api_repository(mark_db_backend: str, db_manager: DatabaseManager) -> ApiRepo
     raise NotImplementedError(f"Unknown DB backend marker: pytest.mark.{mark_db_backend}")
 
 
+ApiRepositoryToBackend: TypeAlias = Callable[[], Iterable[ApiRepository]]
+
+
+@pytest.fixture
+def api_repository_to_backend(mark_db_to_backend: str) -> Iterable[ApiRepositoryToBackend]:
+    for config in db_config(mark_db_to_backend):
+        yield partial(reset_api_repository, mark_db_to_backend, config)
+
+
+def reset_api_repository(mark_db_to_backend: str, config: DbConfig) -> Iterable[ApiRepository]:
+    # reset the DB so we can export from and then import into the same Postgres DB
+    for empty in empty_db(mark_db_to_backend, config):
+        yield api_repository(mark_db_to_backend, db_manager(empty))
+
+
 @pytest.fixture
 def user_repository(mark_db_backend: str, db_manager: DatabaseManager) -> UserRepository:
     if mark_db_backend == pytest.mark.sqlite3.name: