Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions asdf/_tests/test_asdf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import pickle

import numpy as np
import pytest

from asdf import config_context
Expand Down Expand Up @@ -379,3 +381,12 @@ def test_fsspec_http(httpserver):
with fsspec.open(fn) as f:
af = open_asdf(f)
assert_tree_match(tree, af.tree)


def test_asdf_file_pickle_from_dict():
"""Verify that an AsdfFile created from a dict (with no file descriptor) can be pickled"""
tree = {"a": 1, "b": {"c": 2, "d": np.ones((10, 10))}}
af = AsdfFile(tree)
pkl = pickle.dumps(af)
loaded = pickle.loads(pkl) # noqa: S301
assert_tree_match(af.tree, loaded.tree)
28 changes: 27 additions & 1 deletion asdf/_tests/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
Validator,
get_cached_extension_manager,
)
from asdf.extension._manager import _resolve_type
from asdf.extension._manager import ValidatorManager, _resolve_type
from asdf.tagged import TaggedList
from asdf.testing.helpers import roundtrip_object


Expand Down Expand Up @@ -734,6 +735,31 @@ def test_validator():
af.validate()


class ValidatorFailOn(Validator):
schema_property = "fail"
tags = ["fail"]

def __init__(self, fail_on):
self.fail_on = fail_on

def validate(self, schema_property_value, node, schema):
if schema_property_value == self.fail_on:
yield ValidationError("Node was doomed to fail")


def test_validator_manager():
validator = ValidatorManager([ValidatorFailOn("bar")])
errs = list(validator.validate("fail", "foo", TaggedList([], "fail"), {}))
assert len(errs) == 0

errs = list(validator.validate("fail", "bar", TaggedList([], "other"), {}))
assert len(errs) == 0

errs = list(validator.validate("fail", "bar", TaggedList([], "fail"), {}))
assert len(errs) == 1
assert isinstance(errs[0], ValidationError)


def test_converter_deferral():
class Bar:
def __init__(self, value):
Expand Down
81 changes: 47 additions & 34 deletions asdf/extension/_manager.py
Comment thread
sydduckworth marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
from __future__ import annotations

import sys
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING

from asdf.tagged import Tagged
from asdf.util import get_class_name, uri_match

from ._extension import ExtensionProxy

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
from typing import Any

from asdf.exceptions import ValidationError
from asdf.extension import Validator
from asdf.typing import TreeKey


def _resolve_type(path):
"""
Expand Down Expand Up @@ -317,32 +329,26 @@ def _get_cached_extension_manager(extensions):

class ValidatorManager:
"""
Wraps a list of custom validators and indexes them by schema property.
Wraps a list of custom validators and binds them to their associated schemas.

Parameters
----------
validators : iterable of asdf.extension.Validator
List of validators to manage.
"""

def __init__(self, validators):
self._validators = list(validators)
def __init__(self, validators: Iterable[Validator]):
self._validators = {}
for validator in validators:
if validator.schema_property not in self._validators:
self._validators[validator.schema_property] = set()

self._validators_by_schema_property = {}
for validator in self._validators:
if validator.schema_property not in self._validators_by_schema_property:
self._validators_by_schema_property[validator.schema_property] = set()
self._validators_by_schema_property[validator.schema_property].add(validator)
self._validators[validator.schema_property].add(validator)

self._jsonschema_validators_by_schema_property = {}
for schema_property in self._validators_by_schema_property:
self._jsonschema_validators_by_schema_property[schema_property] = self._get_jsonschema_validator(
schema_property,
)

def validate(self, schema_property, schema_property_value, node, schema):
"""
Validate an ASDF tree node against custom validators for a schema property.
def validate(
self, schema_property: str, schema_property_value: Any, node: Tagged, schema: Mapping[TreeKey, Any]
) -> Iterator[ValidationError]:
"""Validate an ASDF tree node against custom validators for a schema property.

Parameters
----------
Expand All @@ -360,27 +366,34 @@ def validate(self, schema_property, schema_property_value, node, schema):
------
asdf.exceptions.ValidationError
"""
if schema_property in self._validators_by_schema_property:
for validator in self._validators_by_schema_property[schema_property]:
if _validator_matches(validator, node):
yield from validator.validate(schema_property_value, node, schema)
for validator in self._validators[schema_property]:
if _validator_matches(validator, node):
yield from validator.validate(schema_property_value, node, schema)

