diff --git a/cosmotech/coal/postgresql/runner.py b/cosmotech/coal/postgresql/runner.py index 683628bb..120a60f4 100644 --- a/cosmotech/coal/postgresql/runner.py +++ b/cosmotech/coal/postgresql/runner.py @@ -105,8 +105,8 @@ def remove_runner_metadata_from_postgresql( schema_table = f"{_psql.db_schema}.{_psql.table_prefix}RunnerMetadata" sql_delete_from_metatable = f""" DELETE FROM {schema_table} - WHERE last_csm_run_id={runner.get("lastRunId")}; + WHERE last_csm_run_id={runner.get("lastRunInfo").get("lastRunId")}; """ curs.execute(sql_delete_from_metatable) conn.commit() - return runner.get("lastRunId") + return runner.get("lastRunInfo").get("lastRunId") diff --git a/cosmotech/coal/store/store.py b/cosmotech/coal/store/store.py index 86015d88..d940fffa 100644 --- a/cosmotech/coal/store/store.py +++ b/cosmotech/coal/store/store.py @@ -6,6 +6,7 @@ # specifically authorized by written means by Cosmo Tech. import pathlib +from functools import wraps import pyarrow from adbc_driver_sqlite import dbapi @@ -15,6 +16,19 @@ from cosmotech.coal.utils.logger import LOGGER +def table_name_to_lower(func): + @wraps(func) + def wrapper(*args, **kwargs): + if "table_name" in kwargs: + kwargs["table_name"] = kwargs["table_name"].lower() + else: + args = list(args) + args[1] = args[1].lower() + return func(*args, **kwargs) + + return wrapper + + class Store: @staticmethod def sanitize_column(column_name: str) -> str: @@ -34,27 +48,31 @@ def reset(self): if self._database_path.exists(): self._database_path.unlink() + @table_name_to_lower def get_table(self, table_name: str) -> pyarrow.Table: if not self.table_exists(table_name): raise ValueError(T("coal.errors.data.no_table").format(table_name=table_name)) - return self.execute_query(f"select * from {table_name}") + return self.execute_query(f'select * from "{table_name}"') + @table_name_to_lower def table_exists(self, table_name) -> bool: return table_name in self.list_tables() + @table_name_to_lower def get_table_schema(self, table_name: str) -> pyarrow.Schema: if not self.table_exists(table_name): raise ValueError(T("coal.errors.data.no_table").format(table_name=table_name)) with dbapi.connect(self._database) as conn: return conn.adbc_get_table_schema(table_name) + @table_name_to_lower def add_table(self, table_name: str, data=pyarrow.Table, replace: bool = False): with dbapi.connect(self._database, autocommit=True) as conn: with conn.cursor() as curs: rows = curs.adbc_ingest(table_name, data, "replace" if replace else "create_append") LOGGER.debug(T("coal.common.data_transfer.rows_inserted").format(rows=rows, table_name=table_name)) - def execute_query(self, sql_query: str) -> pyarrow.Table: + def execute_query(self, sql_query: str, parameters: list = (None,)) -> pyarrow.Table: batch_size = 1024 batch_size_increment = 1024 while True: @@ -62,7 +80,7 @@ def execute_query(self, sql_query: str) -> pyarrow.Table: with dbapi.connect(self._database, autocommit=True) as conn: with conn.cursor() as curs: curs.adbc_statement.set_options(**{"adbc.sqlite.query.batch_rows": str(batch_size)}) - curs.execute(sql_query) + curs.execute(sql_query, parameters) return curs.fetch_arrow_table() except OSError: batch_size += batch_size_increment diff --git a/tests/integration/coal/test_store/test_store_store.py b/tests/integration/coal/test_store/test_store_store.py new file mode 100644 index 00000000..126685e7 --- /dev/null +++ b/tests/integration/coal/test_store/test_store_store.py @@ -0,0 +1,68 @@ +# Copyright (C) - 2022 - 2025 - Cosmo Tech +# This document and all information contained herein is the exclusive property - +# including all intellectual property rights pertaining thereto - of Cosmo Tech. +# Any use, reproduction, translation, broadcasting, transmission, distribution, +# etc., to any person is prohibited unless it has been previously and +# specifically authorized by written means by Cosmo Tech. + +import pyarrow as pa +import pytest + +from cosmotech.coal.store.store import Store + + +@pytest.fixture(scope="function") +def store(): + store = Store(reset=True) + yield store + store.reset() + + +class TestStore: + """Tests for the store class.""" + + def test_get_table(self, store): + """Test get table with table name starting with numbers""" + + # Arrange + table_name = "normal_name" + table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]) + store.add_table(table_name, table) + + # Act + result = store.get_table(table_name) + + # Assert + assert result + + def test_get_table_with_number_name(self, store): + """Test get table with table name starting with numbers""" + + # Arrange + table_name = "10mb" + table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]) + store.add_table(table_name, table) + + # Act + result = store.get_table(table_name) + + # Assert + assert result + + def test_add_get_table_with_upper_and_lower_case(self, store): + """Test add table and get table behavior with uppper and lower cases""" + + # Arrange + + table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"]) + store.add_table("10mb", table) + table = pa.Table.from_arrays([pa.array([4, 5, 6]), pa.array(["A", "B", "C"])], names=["id", "name"]) + store.add_table("10MB", table) + + # Act + UPPER_result = store.get_table("10MB") + upper_result = store.get_table("10mb") + + assert upper_result + assert UPPER_result + assert upper_result == UPPER_result diff --git a/tests/unit/coal/test_postgresql/test_postgresql_runner.py b/tests/unit/coal/test_postgresql/test_postgresql_runner.py index d50ff91e..711099af 100644 --- a/tests/unit/coal/test_postgresql/test_postgresql_runner.py +++ b/tests/unit/coal/test_postgresql/test_postgresql_runner.py @@ -112,7 +112,7 @@ def test_remove_runner_metadata_to_postgresql(self, mock_connect, mock_postgres_ mock_runner = { "id": "test-runner-id", "name": "Test Runner", - "lastRunId": "test-run-id", + "lastRunInfo": {"lastRunId": "test-run-id"}, "runTemplateId": "test-template-id", } diff --git a/tests/unit/coal/test_store/test_store_store.py b/tests/unit/coal/test_store/test_store_store.py index 8098ac3e..8e37f5b6 100644 --- a/tests/unit/coal/test_store/test_store_store.py +++ b/tests/unit/coal/test_store/test_store_store.py @@ -77,7 +77,7 @@ def test_get_table(self, mock_execute_query, mock_table_exists): # Assert mock_table_exists.assert_called_once_with(table_name) - mock_execute_query.assert_called_once_with(f"select * from {table_name}") + mock_execute_query.assert_called_once_with(f'select * from "{table_name}"') assert result == expected_table @patch.object(Store, "table_exists") @@ -270,7 +270,7 @@ def test_execute_query(self, mock_connect): # Assert mock_connect.assert_called_once() mock_cursor.adbc_statement.set_options.assert_called_once_with(**{"adbc.sqlite.query.batch_rows": "1024"}) - mock_cursor.execute.assert_called_once_with(sql_query) + mock_cursor.execute.assert_called_once_with(sql_query, (None,)) mock_cursor.fetch_arrow_table.assert_called_once() assert result == expected_table @@ -306,7 +306,7 @@ def test_execute_query_with_oserror(self, mock_connect): # First call with batch_size = 1024, second with batch_size = 2048 mock_cursor.adbc_statement.set_options.assert_any_call(**{"adbc.sqlite.query.batch_rows": "1024"}) mock_cursor.adbc_statement.set_options.assert_any_call(**{"adbc.sqlite.query.batch_rows": "2048"}) - mock_cursor.execute.assert_called_once_with(sql_query) + mock_cursor.execute.assert_called_once_with(sql_query, (None,)) mock_cursor.fetch_arrow_table.assert_called_once() assert result == expected_table