diff --git a/openml/__init__.py b/openml/__init__.py index ae5db261f..47bc86b4d 100644 --- a/openml/__init__.py +++ b/openml/__init__.py @@ -18,9 +18,11 @@ # License: BSD 3-Clause from __future__ import annotations +from typing import TYPE_CHECKING + from . import ( _api_calls, - config, + _config as _config_module, datasets, evaluations, exceptions, @@ -33,6 +35,7 @@ utils, ) from .__version__ import __version__ +from ._api import _backend from .datasets import OpenMLDataFeature, OpenMLDataset from .evaluations import OpenMLEvaluation from .flows import OpenMLFlow @@ -49,6 +52,11 @@ OpenMLTask, ) +if TYPE_CHECKING: + from ._config import OpenMLConfigManager + +config: OpenMLConfigManager = _config_module.__config + def populate_cache( task_ids: list[int] | None = None, @@ -109,6 +117,7 @@ def populate_cache( "OpenMLTask", "__version__", "_api_calls", + "_backend", "config", "datasets", "evaluations", diff --git a/openml/_api/__init__.py b/openml/_api/__init__.py new file mode 100644 index 000000000..7766016d1 --- /dev/null +++ b/openml/_api/__init__.py @@ -0,0 +1,85 @@ +from .clients import ( + HTTPCache, + HTTPClient, + MinIOClient, +) +from .resources import ( + API_REGISTRY, + DatasetAPI, + DatasetV1API, + DatasetV2API, + EstimationProcedureAPI, + EstimationProcedureV1API, + EstimationProcedureV2API, + EvaluationAPI, + EvaluationMeasureAPI, + EvaluationMeasureV1API, + EvaluationMeasureV2API, + EvaluationV1API, + EvaluationV2API, + FallbackProxy, + FlowAPI, + FlowV1API, + FlowV2API, + ResourceAPI, + ResourceV1API, + ResourceV2API, + RunAPI, + RunV1API, + RunV2API, + SetupAPI, + SetupV1API, + SetupV2API, + StudyAPI, + StudyV1API, + StudyV2API, + TaskAPI, + TaskV1API, + TaskV2API, +) +from .setup import ( + APIBackend, + APIBackendBuilder, + _backend, +) + +__all__ = [ + "API_REGISTRY", + "APIBackend", + "APIBackendBuilder", + "DatasetAPI", + "DatasetV1API", + "DatasetV2API", + "EstimationProcedureAPI", + "EstimationProcedureV1API", + "EstimationProcedureV2API", + "EvaluationAPI", + "EvaluationMeasureAPI", + "EvaluationMeasureV1API", + "EvaluationMeasureV2API", + "EvaluationV1API", + "EvaluationV2API", + "FallbackProxy", + "FlowAPI", + "FlowV1API", + "FlowV2API", + "HTTPCache", + "HTTPClient", + "MinIOClient", + "ResourceAPI", + "ResourceV1API", + "ResourceV2API", + "RunAPI", + "RunV1API", + "RunV2API", + "SetupAPI", + "SetupV1API", + "SetupV2API", + "StudyAPI", + "StudyV1API", + "StudyV2API", + "TaskAPI", + "TaskV1API", + "TaskV2API", + "_backend", +] diff --git a/openml/_api/clients/__init__.py b/openml/_api/clients/__init__.py new file mode 100644 index 000000000..42f11fbcf --- /dev/null +++ b/openml/_api/clients/__init__.py @@ -0,0 +1,8 @@ +from .http import HTTPCache, HTTPClient +from .minio import MinIOClient + +__all__ = [ + "HTTPCache", + "HTTPClient", + "MinIOClient", +] diff --git a/openml/_api/clients/http.py b/openml/_api/clients/http.py new file mode 100644 index 000000000..913d3dd00 --- /dev/null +++ b/openml/_api/clients/http.py @@ -0,0 +1,834 @@ +from __future__ import annotations + +import hashlib +import json +import logging +import math +import random +import time +import xml +from collections.abc import Callable, Mapping +from pathlib import Path +from typing import Any, cast +from urllib.parse import urlencode, urljoin, urlparse + +import requests +import xmltodict +from requests import Response + +import openml +from openml.enums import APIVersion, RetryPolicy +from openml.exceptions import ( + OpenMLAuthenticationError, + OpenMLHashException, + OpenMLServerError, + 
OpenMLServerException, + OpenMLServerNoResult, +) + + +class HTTPCache: + """ + Filesystem-based cache for HTTP responses. + + This class stores HTTP responses on disk using a structured directory layout + derived from the request URL and parameters. Each cached response consists of + three files: metadata (``meta.json``), headers (``headers.json``), and the raw + body (``body.bin``). + + Notes + ----- + The cache key is derived from the URL (domain and path components) and query + parameters, excluding the ``api_key`` parameter. + """ + + @property + def path(self) -> Path: + return Path(openml.config.get_cache_directory()) + + def get_key(self, url: str, params: dict[str, Any]) -> str: + """ + Generate a filesystem-safe cache key for a request. + + The key is constructed from the reversed domain components, URL path + segments, and URL-encoded query parameters (excluding ``api_key``). + + Parameters + ---------- + url : str + The full request URL. + params : dict of str to Any + Query parameters associated with the request. + + Returns + ------- + str + A relative path string representing the cache key. + """ + parsed_url = urlparse(url) + netloc_parts = parsed_url.netloc.split(".")[::-1] + path_parts = parsed_url.path.strip("/").split("/") + + filtered_params = {k: v for k, v in params.items() if k != "api_key"} + params_part = [urlencode(filtered_params)] if filtered_params else [] + + return str(Path(*netloc_parts, *path_parts, *params_part)) + + def _key_to_path(self, key: str) -> Path: + """ + Convert a cache key into an absolute filesystem path. + + Parameters + ---------- + key : str + Cache key as returned by :meth:`get_key`. + + Returns + ------- + pathlib.Path + Absolute path corresponding to the cache entry. + """ + return self.path.joinpath(key) + + def load(self, key: str) -> Response: + """ + Load a cached HTTP response from disk. + + Parameters + ---------- + key : str + Cache key identifying the stored response. + + Returns + ------- + requests.Response + Reconstructed response object with status code, headers, body, and metadata. + + Raises + ------ + FileNotFoundError + If the cache entry or required files are missing. + ValueError + If required metadata is missing or malformed. + """ + path = self._key_to_path(key) + + if not path.exists(): + raise FileNotFoundError(f"Cache entry not found: {path}") + + meta_path = path / "meta.json" + headers_path = path / "headers.json" + body_path = path / "body.bin" + + if not (meta_path.exists() and headers_path.exists() and body_path.exists()): + raise FileNotFoundError(f"Incomplete cache at {path}") + + with meta_path.open("r", encoding="utf-8") as f: + meta = json.load(f) + + with headers_path.open("r", encoding="utf-8") as f: + headers = json.load(f) + + body = body_path.read_bytes() + + response = Response() + response.status_code = meta["status_code"] + response.url = meta["url"] + response.reason = meta["reason"] + response.headers = headers + response._content = body + response.encoding = meta["encoding"] + + return response + + def save(self, key: str, response: Response) -> None: + """ + Persist an HTTP response to disk. + + Parameters + ---------- + key : str + Cache key identifying where to store the response. + response : requests.Response + Response object to cache. + + Notes + ----- + The response body is stored as binary data. Headers and metadata + (status code, URL, reason, encoding, elapsed time, request info, and + creation timestamp) are stored as JSON. 
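+
+        Examples
+        --------
+        A minimal sketch of the save/load round trip, assuming ``response``
+        is a previously fetched ``requests.Response``:
+
+        >>> cache = HTTPCache()
+        >>> key = cache.get_key("https://www.openml.org/api/v1/xml/data/1", {})
+        >>> cache.save(key, response)  # doctest: +SKIP
+        >>> cache.load(key).status_code  # doctest: +SKIP
+        200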
+ """ + path = self._key_to_path(key) + path.mkdir(parents=True, exist_ok=True) + + (path / "body.bin").write_bytes(response.content) + + with (path / "headers.json").open("w", encoding="utf-8") as f: + json.dump(dict(response.headers), f) + + meta = { + "status_code": response.status_code, + "url": response.url, + "reason": response.reason, + "encoding": response.encoding, + "created_at": time.time(), + "request": { + "method": response.request.method if response.request else None, + "url": response.request.url if response.request else None, + "headers": dict(response.request.headers) if response.request else None, + "body": response.request.body if response.request else None, + }, + } + + with (path / "meta.json").open("w", encoding="utf-8") as f: + json.dump(meta, f) + + +class HTTPClient: + """ + HTTP client for interacting with the OpenML API. + + This client supports configurable retry policies, optional filesystem + caching, API key authentication, and response validation including + checksum verification. + + Parameters + ---------- + api_version : APIVersion + Backend API Version. + """ + + def __init__( + self, + *, + api_version: APIVersion, + ) -> None: + self.api_version = api_version + + self.cache = HTTPCache() + + @property + def server(self) -> str: + server = openml.config.servers[self.api_version]["server"] + if server is None: + servers_repr = {k.value: v for k, v in openml.config.servers.items()} + raise ValueError( + f'server found to be None for api_version="{self.api_version}" in {servers_repr}' + ) + return cast("str", server) + + @property + def api_key(self) -> str | None: + return cast("str | None", openml.config.servers[self.api_version]["apikey"]) + + @property + def retries(self) -> int: + return cast("int", openml.config.connection_n_retries) + + @property + def retry_policy(self) -> RetryPolicy: + return RetryPolicy.HUMAN if openml.config.retry_policy == "human" else RetryPolicy.ROBOT + + @property + def retry_func(self) -> Callable: + return self._human_delay if self.retry_policy == RetryPolicy.HUMAN else self._robot_delay + + def _robot_delay(self, n: int) -> float: + """ + Compute delay for automated retry policy. + + Parameters + ---------- + n : int + Current retry attempt number (1-based). + + Returns + ------- + float + Number of seconds to wait before the next retry. + + Notes + ----- + Uses a sigmoid-based growth curve with Gaussian noise to gradually + increase waiting time. + """ + wait = (1 / (1 + math.exp(-(n * 0.5 - 4)))) * 60 + variation = random.gauss(0, wait / 10) + return max(1.0, wait + variation) + + def _human_delay(self, n: int) -> float: + """ + Compute delay for human-like retry policy. + + Parameters + ---------- + n : int + Current retry attempt number (1-based). + + Returns + ------- + float + Number of seconds to wait before the next retry. + """ + return max(1.0, n) + + def _parse_exception_response( + self, + response: Response, + ) -> tuple[int | None, str]: + """ + Parse an error response returned by the server. + + Parameters + ---------- + response : requests.Response + HTTP response containing error details in JSON or XML format. + + Returns + ------- + tuple of (int or None, str) + Parsed error code and combined error message. The code may be + ``None`` if unavailable. 
+ """ + content_type = response.headers.get("Content-Type", "").lower() + + if "application/json" in content_type: + server_exception = response.json() + server_error = server_exception["detail"] + code = server_error.get("code") + message = server_error.get("message") + additional_information = server_error.get("additional_information") + else: + server_exception = xmltodict.parse(response.text) + server_error = server_exception["oml:error"] + code = server_error.get("oml:code") + message = server_error.get("oml:message") + additional_information = server_error.get("oml:additional_information") + + if code is not None: + code = int(code) + + if message and additional_information: + full_message = f"{message} - {additional_information}" + elif message: + full_message = message + elif additional_information: + full_message = additional_information + else: + full_message = "" + + return code, full_message + + def _raise_code_specific_error( + self, + code: int, + message: str, + url: str, + files: Mapping[str, Any] | None, + ) -> None: + """ + Raise specialized exceptions based on OpenML error codes. + + Parameters + ---------- + code : int + Server-provided error code. + message : str + Parsed error message. + url : str + Request URL associated with the error. + files : Mapping of str to Any or None + Files sent with the request, if any. + + Raises + ------ + OpenMLServerNoResult + If the error indicates a missing resource. + OpenMLNotAuthorizedError + If authentication is required or invalid. + OpenMLServerException + For other server-side errors (except retryable database errors). + """ + if code in [111, 372, 512, 500, 482, 542, 674]: + # 512 for runs, 372 for datasets, 500 for flows + # 482 for tasks, 542 for evaluations, 674 for setups + # 111 for dataset descriptions + raise OpenMLServerNoResult(code=code, message=message, url=url) + + # 163: failure to validate flow XML (https://www.openml.org/api_docs#!/flow/post_flow) + if code == 163 and files is not None and "description" in files: + # file_elements['description'] is the XML file description of the flow + message = f"\n{files['description']}\n{message}" + + # Propagate all server errors to the calling functions, except + # for 107 which represents a database connection error. + # These are typically caused by high server load, + # which means trying again might resolve the issue. + # DATABASE_CONNECTION_ERRCODE + if code != 107: + raise OpenMLServerException(code=code, message=message, url=url) + + def _validate_response( + self, + method: str, + url: str, + files: Mapping[str, Any] | None, + response: Response, + ) -> Exception | None: + """ + Validate an HTTP response and determine whether to retry. + + Parameters + ---------- + method : str + HTTP method used for the request. + url : str + Full request URL. + files : Mapping of str to Any or None + Files sent with the request, if any. + response : requests.Response + Received HTTP response. + + Returns + ------- + Exception or None + ``None`` if the response is valid. Otherwise, an exception + indicating the error to raise or retry. + + Raises + ------ + OpenMLServerError + For unexpected server errors or malformed responses. + """ + if ( + "Content-Encoding" not in response.headers + or response.headers["Content-Encoding"] != "gzip" + ): + logging.warning(f"Received uncompressed content from OpenML for {url}.") + + if response.status_code == 200: + return None + + if response.status_code == requests.codes.URI_TOO_LONG: + raise OpenMLServerError(f"URI too long! 
({url})") + + exception: Exception | None = None + code: int | None = None + message: str = "" + + try: + code, message = self._parse_exception_response(response) + + except (requests.exceptions.JSONDecodeError, xml.parsers.expat.ExpatError) as e: + if method != "GET": + extra = f"Status code: {response.status_code}\n{response.text}" + raise OpenMLServerError( + f"Unexpected server error when calling {url}. Please contact the " + f"developers!\n{extra}" + ) from e + + exception = e + + except Exception as e: + # If we failed to parse it out, + # then something has gone wrong in the body we have sent back + # from the server and there is little extra information we can capture. + raise OpenMLServerError( + f"Unexpected server error when calling {url}. Please contact the developers!\n" + f"Status code: {response.status_code}\n{response.text}", + ) from e + + if code is not None: + self._raise_code_specific_error( + code=code, + message=message, + url=url, + files=files, + ) + + if exception is None: + exception = OpenMLServerException(code=code, message=message, url=url) + + return exception + + def __request( # noqa: PLR0913 + self, + session: requests.Session, + method: str, + url: str, + params: Mapping[str, Any], + data: Mapping[str, Any], + headers: Mapping[str, str], + files: Mapping[str, Any] | None, + **request_kwargs: Any, + ) -> tuple[Response | None, Exception | None]: + """ + Execute a single HTTP request attempt. + + Parameters + ---------- + session : requests.Session + Active session used to send the request. + method : str + HTTP method (e.g., ``GET``, ``POST``). + url : str + Full request URL. + params : Mapping of str to Any + Query parameters. + data : Mapping of str to Any + Request body data. + headers : Mapping of str to str + HTTP headers. + files : Mapping of str to Any or None + Files to upload. + **request_kwargs : Any + Additional arguments forwarded to ``requests.Session.request``. + + Returns + ------- + tuple of (requests.Response or None, Exception or None) + Response and potential retry exception. + """ + exception: Exception | None = None + response: Response | None = None + + try: + response = session.request( + method=method, + url=url, + params=params, + data=data, + headers=headers, + files=files, + **request_kwargs, + ) + except ( + requests.exceptions.ChunkedEncodingError, + requests.exceptions.ConnectionError, + requests.exceptions.SSLError, + ) as e: + exception = e + + if response is not None: + exception = self._validate_response( + method=method, + url=url, + files=files, + response=response, + ) + + return response, exception + + def _request( # noqa: PLR0913, C901 + self, + method: str, + path: str, + *, + enable_cache: bool = False, + refresh_cache: bool = False, + use_api_key: bool = False, + md5_checksum: str | None = None, + **request_kwargs: Any, + ) -> Response: + """ + Send an HTTP request with retry, caching, and validation support. + + Parameters + ---------- + method : str + HTTP method to use. + path : str + API path relative to the base URL. + enable_cache : bool, optional + Whether to load/store response from cache. + refresh_cache : bool, optional + Only used when `enable_cache=True`. If True, ignore any existing + cached response and overwrite it with a fresh one. + use_api_key : bool, optional + Whether to include the API key in query parameters. + md5_checksum : str or None, optional + Expected MD5 checksum of the response body. + **request_kwargs : Any + Additional arguments passed to the underlying request. 
+ + Returns + ------- + requests.Response + Final validated response. + + Raises + ------ + Exception + Propagates network, validation, or server exceptions after retries. + OpenMLHashException + If checksum verification fails. + """ + url = urljoin(self.server, path) + retries = max(1, self.retries) + + params = request_kwargs.pop("params", {}).copy() + data = request_kwargs.pop("data", {}).copy() + + if use_api_key: + if self.api_key is None: + raise OpenMLAuthenticationError( + message=( + f"The API call {url} requires authentication via an API key. " + "Please configure OpenML-Python to use your API " + "as described in this example: " + "https://openml.github.io/openml-python/latest/examples/Basics/introduction_tutorial/#authentication" + ) + ) + params["api_key"] = self.api_key + + if method.upper() in {"POST", "PUT", "PATCH"}: + data = {**params, **data} + params = {} + + # prepare headers + headers = request_kwargs.pop("headers", {}).copy() + headers.update(openml.config._HEADERS) + + files = request_kwargs.pop("files", None) + + if enable_cache and not refresh_cache: + cache_key = self.cache.get_key(url, params) + try: + return self.cache.load(cache_key) + except FileNotFoundError: + pass # cache miss, continue + except Exception: + raise # propagate unexpected cache errors + + with requests.Session() as session: + for retry_counter in range(1, retries + 1): + response, exception = self.__request( + session=session, + method=method, + url=url, + params=params, + data=data, + headers=headers, + files=files, + **request_kwargs, + ) + + # executed successfully + if exception is None: + break + # tries completed + if retry_counter >= retries: + raise exception + + delay = self.retry_func(retry_counter) + time.sleep(delay) + + # response is guaranteed to be not `None` + # otherwise an exception would have been raised before + response = cast("Response", response) + + if md5_checksum is not None: + self._verify_checksum(response, md5_checksum) + + if enable_cache: + cache_key = self.cache.get_key(url, params) + self.cache.save(cache_key, response) + + return response + + def _verify_checksum(self, response: Response, md5_checksum: str) -> None: + """ + Verify MD5 checksum of a response body. + + Parameters + ---------- + response : requests.Response + HTTP response whose content should be verified. + md5_checksum : str + Expected hexadecimal MD5 checksum. + + Raises + ------ + OpenMLHashException + If the computed checksum does not match the expected value. + """ + # ruff sees hashlib.md5 as insecure + actual = hashlib.md5(response.content).hexdigest() # noqa: S324 + if actual != md5_checksum: + raise OpenMLHashException( + f"Checksum of downloaded file is unequal to the expected checksum {md5_checksum} " + f"when downloading {response.url}.", + ) + + def get( + self, + path: str, + *, + enable_cache: bool = False, + refresh_cache: bool = False, + use_api_key: bool = False, + md5_checksum: str | None = None, + **request_kwargs: Any, + ) -> Response: + """ + Send a GET request. + + Parameters + ---------- + path : str + API path relative to the base URL. + enable_cache : bool, optional + Whether to use the response cache. + refresh_cache : bool, optional + Whether to ignore existing cached entries. + use_api_key : bool, optional + Whether to include the API key. + md5_checksum : str or None, optional + Expected MD5 checksum for response validation. + **request_kwargs : Any + Additional request arguments. + + Returns + ------- + requests.Response + HTTP response. 
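+
+        Examples
+        --------
+        A hedged usage sketch; the path and server depend on the configured
+        ``api_version``:
+
+        >>> from openml.enums import APIVersion
+        >>> client = HTTPClient(api_version=APIVersion.V1)
+        >>> response = client.get("data/61", enable_cache=True)  # doctest: +SKIP
+        >>> response.status_code  # doctest: +SKIP
+        200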
+ """ + return self._request( + method="GET", + path=path, + enable_cache=enable_cache, + refresh_cache=refresh_cache, + use_api_key=use_api_key, + md5_checksum=md5_checksum, + **request_kwargs, + ) + + def post( + self, + path: str, + *, + use_api_key: bool = True, + **request_kwargs: Any, + ) -> Response: + """ + Send a POST request. + + Parameters + ---------- + path : str + API path relative to the base URL. + use_api_key : bool, optional + Whether to include the API key. + **request_kwargs : Any + Additional request arguments. + + Returns + ------- + requests.Response + HTTP response. + """ + return self._request( + method="POST", + path=path, + enable_cache=False, + use_api_key=use_api_key, + **request_kwargs, + ) + + def delete( + self, + path: str, + **request_kwargs: Any, + ) -> Response: + """ + Send a DELETE request. + + Parameters + ---------- + path : str + API path relative to the base URL. + **request_kwargs : Any + Additional request arguments. + + Returns + ------- + requests.Response + HTTP response. + """ + return self._request( + method="DELETE", + path=path, + enable_cache=False, + use_api_key=True, + **request_kwargs, + ) + + def download( + self, + url: str, + handler: Callable[[Response, Path, str], Path] | None = None, + encoding: str = "utf-8", + file_name: str = "response.txt", + md5_checksum: str | None = None, + ) -> Path: + """ + Download a resource and store it in the cache directory. + + Parameters + ---------- + url : str + Absolute URL of the resource to download. + handler : callable or None, optional + Custom handler function accepting ``(response, path, encoding)`` + and returning a ``pathlib.Path``. + encoding : str, optional + Text encoding used when writing the response body. + file_name : str, optional + Name of the saved file. + md5_checksum : str or None, optional + Expected MD5 checksum for integrity verification. + + Returns + ------- + pathlib.Path + Path to the downloaded file. + + Raises + ------ + OpenMLHashException + If checksum verification fails. + """ + base = self.cache.path + file_path = base / "downloads" / urlparse(url).path.lstrip("/") / file_name + file_path = file_path.expanduser() + file_path.parent.mkdir(parents=True, exist_ok=True) + if file_path.exists(): + return file_path + + response = self.get(url, md5_checksum=md5_checksum) + if handler is not None: + return handler(response, file_path, encoding) + + return self._text_handler(response, file_path, encoding) + + def _text_handler(self, response: Response, path: Path, encoding: str) -> Path: + """ + Write response text content to a file. + + Parameters + ---------- + response : requests.Response + HTTP response containing text data. + path : pathlib.Path + Destination file path. + encoding : str + Text encoding for writing the file. + + Returns + ------- + pathlib.Path + Path to the written file. + """ + with path.open("w", encoding=encoding) as f: + f.write(response.text) + return path diff --git a/openml/_api/clients/minio.py b/openml/_api/clients/minio.py new file mode 100644 index 000000000..920b485e0 --- /dev/null +++ b/openml/_api/clients/minio.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from pathlib import Path + +import openml + + +class MinIOClient: + """ + Lightweight client configuration for interacting with a MinIO-compatible + object storage service. + + This class stores basic configuration such as a base filesystem path and + default HTTP headers. It is intended to be extended with actual request + or storage logic elsewhere. 
+ + Attributes + ---------- + path : pathlib.Path or None + Configured base path for storage operations. + headers : dict of str to str + Default HTTP headers, including a user-agent identifying the + OpenML Python client version. + """ + + @property + def path(self) -> Path: + return Path(openml.config.get_cache_directory()) diff --git a/openml/_api/resources/__init__.py b/openml/_api/resources/__init__.py new file mode 100644 index 000000000..6d957966e --- /dev/null +++ b/openml/_api/resources/__init__.py @@ -0,0 +1,63 @@ +from ._registry import API_REGISTRY +from .base import ( + DatasetAPI, + EstimationProcedureAPI, + EvaluationAPI, + EvaluationMeasureAPI, + FallbackProxy, + FlowAPI, + ResourceAPI, + ResourceV1API, + ResourceV2API, + RunAPI, + SetupAPI, + StudyAPI, + TaskAPI, +) +from .dataset import DatasetV1API, DatasetV2API +from .estimation_procedure import ( + EstimationProcedureV1API, + EstimationProcedureV2API, +) +from .evaluation import EvaluationV1API, EvaluationV2API +from .evaluation_measure import EvaluationMeasureV1API, EvaluationMeasureV2API +from .flow import FlowV1API, FlowV2API +from .run import RunV1API, RunV2API +from .setup import SetupV1API, SetupV2API +from .study import StudyV1API, StudyV2API +from .task import TaskV1API, TaskV2API + +__all__ = [ + "API_REGISTRY", + "DatasetAPI", + "DatasetV1API", + "DatasetV2API", + "EstimationProcedureAPI", + "EstimationProcedureV1API", + "EstimationProcedureV2API", + "EvaluationAPI", + "EvaluationMeasureAPI", + "EvaluationMeasureV1API", + "EvaluationMeasureV2API", + "EvaluationV1API", + "EvaluationV2API", + "FallbackProxy", + "FlowAPI", + "FlowV1API", + "FlowV2API", + "ResourceAPI", + "ResourceV1API", + "ResourceV2API", + "RunAPI", + "RunV1API", + "RunV2API", + "SetupAPI", + "SetupV1API", + "SetupV2API", + "StudyAPI", + "StudyV1API", + "StudyV2API", + "TaskAPI", + "TaskV1API", + "TaskV2API", +] diff --git a/openml/_api/resources/_registry.py b/openml/_api/resources/_registry.py new file mode 100644 index 000000000..66d7ec428 --- /dev/null +++ b/openml/_api/resources/_registry.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from openml.enums import APIVersion, ResourceType + +from .dataset import DatasetV1API, DatasetV2API +from .estimation_procedure import ( + EstimationProcedureV1API, + EstimationProcedureV2API, +) +from .evaluation import EvaluationV1API, EvaluationV2API +from .evaluation_measure import EvaluationMeasureV1API, EvaluationMeasureV2API +from .flow import FlowV1API, FlowV2API +from .run import RunV1API, RunV2API +from .setup import SetupV1API, SetupV2API +from .study import StudyV1API, StudyV2API +from .task import TaskV1API, TaskV2API + +if TYPE_CHECKING: + from .base import ResourceAPI + +API_REGISTRY: dict[ + APIVersion, + dict[ResourceType, type[ResourceAPI]], +] = { + APIVersion.V1: { + ResourceType.DATASET: DatasetV1API, + ResourceType.TASK: TaskV1API, + ResourceType.EVALUATION_MEASURE: EvaluationMeasureV1API, + ResourceType.ESTIMATION_PROCEDURE: EstimationProcedureV1API, + ResourceType.EVALUATION: EvaluationV1API, + ResourceType.FLOW: FlowV1API, + ResourceType.STUDY: StudyV1API, + ResourceType.RUN: RunV1API, + ResourceType.SETUP: SetupV1API, + }, + APIVersion.V2: { + ResourceType.DATASET: DatasetV2API, + ResourceType.TASK: TaskV2API, + ResourceType.EVALUATION_MEASURE: EvaluationMeasureV2API, + ResourceType.ESTIMATION_PROCEDURE: EstimationProcedureV2API, + ResourceType.EVALUATION: EvaluationV2API, + ResourceType.FLOW: FlowV2API, + ResourceType.STUDY: 
StudyV2API, + ResourceType.RUN: RunV2API, + ResourceType.SETUP: SetupV2API, + }, +} diff --git a/openml/_api/resources/base/__init__.py b/openml/_api/resources/base/__init__.py new file mode 100644 index 000000000..ed6dc26f7 --- /dev/null +++ b/openml/_api/resources/base/__init__.py @@ -0,0 +1,30 @@ +from .base import ResourceAPI +from .fallback import FallbackProxy +from .resources import ( + DatasetAPI, + EstimationProcedureAPI, + EvaluationAPI, + EvaluationMeasureAPI, + FlowAPI, + RunAPI, + SetupAPI, + StudyAPI, + TaskAPI, +) +from .versions import ResourceV1API, ResourceV2API + +__all__ = [ + "DatasetAPI", + "EstimationProcedureAPI", + "EvaluationAPI", + "EvaluationMeasureAPI", + "FallbackProxy", + "FlowAPI", + "ResourceAPI", + "ResourceV1API", + "ResourceV2API", + "RunAPI", + "SetupAPI", + "StudyAPI", + "TaskAPI", +] diff --git a/openml/_api/resources/base/base.py b/openml/_api/resources/base/base.py new file mode 100644 index 000000000..625681e3b --- /dev/null +++ b/openml/_api/resources/base/base.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, NoReturn + +from openml.exceptions import ( + OpenMLNotAuthorizedError, + OpenMLNotSupportedError, + OpenMLServerError, + OpenMLServerException, +) + +if TYPE_CHECKING: + from collections.abc import Mapping + from typing import Any + + from openml._api.clients import HTTPClient, MinIOClient + from openml.enums import APIVersion, ResourceType + + +class ResourceAPI(ABC): + """ + Abstract base class for OpenML resource APIs. + + This class defines the common interface for interacting with OpenML + resources (e.g., datasets, flows, runs) across different API versions. + Concrete subclasses must implement the resource-specific operations + such as publishing, deleting, and tagging. + + Parameters + ---------- + http : HTTPClient + Configured HTTP client used for communication with the OpenML API. + minio : MinIOClient + Configured MinIO client used for object storage operations. + + Attributes + ---------- + api_version : APIVersion + API version implemented by the resource. + resource_type : ResourceType + Type of OpenML resource handled by the implementation. + _http : HTTPClient + Internal HTTP client instance. + _minio : MinIOClient or None + Internal MinIO client instance, if provided. + """ + + api_version: APIVersion + resource_type: ResourceType + + def __init__(self, http: HTTPClient, minio: MinIOClient): + self._http = http + self._minio = minio + + @abstractmethod + def delete(self, resource_id: int) -> bool: + """ + Delete a resource by its identifier. + + Parameters + ---------- + resource_id : int + Unique identifier of the resource to delete. + + Returns + ------- + bool + ``True`` if the deletion was successful. + + Notes + ----- + Concrete subclasses must implement this method. + """ + + @abstractmethod + def publish(self, path: str, files: Mapping[str, Any] | None) -> int: + """ + Publish a new resource to the OpenML server. + + Parameters + ---------- + path : str + API endpoint path used for publishing the resource. + files : Mapping of str to Any or None + Files or payload data required for publishing. The structure + depends on the resource type. + + Returns + ------- + int + Identifier of the newly created resource. + + Notes + ----- + Concrete subclasses must implement this method. + """ + + @abstractmethod + def tag(self, resource_id: int, tag: str) -> list[str]: + """ + Add a tag to a resource. 
+ + Parameters + ---------- + resource_id : int + Identifier of the resource to tag. + tag : str + Tag to associate with the resource. + + Returns + ------- + list of str + Updated list of tags assigned to the resource. + + Notes + ----- + Concrete subclasses must implement this method. + """ + + @abstractmethod + def untag(self, resource_id: int, tag: str) -> list[str]: + """ + Remove a tag from a resource. + + Parameters + ---------- + resource_id : int + Identifier of the resource to untag. + tag : str + Tag to remove from the resource. + + Returns + ------- + list of str + Updated list of tags assigned to the resource. + + Notes + ----- + Concrete subclasses must implement this method. + """ + + @abstractmethod + def _get_endpoint_name(self) -> str: + """ + Return the endpoint name for the current resource type. + + Returns + ------- + str + Endpoint segment used in API paths. + + Notes + ----- + Datasets use the special endpoint name ``"data"`` instead of + their enum value. + """ + + def _handle_delete_exception( + self, resource_type: str, exception: OpenMLServerException + ) -> None: + """ + Map V1 deletion error codes to more specific exceptions. + + Parameters + ---------- + resource_type : str + Endpoint name of the resource type. + exception : OpenMLServerException + Original exception raised during deletion. + + Raises + ------ + OpenMLNotAuthorizedError + If the resource cannot be deleted due to ownership or + dependent entities. + OpenMLServerError + If deletion fails for an unknown reason. + OpenMLServerException + If the error code is not specially handled. + """ + # https://github.com/openml/OpenML/blob/21f6188d08ac24fcd2df06ab94cf421c946971b0/openml_OS/views/pages/api_new/v1/xml/pre.php + # Most exceptions are descriptive enough to be raised as their standard + # OpenMLServerException, however there are two cases where we add information: + # - a generic "failed" message, we direct them to the right issue board + # - when the user successfully authenticates with the server, + # but user is not allowed to take the requested action, + # in which case we specify a OpenMLNotAuthorizedError. + by_other_user = [323, 353, 393, 453, 594] + has_dependent_entities = [324, 326, 327, 328, 354, 454, 464, 595] + unknown_reason = [325, 355, 394, 455, 593] + if exception.code in by_other_user: + raise OpenMLNotAuthorizedError( + message=( + f"The {resource_type} can not be deleted because it was not uploaded by you." + ), + ) from exception + if exception.code in has_dependent_entities: + raise OpenMLNotAuthorizedError( + message=( + f"The {resource_type} can not be deleted because " + f"it still has associated entities: {exception.message}" + ), + ) from exception + if exception.code in unknown_reason: + raise OpenMLServerError( + message=( + f"The {resource_type} can not be deleted for unknown reason," + " please open an issue at: https://github.com/openml/openml/issues/new" + ), + ) from exception + raise exception + + def _not_supported(self, *, method: str) -> NoReturn: + """ + Raise an error indicating that a method is not supported. + + Parameters + ---------- + method : str + Name of the unsupported method. + + Raises + ------ + OpenMLNotSupportedError + If the current API version does not support the requested method + for the given resource type. 
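+
+        Examples
+        --------
+        Illustrative only; the exact message depends on the concrete subclass:
+
+        >>> api._not_supported(method="tag")  # doctest: +SKIP
+        Traceback (most recent call last):
+        ...
+        OpenMLNotSupportedError: ...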
+ """ + version = getattr(self.api_version, "value", "unknown") + resource = getattr(self.resource_type, "value", "unknown") + + raise OpenMLNotSupportedError( + f"{self.__class__.__name__}: " + f"{version} API does not support `{method}` " + f"for resource `{resource}`" + ) diff --git a/openml/_api/resources/base/fallback.py b/openml/_api/resources/base/fallback.py new file mode 100644 index 000000000..9b8f64a17 --- /dev/null +++ b/openml/_api/resources/base/fallback.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from openml.exceptions import OpenMLNotSupportedError + + +class FallbackProxy: + """ + Proxy object that provides transparent fallback across multiple API versions. + + This class delegates attribute access to a sequence of API implementations. + When a callable attribute is invoked and raises ``OpenMLNotSupportedError``, + the proxy automatically attempts the same method on subsequent API instances + until one succeeds. + + Parameters + ---------- + *api_versions : Any + One or more API implementation instances ordered by priority. + The first API is treated as the primary implementation, and + subsequent APIs are used as fallbacks. + + Raises + ------ + ValueError + If no API implementations are provided. + + Notes + ----- + Attribute lookup is performed dynamically via ``__getattr__``. + Only methods that raise ``OpenMLNotSupportedError`` trigger fallback + behavior. Other exceptions are propagated immediately. + """ + + def __init__(self, *api_versions: Any): + if not api_versions: + raise ValueError("At least one API version must be provided") + self._apis = api_versions + + def __getattr__(self, name: str) -> Any: + """ + Dynamically resolve attribute access across API implementations. + + Parameters + ---------- + name : str + Name of the attribute being accessed. + + Returns + ------- + Any + The resolved attribute. If it is callable, a wrapped function + providing fallback behavior is returned. + + Raises + ------ + AttributeError + If none of the API implementations define the attribute. + """ + api, attr = self._find_attr(name) + if callable(attr): + return self._wrap_callable(name, api, attr) + return attr + + def _find_attr(self, name: str) -> tuple[Any, Any]: + """ + Find the first API implementation that defines a given attribute. + + Parameters + ---------- + name : str + Name of the attribute to search for. + + Returns + ------- + tuple of (Any, Any) + The API instance and the corresponding attribute. + + Raises + ------ + AttributeError + If no API implementation defines the attribute. + """ + for api in self._apis: + attr = getattr(api, name, None) + if attr is not None: + return api, attr + raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") + + def _wrap_callable( + self, + name: str, + primary_api: Any, + primary_attr: Callable[..., Any], + ) -> Callable[..., Any]: + """ + Wrap a callable attribute to enable fallback behavior. + + Parameters + ---------- + name : str + Name of the method being wrapped. + primary_api : Any + Primary API instance providing the callable. + primary_attr : Callable[..., Any] + Callable attribute obtained from the primary API. + + Returns + ------- + Callable[..., Any] + Wrapped function that attempts the primary call first and + falls back to other APIs if ``OpenMLNotSupportedError`` is raised. 
+ """ + + def wrapper(*args: Any, **kwargs: Any) -> Any: + try: + return primary_attr(*args, **kwargs) + except OpenMLNotSupportedError: + return self._call_fallbacks(name, primary_api, *args, **kwargs) + + return wrapper + + def _call_fallbacks( + self, + name: str, + skip_api: Any, + *args: Any, + **kwargs: Any, + ) -> Any: + """ + Attempt to call a method on fallback API implementations. + + Parameters + ---------- + name : str + Name of the method to invoke. + skip_api : Any + API instance to skip (typically the primary API that already failed). + *args : Any + Positional arguments passed to the method. + **kwargs : Any + Keyword arguments passed to the method. + + Returns + ------- + Any + Result returned by the first successful fallback invocation. + + Raises + ------ + OpenMLNotSupportedError + If all API implementations either do not define the method + or raise ``OpenMLNotSupportedError``. + """ + for api in self._apis: + if api is skip_api: + continue + attr = getattr(api, name, None) + if callable(attr): + try: + return attr(*args, **kwargs) + except OpenMLNotSupportedError: + continue + raise OpenMLNotSupportedError(f"Could not fallback to any API for method: {name}") diff --git a/openml/_api/resources/base/resources.py b/openml/_api/resources/base/resources.py new file mode 100644 index 000000000..ede0e1034 --- /dev/null +++ b/openml/_api/resources/base/resources.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from openml.enums import ResourceType + +from .base import ResourceAPI + + +class DatasetAPI(ResourceAPI): + """Abstract API interface for dataset resources.""" + + resource_type: ResourceType = ResourceType.DATASET + + +class TaskAPI(ResourceAPI): + """Abstract API interface for task resources.""" + + resource_type: ResourceType = ResourceType.TASK + + +class EvaluationMeasureAPI(ResourceAPI): + """Abstract API interface for evaluation measure resources.""" + + resource_type: ResourceType = ResourceType.EVALUATION_MEASURE + + +class EstimationProcedureAPI(ResourceAPI): + """Abstract API interface for estimation procedure resources.""" + + resource_type: ResourceType = ResourceType.ESTIMATION_PROCEDURE + + +class EvaluationAPI(ResourceAPI): + """Abstract API interface for evaluation resources.""" + + resource_type: ResourceType = ResourceType.EVALUATION + + +class FlowAPI(ResourceAPI): + """Abstract API interface for flow resources.""" + + resource_type: ResourceType = ResourceType.FLOW + + +class StudyAPI(ResourceAPI): + """Abstract API interface for study resources.""" + + resource_type: ResourceType = ResourceType.STUDY + + +class RunAPI(ResourceAPI): + """Abstract API interface for run resources.""" + + resource_type: ResourceType = ResourceType.RUN + + +class SetupAPI(ResourceAPI): + """Abstract API interface for setup resources.""" + + resource_type: ResourceType = ResourceType.SETUP diff --git a/openml/_api/resources/base/versions.py b/openml/_api/resources/base/versions.py new file mode 100644 index 000000000..bba59b869 --- /dev/null +++ b/openml/_api/resources/base/versions.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, cast + +import xmltodict + +from openml.enums import APIVersion, ResourceType +from openml.exceptions import ( + OpenMLServerException, +) + +from .base import ResourceAPI + +_LEGAL_RESOURCES_DELETE = [ + ResourceType.DATASET, + ResourceType.TASK, + ResourceType.FLOW, + ResourceType.STUDY, + ResourceType.RUN, + ResourceType.USER, +] + +_LEGAL_RESOURCES_TAG = [ + 
ResourceType.DATASET, + ResourceType.TASK, + ResourceType.FLOW, + ResourceType.SETUP, + ResourceType.RUN, +] + + +class ResourceV1API(ResourceAPI): + """ + Version 1 implementation of the OpenML resource API. + + This class provides XML-based implementations for publishing, + deleting, tagging, and untagging resources using the V1 API + endpoints. Responses are parsed using ``xmltodict``. + + Notes + ----- + V1 endpoints expect and return XML. Error handling follows the + legacy OpenML server behavior and maps specific error codes to + more descriptive exceptions where appropriate. + """ + + api_version: APIVersion = APIVersion.V1 + + def publish(self, path: str, files: Mapping[str, Any] | None) -> int: + """ + Publish a new resource using the V1 API. + + Parameters + ---------- + path : str + API endpoint path for the upload. + files : Mapping of str to Any or None + Files to upload as part of the request payload. + + Returns + ------- + int + Identifier of the newly created resource. + + Raises + ------ + ValueError + If the server response does not contain a valid resource ID. + OpenMLServerException + If the server returns an error during upload. + """ + response = self._http.post(path, files=files) + parsed_response = xmltodict.parse(response.content) + return self._extract_id_from_upload(parsed_response) + + def delete(self, resource_id: int) -> bool: + """ + Delete a resource using the V1 API. + + Parameters + ---------- + resource_id : int + Identifier of the resource to delete. + + Returns + ------- + bool + ``True`` if the server confirms successful deletion. + + Raises + ------ + ValueError + If the resource type is not supported for deletion. + OpenMLNotAuthorizedError + If the user is not permitted to delete the resource. + OpenMLServerError + If deletion fails for an unknown reason. + OpenMLServerException + For other server-side errors. + """ + if self.resource_type not in _LEGAL_RESOURCES_DELETE: + raise ValueError(f"Can't delete a {self.resource_type.value}") + + endpoint_name = self._get_endpoint_name() + path = f"{endpoint_name}/{resource_id}" + try: + response = self._http.delete(path) + result = xmltodict.parse(response.content) + return f"oml:{endpoint_name}_delete" in result + except OpenMLServerException as e: + self._handle_delete_exception(endpoint_name, e) + raise + + def tag(self, resource_id: int, tag: str) -> list[str]: + """ + Add a tag to a resource using the V1 API. + + Parameters + ---------- + resource_id : int + Identifier of the resource to tag. + tag : str + Tag to associate with the resource. + + Returns + ------- + list of str + Updated list of tags assigned to the resource. + + Raises + ------ + ValueError + If the resource type does not support tagging. + OpenMLServerException + If the server returns an error. + """ + if self.resource_type not in _LEGAL_RESOURCES_TAG: + raise ValueError(f"Can't tag a {self.resource_type.value}") + + endpoint_name = self._get_endpoint_name() + path = f"{endpoint_name}/tag" + data = {f"{endpoint_name}_id": resource_id, "tag": tag} + response = self._http.post(path, data=data) + + parsed_response = xmltodict.parse(response.content, force_list={"oml:tag"}) + result = parsed_response[f"oml:{endpoint_name}_tag"] + tags: list[str] = result.get("oml:tag", []) + + return tags + + def untag(self, resource_id: int, tag: str) -> list[str]: + """ + Remove a tag from a resource using the V1 API. + + Parameters + ---------- + resource_id : int + Identifier of the resource to untag. + tag : str + Tag to remove from the resource. 
+ + Returns + ------- + list of str + Updated list of tags assigned to the resource. + + Raises + ------ + ValueError + If the resource type does not support tagging. + OpenMLServerException + If the server returns an error. + """ + if self.resource_type not in _LEGAL_RESOURCES_TAG: + raise ValueError(f"Can't untag a {self.resource_type.value}") + + endpoint_name = self._get_endpoint_name() + path = f"{endpoint_name}/untag" + data = {f"{endpoint_name}_id": resource_id, "tag": tag} + response = self._http.post(path, data=data) + + parsed_response = xmltodict.parse(response.content, force_list={"oml:tag"}) + result = parsed_response[f"oml:{endpoint_name}_untag"] + tags: list[str] = result.get("oml:tag", []) + + return tags + + def _get_endpoint_name(self) -> str: + if self.resource_type == ResourceType.DATASET: + return "data" + return cast("str", self.resource_type.value) + + def _extract_id_from_upload(self, parsed: Mapping[str, Any]) -> int: + """ + Extract the resource identifier from an XML upload response. + + Parameters + ---------- + parsed : Mapping of str to Any + Parsed XML response as returned by ``xmltodict.parse``. + + Returns + ------- + int + Extracted resource identifier. + + Raises + ------ + ValueError + If the response structure is unexpected or no identifier + can be found. + """ + # reads id from upload response + # actual parsed dict: {"oml:upload_flow": {"@xmlns:oml": "...", "oml:id": "42"}} + + # xmltodict always gives exactly one root key + ((_, root_value),) = parsed.items() + + if not isinstance(root_value, Mapping): + raise ValueError("Unexpected XML structure") + + # Look for oml:id directly in the root value + if "oml:id" in root_value: + id_value = root_value["oml:id"] + if isinstance(id_value, (str, int)): + return int(id_value) + + # Fallback: check all values for numeric/string IDs + for v in root_value.values(): + if isinstance(v, (str, int)): + return int(v) + + raise ValueError("No ID found in upload response") + + +class ResourceV2API(ResourceAPI): + """ + Version 2 implementation of the OpenML resource API. + + This class represents the V2 API for resources. Operations such as + publishing, deleting, tagging, and untagging are currently not + supported and will raise ``OpenMLNotSupportedError``. 
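+
+    Examples
+    --------
+    Any concrete V2 resource API behaves the same way; illustrative only,
+    assuming pre-configured ``http`` and ``minio`` clients:
+
+    >>> api = DatasetV2API(http, minio)  # doctest: +SKIP
+    >>> api.tag(61, "demo")  # doctest: +SKIP
+    Traceback (most recent call last):
+    ...
+    OpenMLNotSupportedError: ...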
+ """ + + api_version: APIVersion = APIVersion.V2 + + def publish(self, path: str, files: Mapping[str, Any] | None) -> int: # noqa: ARG002 + self._not_supported(method="publish") + + def delete(self, resource_id: int) -> bool: # noqa: ARG002 + self._not_supported(method="delete") + + def tag(self, resource_id: int, tag: str) -> list[str]: # noqa: ARG002 + self._not_supported(method="tag") + + def untag(self, resource_id: int, tag: str) -> list[str]: # noqa: ARG002 + self._not_supported(method="untag") + + def _get_endpoint_name(self) -> str: + return cast("str", self.resource_type.value) diff --git a/openml/_api/resources/dataset.py b/openml/_api/resources/dataset.py new file mode 100644 index 000000000..520594df9 --- /dev/null +++ b/openml/_api/resources/dataset.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import DatasetAPI, ResourceV1API, ResourceV2API + + +class DatasetV1API(ResourceV1API, DatasetAPI): + """Version 1 API implementation for dataset resources.""" + + +class DatasetV2API(ResourceV2API, DatasetAPI): + """Version 2 API implementation for dataset resources.""" diff --git a/openml/_api/resources/estimation_procedure.py b/openml/_api/resources/estimation_procedure.py new file mode 100644 index 000000000..a45f7af66 --- /dev/null +++ b/openml/_api/resources/estimation_procedure.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import EstimationProcedureAPI, ResourceV1API, ResourceV2API + + +class EstimationProcedureV1API(ResourceV1API, EstimationProcedureAPI): + """Version 1 API implementation for estimation procedure resources.""" + + +class EstimationProcedureV2API(ResourceV2API, EstimationProcedureAPI): + """Version 2 API implementation for estimation procedure resources.""" diff --git a/openml/_api/resources/evaluation.py b/openml/_api/resources/evaluation.py new file mode 100644 index 000000000..fe7e360a6 --- /dev/null +++ b/openml/_api/resources/evaluation.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import EvaluationAPI, ResourceV1API, ResourceV2API + + +class EvaluationV1API(ResourceV1API, EvaluationAPI): + """Version 1 API implementation for evaluation resources.""" + + +class EvaluationV2API(ResourceV2API, EvaluationAPI): + """Version 2 API implementation for evaluation resources.""" diff --git a/openml/_api/resources/evaluation_measure.py b/openml/_api/resources/evaluation_measure.py new file mode 100644 index 000000000..4ed5097f7 --- /dev/null +++ b/openml/_api/resources/evaluation_measure.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import EvaluationMeasureAPI, ResourceV1API, ResourceV2API + + +class EvaluationMeasureV1API(ResourceV1API, EvaluationMeasureAPI): + """Version 1 API implementation for evaluation measure resources.""" + + +class EvaluationMeasureV2API(ResourceV2API, EvaluationMeasureAPI): + """Version 2 API implementation for evaluation measure resources.""" diff --git a/openml/_api/resources/flow.py b/openml/_api/resources/flow.py new file mode 100644 index 000000000..1716d89d3 --- /dev/null +++ b/openml/_api/resources/flow.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import FlowAPI, ResourceV1API, ResourceV2API + + +class FlowV1API(ResourceV1API, FlowAPI): + """Version 1 API implementation for flow resources.""" + + +class FlowV2API(ResourceV2API, FlowAPI): + """Version 2 API implementation for flow resources.""" diff --git a/openml/_api/resources/run.py b/openml/_api/resources/run.py new file mode 100644 index 000000000..4caccb0b6 --- 
/dev/null +++ b/openml/_api/resources/run.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import ResourceV1API, ResourceV2API, RunAPI + + +class RunV1API(ResourceV1API, RunAPI): + """Version 1 API implementation for run resources.""" + + +class RunV2API(ResourceV2API, RunAPI): + """Version 2 API implementation for run resources.""" diff --git a/openml/_api/resources/setup.py b/openml/_api/resources/setup.py new file mode 100644 index 000000000..2896d3d9f --- /dev/null +++ b/openml/_api/resources/setup.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import ResourceV1API, ResourceV2API, SetupAPI + + +class SetupV1API(ResourceV1API, SetupAPI): + """Version 1 API implementation for setup resources.""" + + +class SetupV2API(ResourceV2API, SetupAPI): + """Version 2 API implementation for setup resources.""" diff --git a/openml/_api/resources/study.py b/openml/_api/resources/study.py new file mode 100644 index 000000000..fb073555c --- /dev/null +++ b/openml/_api/resources/study.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import ResourceV1API, ResourceV2API, StudyAPI + + +class StudyV1API(ResourceV1API, StudyAPI): + """Version 1 API implementation for study resources.""" + + +class StudyV2API(ResourceV2API, StudyAPI): + """Version 2 API implementation for study resources.""" diff --git a/openml/_api/resources/task.py b/openml/_api/resources/task.py new file mode 100644 index 000000000..1f62aa3f3 --- /dev/null +++ b/openml/_api/resources/task.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from .base import ResourceV1API, ResourceV2API, TaskAPI + + +class TaskV1API(ResourceV1API, TaskAPI): + """Version 1 API implementation for task resources.""" + + +class TaskV2API(ResourceV2API, TaskAPI): + """Version 2 API implementation for task resources.""" diff --git a/openml/_api/setup/__init__.py b/openml/_api/setup/__init__.py new file mode 100644 index 000000000..80545824f --- /dev/null +++ b/openml/_api/setup/__init__.py @@ -0,0 +1,10 @@ +from .backend import APIBackend +from .builder import APIBackendBuilder + +_backend = APIBackend.get_instance() + +__all__ = [ + "APIBackend", + "APIBackendBuilder", + "_backend", +] diff --git a/openml/_api/setup/backend.py b/openml/_api/setup/backend.py new file mode 100644 index 000000000..8ed37714d --- /dev/null +++ b/openml/_api/setup/backend.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar, cast + +import openml + +from .builder import APIBackendBuilder + +if TYPE_CHECKING: + from openml._api.clients import HTTPClient, MinIOClient + from openml._api.resources import ( + DatasetAPI, + EstimationProcedureAPI, + EvaluationAPI, + EvaluationMeasureAPI, + FlowAPI, + RunAPI, + SetupAPI, + StudyAPI, + TaskAPI, + ) + + +class APIBackend: + """ + Central backend for accessing all OpenML API resource interfaces. + + This class provides a singleton interface to dataset, task, flow, + evaluation, run, setup, study, and other resource APIs. It also + manages configuration through a nested ``Config`` object and + allows dynamic retrieval and updating of configuration values. + + Parameters + ---------- + config : Config, optional + Optional configuration object. If not provided, a default + ``Config`` instance is created. + + Attributes + ---------- + dataset : DatasetAPI + Interface for dataset-related API operations. + task : TaskAPI + Interface for task-related API operations. 
+ evaluation_measure : EvaluationMeasureAPI + Interface for evaluation measure-related API operations. + estimation_procedure : EstimationProcedureAPI + Interface for estimation procedure-related API operations. + evaluation : EvaluationAPI + Interface for evaluation-related API operations. + flow : FlowAPI + Interface for flow-related API operations. + study : StudyAPI + Interface for study-related API operations. + run : RunAPI + Interface for run-related API operations. + setup : SetupAPI + Interface for setup-related API operations. + """ + + _instance: ClassVar[APIBackend | None] = None + _backends: ClassVar[dict[str, APIBackendBuilder]] = {} + + @property + def _backend(self) -> APIBackendBuilder: + api_version = openml.config.api_version + fallback_api_version = openml.config.fallback_api_version + key = f"{api_version}_{fallback_api_version}" + + if key not in self._backends: + _backend = APIBackendBuilder.build( + api_version=api_version, + fallback_api_version=fallback_api_version, + ) + self._backends[key] = _backend + + return self._backends[key] + + @property + def dataset(self) -> DatasetAPI: + return cast("DatasetAPI", self._backend.dataset) + + @property + def task(self) -> TaskAPI: + return cast("TaskAPI", self._backend.task) + + @property + def evaluation_measure(self) -> EvaluationMeasureAPI: + return cast("EvaluationMeasureAPI", self._backend.evaluation_measure) + + @property + def estimation_procedure(self) -> EstimationProcedureAPI: + return cast("EstimationProcedureAPI", self._backend.estimation_procedure) + + @property + def evaluation(self) -> EvaluationAPI: + return cast("EvaluationAPI", self._backend.evaluation) + + @property + def flow(self) -> FlowAPI: + return cast("FlowAPI", self._backend.flow) + + @property + def study(self) -> StudyAPI: + return cast("StudyAPI", self._backend.study) + + @property + def run(self) -> RunAPI: + return cast("RunAPI", self._backend.run) + + @property + def setup(self) -> SetupAPI: + return cast("SetupAPI", self._backend.setup) + + @property + def http_client(self) -> HTTPClient: + return cast("HTTPClient", self._backend.http_client) + + @property + def fallback_http_client(self) -> HTTPClient | None: + return cast("HTTPClient | None", self._backend.fallback_http_client) + + @property + def minio_client(self) -> MinIOClient: + return cast("MinIOClient", self._backend.minio_client) + + @classmethod + def get_instance(cls) -> APIBackend: + """ + Get the singleton instance of the APIBackend. + + Returns + ------- + APIBackend + Singleton instance of the backend. + """ + if cls._instance is None: + cls._instance = cls() + return cls._instance diff --git a/openml/_api/setup/builder.py b/openml/_api/setup/builder.py new file mode 100644 index 000000000..573129316 --- /dev/null +++ b/openml/_api/setup/builder.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import TYPE_CHECKING + +from openml._api.clients import HTTPClient, MinIOClient +from openml._api.resources import API_REGISTRY, FallbackProxy +from openml.enums import ResourceType + +if TYPE_CHECKING: + from openml._api.resources import ResourceAPI + from openml.enums import APIVersion + + +class APIBackendBuilder: + """ + Builder class for constructing API backend instances. + + This class organizes resource-specific API objects (datasets, tasks, + flows, evaluations, runs, setups, studies, etc.) and provides a + centralized access point for both primary and optional fallback APIs. 
+
+    Parameters
+    ----------
+    clients : Mapping[str, HTTPClient | MinIOClient | None]
+        Mapping holding the ``http_client``, ``fallback_http_client`` and
+        ``minio_client`` used by the resource APIs.
+    resource_apis : Mapping[ResourceType, ResourceAPI | FallbackProxy]
+        Mapping of resource types to their corresponding API instances
+        or fallback proxies.
+
+    Attributes
+    ----------
+    dataset : ResourceAPI | FallbackProxy
+        API interface for dataset resources.
+    task : ResourceAPI | FallbackProxy
+        API interface for task resources.
+    evaluation_measure : ResourceAPI | FallbackProxy
+        API interface for evaluation measure resources.
+    estimation_procedure : ResourceAPI | FallbackProxy
+        API interface for estimation procedure resources.
+    evaluation : ResourceAPI | FallbackProxy
+        API interface for evaluation resources.
+    flow : ResourceAPI | FallbackProxy
+        API interface for flow resources.
+    study : ResourceAPI | FallbackProxy
+        API interface for study resources.
+    run : ResourceAPI | FallbackProxy
+        API interface for run resources.
+    setup : ResourceAPI | FallbackProxy
+        API interface for setup resources.
+    http_client : HTTPClient
+        Client for HTTP communication.
+    fallback_http_client : HTTPClient | None
+        Fallback client for HTTP communication.
+    minio_client : MinIOClient
+        Client for MinIO communication.
+    """
+
+    def __init__(
+        self,
+        clients: Mapping[str, HTTPClient | MinIOClient | None],
+        resource_apis: Mapping[ResourceType, ResourceAPI | FallbackProxy],
+    ):
+        self.dataset = resource_apis[ResourceType.DATASET]
+        self.task = resource_apis[ResourceType.TASK]
+        self.evaluation_measure = resource_apis[ResourceType.EVALUATION_MEASURE]
+        self.estimation_procedure = resource_apis[ResourceType.ESTIMATION_PROCEDURE]
+        self.evaluation = resource_apis[ResourceType.EVALUATION]
+        self.flow = resource_apis[ResourceType.FLOW]
+        self.study = resource_apis[ResourceType.STUDY]
+        self.run = resource_apis[ResourceType.RUN]
+        self.setup = resource_apis[ResourceType.SETUP]
+        self.http_client = clients["http_client"]
+        self.fallback_http_client = clients["fallback_http_client"]
+        self.minio_client = clients["minio_client"]
+
+    @classmethod
+    def build(
+        cls,
+        api_version: APIVersion,
+        fallback_api_version: APIVersion | None,
+    ) -> APIBackendBuilder:
+        """
+        Construct an APIBackendBuilder instance for the requested API versions.
+
+        This method initializes HTTP and MinIO clients, creates resource-specific
+        API instances for the primary API version, and optionally wraps them
+        with fallback proxies if a fallback API version is configured.
+
+        Parameters
+        ----------
+        api_version : APIVersion
+            Primary API version for which the resource APIs are built.
+        fallback_api_version : APIVersion | None
+            Optional API version used to build fallback resource APIs.
+
+        Returns
+        -------
+        APIBackendBuilder
+            Builder instance with all resource API interfaces initialized.
+ """ + minio_client = MinIOClient() + primary_http_client = HTTPClient(api_version=api_version) + clients: dict[str, HTTPClient | MinIOClient | None] = { + "http_client": primary_http_client, + "fallback_http_client": None, + "minio_client": minio_client, + } + + resource_apis: dict[ResourceType, ResourceAPI] = {} + for resource_type, resource_api_cls in API_REGISTRY[api_version].items(): + resource_apis[resource_type] = resource_api_cls(primary_http_client, minio_client) + + if fallback_api_version is None: + return cls(clients, resource_apis) + + fallback_http_client = HTTPClient(api_version=fallback_api_version) + clients["fallback_http_client"] = fallback_http_client + + fallback_resource_apis: dict[ResourceType, ResourceAPI] = {} + for resource_type, resource_api_cls in API_REGISTRY[fallback_api_version].items(): + fallback_resource_apis[resource_type] = resource_api_cls( + fallback_http_client, minio_client + ) + + merged: dict[ResourceType, FallbackProxy] = { + name: FallbackProxy(resource_apis[name], fallback_resource_apis[name]) + for name in resource_apis + } + + return cls(clients, merged) diff --git a/openml/_api_calls.py b/openml/_api_calls.py index 5da635c70..179c814e7 100644 --- a/openml/_api_calls.py +++ b/openml/_api_calls.py @@ -12,6 +12,7 @@ import xml import zipfile from pathlib import Path +from typing import cast import minio import requests @@ -19,7 +20,8 @@ import xmltodict from urllib3 import ProxyManager -from . import config +import openml + from .__version__ import __version__ from .exceptions import ( OpenMLAuthenticationError, @@ -70,7 +72,7 @@ def resolve_env_proxies(url: str) -> str | None: def _create_url_from_endpoint(endpoint: str) -> str: - url = config.server + url = cast("str", openml.config.server) if not url.endswith("/"): url += "/" url += endpoint @@ -171,7 +173,7 @@ def _download_minio_file( bucket_name=bucket, object_name=object_name, file_path=str(destination), - progress=ProgressBar() if config.show_progress else None, + progress=ProgressBar() if openml.config.show_progress else None, request_headers=_HEADERS, ) if destination.is_file() and destination.suffix == ".zip": @@ -300,7 +302,8 @@ def _file_id_to_url(file_id: int, filename: str | None = None) -> str: Presents the URL how to download a given file id filename is optional """ - openml_url = config.server.split("/api/") + openml_server = cast("str", openml.config.server) + openml_url = openml_server.split("/api/") url = openml_url[0] + f"/data/download/{file_id!s}" if filename is not None: url += "/" + filename @@ -316,7 +319,7 @@ def _read_url_files( and sending file_elements as files """ data = {} if data is None else data - data["api_key"] = config.apikey + data["api_key"] = openml.config.apikey if file_elements is None: file_elements = {} # Using requests.post sets header 'Accept-encoding' automatically to @@ -336,8 +339,8 @@ def __read_url( md5_checksum: str | None = None, ) -> requests.Response: data = {} if data is None else data - if config.apikey: - data["api_key"] = config.apikey + if openml.config.apikey: + data["api_key"] = openml.config.apikey return _send_request( request_method=request_method, url=url, @@ -362,10 +365,10 @@ def _send_request( # noqa: C901, PLR0912 files: FILE_ELEMENTS_TYPE | None = None, md5_checksum: str | None = None, ) -> requests.Response: - n_retries = max(1, config.connection_n_retries) + n_retries = max(1, openml.config.connection_n_retries) response: requests.Response | None = None - delay_method = _human_delay if config.retry_policy == "human" else 
_robot_delay + delay_method = _human_delay if openml.config.retry_policy == "human" else _robot_delay # Error to raise in case of retrying too often. Will be set to the last observed exception. retry_raise_e: Exception | None = None diff --git a/openml/_config.py b/openml/_config.py new file mode 100644 index 000000000..f50372a21 --- /dev/null +++ b/openml/_config.py @@ -0,0 +1,556 @@ +"""Store module level information like the API key, cache directory and the server""" + +# License: BSD 3-Clause +from __future__ import annotations + +import configparser +import logging +import logging.handlers +import os +import platform +import shutil +import warnings +from collections.abc import Iterator +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass, field, fields, replace +from io import StringIO +from pathlib import Path +from typing import Any, ClassVar, Literal, cast +from urllib.parse import urlparse + +from openml.enums import APIVersion + +from .__version__ import __version__ + +logger = logging.getLogger(__name__) +openml_logger = logging.getLogger("openml") + + +SERVERS_REGISTRY: dict[str, dict[APIVersion, dict[str, str | None]]] = { + "production": { + APIVersion.V1: { + "server": "https://www.openml.org/api/v1/xml/", + "apikey": None, + }, + APIVersion.V2: { + "server": None, + "apikey": None, + }, + }, + "test": { + APIVersion.V1: { + "server": "https://test.openml.org/api/v1/xml/", + "apikey": "normaluser", + }, + APIVersion.V2: { + "server": None, + "apikey": None, + }, + }, + "local": { + APIVersion.V1: { + "server": "http://localhost:8000/api/v1/xml/", + "apikey": "normaluser", + }, + APIVersion.V2: { + "server": "http://localhost:8002/api/v1/xml/", + "apikey": "normaluser", + }, + }, +} + + +def _resolve_default_cache_dir() -> Path: + user_defined_cache_dir = os.environ.get("OPENML_CACHE_DIR") + if user_defined_cache_dir is not None: + return Path(user_defined_cache_dir) + + if platform.system().lower() != "linux": + return Path("~", ".openml").expanduser() + + xdg_cache_home = os.environ.get("XDG_CACHE_HOME") + if xdg_cache_home is None: + return Path("~", ".cache", "openml").expanduser() + + cache_dir = Path(xdg_cache_home) / "openml" + if cache_dir.exists(): + return cache_dir + + heuristic_dir_for_backwards_compat = Path(xdg_cache_home) / "org" / "openml" + if not heuristic_dir_for_backwards_compat.exists(): + return cache_dir + + root_dir_to_delete = Path(xdg_cache_home) / "org" + openml_logger.warning( + "An old cache directory was found at '%s'. This directory is no longer used by " + "OpenML-Python. To silence this warning you would need to delete the old cache " + "directory. 
The cached files will then be located in '%s'.", + root_dir_to_delete, + cache_dir, + ) + return Path(xdg_cache_home) + + +@dataclass +class OpenMLConfig: + """Dataclass storing the OpenML configuration.""" + + servers: dict[APIVersion, dict[str, str | None]] = field( + default_factory=lambda: deepcopy(SERVERS_REGISTRY["production"]) + ) + api_version: APIVersion = APIVersion.V1 + fallback_api_version: APIVersion | None = None + cachedir: Path = field(default_factory=_resolve_default_cache_dir) + avoid_duplicate_runs: bool = False + retry_policy: Literal["human", "robot"] = "human" + connection_n_retries: int = 5 + show_progress: bool = False + + @property + def server(self) -> str: + server = self.servers[self.api_version]["server"] + if server is None: + servers_repr = {k.value: v for k, v in self.servers.items()} + raise ValueError( + f'server found to be None for api_version="{self.api_version}" in {servers_repr}' + ) + return server + + @server.setter + def server(self, value: str | None) -> None: + self.servers[self.api_version]["server"] = value + + @property + def apikey(self) -> str | None: + return self.servers[self.api_version]["apikey"] + + @apikey.setter + def apikey(self, value: str | None) -> None: + self.servers[self.api_version]["apikey"] = value + + +class OpenMLConfigManager: + """The OpenMLConfigManager manages the configuration of the openml-python package.""" + + def __init__(self) -> None: + self.console_handler: logging.StreamHandler | None = None + self.file_handler: logging.handlers.RotatingFileHandler | None = None + + server_test_v1_apikey = self.get_servers("test")[APIVersion.V1]["apikey"] + server_test_v1_server = self.get_servers("test")[APIVersion.V1]["server"] + + self.OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR" + self.OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET" + self._TEST_SERVER_NORMAL_USER_KEY = server_test_v1_apikey + self._HEADERS: dict[str, str] = {"user-agent": f"openml-python/{__version__}"} + self.OPENML_TEST_SERVER_ADMIN_KEY_ENV_VAR = "OPENML_TEST_SERVER_ADMIN_KEY" + self.TEST_SERVER_URL = cast("str", server_test_v1_server).split("/api/v1/xml")[0] + + self._config: OpenMLConfig = OpenMLConfig() + # for legacy test `test_non_writable_home` + self._defaults: dict[str, Any] = OpenMLConfig().__dict__.copy() + self._root_cache_directory: Path = self._config.cachedir + + self.logger = logger + self.openml_logger = openml_logger + + self._examples = ConfigurationForExamples(self) + + self._setup() + + def __getattr__(self, name: str) -> Any: + if hasattr(self._config, name): + return getattr(self._config, name) + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") + + _FIELDS: ClassVar[set[str]] = {f.name for f in fields(OpenMLConfig)} + + def __setattr__(self, name: str, value: Any) -> None: + # during __init__ before _config exists + if name in { + "_config", + "_root_cache_directory", + "console_handler", + "file_handler", + "logger", + "openml_logger", + "_examples", + "OPENML_CACHE_DIR_ENV_VAR", + "OPENML_SKIP_PARQUET_ENV_VAR", + "_TEST_SERVER_NORMAL_USER_KEY", + "_HEADERS", + }: + return object.__setattr__(self, name, value) + + if name in self._FIELDS: + # write into dataclass, not manager (prevents shadowing) + if name == "cachedir": + object.__setattr__(self, "_root_cache_directory", Path(value)) + object.__setattr__(self, "_config", replace(self._config, **{name: value})) + return None + + if name in ["server", "apikey"]: + setattr(self._config, name, value) + return None + + object.__setattr__(self, 
name, value) + return None + + def _create_log_handlers(self, create_file_handler: bool = True) -> None: # noqa: FBT002 + if self.console_handler is not None or self.file_handler is not None: + self.logger.debug("Requested to create log handlers, but they are already created.") + return + + message_format = "[%(levelname)s] [%(asctime)s:%(name)s] %(message)s" + output_formatter = logging.Formatter(message_format, datefmt="%H:%M:%S") + + self.console_handler = logging.StreamHandler() + self.console_handler.setFormatter(output_formatter) + + if create_file_handler: + one_mb = 2**20 + log_path = self._root_cache_directory / "openml_python.log" + self.file_handler = logging.handlers.RotatingFileHandler( + log_path, + maxBytes=one_mb, + backupCount=1, + delay=True, + ) + self.file_handler.setFormatter(output_formatter) + + def _convert_log_levels(self, log_level: int) -> tuple[int, int]: + openml_to_python = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} + python_to_openml = { + logging.DEBUG: 2, + logging.INFO: 1, + logging.WARNING: 0, + logging.CRITICAL: 0, + logging.ERROR: 0, + } + openml_level = python_to_openml.get(log_level, log_level) + python_level = openml_to_python.get(log_level, log_level) + return openml_level, python_level + + def _set_level_register_and_store(self, handler: logging.Handler, log_level: int) -> None: + _oml_level, py_level = self._convert_log_levels(log_level) + handler.setLevel(py_level) + + if self.openml_logger.level > py_level or self.openml_logger.level == logging.NOTSET: + self.openml_logger.setLevel(py_level) + + if handler not in self.openml_logger.handlers: + self.openml_logger.addHandler(handler) + + def set_console_log_level(self, console_output_level: int) -> None: + """Set the log level for console output.""" + assert self.console_handler is not None + self._set_level_register_and_store(self.console_handler, console_output_level) + + def set_file_log_level(self, file_output_level: int) -> None: + """Set the log level for file output.""" + assert self.file_handler is not None + self._set_level_register_and_store(self.file_handler, file_output_level) + + def get_server_base_url(self) -> str: + """Get the base URL of the OpenML server (i.e., without /api).""" + domain, _ = self._config.server.split("/api", maxsplit=1) + return domain.replace("api", "www") + + def get_servers(self, mode: str) -> dict[APIVersion, dict[str, str | None]]: + if mode not in SERVERS_REGISTRY: + raise ValueError( + f'invalid mode="{mode}" allowed modes: {", ".join(list(SERVERS_REGISTRY.keys()))}' + ) + return deepcopy(SERVERS_REGISTRY[mode]) + + def set_servers(self, mode: str) -> None: + servers = self.get_servers(mode) + self._config = replace(self._config, servers=servers) + + def set_api_version( + self, + api_version: APIVersion, + fallback_api_version: APIVersion | None = None, + ) -> None: + if api_version not in APIVersion: + raise ValueError( + f'invalid api_version="{api_version}" ' + f"allowed versions: {', '.join(list(APIVersion))}" + ) + + if fallback_api_version is not None and fallback_api_version not in APIVersion: + raise ValueError( + f'invalid fallback_api_version="{fallback_api_version}" ' + f"allowed versions: {', '.join(list(APIVersion))}" + ) + + self._config = replace( + self._config, + api_version=api_version, + fallback_api_version=fallback_api_version, + ) + + def set_retry_policy( + self, value: Literal["human", "robot"], n_retries: int | None = None + ) -> None: + """Set the retry policy for server connections.""" + default_retries_by_policy = 
{"human": 5, "robot": 50} + + if value not in default_retries_by_policy: + raise ValueError( + f"Detected retry_policy '{value}' but must be one of " + f"{list(default_retries_by_policy.keys())}", + ) + if n_retries is not None and not isinstance(n_retries, int): + raise TypeError( + f"`n_retries` must be of type `int` or `None` but is `{type(n_retries)}`." + ) + + if isinstance(n_retries, int) and n_retries < 1: + raise ValueError(f"`n_retries` is '{n_retries}' but must be positive.") + + self._config = replace( + self._config, + retry_policy=value, + connection_n_retries=( + default_retries_by_policy[value] if n_retries is None else n_retries + ), + ) + + def _handle_xdg_config_home_backwards_compatibility(self, xdg_home: str) -> Path: + config_dir = Path(xdg_home) / "openml" + + backwards_compat_config_file = Path(xdg_home) / "config" + if not backwards_compat_config_file.exists(): + return config_dir + + try: + self._parse_config(backwards_compat_config_file) + except Exception: # noqa: BLE001 + return config_dir + + correct_config_location = config_dir / "config" + try: + shutil.copy(backwards_compat_config_file, correct_config_location) + self.openml_logger.warning( + "An openml configuration file was found at the old location " + f"at {backwards_compat_config_file}. We have copied it to the new " + f"location at {correct_config_location}. " + "\nTo silence this warning please verify that the configuration file " + f"at {correct_config_location} is correct and delete the file at " + f"{backwards_compat_config_file}." + ) + return config_dir + except Exception as e: # noqa: BLE001 + self.openml_logger.warning( + "While attempting to perform a backwards compatible fix, we " + f"failed to copy the openml config file at " + f"{backwards_compat_config_file}' to {correct_config_location}" + f"\n{type(e)}: {e}", + "\n\nTo silence this warning, please copy the file " + "to the new location and delete the old file at " + f"{backwards_compat_config_file}.", + ) + return backwards_compat_config_file + + def determine_config_file_path(self) -> Path: + """Determine the path to the openml configuration file.""" + if platform.system().lower() == "linux": + xdg_home = os.environ.get("XDG_CONFIG_HOME") + if xdg_home is not None: + config_dir = self._handle_xdg_config_home_backwards_compatibility(xdg_home) + else: + config_dir = Path("~", ".config", "openml") + else: + config_dir = Path("~") / ".openml" + + config_dir = Path(config_dir).expanduser().resolve() + return config_dir / "config" + + def _parse_config(self, config_file: str | Path) -> dict[str, Any]: + config_file = Path(config_file) + config = configparser.RawConfigParser(defaults=OpenMLConfig().__dict__) # type: ignore + + config_file_ = StringIO() + config_file_.write("[FAKE_SECTION]\n") + try: + with config_file.open("r") as fh: + for line in fh: + config_file_.write(line) + except FileNotFoundError: + self.logger.info( + "No config file found at %s, using default configuration.", config_file + ) + except OSError as e: + self.logger.info("Error opening file %s: %s", config_file, e.args[0]) + config_file_.seek(0) + config.read_file(config_file_) + configuration = dict(config.items("FAKE_SECTION")) + for boolean_field in ["avoid_duplicate_runs", "show_progress"]: + if isinstance(config["FAKE_SECTION"][boolean_field], str): + configuration[boolean_field] = config["FAKE_SECTION"].getboolean(boolean_field) # type: ignore + return configuration # type: ignore + + def start_using_configuration_for_example(self) -> None: + """Sets the 
configuration to connect to the test server with valid apikey.""" + return self._examples.start_using_configuration_for_example() + + def stop_using_configuration_for_example(self) -> None: + """Store the configuration as it was before `start_use_example_configuration`.""" + return self._examples.stop_using_configuration_for_example() + + def _setup(self, config: dict[str, Any] | None = None) -> None: + config_file = self.determine_config_file_path() + config_dir = config_file.parent + + try: + if not config_dir.exists(): + config_dir.mkdir(exist_ok=True, parents=True) + except PermissionError: + self.openml_logger.warning( + f"No permission to create OpenML directory at {config_dir}!" + " This can result in OpenML-Python not working properly." + ) + + if config is None: + config = self._parse_config(config_file) + + self._config = replace( + self._config, + servers=config["servers"], + api_version=config["api_version"], + fallback_api_version=config["fallback_api_version"], + show_progress=config["show_progress"], + avoid_duplicate_runs=config["avoid_duplicate_runs"], + retry_policy=config["retry_policy"], + connection_n_retries=int(config["connection_n_retries"]), + ) + if "server" in config: + self._config.server = config["server"] + if "apikey" in config: + self._config.apikey = config["apikey"] + + user_defined_cache_dir = os.environ.get(self.OPENML_CACHE_DIR_ENV_VAR) + if user_defined_cache_dir is not None: + short_cache_dir = Path(user_defined_cache_dir) + else: + short_cache_dir = Path(config["cachedir"]) + + self._root_cache_directory = short_cache_dir.expanduser().resolve() + self._config = replace(self._config, cachedir=self._root_cache_directory) + + try: + cache_exists = self._root_cache_directory.exists() + if not cache_exists: + self._root_cache_directory.mkdir(exist_ok=True, parents=True) + self._create_log_handlers() + except PermissionError: + self.openml_logger.warning( + f"No permission to create OpenML directory at {self._root_cache_directory}!" + " This can result in OpenML-Python not working properly." + ) + self._create_log_handlers(create_file_handler=False) + + def set_field_in_config_file(self, field: str, value: Any) -> None: + """Set a field in the configuration file.""" + if not hasattr(OpenMLConfig(), field): + raise ValueError( + f"Field '{field}' is not valid and must be one of " + f"'{OpenMLConfig().__dict__.keys()}'." 
+ ) + + self._config = replace(self._config, **{field: value}) + config_file = self.determine_config_file_path() + existing = self._parse_config(config_file) + with config_file.open("w") as fh: + for f in OpenMLConfig().__dict__: + v = value if f == field else existing.get(f) + if v is not None: + fh.write(f"{f} = {v}\n") + + def get_config_as_dict(self) -> dict[str, Any]: + """Get the current configuration as a dictionary.""" + return self._config.__dict__.copy() + + def get_cache_directory(self) -> str: + """Get the cache directory for the current server.""" + url_suffix = urlparse(self._config.server).netloc + url_parts = url_suffix.replace(":", "_").split(".")[::-1] + reversed_url_suffix = os.sep.join(url_parts) # noqa: PTH118 + return os.path.join(self._root_cache_directory, reversed_url_suffix) # noqa: PTH118 + + def set_root_cache_directory(self, root_cache_directory: str | Path) -> None: + """Set the root cache directory.""" + self._root_cache_directory = Path(root_cache_directory) + self._config = replace(self._config, cachedir=self._root_cache_directory) + + @contextmanager + def overwrite_config_context(self, config: dict[str, Any]) -> Iterator[dict[str, Any]]: + """Overwrite the current configuration within a context manager.""" + existing_config = self.get_config_as_dict() + merged_config = {**existing_config, **config} + + self._setup(merged_config) + yield merged_config + self._setup(existing_config) + + +class ConfigurationForExamples: + """Allows easy switching to and from a test configuration, used for examples.""" + + _last_used_servers = None + _start_last_called = False + + def __init__(self, manager: OpenMLConfigManager): + self._manager = manager + self._test_servers = manager.get_servers("test") + + def start_using_configuration_for_example(self) -> None: + """Sets the configuration to connect to the test server with valid apikey. + + To configuration as was before this call is stored, and can be recovered + by using the `stop_use_example_configuration` method. + """ + if self._start_last_called and self._manager._config.servers == self._test_servers: + # Method is called more than once in a row without modifying the server or apikey. + # We don't want to save the current test configuration as a last used configuration. + return + + self._last_used_servers = self._manager._config.servers + type(self)._start_last_called = True + + # Test server key for examples + self._manager._config = replace( + self._manager._config, + servers=self._test_servers, + ) + warnings.warn( + f"Switching to the test servers {self._test_servers} to not upload results to " + "the live server. Using the test server may result in reduced performance of the " + "API!", + stacklevel=2, + ) + + def stop_using_configuration_for_example(self) -> None: + """Return to configuration as it was before `start_use_example_configuration`.""" + if not type(self)._start_last_called: + # We don't want to allow this because it will (likely) result in the `server` and + # `apikey` variables being set to None. + raise RuntimeError( + "`stop_use_example_configuration` called without a saved config." 
+ "`start_use_example_configuration` must be called first.", + ) + + self._manager._config = replace( + self._manager._config, + servers=cast("dict[APIVersion, dict[str, str | None]]", self._last_used_servers), + ) + type(self)._start_last_called = False + + +__config = OpenMLConfigManager() + + +def __getattr__(name: str) -> Any: + return getattr(__config, name) diff --git a/openml/base.py b/openml/base.py index a282be8eb..f79bc2931 100644 --- a/openml/base.py +++ b/openml/base.py @@ -8,8 +8,8 @@ import xmltodict +import openml import openml._api_calls -import openml.config from .utils import _get_rest_api_type_alias, _tag_openml_base diff --git a/openml/cli.py b/openml/cli.py index c33578f6e..838f774d1 100644 --- a/openml/cli.py +++ b/openml/cli.py @@ -6,10 +6,11 @@ import string import sys from collections.abc import Callable +from dataclasses import fields from pathlib import Path from urllib.parse import urlparse -from openml import config +import openml from openml.__version__ import __version__ @@ -59,17 +60,17 @@ def wait_until_valid_input( def print_configuration() -> None: - file = config.determine_config_file_path() + file = openml.config.determine_config_file_path() header = f"File '{file}' contains (or defaults to):" print(header) - max_key_length = max(map(len, config.get_config_as_dict())) - for field, value in config.get_config_as_dict().items(): + max_key_length = max(map(len, openml.config.get_config_as_dict())) + for field, value in openml.config.get_config_as_dict().items(): print(f"{field.ljust(max_key_length)}: {value}") def verbose_set(field: str, value: str) -> None: - config.set_field_in_config_file(field, value) + openml.config.set_field_in_config_file(field, value) print(f"{field} set to '{value}'.") @@ -82,7 +83,7 @@ def check_apikey(apikey: str) -> str: return "" instructions = ( - f"Your current API key is set to: '{config.apikey}'. " + f"Your current API key is set to: '{openml.config.apikey}'. " "You can get an API key at https://new.openml.org. " "You must create an account if you don't have one yet:\n" " 1. 
Log in with the account.\n" @@ -109,7 +110,7 @@ def check_server(server: str) -> str: def replace_shorthand(server: str) -> str: if server == "test": - return f"{config.TEST_SERVER_URL}/api/v1/xml" + return f"{openml.config.TEST_SERVER_URL}/api/v1/xml" if server == "production_server": return "https://www.openml.org/api/v1/xml" return server @@ -347,7 +348,9 @@ def main() -> None: "'https://openml.github.io/openml-python/main/usage.html#configuration'.", ) - configurable_fields = [f for f in config._defaults if f not in ["max_retries"]] + configurable_fields = [ + f.name for f in fields(openml._config.OpenMLConfig) if f.name not in ["max_retries"] + ] parser_configure.add_argument( "field", diff --git a/openml/config.py b/openml/config.py deleted file mode 100644 index 638b45650..000000000 --- a/openml/config.py +++ /dev/null @@ -1,529 +0,0 @@ -"""Store module level information like the API key, cache directory and the server""" - -# License: BSD 3-Clause -from __future__ import annotations - -import configparser -import logging -import logging.handlers -import os -import platform -import shutil -import warnings -from collections.abc import Iterator -from contextlib import contextmanager -from io import StringIO -from pathlib import Path -from typing import Any, Literal, cast -from typing_extensions import TypedDict -from urllib.parse import urlparse - -logger = logging.getLogger(__name__) -openml_logger = logging.getLogger("openml") -console_handler: logging.StreamHandler | None = None -file_handler: logging.handlers.RotatingFileHandler | None = None - -OPENML_CACHE_DIR_ENV_VAR = "OPENML_CACHE_DIR" -OPENML_SKIP_PARQUET_ENV_VAR = "OPENML_SKIP_PARQUET" -OPENML_TEST_SERVER_ADMIN_KEY_ENV_VAR = "OPENML_TEST_SERVER_ADMIN_KEY" -_TEST_SERVER_NORMAL_USER_KEY = "normaluser" - -TEST_SERVER_URL = "https://test.openml.org" - - -class _Config(TypedDict): - apikey: str - server: str - cachedir: Path - avoid_duplicate_runs: bool - retry_policy: Literal["human", "robot"] - connection_n_retries: int - show_progress: bool - - -def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT002 - """Creates but does not attach the log handlers.""" - global console_handler, file_handler # noqa: PLW0603 - if console_handler is not None or file_handler is not None: - logger.debug("Requested to create log handlers, but they are already created.") - return - - message_format = "[%(levelname)s] [%(asctime)s:%(name)s] %(message)s" - output_formatter = logging.Formatter(message_format, datefmt="%H:%M:%S") - - console_handler = logging.StreamHandler() - console_handler.setFormatter(output_formatter) - - if create_file_handler: - one_mb = 2**20 - log_path = _root_cache_directory / "openml_python.log" - file_handler = logging.handlers.RotatingFileHandler( - log_path, - maxBytes=one_mb, - backupCount=1, - delay=True, - ) - file_handler.setFormatter(output_formatter) - - -def _convert_log_levels(log_level: int) -> tuple[int, int]: - """Converts a log level that's either defined by OpenML/Python to both specifications.""" - # OpenML verbosity level don't match Python values directly: - openml_to_python = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} - python_to_openml = { - logging.DEBUG: 2, - logging.INFO: 1, - logging.WARNING: 0, - logging.CRITICAL: 0, - logging.ERROR: 0, - } - # Because the dictionaries share no keys, we use `get` to convert as necessary: - openml_level = python_to_openml.get(log_level, log_level) - python_level = openml_to_python.get(log_level, log_level) - return 
openml_level, python_level - - -def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> None: - """Set handler log level, register it if needed, save setting to config file if specified.""" - _oml_level, py_level = _convert_log_levels(log_level) - handler.setLevel(py_level) - - if openml_logger.level > py_level or openml_logger.level == logging.NOTSET: - openml_logger.setLevel(py_level) - - if handler not in openml_logger.handlers: - openml_logger.addHandler(handler) - - -def set_console_log_level(console_output_level: int) -> None: - """Set console output to the desired level and register it with openml logger if needed.""" - global console_handler # noqa: PLW0602 - assert console_handler is not None - _set_level_register_and_store(console_handler, console_output_level) - - -def set_file_log_level(file_output_level: int) -> None: - """Set file output to the desired level and register it with openml logger if needed.""" - global file_handler # noqa: PLW0602 - assert file_handler is not None - _set_level_register_and_store(file_handler, file_output_level) - - -# Default values (see also https://github.com/openml/OpenML/wiki/Client-API-Standards) -_user_path = Path("~").expanduser().absolute() - - -def _resolve_default_cache_dir() -> Path: - user_defined_cache_dir = os.environ.get(OPENML_CACHE_DIR_ENV_VAR) - if user_defined_cache_dir is not None: - return Path(user_defined_cache_dir) - - if platform.system().lower() != "linux": - return _user_path / ".openml" - - xdg_cache_home = os.environ.get("XDG_CACHE_HOME") - if xdg_cache_home is None: - return Path("~", ".cache", "openml") - - # This is the proper XDG_CACHE_HOME directory, but - # we unfortunately had a problem where we used XDG_CACHE_HOME/org, - # we check heuristically if this old directory still exists and issue - # a warning if it does. There's too much data to move to do this for the user. - - # The new cache directory exists - cache_dir = Path(xdg_cache_home) / "openml" - if cache_dir.exists(): - return cache_dir - - # The old cache directory *does not* exist - heuristic_dir_for_backwards_compat = Path(xdg_cache_home) / "org" / "openml" - if not heuristic_dir_for_backwards_compat.exists(): - return cache_dir - - root_dir_to_delete = Path(xdg_cache_home) / "org" - openml_logger.warning( - "An old cache directory was found at '%s'. This directory is no longer used by " - "OpenML-Python. To silence this warning you would need to delete the old cache " - "directory. The cached files will then be located in '%s'.", - root_dir_to_delete, - cache_dir, - ) - return Path(xdg_cache_home) - - -_defaults: _Config = { - "apikey": "", - "server": "https://www.openml.org/api/v1/xml", - "cachedir": _resolve_default_cache_dir(), - "avoid_duplicate_runs": False, - "retry_policy": "human", - "connection_n_retries": 5, - "show_progress": False, -} - -# Default values are actually added here in the _setup() function which is -# called at the end of this module -server = _defaults["server"] - - -def get_server_base_url() -> str: - """Return the base URL of the currently configured server. 
- - Turns ``"https://api.openml.org/api/v1/xml"`` in ``"https://www.openml.org/"`` - and ``"https://test.openml.org/api/v1/xml"`` in ``"https://test.openml.org/"`` - - Returns - ------- - str - """ - domain, _path = server.split("/api", maxsplit=1) - return domain.replace("api", "www") - - -apikey: str = _defaults["apikey"] -show_progress: bool = _defaults["show_progress"] -# The current cache directory (without the server name) -_root_cache_directory: Path = Path(_defaults["cachedir"]) -avoid_duplicate_runs = _defaults["avoid_duplicate_runs"] - -retry_policy: Literal["human", "robot"] = _defaults["retry_policy"] -connection_n_retries: int = _defaults["connection_n_retries"] - - -def set_retry_policy(value: Literal["human", "robot"], n_retries: int | None = None) -> None: - global retry_policy # noqa: PLW0603 - global connection_n_retries # noqa: PLW0603 - default_retries_by_policy = {"human": 5, "robot": 50} - - if value not in default_retries_by_policy: - raise ValueError( - f"Detected retry_policy '{value}' but must be one of " - f"{list(default_retries_by_policy.keys())}", - ) - if n_retries is not None and not isinstance(n_retries, int): - raise TypeError(f"`n_retries` must be of type `int` or `None` but is `{type(n_retries)}`.") - - if isinstance(n_retries, int) and n_retries < 1: - raise ValueError(f"`n_retries` is '{n_retries}' but must be positive.") - - retry_policy = value - connection_n_retries = default_retries_by_policy[value] if n_retries is None else n_retries - - -class ConfigurationForExamples: - """Allows easy switching to and from a test configuration, used for examples.""" - - _last_used_server = None - _last_used_key = None - _start_last_called = False - _test_server = f"{TEST_SERVER_URL}/api/v1/xml" - _test_apikey = _TEST_SERVER_NORMAL_USER_KEY - - @classmethod - def start_using_configuration_for_example(cls) -> None: - """Sets the configuration to connect to the test server with valid apikey. - - To configuration as was before this call is stored, and can be recovered - by using the `stop_use_example_configuration` method. - """ - global server # noqa: PLW0603 - global apikey # noqa: PLW0603 - - if cls._start_last_called and server == cls._test_server and apikey == cls._test_apikey: - # Method is called more than once in a row without modifying the server or apikey. - # We don't want to save the current test configuration as a last used configuration. - return - - cls._last_used_server = server - cls._last_used_key = apikey - cls._start_last_called = True - - # Test server key for examples - server = cls._test_server - apikey = cls._test_apikey - warnings.warn( - f"Switching to the test server {server} to not upload results to the live server. " - "Using the test server may result in reduced performance of the API!", - stacklevel=2, - ) - - @classmethod - def stop_using_configuration_for_example(cls) -> None: - """Return to configuration as it was before `start_use_example_configuration`.""" - if not cls._start_last_called: - # We don't want to allow this because it will (likely) result in the `server` and - # `apikey` variables being set to None. - raise RuntimeError( - "`stop_use_example_configuration` called without a saved config." 
- "`start_use_example_configuration` must be called first.", - ) - - global server # noqa: PLW0603 - global apikey # noqa: PLW0603 - - server = cast("str", cls._last_used_server) - apikey = cast("str", cls._last_used_key) - cls._start_last_called = False - - -def _handle_xdg_config_home_backwards_compatibility( - xdg_home: str, -) -> Path: - # NOTE(eddiebergman): A previous bug results in the config - # file being located at `${XDG_CONFIG_HOME}/config` instead - # of `${XDG_CONFIG_HOME}/openml/config`. As to maintain backwards - # compatibility, where users may already may have had a configuration, - # we copy it over an issue a warning until it's deleted. - # As a heurisitic to ensure that it's "our" config file, we try parse it first. - config_dir = Path(xdg_home) / "openml" - - backwards_compat_config_file = Path(xdg_home) / "config" - if not backwards_compat_config_file.exists(): - return config_dir - - # If it errors, that's a good sign it's not ours and we can - # safely ignore it, jumping out of this block. This is a heurisitc - try: - _parse_config(backwards_compat_config_file) - except Exception: # noqa: BLE001 - return config_dir - - # Looks like it's ours, lets try copy it to the correct place - correct_config_location = config_dir / "config" - try: - # We copy and return the new copied location - shutil.copy(backwards_compat_config_file, correct_config_location) - openml_logger.warning( - "An openml configuration file was found at the old location " - f"at {backwards_compat_config_file}. We have copied it to the new " - f"location at {correct_config_location}. " - "\nTo silence this warning please verify that the configuration file " - f"at {correct_config_location} is correct and delete the file at " - f"{backwards_compat_config_file}." - ) - return config_dir - except Exception as e: # noqa: BLE001 - # We failed to copy and its ours, return the old one. - openml_logger.warning( - "While attempting to perform a backwards compatible fix, we " - f"failed to copy the openml config file at " - f"{backwards_compat_config_file}' to {correct_config_location}" - f"\n{type(e)}: {e}", - "\n\nTo silence this warning, please copy the file " - "to the new location and delete the old file at " - f"{backwards_compat_config_file}.", - ) - return backwards_compat_config_file - - -def determine_config_file_path() -> Path: - if platform.system().lower() == "linux": - xdg_home = os.environ.get("XDG_CONFIG_HOME") - if xdg_home is not None: - config_dir = _handle_xdg_config_home_backwards_compatibility(xdg_home) - else: - config_dir = Path("~", ".config", "openml") - else: - config_dir = Path("~") / ".openml" - - # Still use os.path.expanduser to trigger the mock in the unit test - config_dir = Path(config_dir).expanduser().resolve() - return config_dir / "config" - - -def _setup(config: _Config | None = None) -> None: - """Setup openml package. Called on first import. - - Reads the config file and sets up apikey, server, cache appropriately. - key and server can be set by the user simply using - openml.config.apikey = THEIRKEY - openml.config.server = SOMESERVER - We could also make it a property but that's less clear. 
- """ - global apikey # noqa: PLW0603 - global server # noqa: PLW0603 - global _root_cache_directory # noqa: PLW0603 - global avoid_duplicate_runs # noqa: PLW0603 - global show_progress # noqa: PLW0603 - - config_file = determine_config_file_path() - config_dir = config_file.parent - - # read config file, create directory for config file - try: - if not config_dir.exists(): - config_dir.mkdir(exist_ok=True, parents=True) - except PermissionError: - openml_logger.warning( - f"No permission to create OpenML directory at {config_dir}!" - " This can result in OpenML-Python not working properly." - ) - - if config is None: - config = _parse_config(config_file) - - avoid_duplicate_runs = config["avoid_duplicate_runs"] - apikey = config["apikey"] - server = config["server"] - show_progress = config["show_progress"] - n_retries = int(config["connection_n_retries"]) - - set_retry_policy(config["retry_policy"], n_retries) - - user_defined_cache_dir = os.environ.get(OPENML_CACHE_DIR_ENV_VAR) - if user_defined_cache_dir is not None: - short_cache_dir = Path(user_defined_cache_dir) - else: - short_cache_dir = Path(config["cachedir"]) - _root_cache_directory = short_cache_dir.expanduser().resolve() - - try: - cache_exists = _root_cache_directory.exists() - # create the cache subdirectory - if not cache_exists: - _root_cache_directory.mkdir(exist_ok=True, parents=True) - _create_log_handlers() - except PermissionError: - openml_logger.warning( - f"No permission to create OpenML directory at {_root_cache_directory}!" - " This can result in OpenML-Python not working properly." - ) - _create_log_handlers(create_file_handler=False) - - -def set_field_in_config_file(field: str, value: Any) -> None: - """Overwrites the `field` in the configuration file with the new `value`.""" - if field not in _defaults: - raise ValueError(f"Field '{field}' is not valid and must be one of '{_defaults.keys()}'.") - - # TODO(eddiebergman): This use of globals has gone too far - globals()[field] = value - config_file = determine_config_file_path() - config = _parse_config(config_file) - with config_file.open("w") as fh: - for f in _defaults: - # We can't blindly set all values based on globals() because when the user - # sets it through config.FIELD it should not be stored to file. - # There doesn't seem to be a way to avoid writing defaults to file with configparser, - # because it is impossible to distinguish from an explicitly set value that matches - # the default value, to one that was set to its default because it was omitted. - value = globals()[f] if f == field else config.get(f) # type: ignore - if value is not None: - fh.write(f"{f} = {value}\n") - - -def _parse_config(config_file: str | Path) -> _Config: - """Parse the config file, set up defaults.""" - config_file = Path(config_file) - config = configparser.RawConfigParser(defaults=_defaults) # type: ignore - - # The ConfigParser requires a [SECTION_HEADER], which we do not expect in our config file. 
- # Cheat the ConfigParser module by adding a fake section header - config_file_ = StringIO() - config_file_.write("[FAKE_SECTION]\n") - try: - with config_file.open("r") as fh: - for line in fh: - config_file_.write(line) - except FileNotFoundError: - logger.info("No config file found at %s, using default configuration.", config_file) - except OSError as e: - logger.info("Error opening file %s: %s", config_file, e.args[0]) - config_file_.seek(0) - config.read_file(config_file_) - configuration = dict(config.items("FAKE_SECTION")) - for boolean_field in ["avoid_duplicate_runs", "show_progress"]: - if isinstance(config["FAKE_SECTION"][boolean_field], str): - configuration[boolean_field] = config["FAKE_SECTION"].getboolean(boolean_field) # type: ignore - return configuration # type: ignore - - -def get_config_as_dict() -> _Config: - return { - "apikey": apikey, - "server": server, - "cachedir": _root_cache_directory, - "avoid_duplicate_runs": avoid_duplicate_runs, - "connection_n_retries": connection_n_retries, - "retry_policy": retry_policy, - "show_progress": show_progress, - } - - -# NOTE: For backwards compatibility, we keep the `str` -def get_cache_directory() -> str: - """Get the current cache directory. - - This gets the cache directory for the current server relative - to the root cache directory that can be set via - ``set_root_cache_directory()``. The cache directory is the - ``root_cache_directory`` with additional information on which - subdirectory to use based on the server name. By default it is - ``root_cache_directory / org / openml / www`` for the standard - OpenML.org server and is defined as - ``root_cache_directory / top-level domain / second-level domain / - hostname`` - ``` - - Returns - ------- - cachedir : string - The current cache directory. - - """ - url_suffix = urlparse(server).netloc - url_parts = url_suffix.replace(":", "_").split(".")[::-1] - reversed_url_suffix = os.sep.join(url_parts) # noqa: PTH118 - return os.path.join(_root_cache_directory, reversed_url_suffix) # noqa: PTH118 - - -def set_root_cache_directory(root_cache_directory: str | Path) -> None: - """Set module-wide base cache directory. - - Sets the root cache directory, wherin the cache directories are - created to store content from different OpenML servers. For example, - by default, cached data for the standard OpenML.org server is stored - at ``root_cache_directory / org / openml / www``, and the general - pattern is ``root_cache_directory / top-level domain / second-level - domain / hostname``. - - Parameters - ---------- - root_cache_directory : string - Path to use as cache directory. 
- - See Also - -------- - get_cache_directory - """ - global _root_cache_directory # noqa: PLW0603 - _root_cache_directory = Path(root_cache_directory) - - -start_using_configuration_for_example = ( - ConfigurationForExamples.start_using_configuration_for_example -) -stop_using_configuration_for_example = ConfigurationForExamples.stop_using_configuration_for_example - - -@contextmanager -def overwrite_config_context(config: dict[str, Any]) -> Iterator[_Config]: - """A context manager to temporarily override variables in the configuration.""" - existing_config = get_config_as_dict() - merged_config = {**existing_config, **config} - - _setup(merged_config) # type: ignore - yield merged_config # type: ignore - - _setup(existing_config) - - -__all__ = [ - "get_cache_directory", - "get_config_as_dict", - "set_root_cache_directory", - "start_using_configuration_for_example", - "stop_using_configuration_for_example", -] - -_setup() diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py index d9eee278d..59d6205ba 100644 --- a/openml/datasets/dataset.py +++ b/openml/datasets/dataset.py @@ -17,8 +17,8 @@ import scipy.sparse import xmltodict +import openml from openml.base import OpenMLBase -from openml.config import OPENML_SKIP_PARQUET_ENV_VAR from .data_feature import OpenMLDataFeature @@ -375,7 +375,9 @@ def _download_data(self) -> None: # import required here to avoid circular import. from .functions import _get_dataset_arff, _get_dataset_parquet - skip_parquet = os.environ.get(OPENML_SKIP_PARQUET_ENV_VAR, "false").casefold() == "true" + skip_parquet = ( + os.environ.get(openml.config.OPENML_SKIP_PARQUET_ENV_VAR, "false").casefold() == "true" + ) if self._parquet_url is not None and not skip_parquet: parquet_file = _get_dataset_parquet(self) self.parquet_file = None if parquet_file is None else str(parquet_file) diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py index 3ac657ea0..432938520 100644 --- a/openml/datasets/functions.py +++ b/openml/datasets/functions.py @@ -19,9 +19,9 @@ import xmltodict from scipy.sparse import coo_matrix +import openml import openml._api_calls import openml.utils -from openml.config import OPENML_SKIP_PARQUET_ENV_VAR from openml.exceptions import ( OpenMLHashException, OpenMLPrivateDatasetError, @@ -492,7 +492,9 @@ def get_dataset( # noqa: C901, PLR0912 qualities_file = _get_dataset_qualities_file(did_cache_dir, dataset_id) parquet_file = None - skip_parquet = os.environ.get(OPENML_SKIP_PARQUET_ENV_VAR, "false").casefold() == "true" + skip_parquet = ( + os.environ.get(openml.config.OPENML_SKIP_PARQUET_ENV_VAR, "false").casefold() == "true" + ) download_parquet = "oml:parquet_url" in description and not skip_parquet if download_parquet and (download_data or download_all_files): try: diff --git a/openml/enums.py b/openml/enums.py new file mode 100644 index 000000000..f5a4381b7 --- /dev/null +++ b/openml/enums.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from enum import Enum + + +class APIVersion(str, Enum): + """Supported OpenML API versions.""" + + V1 = "v1" + V2 = "v2" + + +class ResourceType(str, Enum): + """Canonical resource types exposed by the OpenML API.""" + + DATASET = "dataset" + TASK = "task" + TASK_TYPE = "task_type" + EVALUATION_MEASURE = "evaluation_measure" + ESTIMATION_PROCEDURE = "estimation_procedure" + EVALUATION = "evaluation" + FLOW = "flow" + STUDY = "study" + RUN = "run" + SETUP = "setup" + USER = "user" + + +class RetryPolicy(str, Enum): + """Retry behavior for failed API requests.""" + + HUMAN 
= "human" + ROBOT = "robot" diff --git a/openml/evaluations/evaluation.py b/openml/evaluations/evaluation.py index 5db087024..e15bf728a 100644 --- a/openml/evaluations/evaluation.py +++ b/openml/evaluations/evaluation.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass -import openml.config +import openml import openml.datasets import openml.flows import openml.runs diff --git a/openml/exceptions.py b/openml/exceptions.py index 1c1343ff3..e96ebfcb2 100644 --- a/openml/exceptions.py +++ b/openml/exceptions.py @@ -88,3 +88,7 @@ def __init__(self, message: str): class ObjectNotPublishedError(PyOpenMLError): """Indicates an object has not been published yet.""" + + +class OpenMLNotSupportedError(PyOpenMLError): + """Raised when an API operation is not supported for a resource/version.""" diff --git a/openml/runs/functions.py b/openml/runs/functions.py index 503788dbd..b8eb739ae 100644 --- a/openml/runs/functions.py +++ b/openml/runs/functions.py @@ -18,7 +18,6 @@ import openml import openml._api_calls import openml.utils -from openml import config from openml.exceptions import ( OpenMLCacheException, OpenMLRunsExistError, @@ -45,7 +44,7 @@ # Avoid import cycles: https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles if TYPE_CHECKING: - from openml.config import _Config + from openml._config import _Config from openml.extensions.extension_interface import Extension # get_dict is in run.py to avoid circular imports @@ -107,7 +106,7 @@ def run_model_on_task( # noqa: PLR0913 """ if avoid_duplicate_runs is None: avoid_duplicate_runs = openml.config.avoid_duplicate_runs - if avoid_duplicate_runs and not config.apikey: + if avoid_duplicate_runs and not openml.config.apikey: warnings.warn( "avoid_duplicate_runs is set to True, but no API key is set. " "Please set your API key in the OpenML configuration file, see" @@ -336,7 +335,7 @@ def run_flow_on_task( # noqa: C901, PLR0912, PLR0915, PLR0913 message = f"Executed Task {task.task_id} with Flow id:{run.flow_id}" else: message = f"Executed Task {task.task_id} on local Flow with name {flow.name}." - config.logger.info(message) + openml.config.logger.info(message) return run @@ -528,7 +527,7 @@ def _run_task_get_arffcontent( # noqa: PLR0915, PLR0912, C901 # The forked child process may not copy the configuration state of OpenML from the parent. # Current configuration setup needs to be copied and passed to the child processes. 
- _config = config.get_config_as_dict() + _config = openml.config.get_config_as_dict() # Execute runs in parallel # assuming the same number of tasks as workers (n_jobs), the total compute time for this # statement will be similar to the slowest run @@ -733,7 +732,7 @@ def _run_task_get_arffcontent_parallel_helper( # noqa: PLR0913 """ # Sets up the OpenML instantiated in the child process to match that of the parent's # if configuration=None, loads the default - config._setup(configuration) + openml.config._setup(configuration) train_indices, test_indices = task.get_train_test_split_indices( repeat=rep_no, @@ -762,7 +761,7 @@ def _run_task_get_arffcontent_parallel_helper( # noqa: PLR0913 f"task_class={task.__class__.__name__}" ) - config.logger.info( + openml.config.logger.info( f"Going to run model {model!s} on " f"dataset {openml.datasets.get_dataset(task.dataset_id).name} " f"for repeat {rep_no} fold {fold_no} sample {sample_no}" diff --git a/openml/setups/functions.py b/openml/setups/functions.py index 4bf279ed1..a24d3a456 100644 --- a/openml/setups/functions.py +++ b/openml/setups/functions.py @@ -14,7 +14,6 @@ import openml import openml.exceptions import openml.utils -from openml import config from openml.flows import OpenMLFlow, flow_exists from .setup import OpenMLParameter, OpenMLSetup @@ -84,7 +83,7 @@ def _get_cached_setup(setup_id: int) -> OpenMLSetup: OpenMLCacheException If the setup file for the given setup ID is not cached. """ - cache_dir = Path(config.get_cache_directory()) + cache_dir = Path(openml.config.get_cache_directory()) setup_cache_dir = cache_dir / "setups" / str(setup_id) try: setup_file = setup_cache_dir / "description.xml" @@ -112,7 +111,7 @@ def get_setup(setup_id: int) -> OpenMLSetup: ------- OpenMLSetup (an initialized openml setup object) """ - setup_dir = Path(config.get_cache_directory()) / "setups" / str(setup_id) + setup_dir = Path(openml.config.get_cache_directory()) / "setups" / str(setup_id) setup_dir.mkdir(exist_ok=True, parents=True) setup_file = setup_dir / "description.xml" diff --git a/openml/setups/setup.py b/openml/setups/setup.py index 170838138..19a11e0d4 100644 --- a/openml/setups/setup.py +++ b/openml/setups/setup.py @@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass from typing import Any -import openml.config +import openml import openml.flows diff --git a/openml/study/functions.py b/openml/study/functions.py index bb24ddcff..367537773 100644 --- a/openml/study/functions.py +++ b/openml/study/functions.py @@ -8,8 +8,8 @@ import pandas as pd import xmltodict +import openml import openml._api_calls -import openml.config import openml.utils from openml.study.study import OpenMLBenchmarkSuite, OpenMLStudy diff --git a/openml/study/study.py b/openml/study/study.py index 7a9c80bbe..803c6455b 100644 --- a/openml/study/study.py +++ b/openml/study/study.py @@ -5,8 +5,8 @@ from collections.abc import Sequence from typing import Any +import openml from openml.base import OpenMLBase -from openml.config import get_server_base_url class BaseStudy(OpenMLBase): @@ -111,7 +111,7 @@ def _get_repr_body_fields(self) -> Sequence[tuple[str, str | int | list[str]]]: fields["ID"] = self.study_id fields["Study URL"] = self.openml_url if self.creator is not None: - fields["Creator"] = f"{get_server_base_url()}/u/{self.creator}" + fields["Creator"] = f"{openml.config.get_server_base_url()}/u/{self.creator}" if self.creation_date is not None: fields["Upload Time"] = self.creation_date.replace("T", " ") if self.data is not None: diff --git 
a/openml/tasks/task.py b/openml/tasks/task.py index 385b1f949..86c158bcc 100644 --- a/openml/tasks/task.py +++ b/openml/tasks/task.py @@ -9,8 +9,8 @@ from typing import TYPE_CHECKING, Any, ClassVar from typing_extensions import TypedDict +import openml import openml._api_calls -import openml.config from openml import datasets from openml.base import OpenMLBase from openml.utils import _create_cache_directory_for_id diff --git a/openml/testing.py b/openml/testing.py index 9f694f9bf..76b84b9f3 100644 --- a/openml/testing.py +++ b/openml/testing.py @@ -15,6 +15,8 @@ import requests import openml +from openml._api import API_REGISTRY, HTTPCache, HTTPClient, MinIOClient, ResourceAPI +from openml.enums import APIVersion, ResourceType from openml.exceptions import OpenMLServerException from openml.tasks import TaskType @@ -47,7 +49,7 @@ class TestBase(unittest.TestCase): "user": [], } flow_name_tracker: ClassVar[list[str]] = [] - test_server = f"{openml.config.TEST_SERVER_URL}/api/v1/xml" + test_server = f"{openml.config.TEST_SERVER_URL}/api/v1/xml/" admin_key = os.environ.get(openml.config.OPENML_TEST_SERVER_ADMIN_KEY_ENV_VAR) user_key = openml.config._TEST_SERVER_NORMAL_USER_KEY @@ -55,6 +57,11 @@ class TestBase(unittest.TestCase): logger = logging.getLogger("unit_tests_published_entities") logger.setLevel(logging.DEBUG) + # migration-specific attributes + cache: HTTPCache + http_clients: dict[APIVersion, HTTPClient] + minio_client: MinIOClient + def setUp(self, n_levels: int = 1, tmpdir_suffix: str = "") -> None: """Setup variables and temporary directories. @@ -108,6 +115,13 @@ def setUp(self, n_levels: int = 1, tmpdir_suffix: str = "") -> None: self.connection_n_retries = openml.config.connection_n_retries openml.config.set_retry_policy("robot", n_retries=20) + self.cache = HTTPCache() + self.http_clients = { + APIVersion.V1: HTTPClient(api_version=APIVersion.V1), + APIVersion.V2: HTTPClient(api_version=APIVersion.V2), + } + self.minio_client = MinIOClient() + def use_production_server(self) -> None: """ Use the production server for the OpenML API calls. 
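
Note: the setUp() hunk above gives every test one HTTPClient per APIVersion and a shared MinIOClient. As a minimal sketch (not part of the patch) of the same wiring outside TestBase, using only names introduced in this diff; the registry lookup mirrors APIBackendBuilder.build and the _create_resource helper added in the next hunk:

from openml._api import API_REGISTRY, HTTPClient, MinIOClient
from openml.enums import APIVersion, ResourceType

# One HTTP client per protocol version, plus the shared MinIO client.
http_v1 = HTTPClient(api_version=APIVersion.V1)
minio = MinIOClient()

# Resolve the v1 dataset API class from the registry and hand it its clients,
# the same way APIBackendBuilder.build instantiates resource APIs.
dataset_api_cls = API_REGISTRY[APIVersion.V1][ResourceType.DATASET]
dataset_api = dataset_api_cls(http_v1, minio)
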
@@ -275,6 +289,11 @@ def _check_fold_timing_evaluations(  # noqa: PLR0913
                 assert evaluation >= min_val
                 assert evaluation <= max_val
 
+    def _create_resource(self, api_version: APIVersion, resource_type: ResourceType) -> ResourceAPI:
+        http_client = self.http_clients[api_version]
+        resource_cls = API_REGISTRY[api_version][resource_type]
+        return resource_cls(http=http_client, minio=self.minio_client)
+
 
 def check_task_existence(
     task_type: TaskType,
diff --git a/openml/utils/_openml.py b/openml/utils/_openml.py
index f18dbe3e0..2bf54690e 100644
--- a/openml/utils/_openml.py
+++ b/openml/utils/_openml.py
@@ -26,7 +26,6 @@
 import openml
 import openml._api_calls
 import openml.exceptions
-from openml import config
 
 # Avoid import cycles: https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
 if TYPE_CHECKING:
@@ -336,7 +335,7 @@ def _list_all(  # noqa: C901
 
 def _get_cache_dir_for_key(key: str) -> Path:
-    return Path(config.get_cache_directory()) / key
+    return Path(openml.config.get_cache_directory()) / key
 
 
 def _create_cache_directory(key: str) -> Path:
@@ -443,12 +442,12 @@ def get_cache_size() -> int:
     cache_size: int
         Total size of cache in bytes
     """
-    path = Path(config.get_cache_directory())
+    path = Path(openml.config.get_cache_directory())
     return sum(f.stat().st_size for f in path.rglob("*") if f.is_file())
 
 
 def _create_lockfiles_dir() -> Path:
-    path = Path(config.get_cache_directory()) / "locks"
+    path = Path(openml.config.get_cache_directory()) / "locks"
     # TODO(eddiebergman): Not sure why this is allowed to error and ignore???
     with contextlib.suppress(OSError):
         path.mkdir(exist_ok=True, parents=True)
diff --git a/tests/conftest.py b/tests/conftest.py
index 2a7a6dcc7..c8455334b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -34,6 +34,8 @@
 from pathlib import Path
 
 import pytest
 import openml_sklearn
+from openml._api import HTTPClient, MinIOClient
+from openml.enums import APIVersion
 
 import openml
 from openml.testing import TestBase
@@ -273,11 +275,11 @@ def as_robot() -> Iterator[None]:
 @pytest.fixture(autouse=True)
 def with_server(request):
     if "production_server" in request.keywords:
-        openml.config.server = "https://www.openml.org/api/v1/xml"
+        openml.config.server = "https://www.openml.org/api/v1/xml/"
         openml.config.apikey = None
         yield
         return
-    openml.config.server = f"{openml.config.TEST_SERVER_URL}/api/v1/xml"
+    openml.config.server = f"{openml.config.TEST_SERVER_URL}/api/v1/xml/"
     openml.config.apikey = TestBase.user_key
     yield
@@ -307,3 +309,28 @@ def workdir(tmp_path):
     os.chdir(tmp_path)
     yield tmp_path
     os.chdir(original_cwd)
+
+
+@pytest.fixture
+def use_api_v1() -> None:
+    openml.config.set_api_version(api_version=APIVersion.V1)
+
+
+@pytest.fixture
+def use_api_v2() -> None:
+    openml.config.set_api_version(api_version=APIVersion.V2)
+
+
+@pytest.fixture
+def http_client_v1() -> HTTPClient:
+    return HTTPClient(api_version=APIVersion.V1)
+
+
+@pytest.fixture
+def http_client_v2() -> HTTPClient:
+    return HTTPClient(api_version=APIVersion.V2)
+
+
+@pytest.fixture
+def minio_client() -> MinIOClient:
+    return MinIOClient()
diff --git a/tests/test_api/__init__.py b/tests/test_api/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_api/test_http.py b/tests/test_api/test_http.py
new file mode 100644
index 000000000..e2150f5b0
--- /dev/null
+++ b/tests/test_api/test_http.py
@@ -0,0 +1,238 @@
+from requests import Response, Request, Session
+from unittest.mock import patch
+import pytest
+import os
+from pathlib import Path
+from urllib.parse import urljoin, urlparse
+from openml.enums import APIVersion
+from openml.exceptions import OpenMLAuthenticationError
+from openml._api import HTTPClient, HTTPCache
+import openml
+
+
+@pytest.fixture
+def cache(http_client_v1) -> HTTPCache:
+    return http_client_v1.cache
+
+
+@pytest.fixture
+def http_client(http_client_v1) -> HTTPClient:
+    return http_client_v1
+
+
+@pytest.fixture
+def sample_path() -> str:
+    return "task/1"
+
+
+@pytest.fixture
+def sample_url(sample_path) -> str:
+    return urljoin(openml.config.server, sample_path)
+
+
+@pytest.fixture
+def sample_download_url() -> str:
+    server = openml.config.server.split("api/")[0]
+    endpoint = "data/v1/download/1/anneal.arff"
+    url = server + endpoint
+    return url
+
+
+def test_cache(cache, sample_url):
+    params = {"param1": "value1", "param2": "value2"}
+
+    parsed_url = urlparse(sample_url)
+    netloc_parts = parsed_url.netloc.split(".")[::-1]
+    path_parts = parsed_url.path.strip("/").split("/")
+    params_key = "&".join([f"{k}={v}" for k, v in params.items()])
+
+    key = cache.get_key(sample_url, params)
+
+    expected_key = os.path.join(
+        *netloc_parts,
+        *path_parts,
+        params_key,
+    )
+
+    assert key == expected_key
+
+    # mock response
+    req = Request("GET", sample_url).prepare()
+    response = Response()
+    response.status_code = 200
+    response.url = sample_url
+    response.reason = "OK"
+    response._content = b"test"
+    response.headers = {"Content-Type": "text/xml"}
+    response.encoding = "utf-8"
+    response.request = req
+    response.elapsed = type("Elapsed", (), {"total_seconds": lambda x: 0.1})()
+
+    cache.save(key, response)
+    cached = cache.load(key)
+
+    assert cached.status_code == 200
+    assert cached.url == sample_url
+    assert cached.content == b"test"
+    assert cached.headers["Content-Type"] == "text/xml"
+
+
+@pytest.mark.uses_test_server()
+def test_get(http_client):
+    response = http_client.get("task/1")
+
+    assert response.status_code == 200
+    assert b"
+
+
+@pytest.fixture
+def dummy_task_v1(http_client_v1, minio_client) -> DummyTaskV1API:
+    return DummyTaskV1API(http=http_client_v1, minio=minio_client)
+
+
+@pytest.fixture
+def dummy_task_v2(http_client_v2, minio_client) -> DummyTaskV2API:
+    return DummyTaskV2API(http=http_client_v2, minio=minio_client)
+
+
+@pytest.fixture
+def dummy_task_fallback(dummy_task_v1, dummy_task_v2) -> FallbackProxy:
+    return FallbackProxy(dummy_task_v2, dummy_task_v1)
+
+
+def test_v1_publish(dummy_task_v1, use_api_v1):
+    resource = dummy_task_v1
+    resource_name = resource.resource_type.value
+    resource_files = {"description": "Resource Description File"}
+    resource_id = 123
+
+    with patch.object(Session, "request") as mock_request:
+        mock_request.return_value = Response()
+        mock_request.return_value.status_code = 200
+        mock_request.return_value._content = (
+            f'\n'
+            f"\t{resource_id}\n"
+            f"\n"
+        ).encode("utf-8")
+
+        published_resource_id = resource.publish(
+            resource_name,
+            files=resource_files,
+        )
+
+        assert resource_id == published_resource_id
+
+        mock_request.assert_called_once_with(
+            method="POST",
+            url=openml.config.server + resource_name,
+            params={},
+            data={"api_key": openml.config.apikey},
+            headers=openml.config._HEADERS,
+            files=resource_files,
+        )
+
+
+def test_v1_delete(dummy_task_v1, use_api_v1):
+    resource = dummy_task_v1
+    resource_name = resource.resource_type.value
+    resource_id = 123
+
+    with patch.object(Session, "request") as mock_request:
+        mock_request.return_value = Response()
+        mock_request.return_value.status_code = 200
+        mock_request.return_value._content = (
+            f'\n'
+            f" {resource_id}\n"
+            f"\n"
+        ).encode("utf-8")
+
+        resource.delete(resource_id)
+
+        mock_request.assert_called_once_with(
+            method="DELETE",
+            url=(
+                openml.config.server
+                + resource_name
+                + "/"
+                + str(resource_id)
+            ),
+            params={"api_key": openml.config.apikey},
+            data={},
+            headers=openml.config._HEADERS,
+            files=None,
+        )
+
+
+def test_v1_tag(dummy_task_v1, use_api_v1):
+    resource = dummy_task_v1
+    resource_id = 123
+    resource_tag = "TAG"
+
+    with patch.object(Session, "request") as mock_request:
+        mock_request.return_value = Response()
+        mock_request.return_value.status_code = 200
+        mock_request.return_value._content = (
+            f''
+            f"{resource_id}"
+            f"{resource_tag}"
+            f""
+        ).encode("utf-8")
+
+        tags = resource.tag(resource_id, resource_tag)
+
+        assert resource_tag in tags
+
+        mock_request.assert_called_once_with(
+            method="POST",
+            url=(
+                openml.config.server
+                + resource.resource_type
+                + "/tag"
+            ),
+            params={},
+            data={
+                "api_key": openml.config.apikey,
+                "task_id": resource_id,
+                "tag": resource_tag,
+            },
+            headers=openml.config._HEADERS,
+            files=None,
+        )
+
+
+def test_v1_untag(dummy_task_v1, use_api_v1):
+    resource = dummy_task_v1
+    resource_id = 123
+    resource_tag = "TAG"
+
+    with patch.object(Session, "request") as mock_request:
+        mock_request.return_value = Response()
+        mock_request.return_value.status_code = 200
+        mock_request.return_value._content = (
+            f''
+            f"{resource_id}"
+            f""
+        ).encode("utf-8")
+
+        tags = resource.untag(resource_id, resource_tag)
+
+        assert resource_tag not in tags
+
+        mock_request.assert_called_once_with(
+            method="POST",
+            url=(
+                openml.config.server
+                + resource.resource_type
+                + "/untag"
+            ),
+            params={},
+            data={
+                "api_key": openml.config.apikey,
+                "task_id": resource_id,
+                "tag": resource_tag,
+            },
+            headers=openml.config._HEADERS,
+            files=None,
+        )
+
+
+def test_v2_publish(dummy_task_v2, use_api_v2):
+    with pytest.raises(OpenMLNotSupportedError):
+        dummy_task_v2.publish(path=None, files=None)
+
+
+def test_v2_delete(dummy_task_v2, use_api_v2):
+    with pytest.raises(OpenMLNotSupportedError):
+        dummy_task_v2.delete(resource_id=None)
+
+
+def test_v2_tag(dummy_task_v2, use_api_v2):
+    with pytest.raises(OpenMLNotSupportedError):
+        dummy_task_v2.tag(resource_id=None, tag=None)
+
+
+def test_v2_untag(dummy_task_v2, use_api_v2):
+    with pytest.raises(OpenMLNotSupportedError):
+        dummy_task_v2.untag(resource_id=None, tag=None)
+
+
+def test_fallback_publish(dummy_task_fallback):
+    with patch.object(ResourceV1API, "publish") as mock_publish:
+        mock_publish.return_value = None
+        dummy_task_fallback.publish(path=None, files=None)
+        mock_publish.assert_called_once_with(path=None, files=None)
+
+
+def test_fallback_delete(dummy_task_fallback):
+    with patch.object(ResourceV1API, "delete") as mock_delete:
+        mock_delete.return_value = None
+        dummy_task_fallback.delete(resource_id=None)
+        mock_delete.assert_called_once_with(resource_id=None)
+
+
+def test_fallback_tag(dummy_task_fallback):
+    with patch.object(ResourceV1API, "tag") as mock_tag:
+        mock_tag.return_value = None
+        dummy_task_fallback.tag(resource_id=None, tag=None)
+        mock_tag.assert_called_once_with(resource_id=None, tag=None)
+
+
+def test_fallback_untag(dummy_task_fallback):
+    with patch.object(ResourceV1API, "untag") as mock_untag:
+        mock_untag.return_value = None
+        dummy_task_fallback.untag(resource_id=None, tag=None)
+        mock_untag.assert_called_once_with(resource_id=None, tag=None)
diff --git a/tests/test_evaluations/test_evaluations_example.py b/tests/test_evaluations/test_evaluations_example.py
index a9ad7e8c1..b321f475d 100644
--- a/tests/test_evaluations/test_evaluations_example.py
+++ b/tests/test_evaluations/test_evaluations_example.py
@@ -3,14 +3,13 @@
 
 import unittest
 
-from openml.config import overwrite_config_context
-
+import openml
 
 class TestEvaluationsExample(unittest.TestCase):
     def test_example_python_paper(self):
         # Example script which will appear in the upcoming OpenML-Python paper
         # This test ensures that the example will keep running!
-        with overwrite_config_context(
+        with openml.config.overwrite_config_context(  # noqa: F823
             {
                 "server": "https://www.openml.org/api/v1/xml",
                 "apikey": None,
@@ -18,7 +17,6 @@
         ):
             import matplotlib.pyplot as plt
             import numpy as np
-            import openml
 
             df = openml.evaluations.list_evaluations_setups(
                 "predictive_accuracy",
diff --git a/tests/test_openml/test_api_calls.py b/tests/test_openml/test_api_calls.py
index 3f30f38ba..cf021a0ab 100644
--- a/tests/test_openml/test_api_calls.py
+++ b/tests/test_openml/test_api_calls.py
@@ -9,7 +9,6 @@
 import pytest
 
 import openml
-from openml.config import ConfigurationForExamples
 import openml.testing
 from openml._api_calls import _download_minio_bucket, API_TOKEN_HELP_LINK
diff --git a/tests/test_openml/test_config.py b/tests/test_openml/test_config.py
index 13b06223a..0cd642fe7 100644
--- a/tests/test_openml/test_config.py
+++ b/tests/test_openml/test_config.py
@@ -9,12 +9,14 @@
 from typing import Any, Iterator
 from pathlib import Path
 import platform
+from urllib.parse import urlparse
 
 import pytest
 
-import openml.config
+import openml
 import openml.testing
 from openml.testing import TestBase
+from openml.enums import APIVersion
 
 
 @contextmanager
@@ -37,7 +39,7 @@ def safe_environ_patcher(key: str, value: Any) -> Iterator[None]:
 
 class TestConfig(openml.testing.TestBase):
     @unittest.mock.patch("openml.config.openml_logger.warning")
-    @unittest.mock.patch("openml.config._create_log_handlers")
+    @unittest.mock.patch("openml._config.OpenMLConfigManager._create_log_handlers")
     @unittest.skipIf(os.name == "nt", "https://github.com/openml/openml-python/issues/1033")
     @unittest.skipIf(
         platform.uname().release.endswith(("-Microsoft", "microsoft-standard-WSL2")),
@@ -77,22 +79,24 @@ def test_get_config_as_dict(self):
         """Checks if the current configuration is returned accurately as a dict."""
         config = openml.config.get_config_as_dict()
         _config = {}
-        _config["apikey"] = TestBase.user_key
-        _config["server"] = f"{openml.config.TEST_SERVER_URL}/api/v1/xml"
+        _config["api_version"] = APIVersion.V1
+        _config["fallback_api_version"] = None
+        _config["servers"] = openml.config.get_servers("test")
         _config["cachedir"] = self.workdir
         _config["avoid_duplicate_runs"] = False
         _config["connection_n_retries"] = 20
         _config["retry_policy"] = "robot"
         _config["show_progress"] = False
         assert isinstance(config, dict)
-        assert len(config) == 7
+        assert len(config) == 8
         self.assertDictEqual(config, _config)
 
     def test_setup_with_config(self):
         """Checks if the OpenML configuration can be updated using _setup()."""
         _config = {}
-        _config["apikey"] = TestBase.user_key
-        _config["server"] = "https://www.openml.org/api/v1/xml"
+        _config["api_version"] = APIVersion.V1
+        _config["fallback_api_version"] = None
+        _config["servers"] = openml.config.get_servers("test")
         _config["cachedir"] = self.workdir
         _config["avoid_duplicate_runs"] = True
         _config["retry_policy"] = "human"
@@ -127,7 +131,6 @@ def test_switch_from_example_configuration(self):
         openml.config.start_using_configuration_for_example()
         openml.config.stop_using_configuration_for_example()
 
-        assert openml.config.apikey == TestBase.user_key
         assert openml.config.server == self.production_server
 
@@ -136,7 +139,7 @@ def test_example_configuration_stop_before_start(self):
         error_regex = ".*stop_use_example_configuration.*start_use_example_configuration.*first"
         # Tests do not reset the state of this class. Thus, we ensure it is in
         # the original state before the test.
-        openml.config.ConfigurationForExamples._start_last_called = False
+        openml.config._examples._start_last_called = False
         self.assertRaisesRegex(
             RuntimeError,
             error_regex,
@@ -191,5 +194,6 @@ def test_openml_cache_dir_env_var(tmp_path: Path) -> None:
     with safe_environ_patcher("OPENML_CACHE_DIR", str(expected_path)):
         openml.config._setup()
 
+        assert openml.config._root_cache_directory == expected_path
         assert openml.config.get_cache_directory() == str(expected_path / "org" / "openml" / "www")
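The test changes above introduce per-version HTTP client fixtures, a shared MinIO client, and a registry that maps an API version and resource type to a concrete resource class. A minimal sketch of how those pieces compose in a test, using the fixtures and registry lookup shown in this patch; ResourceType.TASK is an assumed member name and the whole block is illustrative rather than part of the patch:

    import pytest

    from openml._api import API_REGISTRY
    from openml.enums import APIVersion, ResourceType


    @pytest.mark.uses_test_server()
    def test_task_api_v1_roundtrip(http_client_v1, minio_client, use_api_v1):
        # Look up the concrete v1 task API class in the registry and build it
        # from the versioned HTTP client plus the shared MinIO client.
        task_api_cls = API_REGISTRY[APIVersion.V1][ResourceType.TASK]
        task_api = task_api_cls(http=http_client_v1, minio=minio_client)

        # The underlying client exposes plain GET access to v1 endpoints,
        # mirroring the test_get case in tests/test_api/test_http.py.
        response = http_client_v1.get("task/1")
        assert response.status_code == 200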