def get_jsonschema_validators(self):
"""
Get a dictionary of validator methods suitable for use
with the jsonschema library.
def get_jsonschema_validators(self) -> dict[str, JsonSchemaValidators]:
"""Get a dictionary mapping schema names to ``jsonschema``-compatible validator functions."""
return {
schema_property: JsonSchemaValidators(schema_property, frozenset(validators))
for schema_property, validators in self._validators.items()
}

Returns
-------
dict of str: callable
"""
return dict(self._jsonschema_validators_by_schema_property)

def _get_jsonschema_validator(self, schema_property):
def _validator(_, schema_property_value, node, schema):
return self.validate(schema_property, schema_property_value, node, schema)
@dataclass(frozen=True, slots=True)
class JsonSchemaValidators:
"""Callable that wraps a set of `Validator` objects to make them compatible with `jsonschema`.

Each validator is always passed `schema_property` as its first argument regardless of the actual input schema.
"""

schema_property: str
validators: frozenset[Validator]

return _validator
def __call__(
self, _schema_property: Any, schema_property_value: Any, node: Tagged, schema: Mapping[TreeKey, Any]
) -> Iterator[ValidationError]:
for validator in self.validators:
if _validator_matches(validator, node):
yield from validator.validate(schema_property_value, node, schema)


def _validator_matches(validator, node):
Expand Down
21 changes: 18 additions & 3 deletions asdf/extension/_validator.py
Comment thread
braingram marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping

from asdf.exceptions import ValidationError
from asdf.tagged import Tagged
from asdf.typing import TreeKey


class Validator(abc.ABC):
Expand All @@ -8,13 +18,14 @@ class Validator(abc.ABC):
"""

@abc.abstractproperty
def schema_property(self):
def schema_property(self) -> str:
"""
Name of the schema property used to invoke this validator.
"""
...

@abc.abstractproperty
def tags(self):
def tags(self) -> Iterable[str]:
"""
Get the YAML tags that are appropriate to this validator.
URI patterns are permitted, see `asdf.util.uri_match` for details.
Expand All @@ -24,9 +35,12 @@ def tags(self):
iterable of str
Tag URIs or URI patterns.
"""
...

@abc.abstractmethod
def validate(self, schema_property_value, node, schema):
def validate(
self, schema_property_value: Any, node: Tagged, schema: Mapping[TreeKey, Any]
) -> Iterator[ValidationError]:
"""
Validate the given node from the ASDF tree.

Expand Down Expand Up @@ -54,3 +68,4 @@ def validate(self, schema_property_value, node, schema):
asdf.exceptions.ValidationError
Yield an instance of ValidationError for each error present in the node.
"""
...
5 changes: 4 additions & 1 deletion asdf/tags/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
# to pass an isinstance(..., dict) check and to allow it to be "lazy"
# loaded when "lazy_tree=True".
class AsdfObject(collections.UserDict, dict):
pass
def __reduce__(self):
# Necessary for correct pickling/unpickling
# Otherwise pickle will use dict's reduce method which causes UserDict to fail to unpickle
return super(collections.UserDict, self).__reduce__()


class Software(dict):
Expand Down
6 changes: 1 addition & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ omit = [
]

[tool.coverage.report]
exclude_lines = [
# Have to re-enable the standard pragma
"pragma: no cover",
exclude_also = [
# Don't complain about packages we have installed
"except ImportError",
# Don't complain if tests don't hit assertions
Expand All @@ -152,8 +150,6 @@ exclude_lines = [
'def main\(.*\):',
# Ignore branches that don't pertain to this version of Python
"pragma: py{ ignore_python_version }",
# Ignore type-checking imports
"if TYPE_CHECKING:",
]

[tool.ruff]
Expand Down
Loading