Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cosmotech/coal/postgresql/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def remove_runner_metadata_from_postgresql(
schema_table = f"{_psql.db_schema}.{_psql.table_prefix}RunnerMetadata"
sql_delete_from_metatable = f"""
DELETE FROM {schema_table}
WHERE last_csm_run_id={runner.get("lastRunId")};
WHERE last_csm_run_id={runner.get("lastRunInfo").get("lastRunId")};
"""
curs.execute(sql_delete_from_metatable)
conn.commit()
return runner.get("lastRunId")
return runner.get("lastRunInfo").get("lastRunId")
24 changes: 21 additions & 3 deletions cosmotech/coal/store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# specifically authorized by written means by Cosmo Tech.

import pathlib
from functools import wraps

import pyarrow
from adbc_driver_sqlite import dbapi
Expand All @@ -15,6 +16,19 @@
from cosmotech.coal.utils.logger import LOGGER


def table_name_to_lower(func):
@wraps(func)
def wrapper(*args, **kwargs):
if "table_name" in kwargs:
kwargs["table_name"] = kwargs["table_name"].lower()
else:
args = list(args)
args[1] = args[1].lower()
return func(*args, **kwargs)

return wrapper


class Store:
@staticmethod
def sanitize_column(column_name: str) -> str:
Expand All @@ -34,35 +48,39 @@ def reset(self):
if self._database_path.exists():
self._database_path.unlink()

@table_name_to_lower
def get_table(self, table_name: str) -> pyarrow.Table:
if not self.table_exists(table_name):
raise ValueError(T("coal.errors.data.no_table").format(table_name=table_name))
return self.execute_query(f"select * from {table_name}")
return self.execute_query(f'select * from "{table_name}"')

@table_name_to_lower
def table_exists(self, table_name) -> bool:
return table_name in self.list_tables()

@table_name_to_lower
def get_table_schema(self, table_name: str) -> pyarrow.Schema:
if not self.table_exists(table_name):
raise ValueError(T("coal.errors.data.no_table").format(table_name=table_name))
with dbapi.connect(self._database) as conn:
return conn.adbc_get_table_schema(table_name)

@table_name_to_lower
def add_table(self, table_name: str, data=pyarrow.Table, replace: bool = False):
with dbapi.connect(self._database, autocommit=True) as conn:
with conn.cursor() as curs:
rows = curs.adbc_ingest(table_name, data, "replace" if replace else "create_append")
LOGGER.debug(T("coal.common.data_transfer.rows_inserted").format(rows=rows, table_name=table_name))

def execute_query(self, sql_query: str) -> pyarrow.Table:
def execute_query(self, sql_query: str, parameters: list = (None,)) -> pyarrow.Table:
batch_size = 1024
batch_size_increment = 1024
while True:
try:
with dbapi.connect(self._database, autocommit=True) as conn:
with conn.cursor() as curs:
curs.adbc_statement.set_options(**{"adbc.sqlite.query.batch_rows": str(batch_size)})
curs.execute(sql_query)
curs.execute(sql_query, parameters)
return curs.fetch_arrow_table()
except OSError:
batch_size += batch_size_increment
Expand Down
68 changes: 68 additions & 0 deletions tests/integration/coal/test_store/test_store_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (C) - 2022 - 2025 - Cosmo Tech
# This document and all information contained herein is the exclusive property -
# including all intellectual property rights pertaining thereto - of Cosmo Tech.
# Any use, reproduction, translation, broadcasting, transmission, distribution,
# etc., to any person is prohibited unless it has been previously and
# specifically authorized by written means by Cosmo Tech.

import pyarrow as pa
import pytest

from cosmotech.coal.store.store import Store


@pytest.fixture(scope="function")
def store():
store = Store(reset=True)
yield store
store.reset()


class TestStore:
"""Tests for the store class."""

def test_get_table(self, store):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could add a test with case sensitive names, "UPPER" and "upper" should be either the same or different tables
We need to chose the standard working here, according to sql they should be the same, but we could have user having a "File.csv" and "file.csv" living next to each other.
Current implementation would have "File" and "file" being 2 different tables in the get_table, but we need to make sure that the add_table act identical then.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Table FILE and file are the same in the sqlite.

"""Test get table with table name starting with numbers"""

# Arrange
table_name = "normal_name"
table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"])
store.add_table(table_name, table)

# Act
result = store.get_table(table_name)

# Assert
assert result

def test_get_table_with_number_name(self, store):
"""Test get table with table name starting with numbers"""

# Arrange
table_name = "10mb"
table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"])
store.add_table(table_name, table)

# Act
result = store.get_table(table_name)

# Assert
assert result

def test_add_get_table_with_upper_and_lower_case(self, store):
"""Test add table and get table behavior with uppper and lower cases"""

# Arrange

table = pa.Table.from_arrays([pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["id", "name"])
store.add_table("10mb", table)
table = pa.Table.from_arrays([pa.array([4, 5, 6]), pa.array(["A", "B", "C"])], names=["id", "name"])
store.add_table("10MB", table)

# Act
UPPER_result = store.get_table("10MB")
upper_result = store.get_table("10mb")

assert upper_result
assert UPPER_result
assert upper_result == UPPER_result
2 changes: 1 addition & 1 deletion tests/unit/coal/test_postgresql/test_postgresql_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_remove_runner_metadata_to_postgresql(self, mock_connect, mock_postgres_
mock_runner = {
"id": "test-runner-id",
"name": "Test Runner",
"lastRunId": "test-run-id",
"lastRunInfo": {"lastRunId": "test-run-id"},
"runTemplateId": "test-template-id",
}

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/coal/test_store/test_store_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_get_table(self, mock_execute_query, mock_table_exists):

# Assert
mock_table_exists.assert_called_once_with(table_name)
mock_execute_query.assert_called_once_with(f"select * from {table_name}")
mock_execute_query.assert_called_once_with(f'select * from "{table_name}"')
assert result == expected_table

@patch.object(Store, "table_exists")
Expand Down Expand Up @@ -270,7 +270,7 @@ def test_execute_query(self, mock_connect):
# Assert
mock_connect.assert_called_once()
mock_cursor.adbc_statement.set_options.assert_called_once_with(**{"adbc.sqlite.query.batch_rows": "1024"})
mock_cursor.execute.assert_called_once_with(sql_query)
mock_cursor.execute.assert_called_once_with(sql_query, (None,))
mock_cursor.fetch_arrow_table.assert_called_once()
assert result == expected_table

Expand Down Expand Up @@ -306,7 +306,7 @@ def test_execute_query_with_oserror(self, mock_connect):
# First call with batch_size = 1024, second with batch_size = 2048
mock_cursor.adbc_statement.set_options.assert_any_call(**{"adbc.sqlite.query.batch_rows": "1024"})
mock_cursor.adbc_statement.set_options.assert_any_call(**{"adbc.sqlite.query.batch_rows": "2048"})
mock_cursor.execute.assert_called_once_with(sql_query)
mock_cursor.execute.assert_called_once_with(sql_query, (None,))
mock_cursor.fetch_arrow_table.assert_called_once()
assert result == expected_table

Expand Down
Loading