Skip to content

Commit a6c7228

Browse files
committed
Refactor metadata around data source table configs
Why these changes are being introduced: * Continues the data source refactor by removing the split between metadata type configs, valid tables, and current view specs. * The previous structure duplicated table metadata across the data sources and metadata layer, which made readable tables and derived metadata columns harder to maintain. How this addresses that need: * Introduce DataSourceTableConfig so each data source defines its base and custom readable tables in one place. * Derive metadata/read column lists directly from each source schema and use source classes throughout TIMDEXDatasetMetadata. * Update exports and tests to use the new source-driven table and column interfaces. Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/USE-496
1 parent 8eadf94 commit a6c7228

10 files changed

Lines changed: 313 additions & 442 deletions

File tree

tests/test_embeddings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from timdex_dataset_api import TIMDEXDataset
1313
from timdex_dataset_api.embeddings import DatasetEmbedding, TIMDEXEmbeddings
1414

15-
EMBEDDINGS_DEFAULT_COLUMNS_SET = set(TIMDEXEmbeddings.DEFAULT_READ_COLUMNS)
15+
EMBEDDINGS_AVAILABLE_COLUMNS_SET = set(TIMDEXEmbeddings.AVAILABLE_READ_COLUMNS)
1616

1717

1818
def test_dataset_embedding_init():
@@ -145,7 +145,7 @@ def test_embeddings_read_batches_yields_pyarrow_record_batches(
145145
def test_embeddings_read_batches_all_columns_by_default(timdex_embeddings_with_runs):
146146
batches = timdex_embeddings_with_runs.read_batches_iter()
147147
batch = next(batches)
148-
assert set(batch.column_names) == EMBEDDINGS_DEFAULT_COLUMNS_SET
148+
assert set(batch.column_names) == EMBEDDINGS_AVAILABLE_COLUMNS_SET
149149

150150

151151
def test_embeddings_read_batches_filter_columns(timdex_embeddings_with_runs):
@@ -273,7 +273,7 @@ def test_embeddings_read_dicts_yields_dictionary_for_each_embeddings_record(
273273
dict_iter = timdex_embeddings_with_runs.read_dicts_iter()
274274
record = next(dict_iter)
275275
assert isinstance(record, dict)
276-
assert set(record.keys()) == EMBEDDINGS_DEFAULT_COLUMNS_SET
276+
assert set(record.keys()) == EMBEDDINGS_AVAILABLE_COLUMNS_SET
277277

278278

279279
def test_current_embeddings_view_single_run(timdex_dataset_for_embeddings_views):

tests/test_metadata.py

Lines changed: 56 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import os
55
from pathlib import Path
66

7+
import pytest
78
from duckdb import DuckDBPyConnection
89

910
from tests.utils import generate_sample_embeddings_for_run, generate_sample_records
1011
from timdex_dataset_api import TIMDEXDataset
12+
from timdex_dataset_api.data_source import TIMDEXDataSource
1113
from timdex_dataset_api.embeddings import TIMDEXEmbeddings
12-
from timdex_dataset_api.metadata import DataTypeMetadataConfig, TIMDEXDatasetMetadata
14+
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
1315
from timdex_dataset_api.records import TIMDEXRecords
1416

1517

@@ -33,46 +35,53 @@ def test_tdm_s3_dataset_structure_properties(timdex_dataset_empty):
3335
assert timdex_dataset_empty.location_scheme == "file"
3436

3537

36-
def test_data_type_metadata_config_prejoin_records_default_true():
37-
config = DataTypeMetadataConfig(
38-
name="example",
39-
metadata_columns=["timdex_record_id"],
40-
data_path="data/example",
41-
)
42-
assert config.prejoin_records is True
43-
44-
45-
def test_data_source_metadata_configs_are_derived_from_base_class():
46-
assert TIMDEXRecords.METADATA_CONFIG.name == TIMDEXRecords.NAME
47-
assert TIMDEXRecords.METADATA_CONFIG.data_path == TIMDEXRecords.DATA_PATH
48-
assert TIMDEXRecords.METADATA_CONFIG.prejoin_records is False
38+
def test_data_source_metadata_columns_are_derived_from_base_class():
4939
assert (
50-
TIMDEXRecords.METADATA_CONFIG.metadata_columns
40+
TIMDEXRecords.SOURCE_METADATA_COLUMNS
5141
== TIMDEXDatasetMetadata.BASE_METADATA_COLUMNS
5242
)
43+
assert TIMDEXRecords.METADATA_COLUMNS == TIMDEXDatasetMetadata.BASE_METADATA_COLUMNS
5344

54-
assert TIMDEXEmbeddings.METADATA_CONFIG.name == TIMDEXEmbeddings.NAME
55-
assert TIMDEXEmbeddings.METADATA_CONFIG.data_path == TIMDEXEmbeddings.DATA_PATH
56-
assert TIMDEXEmbeddings.METADATA_CONFIG.prejoin_records is True
57-
assert TIMDEXEmbeddings.METADATA_CONFIG.metadata_columns == [
45+
assert TIMDEXEmbeddings.SOURCE_METADATA_COLUMNS == [
5846
"timdex_record_id",
5947
"run_id",
6048
"run_record_offset",
61-
*TIMDEXEmbeddings.ADDITIONAL_METADATA_COLUMNS,
6249
"filename",
50+
"embedding_timestamp",
51+
"embedding_model",
52+
"embedding_strategy",
6353
]
54+
assert [
55+
*TIMDEXDatasetMetadata.BASE_METADATA_COLUMNS,
56+
"embedding_timestamp",
57+
"embedding_model",
58+
"embedding_strategy",
59+
] == TIMDEXEmbeddings.METADATA_COLUMNS
60+
61+
62+
def test_data_source_subclass_requires_contract_vars():
63+
with pytest.raises(
64+
TypeError,
65+
match=(
66+
"InvalidDataSource must define required class vars: "
67+
"SCHEMA, DATA_COLUMNS, DATA_PATH"
68+
),
69+
):
6470

71+
class InvalidDataSource(TIMDEXDataSource):
72+
NAME = "invalid"
6573

66-
def test_dataset_registers_current_view_specs_from_data_sources(tmp_path):
67-
td = TIMDEXDataset(str(tmp_path / "register_current_view_specs"))
6874

69-
expected_view_names = [
70-
spec.name
71-
for spec in (
72-
TIMDEXRecords.CURRENT_VIEW_SPECS + TIMDEXEmbeddings.CURRENT_VIEW_SPECS
73-
)
75+
def test_dataset_registers_table_configs_from_data_sources(tmp_path):
76+
td = TIMDEXDataset(str(tmp_path / "register_table_configs"))
77+
78+
expected_table_names = [
79+
table_config.name
80+
for table_config in (TIMDEXRecords.TABLES + TIMDEXEmbeddings.TABLES)
7481
]
75-
assert [spec.name for spec in td.current_metadata_view_specs] == expected_view_names
82+
assert [
83+
table_config.name for table_config in td.table_configs
84+
] == expected_table_names
7685

7786

7887
def test_tdm_create_metadata_database_file_success(
@@ -136,9 +145,7 @@ def test_tdm_views_created_on_init(timdex_metadata):
136145
assert expected_views <= actual_views
137146

138147

139-
def test_tdm_current_view_specs_missing_dependencies_are_skipped_generically(
140-
caplog, tmp_path
141-
):
148+
def test_tdm_custom_tables_missing_dependencies_are_skipped_generically(caplog, tmp_path):
142149
dataset_path = str(tmp_path / "current_view_missing_dependencies")
143150

144151
td = TIMDEXDataset(dataset_path)
@@ -166,25 +173,28 @@ def test_tdm_current_view_specs_missing_dependencies_are_skipped_generically(
166173
""").to_df()
167174
metadata_names = set(metadata_objects.table_name)
168175

169-
missing_specs = []
170-
for spec in td_with_metadata.current_metadata_view_specs:
176+
missing_tables = []
177+
for table_config in td_with_metadata.table_configs:
178+
if table_config.kind != "custom":
179+
continue
180+
171181
missing_required_tables = [
172182
table_name
173-
for table_name in spec.required_metadata_tables
183+
for table_name in table_config.required_metadata_tables
174184
if table_name not in metadata_names
175185
]
176186
if not missing_required_tables:
177187
continue
178188

179-
missing_specs.append(spec.name)
180-
assert spec.name not in metadata_names
189+
missing_tables.append(table_config.name)
190+
assert table_config.name not in metadata_names
181191
assert (
182192
"Skipping metadata."
183-
f"{spec.name} view creation because missing dependencies: "
193+
f"{table_config.name} view creation because missing dependencies: "
184194
f"{', '.join(missing_required_tables)}"
185195
) in caplog.text
186196

187-
assert missing_specs
197+
assert missing_tables
188198

189199

190200
def test_tdm_records_view_structure(timdex_metadata):
@@ -374,7 +384,7 @@ def test_tdm_merge_append_deltas_static_counts_match_records_count_before_merge(
374384
def test_tdm_merge_append_deltas_adds_records_to_static_db(
375385
timdex_metadata_with_deltas, timdex_metadata_merged_deltas
376386
):
377-
columns = ",".join(TIMDEXRecords.METADATA_CONFIG.metadata_columns)
387+
columns = ",".join(TIMDEXRecords.SOURCE_METADATA_COLUMNS)
378388
append_deltas = timdex_metadata_with_deltas.timdex_dataset.conn.query(f"""
379389
select
380390
{columns}
@@ -396,10 +406,10 @@ def test_tdm_merge_append_deltas_deletes_append_deltas(
396406
timdex_metadata_with_deltas, timdex_metadata_merged_deltas
397407
):
398408
records_deltas_path_before = timdex_metadata_with_deltas.append_deltas_path_for(
399-
TIMDEXRecords.METADATA_CONFIG
409+
TIMDEXRecords
400410
)
401411
records_deltas_path_after = timdex_metadata_merged_deltas.append_deltas_path_for(
402-
TIMDEXRecords.METADATA_CONFIG
412+
TIMDEXRecords
403413
)
404414

405415
assert timdex_metadata_with_deltas.append_deltas_count != 0
@@ -436,14 +446,7 @@ def test_tdm_embeddings_metadata_view_structure(tmp_path):
436446
"""select * from metadata.embeddings limit 1;"""
437447
).to_df()
438448
assert len(embeddings_df) == 1
439-
# pre-joined view includes native embeddings columns + records columns
440-
expected_columns = set(TIMDEXEmbeddings.METADATA_CONFIG.metadata_columns) | {
441-
"source",
442-
"run_date",
443-
"run_type",
444-
"action",
445-
"run_timestamp",
446-
}
449+
expected_columns = set(TIMDEXEmbeddings.METADATA_COLUMNS)
447450
assert set(embeddings_df.columns) == expected_columns
448451

449452

@@ -475,14 +478,7 @@ def test_tdm_current_embeddings_view_structure(tmp_path):
475478
).to_df()
476479

477480
assert len(current_embeddings_df) == 1
478-
# pre-joined view includes native embeddings columns + records columns
479-
expected_columns = set(TIMDEXEmbeddings.METADATA_CONFIG.metadata_columns) | {
480-
"source",
481-
"run_date",
482-
"run_type",
483-
"action",
484-
"run_timestamp",
485-
}
481+
expected_columns = set(TIMDEXEmbeddings.METADATA_COLUMNS)
486482
assert set(current_embeddings_df.columns) == expected_columns
487483

488484

@@ -590,14 +586,7 @@ def test_tdm_current_run_embeddings_view_structure(tmp_path):
590586
).to_df()
591587

592588
assert len(current_run_embeddings_df) == 1
593-
# pre-joined view includes native embeddings columns + records columns
594-
expected_columns = set(TIMDEXEmbeddings.METADATA_CONFIG.metadata_columns) | {
595-
"source",
596-
"run_date",
597-
"run_type",
598-
"action",
599-
"run_timestamp",
600-
}
589+
expected_columns = set(TIMDEXEmbeddings.METADATA_COLUMNS)
601590
assert set(current_run_embeddings_df.columns) == expected_columns
602591

603592

@@ -747,11 +736,7 @@ def test_tdm_keyset_paginated_query_on_prejoined_embeddings_view(tmp_path):
747736
# execute and verify results
748737
result_df = td.conn.query(query).to_df()
749738
assert len(result_df) == 10 # noqa: PLR2004
750-
expected_cols = set(
751-
TIMDEXDatasetMetadata.BASE_METADATA_COLUMNS
752-
+ TIMDEXEmbeddings.ADDITIONAL_METADATA_COLUMNS
753-
+ ["run_id_hash", "filename_hash"]
754-
)
739+
expected_cols = {*TIMDEXEmbeddings.METADATA_COLUMNS, "run_id_hash", "filename_hash"}
755740
assert set(result_df.columns) == expected_cols
756741

757742

@@ -783,9 +768,7 @@ def test_tdm_embeddings_write_append_deltas_without_static_embeddings_table(tmp_
783768
"""select count(*) from metadata.embeddings_append_deltas;"""
784769
).fetchone()[0]
785770

786-
embeddings_deltas_path = td.metadata.append_deltas_path_for(
787-
TIMDEXEmbeddings.METADATA_CONFIG
788-
)
771+
embeddings_deltas_path = td.metadata.append_deltas_path_for(TIMDEXEmbeddings)
789772
assert embeddings_count == record_count
790773
assert embeddings_deltas_count == record_count
791774
assert os.listdir(embeddings_deltas_path)

tests/test_read.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from timdex_dataset_api.records import TIMDEXRecords
1010

11-
DATASET_COLUMNS_SET = set(TIMDEXRecords.DEFAULT_READ_COLUMNS)
11+
DATASET_COLUMNS_SET = set(TIMDEXRecords.AVAILABLE_READ_COLUMNS)
1212

1313

1414
def _count_rows_via_duckdb_parquet(timdex_dataset) -> int:

tests/test_write.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_dataset_write_single_append_delta_success(
159159
):
160160
written_files = timdex_dataset_empty.records.write(sample_records_generator(1_000))
161161
records_deltas_path = timdex_dataset_empty.metadata.append_deltas_path_for(
162-
TIMDEXRecords.METADATA_CONFIG
162+
TIMDEXRecords
163163
)
164164
append_deltas = os.listdir(records_deltas_path)
165165

@@ -175,7 +175,7 @@ def test_dataset_write_multiple_append_deltas_success(
175175

176176
written_files = timdex_dataset_empty.records.write(sample_records_generator(1_000))
177177
records_deltas_path = timdex_dataset_empty.metadata.append_deltas_path_for(
178-
TIMDEXRecords.METADATA_CONFIG
178+
TIMDEXRecords
179179
)
180180
append_deltas = os.listdir(records_deltas_path)
181181

@@ -188,9 +188,9 @@ def test_dataset_write_append_delta_expected_metadata_columns(
188188
):
189189
timdex_dataset_empty.records.write(sample_records_generator(1_000))
190190
records_deltas_path = timdex_dataset_empty.metadata.append_deltas_path_for(
191-
TIMDEXRecords.METADATA_CONFIG
191+
TIMDEXRecords
192192
)
193193
append_delta_filepath = os.listdir(records_deltas_path)[0]
194194

195195
append_delta = pq.ParquetFile(Path(records_deltas_path) / append_delta_filepath)
196-
assert append_delta.schema.names == TIMDEXRecords.METADATA_CONFIG.metadata_columns
196+
assert append_delta.schema.names == TIMDEXRecords.SOURCE_METADATA_COLUMNS

timdex_dataset_api/__init__.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,21 @@
22

33
from importlib.metadata import version
44

5-
from timdex_dataset_api.data_source import TIMDEXDataSource, ValidTable
5+
from timdex_dataset_api.data_source import DataSourceTableConfig, TIMDEXDataSource
66
from timdex_dataset_api.dataset import TIMDEXDataset
77
from timdex_dataset_api.embeddings import DatasetEmbedding, TIMDEXEmbeddings
8-
from timdex_dataset_api.metadata import (
9-
CurrentMetadataViewSpec,
10-
DataTypeMetadataConfig,
11-
TIMDEXDatasetMetadata,
12-
)
8+
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
139
from timdex_dataset_api.records import DatasetRecord, TIMDEXRecords
1410

1511
__version__ = version("timdex_dataset_api")
1612

1713
__all__ = [
18-
"CurrentMetadataViewSpec",
19-
"DataTypeMetadataConfig",
14+
"DataSourceTableConfig",
2015
"DatasetEmbedding",
2116
"DatasetRecord",
2217
"TIMDEXDataSource",
2318
"TIMDEXDataset",
2419
"TIMDEXDatasetMetadata",
2520
"TIMDEXEmbeddings",
2621
"TIMDEXRecords",
27-
"ValidTable",
2822
]

0 commit comments

Comments
 (0)