Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions pyiceberg/catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -74,7 +75,11 @@

logger = logging.getLogger(__name__)

_ENV_CONFIG = Config()

@lru_cache(maxsize=1)
def _get_env_config() -> Config:
    """Lazily load and memoize the environment-backed configuration.

    The ``lru_cache`` wrapper exposes ``cache_clear()`` so tests can force a reload.
    """
    loaded = Config.load()
    return loaded


TOKEN = "token"
TYPE = "type"
Expand Down Expand Up @@ -243,9 +248,9 @@ def load_catalog(name: str | None = None, **properties: str | None) -> Catalog:
or if it could not determine the catalog based on the properties.
"""
if name is None:
name = _ENV_CONFIG.get_default_catalog_name()
name = _get_env_config().get_default_catalog_name()

env = _ENV_CONFIG.get_catalog_config(name)
env = _get_env_config().get_catalog_config(name)
conf: RecursiveDict = merge_config(env or {}, cast(RecursiveDict, properties))

catalog_type: CatalogType | None
Expand Down Expand Up @@ -278,7 +283,7 @@ def load_catalog(name: str | None = None, **properties: str | None) -> Catalog:


def list_catalogs() -> list[str]:
    """Return the names of all catalogs known to the environment configuration."""
    # Defect fixed: the pasted diff kept both the removed `_ENV_CONFIG` line and
    # the new one; only the lazily-loaded config accessor remains.
    return _get_env_config().get_known_catalogs()


def delete_files(io: FileIO, files_to_delete: set[str], file_type: str) -> None:
Expand Down Expand Up @@ -781,7 +786,7 @@ def _convert_schema_if_needed(

from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
downcast_ns_timestamp_to_us = Config.load().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
if isinstance(schema, pa.Schema):
schema: Schema = visit_pyarrow( # type: ignore
schema,
Expand Down
2 changes: 1 addition & 1 deletion pyiceberg/catalog/bigquery_metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, name: str, **properties: str):
raise ValueError(f"Missing property: {GCP_PROJECT_ID}")

# BigQuery requires current-snapshot-id to be present for tables to be created.
if not Config().get_bool("legacy-current-snapshot-id"):
if not Config.load().get_bool("legacy-current-snapshot-id"):
raise ValueError("legacy-current-snapshot-id must be enabled to work with BigQuery.")

if credentials_file and credentials_info_str:
Expand Down
6 changes: 3 additions & 3 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1744,7 +1744,7 @@ def __init__(
self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
self._case_sensitive = case_sensitive
self._limit = limit
self._downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE)
self._downcast_ns_timestamp_to_us = Config.load().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE)

@property
def _projected_field_ids(self) -> set[int]:
Expand Down Expand Up @@ -2685,7 +2685,7 @@ def write_parquet(task: WriteTask) -> DataFile:
else:
file_schema = table_schema

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
downcast_ns_timestamp_to_us = Config.load().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
batches = [
_to_requested_schema(
requested_schema=file_schema,
Expand Down Expand Up @@ -2892,7 +2892,7 @@ def _dataframe_to_data_files(
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)
name_mapping = table_metadata.schema().name_mapping
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
downcast_ns_timestamp_to_us = Config.load().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
task_schema = pyarrow_to_schema(
df.schema,
name_mapping=name_mapping,
Expand Down
2 changes: 1 addition & 1 deletion pyiceberg/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def table_metadata(metadata: TableMetadata, output_file: OutputFile, overwrite:
"""
with output_file.create(overwrite=overwrite) as output_stream:
# We need to serialize None values, in order to dump `None` current-snapshot-id as `-1`
exclude_none = False if Config().get_bool("legacy-current-snapshot-id") else True
exclude_none = False if Config.load().get_bool("legacy-current-snapshot-id") else True

