Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: no-commit-to-branch
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-json
- id: check-toml
- id: check-added-large-files

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.3
rev: v0.12.5
hooks:
- id: ruff
- id: ruff-check
- id: ruff-format

- repo: https://github.com/pycqa/pydocstyle
rev: 6.3.0
hooks:
- id: pydocstyle
exclude: (^tests)
args:
- --convention=google

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
rev: v1.17.0
hooks:
- id: mypy
args:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# DataSerious

[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
![Tests](https://github.com/Noza23/dataserious/actions/workflows/tests.yaml/badge.svg)
[![codecov](https://codecov.io/gh/Noza23/dataserious/graph/badge.svg?token=m9yHQyL0sQ)](https://codecov.io/gh/Noza23/dataserious)

Expand Down
42 changes: 25 additions & 17 deletions dataserious/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import itertools
import json
import operator
import random
import secrets
import sys
from dataclasses import MISSING
from functools import reduce
Expand Down Expand Up @@ -66,7 +66,7 @@ class BaseConfig:
Example::

class SomeConfig(BaseConfig):
z: list[int] = field(metadata={'description': 'list of integers'})
z: list[int] = field(metadata={"description": "list of integers"})

"""

Expand Down Expand Up @@ -100,9 +100,10 @@ def __post_init__(self):
class SomeConfig(BaseConfig):
x: int


def __post_init__(self):
super().__post_init__()
assert x < 2 # Additional check
super().__post_init__()
assert x < 2 # Additional check
"""
for field in self.fields():
self._modify_field(
Expand All @@ -111,11 +112,11 @@ def __post_init__(self):

if not check_type(getattr(self, field.name), field.type):
raise TypeError(
'\n'
f'| loc: {self.__class__.__name__}.{field.name}\n'
f'| expects: {type_to_view_string(field.type)}\n'
f'| got: {type(getattr(self, field.name))}\n'
f'| description: {field.metadata.get("description")}\n'
"\n"
f"| loc: {self.__class__.__name__}.{field.name}\n"
f"| expects: {type_to_view_string(field.type)}\n"
f"| got: {type(getattr(self, field.name))}\n"
f"| description: {field.metadata.get('description')}\n"
)

def __contains__(self, item):
Expand Down Expand Up @@ -174,12 +175,14 @@ def get_by_path(self, path: list[str] | str):
class SomeOtherConfig(BaseConfig):
z: int


class SomeConfig(BaseConfig):
x: SomeOtherConfig
y: int


c = SomeConfig(SomeOtherConfig(z=1), 2)
assert c.get_by_path('x.z') == 1
assert c.get_by_path("x.z") == 1

"""
if isinstance(path, str):
Expand All @@ -201,8 +204,9 @@ def replace(self: C, /, **changes) -> C:
Example::

class SomeConfig(BaseConfig):
x: int
y: int
x: int
y: int


c = SomeConfig(1, 2)
c_new = c.replace(x=3)
Expand Down Expand Up @@ -302,7 +306,7 @@ def from_dir(cls, path: str | Path):
load all of them.

"""
patt = f'*[{"|".join(YAML_SUFFIXES + JSON_SUFFIXES)}]'
patt = f"*[{'|'.join(YAML_SUFFIXES + JSON_SUFFIXES)}]"
return [cls.from_file(p) for p in Path(path).glob(patt)]

@classmethod
Expand Down Expand Up @@ -491,12 +495,13 @@ def _yield_config_search_space(
mapping = _get_grid_mapping(search_tree)
product = itertools.product(*mapping.values())
if random_n: # Random Search might get slow for large search spaces.
assert seed is not None, "Seed must be provided for Random Search."
if seed is None:
raise ValueError("Seed must be provided for Random Search.")
product_list = list(product)
random.Random(seed).shuffle(product_list)
secrets.SystemRandom(seed).shuffle(product_list)
product = product_list[:random_n] # type: ignore[assignment]
for values in product:
for k, values_ in zip(mapping.keys(), values):
for k, values_ in zip(mapping.keys(), values, strict=True):
set_config_value_by_path(config_tree, k, values_)
yield config_tree

Expand Down Expand Up @@ -628,7 +633,10 @@ def parse(attr, annot: Annotation):
}

if isinstance(annot, ForwardRef):
return parse(attr, eval(annot.__forward_arg__))
annot = ForwardRef._evaluate(
annot, globalns=globals(), localns=locals(), recursive_guard=frozenset()
)
return parse(attr, annot)
return attr


Expand Down
22 changes: 13 additions & 9 deletions dataserious/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def check_type(attr, annot: Annotation) -> bool:
False
>>> check_type([1, "a"], list[int | str])
True
>>> check_type({'a': [1, 2], 'b': ["a", "b"]}, JsonType)
>>> check_type({"a": [1, 2], "b": ["a", "b"]}, JsonType)
True
>>> check_type([BaseConfig()], list[BaseConfig])
True
Expand All @@ -33,29 +33,30 @@ def check_type(attr, annot: Annotation) -> bool:
return True

if isinstance(annot, UnionTypes):
return any([check_type(attr, t) for t in get_args(annot)])
return any(check_type(attr, t) for t in get_args(annot))

if isinstance(annot, GenericAliasTypes):
origin = get_origin(annot)
args = get_args(annot)
if isclasssubclass(origin, (List, Set)):
return isinstance(attr, (list, set)) and all(
[check_type(element, args[0]) for element in attr]
check_type(element, args[0]) for element in attr
)
elif isclasssubclass(origin, Dict):
return isinstance(attr, dict) and all(
[
check_type(key, args[0]) and check_type(value, args[1])
for key, value in attr.items()
]
check_type(key, args[0]) and check_type(value, args[1])
for key, value in attr.items()
)
elif origin is Literal:
return attr in args
else:
return isinstance(attr, origin)

if isinstance(annot, ForwardRef):
return check_type(attr, eval(annot.__forward_arg__))
annot = ForwardRef._evaluate(
annot, globalns=globals(), localns=locals(), recursive_guard=frozenset()
)
return check_type(attr, annot)

return isinstance(attr, annot)

Expand Down Expand Up @@ -85,5 +86,8 @@ def type_to_view_string(annot: Annotation):
if origin is Literal:
return " | ".join(a for a in args)
if isinstance(annot, ForwardRef):
return type_to_view_string(eval(annot.__forward_arg__))
annot = ForwardRef._evaluate(
annot, globalns=globals(), localns=locals(), recursive_guard=frozenset()
)
return type_to_view_string(annot)
return annot.__name__
21 changes: 20 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@ dev = ["pytest", "coverage", "ruff", "pre-commit"]

[tool.ruff]
line-length = 88

[tool.ruff.lint]
select = [
"I", # Sort Imports
"D", # Docstrings
"A", # flake8-builtins
"T20", # flake8-print
"C4", # flake8-comprehensions
"B", # flake8-bugbear
"S", # flake8-bandit
"BLE", #flake8-blind-except
]
fixable = ["I"]
extend-select = ["I", "T"]
pydocstyle = { convention = "google" }

[tool.ruff.lint.per-file-ignores]
"tests/*" = [ "D", "B", "S", "C4" ]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
docstring-code-format = true