diff --git a/asdf/_tests/test_asdf.py b/asdf/_tests/test_asdf.py index b7095467f..cf1cdaa3d 100644 --- a/asdf/_tests/test_asdf.py +++ b/asdf/_tests/test_asdf.py @@ -1,5 +1,7 @@ import os +import pickle +import numpy as np import pytest from asdf import config_context @@ -379,3 +381,12 @@ def test_fsspec_http(httpserver): with fsspec.open(fn) as f: af = open_asdf(f) assert_tree_match(tree, af.tree) + + +def test_asdf_file_pickle_from_dict(): + """Verify that an AsdfFile created from a dict (with no file descriptor) can be pickled""" + tree = {"a": 1, "b": {"c": 2, "d": np.ones((10, 10))}} + af = AsdfFile(tree) + pkl = pickle.dumps(af) + loaded = pickle.loads(pkl) # noqa: S301 + assert_tree_match(af.tree, loaded.tree) diff --git a/asdf/_tests/test_extension.py b/asdf/_tests/test_extension.py index de0c0c1fa..ae13a46ad 100644 --- a/asdf/_tests/test_extension.py +++ b/asdf/_tests/test_extension.py @@ -20,7 +20,8 @@ Validator, get_cached_extension_manager, ) -from asdf.extension._manager import _resolve_type +from asdf.extension._manager import ValidatorManager, _resolve_type +from asdf.tagged import TaggedList from asdf.testing.helpers import roundtrip_object @@ -734,6 +735,31 @@ def test_validator(): af.validate() +class ValidatorFailOn(Validator): + schema_property = "fail" + tags = ["fail"] + + def __init__(self, fail_on): + self.fail_on = fail_on + + def validate(self, schema_property_value, node, schema): + if schema_property_value == self.fail_on: + yield ValidationError("Node was doomed to fail") + + +def test_validator_manager(): + validator = ValidatorManager([ValidatorFailOn("bar")]) + errs = list(validator.validate("fail", "foo", TaggedList([], "fail"), {})) + assert len(errs) == 0 + + errs = list(validator.validate("fail", "bar", TaggedList([], "other"), {})) + assert len(errs) == 0 + + errs = list(validator.validate("fail", "bar", TaggedList([], "fail"), {})) + assert len(errs) == 1 + assert isinstance(errs[0], ValidationError) + + def test_converter_deferral(): class Bar: def __init__(self, value): diff --git a/asdf/extension/_manager.py b/asdf/extension/_manager.py index d02b35600..74e8ec2ed 100644 --- a/asdf/extension/_manager.py +++ b/asdf/extension/_manager.py @@ -1,11 +1,23 @@ +from __future__ import annotations + import sys +from dataclasses import dataclass from functools import lru_cache +from typing import TYPE_CHECKING from asdf.tagged import Tagged from asdf.util import get_class_name, uri_match from ._extension import ExtensionProxy +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping + from typing import Any + + from asdf.exceptions import ValidationError + from asdf.extension import Validator + from asdf.typing import TreeKey + def _resolve_type(path): """ @@ -317,7 +329,7 @@ def _get_cached_extension_manager(extensions): class ValidatorManager: """ - Wraps a list of custom validators and indexes them by schema property. + Wraps a list of custom validators and binds them to their associated schemas. Parameters ---------- @@ -325,24 +337,18 @@ class ValidatorManager: List of validators to manage. """ - def __init__(self, validators): - self._validators = list(validators) + def __init__(self, validators: Iterable[Validator]): + self._validators = {} + for validator in validators: + if validator.schema_property not in self._validators: + self._validators[validator.schema_property] = set() - self._validators_by_schema_property = {} - for validator in self._validators: - if validator.schema_property not in self._validators_by_schema_property: - self._validators_by_schema_property[validator.schema_property] = set() - self._validators_by_schema_property[validator.schema_property].add(validator) + self._validators[validator.schema_property].add(validator) - self._jsonschema_validators_by_schema_property = {} - for schema_property in self._validators_by_schema_property: - self._jsonschema_validators_by_schema_property[schema_property] = self._get_jsonschema_validator( - schema_property, - ) - - def validate(self, schema_property, schema_property_value, node, schema): - """ - Validate an ASDF tree node against custom validators for a schema property. + def validate( + self, schema_property: str, schema_property_value: Any, node: Tagged, schema: Mapping[TreeKey, Any] + ) -> Iterator[ValidationError]: + """Validate an ASDF tree node against custom validators for a schema property. Parameters ---------- @@ -360,27 +366,34 @@ def validate(self, schema_property, schema_property_value, node, schema): ------ asdf.exceptions.ValidationError """ - if schema_property in self._validators_by_schema_property: - for validator in self._validators_by_schema_property[schema_property]: - if _validator_matches(validator, node): - yield from validator.validate(schema_property_value, node, schema) + for validator in self._validators[schema_property]: + if _validator_matches(validator, node): + yield from validator.validate(schema_property_value, node, schema) - def get_jsonschema_validators(self): - """ - Get a dictionary of validator methods suitable for use - with the jsonschema library. + def get_jsonschema_validators(self) -> dict[str, JsonSchemaValidators]: + """Get a dictionary mapping schema names to ``jsonschema``-compatible validator functions.""" + return { + schema_property: JsonSchemaValidators(schema_property, frozenset(validators)) + for schema_property, validators in self._validators.items() + } - Returns - ------- - dict of str: callable - """ - return dict(self._jsonschema_validators_by_schema_property) - def _get_jsonschema_validator(self, schema_property): - def _validator(_, schema_property_value, node, schema): - return self.validate(schema_property, schema_property_value, node, schema) +@dataclass(frozen=True, slots=True) +class JsonSchemaValidators: + """Callable that wraps a set of `Validator` objects to make them compatible with `jsonschema`. + + Each validator is always passed `schema_property` as its first argument regardless of the actual input schema. + """ + + schema_property: str + validators: frozenset[Validator] - return _validator + def __call__( + self, _schema_property: Any, schema_property_value: Any, node: Tagged, schema: Mapping[TreeKey, Any] + ) -> Iterator[ValidationError]: + for validator in self.validators: + if _validator_matches(validator, node): + yield from validator.validate(schema_property_value, node, schema) def _validator_matches(validator, node): diff --git a/asdf/extension/_validator.py b/asdf/extension/_validator.py index c695ac7a9..155e2d3b3 100644 --- a/asdf/extension/_validator.py +++ b/asdf/extension/_validator.py @@ -1,4 +1,14 @@ +from __future__ import annotations + import abc +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping + + from asdf.exceptions import ValidationError + from asdf.tagged import Tagged + from asdf.typing import TreeKey class Validator(abc.ABC): @@ -8,13 +18,14 @@ class Validator(abc.ABC): """ @abc.abstractproperty - def schema_property(self): + def schema_property(self) -> str: """ Name of the schema property used to invoke this validator. """ + ... @abc.abstractproperty - def tags(self): + def tags(self) -> Iterable[str]: """ Get the YAML tags that are appropriate to this validator. URI patterns are permitted, see `asdf.util.uri_match` for details. @@ -24,9 +35,12 @@ def tags(self): iterable of str Tag URIs or URI patterns. """ + ... @abc.abstractmethod - def validate(self, schema_property_value, node, schema): + def validate( + self, schema_property_value: Any, node: Tagged, schema: Mapping[TreeKey, Any] + ) -> Iterator[ValidationError]: """ Validate the given node from the ASDF tree. @@ -54,3 +68,4 @@ def validate(self, schema_property_value, node, schema): asdf.exceptions.ValidationError Yield an instance of ValidationError for each error present in the node. """ + ... diff --git a/asdf/tags/core/__init__.py b/asdf/tags/core/__init__.py index 2b74bcb13..2f42c5fdf 100644 --- a/asdf/tags/core/__init__.py +++ b/asdf/tags/core/__init__.py @@ -24,7 +24,10 @@ # to pass an isinstance(..., dict) check and to allow it to be "lazy" # loaded when "lazy_tree=True". class AsdfObject(collections.UserDict, dict): - pass + def __reduce__(self): + # Necessary for correct pickling/unpickling + # Otherwise pickle will use dict's reduce method which causes UserDict to fail to unpickle + return super(collections.UserDict, self).__reduce__() class Software(dict): diff --git a/pyproject.toml b/pyproject.toml index b4a5b2d18..a2f6f676b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,9 +140,7 @@ omit = [ ] [tool.coverage.report] -exclude_lines = [ - # Have to re-enable the standard pragma - "pragma: no cover", +exclude_also = [ # Don't complain about packages we have installed "except ImportError", # Don't complain if tests don't hit assertions @@ -152,8 +150,6 @@ exclude_lines = [ 'def main\(.*\):', # Ignore branches that don't pertain to this version of Python "pragma: py{ ignore_python_version }", - # Ignore type-checking imports - "if TYPE_CHECKING:", ] [tool.ruff]