json_bytes = metadata.model_dump_json(exclude_none=exclude_none).encode(UTF8)
json_bytes = Compressor.get_compressor(output_file.location).bytes_compressor()(json_bytes)
Expand Down
8 changes: 4 additions & 4 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def append(self, df: pa.Table, snapshot_properties: dict[str, str] = EMPTY_DICT,
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
downcast_ns_timestamp_to_us = Config.load().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
self.table_metadata.schema(),
provided_schema=df.schema,
Expand Down Expand Up @@ -523,7 +523,7 @@ def dynamic_partition_overwrite(
f"in the latest partition spec: {field}"
)

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
downcast_ns_timestamp_to_us = Config.load().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
self.table_metadata.schema(),
provided_schema=df.schema,
Expand Down Expand Up @@ -588,7 +588,7 @@ def overwrite(
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
downcast_ns_timestamp_to_us = Config.load().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
self.table_metadata.schema(),
provided_schema=df.schema,
Expand Down Expand Up @@ -787,7 +787,7 @@ def upsert(

from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
downcast_ns_timestamp_to_us = Config.load().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_pyarrow_schema_compatible(
self.table_metadata.schema(),
provided_schema=df.schema,
Expand Down
2 changes: 1 addition & 1 deletion pyiceberg/table/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def sort_order_by_id(self, sort_order_id: int) -> SortOrder | None:

@field_serializer("current_snapshot_id")
def serialize_current_snapshot_id(self, current_snapshot_id: int | None) -> int | None:
    """Serialize ``current_snapshot_id``, emitting -1 instead of None in legacy mode.

    Some engines require ``current-snapshot-id`` to be present, so the
    ``legacy-current-snapshot-id`` flag maps a missing snapshot id to -1.
    """
    # Defect fixed: the diff paste retained both the old `Config()` condition and
    # the new `Config.load()` one; only the new condition is kept.
    if current_snapshot_id is None and Config.load().get_bool("legacy-current-snapshot-id"):
        return -1
    return current_snapshot_id

Expand Down
2 changes: 1 addition & 1 deletion pyiceberg/utils/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class ExecutorFactory:
@staticmethod
def max_workers() -> int | None:
    """Return the configured ``max-workers`` value, or None when unset."""
    # Defect fixed: the pasted diff contained both the old `Config()` return and
    # the new `Config.load()` return, making the second unreachable.
    return Config.load().get_int("max-workers")

@staticmethod
def get_or_create() -> Executor:
Expand Down
14 changes: 10 additions & 4 deletions pyiceberg/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import os

Expand Down Expand Up @@ -59,10 +61,14 @@ def _lowercase_dictionary_keys(input_dict: RecursiveDict) -> RecursiveDict:
class Config:
config: RecursiveDict

def __init__(self) -> None:
config = self._from_configuration_files() or {}
config = merge_config(config, self._from_environment_variables(config))
self.config = FrozenDict(**config)
def __init__(self, config: RecursiveDict | None = None) -> None:
self.config = FrozenDict(**(config or {}))

@classmethod
def load(cls) -> Config:
config = cls._from_configuration_files() or {}
config = merge_config(config, cls._from_environment_variables(config))
return cls(config)

@staticmethod
def _from_configuration_files() -> RecursiveDict | None:
Expand Down
19 changes: 17 additions & 2 deletions tests/catalog/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@
# under the License.
# pylint:disable=redefined-outer-name


from collections.abc import Generator
from pathlib import PosixPath

import pytest
from pytest_mock import MockFixture

from pyiceberg.catalog import Catalog, load_catalog
from pyiceberg.catalog import Catalog, _get_env_config, load_catalog
from pyiceberg.catalog.memory import InMemoryCatalog
from pyiceberg.io import WAREHOUSE
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType
from pyiceberg.utils.config import Config


@pytest.fixture
Expand Down Expand Up @@ -64,6 +65,20 @@ def test_load_catalog_has_type_and_impl() -> None:
)


def test_get_env_config_is_lazy_and_cached(mocker: MockFixture) -> None:
    """_get_env_config() should call Config.load once and cache the result."""
    # Capture the real cached config first so it can be restored at the end.
    original_config = _get_env_config()
    _get_env_config.cache_clear()
    config = Config({"catalog": {"test": {"type": "in-memory"}}})
    load_mock = mocker.patch("pyiceberg.catalog.Config.load", return_value=config)
    # Two calls must return the same object from a single Config.load invocation.
    assert _get_env_config() is config
    assert _get_env_config() is config
    load_mock.assert_called_once()

    # Repopulate the cache with the original config so later tests see real state.
    _get_env_config.cache_clear()
    mocker.patch("pyiceberg.catalog.Config.load", return_value=original_config)
    assert _get_env_config() is original_config


def test_catalog_repr(catalog: InMemoryCatalog) -> None:
    """repr() should show the catalog name together with its concrete class."""
    expected = "test.in_memory.catalog (<class 'pyiceberg.catalog.memory.InMemoryCatalog'>)"
    assert repr(catalog) == expected
Expand Down
12 changes: 6 additions & 6 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1887,18 +1887,18 @@ def test_catalog_from_environment_variables(catalog_config_mock: mock.Mock, rest


@mock.patch.dict(os.environ, EXAMPLE_ENV)
def test_catalog_from_environment_variables_override(rest_mock: Mocker) -> None:
    """Explicit keyword properties must override catalog config from the environment."""
    # Defect fixed: the diff paste merged the removed decorator/signature and the
    # stale `catalog_config_mock` lines with their replacements; only the new
    # `_get_env_config`-patching version is kept.
    rest_mock.get(
        "https://other-service.io/api/v1/config",
        json={"defaults": {}, "overrides": {}},
        status_code=200,
    )
    env_config: RecursiveDict = Config._from_environment_variables({})

    mock_env_config = mock.Mock()
    mock_env_config.get_catalog_config.return_value = cast(RecursiveDict, env_config.get("catalog")).get("production")
    with mock.patch("pyiceberg.catalog._get_env_config", return_value=mock_env_config):
        catalog = cast(RestCatalog, load_catalog("production", uri="https://other-service.io/api"))
        assert catalog.uri == "https://other-service.io/api"


def test_catalog_from_parameters_empty_env(rest_mock: Mocker) -> None:
Expand Down
6 changes: 3 additions & 3 deletions tests/cli/test_console.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@
def test_missing_uri(mocker: MockFixture, empty_home_dir_path: str) -> None:
    """Running `list` with no configured URI should fail with a helpful message."""
    # Defect fixed: the diff paste kept both the old `_ENV_CONFIG` patch / exact
    # output assert and their replacements; only the new versions remain.
    # mock to prevent parsing ~/.pyiceberg.yaml or {PYICEBERG_HOME}/.pyiceberg.yaml
    mocker.patch.dict(os.environ, values={"HOME": empty_home_dir_path, "PYICEBERG_HOME": empty_home_dir_path})
    mocker.patch("pyiceberg.catalog._get_env_config", return_value=Config())

    runner = CliRunner()
    result = runner.invoke(run, ["list"])

    assert result.exit_code == 1
    assert "URI missing, please provide using --uri" in result.output


def test_hive_catalog_missing_uri_shows_helpful_error(mocker: MockFixture) -> None:
mock_env_config = mocker.MagicMock()
mock_env_config.get_catalog_config.return_value = {"type": "hive"}
mocker.patch("pyiceberg.catalog._ENV_CONFIG", mock_env_config)
mocker.patch("pyiceberg.catalog._get_env_config", return_value=mock_env_config)

runner = CliRunner()
result = runner.invoke(run, ["--catalog", "my_hive_catalog", "list"])
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def _isolate_pyiceberg_config() -> None:
from pyiceberg.utils.config import Config

with mock.patch.object(Config, "_from_configuration_files", return_value=None):
_catalog_module._ENV_CONFIG = Config()
_catalog_module._get_env_config.cache_clear()
_catalog_module._get_env_config()


def pytest_addoption(parser: pytest.Parser) -> None:
Expand Down
43 changes: 33 additions & 10 deletions tests/utils/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,30 @@


def test_config() -> None:
    """Config() should be a pure empty container and perform no implicit IO."""
    # Defect fixed: the diff paste retained both the old and new docstrings;
    # only the new one (matching the lazy-loading behavior) is kept.
    assert Config()
    assert Config().config == {}


def test_config_does_not_load_implicitly() -> None:
    """Plain construction must not read config files or environment variables."""
    with mock.patch.object(Config, "_from_configuration_files") as from_files_mock:
        with mock.patch.object(Config, "_from_environment_variables") as from_env_mock:
            Config()

    from_files_mock.assert_not_called()
    from_env_mock.assert_not_called()


@mock.patch.dict(os.environ, EXAMPLE_ENV)
def test_from_environment_variables() -> None:
    """Catalog config should be readable from PYICEBERG_* environment variables."""
    # Defect fixed: duplicate old `Config()` assert removed; only `Config.load()` remains.
    assert Config.load().get_catalog_config("production") == {"uri": "https://service.io/api"}


@mock.patch.dict(os.environ, EXAMPLE_ENV)
def test_from_environment_variables_uppercase() -> None:
    """Catalog lookup should be case-insensitive on the catalog name."""
    # Defect fixed: duplicate old `Config()` assert removed; only `Config.load()` remains.
    assert Config.load().get_catalog_config("PRODUCTION") == {"uri": "https://service.io/api"}


@mock.patch.dict(
Expand All @@ -50,7 +62,7 @@ def test_from_environment_variables_uppercase() -> None:
},
)
def test_fix_nested_objects_from_environment_variables() -> None:
    """Nested env keys (SECTION__SUBKEY) should flatten into dotted properties."""
    # Defect fixed: duplicate old `Config()` assert removed; only `Config.load()` remains.
    assert Config.load().get_catalog_config("PRODUCTION") == {
        "s3.region": "eu-north-1",
        "s3.access-key-id": "username",
    }
Expand All @@ -59,7 +71,7 @@ def test_fix_nested_objects_from_environment_variables() -> None:
@mock.patch.dict(os.environ, EXAMPLE_ENV)
@mock.patch.dict(os.environ, {"PYICEBERG_CATALOG__DEVELOPMENT__URI": "https://dev.service.io/api"})
def test_list_all_known_catalogs() -> None:
    """Every catalog named in the environment should be discoverable."""
    # Defect fixed: duplicate old `Config()` line removed; only `Config.load()` remains.
    catalogs = Config.load().get_known_catalogs()
    assert "production" in catalogs
    assert "development" in catalogs

Expand All @@ -71,7 +83,7 @@ def test_from_configuration_files(tmp_path_factory: pytest.TempPathFactory) -> N
file.write(yaml_str)

os.environ["PYICEBERG_HOME"] = config_path
assert Config().get_catalog_config("production") == {"uri": "https://service.io/api"}
assert Config.load().get_catalog_config("production") == {"uri": "https://service.io/api"}


def test_lowercase_dictionary_keys() -> None:
Expand All @@ -95,13 +107,13 @@ def test_from_configuration_files_get_typed_value(tmp_path_factory: pytest.TempP

os.environ["PYICEBERG_HOME"] = config_path
with pytest.raises(ValueError):
Config().get_bool("max-workers")
Config.load().get_bool("max-workers")

with pytest.raises(ValueError):
Config().get_int("legacy-current-snapshot-id")
Config.load().get_int("legacy-current-snapshot-id")

assert Config().get_bool("legacy-current-snapshot-id")
assert Config().get_int("max-workers") == 4
assert Config.load().get_bool("legacy-current-snapshot-id")
assert Config.load().get_int("max-workers") == 4


@pytest.mark.parametrize(
Expand Down Expand Up @@ -183,3 +195,14 @@ def create_config_file(path: str, uri: str | None) -> None:
assert (
result["catalog"]["default"]["uri"] if result else None # type: ignore
) == expected_result, f"Unexpected configuration result. Expected: {expected_result}, Actual: {result}"


def test_load_reads_file_and_environment_once(monkeypatch: pytest.MonkeyPatch) -> None:
    """Config.load() reads the config file exactly once and overlays environment variables."""
    monkeypatch.setenv("PYICEBERG_CATALOG__PRODUCTION__URI", "https://env.service.io/api")
    file_config = {"catalog": {"production": {"type": "rest"}}}
    with mock.patch.object(Config, "_from_configuration_files", return_value=file_config) as files_mock:
        config = Config.load()

    files_mock.assert_called_once()
    assert config.get_catalog_config("production") == {"type": "rest", "uri": "https://env.service.io/api"}