diff --git a/.claude/settings.json b/.claude/settings.json index ba87956..0145759 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -5,6 +5,7 @@ "Bash(git log:*)", "Bash(git diff:*)", "Bash(git branch:*)", + "Bash(git show:*)", "Bash(gh issue view:*)", "Bash(gh pr view:*)", "Glob", diff --git a/pyproject.toml b/pyproject.toml index 8f852ad..4b539fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling", "hatch-requirements-txt"] +requires = ["hatchling"] build-backend = "hatchling.build" [tool.setuptools] @@ -7,7 +7,7 @@ packages = ["object_filtering"] [project] name = "object_filtering" -version = "0.3.0" +version = "0.4.0" authors = [ { name="Scott Ratchford", email="object_filtering@scottratchford.com" }, ] @@ -28,7 +28,9 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Utilities", ] -dynamic = ["dependencies"] +dependencies = [ + "numpy>=2.0.0", +] [project.urls] Homepage = "https://github.com/KyberCritter/Object-Filtering" @@ -43,6 +45,3 @@ dev = [ pythonpath = [ "." ] - -[tool.hatch.metadata.hooks.requirements_txt] -files = ["requirements.txt"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 9c03a3f..0000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -numpy>=2.0.0 diff --git a/src/object_filtering/object_filtering.py b/src/object_filtering/object_filtering.py index dd58634..1427dda 100644 --- a/src/object_filtering/object_filtering.py +++ b/src/object_filtering/object_filtering.py @@ -481,16 +481,14 @@ def get_logical_expression_type( raise ValueError("expression is not a LogicalExpression.") -def is_logical_expression_valid(expression: LogicalExpression, obj: Any = None) -> bool: +def is_logical_expression_valid(expression: LogicalExpression) -> bool: """Determines whether a LogicalExpression conforms to the format from the documentation. Args: expression (LogicalExpression): The LogicalExpression to check the - validity of. - obj (Any): The object that will be filtered. All criteria in the Rules - must be present and whitelisted for its type. Defaults to None. If - None, criteria validity checks will be skipped. + validity of. A plain dict is implicitly converted to the + appropriate _LogicalExpressionBase subclass. Raises: ValueError: If expression is not a LogicalExpression @@ -505,29 +503,26 @@ def is_logical_expression_valid(expression: LogicalExpression, obj: Any = None) # True and False are both valid return True elif expr_type == Rule: - return is_rule_valid(expression, obj) + return is_rule_valid(expression) elif expr_type == ConditionalExpression: - return is_conditional_expression_valid(expression, obj) + return is_conditional_expression_valid(expression) elif expr_type == GroupExpression: - return is_group_expression_valid(expression, obj) + return is_group_expression_valid(expression) elif expr_type == ObjectFilter: - return is_filter_valid(expression, obj) + return is_filter_valid(expression) else: raise ValueError("expression is not a LogicalExpression.") -def is_rule_valid(rule: dict, obj: Any = None) -> bool: +def is_rule_valid(rule: dict | Rule) -> bool: """Determines whether a rule conforms to the format from the documentation. - All methods used as criteria must be decorated with @filter_criterion. Raises an error if the rule is not valid. Args: - rule (dict): The rule to check the validity of. - obj (Any): The object that will be filtered. - All criteria in the rules must be present and whitelisted for its type. - Defaults to None. + rule (dict | Rule): The rule to check the validity of. A plain dict is + implicitly converted to a Rule object. Returns: - bool: Whether the rule is valid + bool: Whether the rule is valid. - Required keys and their values' data types: - criterion (str): The variable or method to compare against. @@ -550,44 +545,29 @@ def is_rule_valid(rule: dict, obj: Any = None) -> bool: raise FilterError("rule parameter is not a list.") if not isinstance(rule["multi_value_behavior"], str): raise FilterError("rule multi_value_behavior is not a string.") - - if obj is not None: - # value checks - if rule["operator"].upper() not in VALID_OPERATORS: - raise FilterError("rule operator is not a valid operator.") - # special variable handling - if rule["criterion"] in SPECIAL_VARIABLES: - if rule["criterion"] == "$CLASS$" and rule["operator"] not in CLASS_VARIABLE_OPERATORS: - raise FilterError("$CLASS$ only supports == and != operators.") - return True - try: # check if method exists - method = getattr(obj, rule["criterion"]) - except: - raise FilterError(f"method {rule['criterion']} does not exist in obj.") - # check if method is decorated with @filter_criterion - if not isinstance(obj, ObjectWrapper): - if callable(method) and not hasattr(method, "_is_whitelisted"): - raise FilterError(f"method {rule['criterion']} is not whitelisted in obj. No _is_whitelisted method.") - if hasattr(method, "_is_whitelisted") and not method._is_whitelisted: - raise FilterError(f"method {rule['criterion']} is not whitelisted in obj.") - if rule["multi_value_behavior"] not in VALID_MULTI_VALUE_BEHAVIORS: - raise FilterError(f"rule multi_value_behavior is not a valid multi_value_behavior.") + # value checks + if rule["operator"].upper() not in VALID_OPERATORS: + raise FilterError("rule operator is not a valid operator.") + # special variable handling + if rule["criterion"] in SPECIAL_VARIABLES: + if rule["criterion"] == "$CLASS$" and rule["operator"] not in CLASS_VARIABLE_OPERATORS: + raise FilterError("$CLASS$ only supports == and != operators.") return True -def is_conditional_expression_valid(expression: ConditionalExpression, obj: Any = None) -> bool: +def is_conditional_expression_valid(expression: dict | ConditionalExpression) -> bool: """Determines whether a ConditionalExpression conforms to the format from the documentation. Raises an error if the ConditionalExpression is not valid. Args: - expression (dict): The ConditionalExpression to check the validity of. - obj (Any): The object that will be filtered. All criteria in the Rules - must be present and whitelisted for its type. Defaults to None. + expression (dict | ConditionalExpression): The ConditionalExpression to + check the validity of. A plain dict is implicitly converted to a + ConditionalExpression object. Returns: bool: Whether the conditional expression is valid. - + - Required keys and their values' data types: - if (LogicalExpression): The first LogicalExpression to evaluate - then (LogicalExpression): The LogicalExpression to evaluate if the @@ -599,21 +579,20 @@ def is_conditional_expression_valid(expression: ConditionalExpression, obj: Any expression = dict_to_logical_expression(expression) if get_logical_expression_type(expression) != ConditionalExpression: raise FilterError("expression is not a ConditionalExpression.") - return all([is_logical_expression_valid(exp, obj) for exp in expression.values()]) + return all([is_logical_expression_valid(exp) for exp in expression.values()]) -def is_group_expression_valid(expression: GroupExpression, obj: Any = None) -> bool: +def is_group_expression_valid(expression: GroupExpression) -> bool: """Determines whether the GroupExpression conforms to the format from the documentation. Args: expression (GroupExpression): The GroupExpression to check the validity - of. - obj (Any): The object that will be filtered. All criteria in the Rules - must be present and whitelisted for its type. Defaults to None. + of. A plain dict is implicitly converted to the appropriate + _LogicalExpressionBase subclass. Returns: bool: Whether the GroupExpression is valid. - + - Required keys and their values' data types: - logical_operator (str): "and" or "or" - logical_expressions (list[LogicalExpression]): The LogicalExpressions @@ -623,14 +602,15 @@ def is_group_expression_valid(expression: GroupExpression, obj: Any = None) -> b expression = dict_to_logical_expression(expression) if not expression["logical_operator"] in VALID_LOGICAL_OPERATORS: # must be "and" or "or" raise FilterError("expression logical_operator is not a valid logical operator.") - return all([is_logical_expression_valid(exp, obj) for exp in expression["logical_expressions"]]) + return all([is_logical_expression_valid(exp) for exp in expression["logical_expressions"]]) def is_filter_valid(filter: ObjectFilter, obj: Any = None) -> bool: """Determines whether an ObjectFilter conforms to the format from the documentation. Args: - filter (ObjectFilter): The ObjectFilter to check the validity of. + filter (ObjectFilter): The ObjectFilter to check the validity of. A + plain dict is implicitly converted to an ObjectFilter object. obj (Any): The object that will be filtered. All criteria in the Rules must be present and whitelisted for its type. Defaults to None. @@ -674,7 +654,7 @@ def is_filter_valid(filter: ObjectFilter, obj: Any = None) -> bool: if filter["priority"] < 0: raise ValueError("filter priority is less than 0.") - if not is_logical_expression_valid(filter["logical_expression"], obj): + if not is_logical_expression_valid(filter["logical_expression"]): return False # validate obj type if obj is not None and not type_name_matches(obj, filter["object_types"]): @@ -692,7 +672,8 @@ def sanitize_filter(filter: ObjectFilter) -> ObjectFilter: altered deep copy while preserving the original. Args: - filter (ObjectFilter): The ObjectFilter to sanitize. + filter (ObjectFilter): The ObjectFilter to sanitize. A plain dict is + implicitly converted to an ObjectFilter object. Raises: TypeError: If the ObjectFilter is not a dict. @@ -716,7 +697,7 @@ def sanitize_filter(filter: ObjectFilter) -> ObjectFilter: sanitized[key] = value # Keep other data types unchanged return sanitized -def get_value(obj: Any, rule: dict) -> Any: +def get_value(obj: Any, rule: dict | Rule) -> Any: """Returns the value of an attribute of `obj`, based on `rule["criterion"]`. @@ -727,7 +708,8 @@ def get_value(obj: Any, rule: dict) -> Any: Args: obj (Any): The object that the rule will be executed with. All criteria in the rules must be present and whitelisted for its type. - rule (dict): The rule to execute. + rule (dict | Rule): The rule to execute. A plain dict is implicitly + converted to a Rule object. Raises: ValueError: If the criterion is a method without `@filter_criterion`. @@ -768,14 +750,16 @@ def get_value(obj: Any, rule: dict) -> Any: Execution Functions """ -def execute_logical_expression_on_object(obj: Any, expression: LogicalExpression) -> bool: +def execute_logical_expression_on_object(obj: Any, expression: dict | LogicalExpression) -> bool: """Executes a LogicalExpression on an object. Args: obj (Any): The object that the LogicalExpression will be executed with. All criteria in the Rules must be present and whitelisted for its type. - expression (LogicalExpression): The LogicalExpression to execute. + expression (dict | LogicalExpression): The LogicalExpression to + execute. A plain dict is implicitly converted to the appropriate + _LogicalExpressionBase subclass. Raises: ValueError: If expression is not a LogicalExpression @@ -833,7 +817,8 @@ def execute_rule_on_object(obj: Any, rule: dict) -> bool: Args: obj (Any): The object that the rule will be executed with. All criteria in the rules must be present and whitelisted for its type. - rule (dict): The rule to execute. + rule (dict): The rule to execute. A plain dict is implicitly converted + to the appropriate _LogicalExpressionBase subclass. Raises: ValueError: If rule["comparison_value"] is not valid. @@ -879,7 +864,9 @@ def execute_conditional_expression_on_object(obj: Any, expression: dict) -> bool Args: obj (Any): The object that the conditional expression will be executed with. All criteria in the rules must be present and whitelisted for its type. - expression (dict): The conditional expression to execute. + expression (dict): The conditional expression to execute. A plain dict + is implicitly converted to the appropriate _LogicalExpressionBase + subclass. Raises: ValueError: If expression does not match the format of a conditional expression. @@ -902,7 +889,9 @@ def execute_group_expression_on_object(obj: Any, expression: GroupExpression) -> Args: obj (Any): The object that the GroupExpression will be executed with. All criteria in the rules must be present and whitelisted for its type. - expression (GroupExpression): The GroupExpression to execute. + expression (GroupExpression): The GroupExpression to execute. A plain + dict is implicitly converted to the appropriate + _LogicalExpressionBase subclass. Raises: ValueError: If expression does not match the format of a GroupExpression. @@ -928,7 +917,9 @@ def execute_filter_on_object(obj, filter: ObjectFilter, sanitize: bool = True) - Args: obj: Any object. - filter (ObjectFilter): An ObjectFilter to execute. + filter (ObjectFilter): An ObjectFilter to execute. A plain dict is + implicitly converted to the appropriate _LogicalExpressionBase + subclass. sanitize (bool, optional): Whether or not to santize the ObjectFilters before execution. Defaults to True. @@ -955,7 +946,9 @@ def execute_filter_on_array(obj_array: np.ndarray[Any], filter: dict, sanitize: Args: obj_array (np.ndarray[Any]): Array of any type of object. - filter (ObjectFilter): An ObjectFilter to execute. + filter (ObjectFilter): An ObjectFilter to execute. A plain dict is + implicitly converted to the appropriate _LogicalExpressionBase + subclass. sanitize (bool, optional): Whether or not to santize the ObjectFilters before execution. Defaults to True. @@ -996,7 +989,8 @@ def execute_filter_list_on_object( Args: obj (Any): Any object. filter_list (list[ObjectFilter]): A list of ObjectFilter to execute on - `obj`. + `obj`. Plain dicts in the list are implicitly converted to the + appropriate _LogicalExpressionBase subclass. sanitize (bool, optional): Whether or not to santize the ObjectFilters before execution. Defaults to True. @@ -1021,7 +1015,8 @@ def execute_filter_list_on_array( Args: obj_array (np.ndarray[Any]): Array of any type of object. filter_list (list[dict]): A list of ObjectFilters to execute on the - elements of `obj_array`. + elements of `obj_array`. Plain dicts in the list are implicitly + converted to the appropriate _LogicalExpressionBase subclass. sanitize (bool, optional): Whether or not to santize the ObjectFilters before execution. Defaults to True. @@ -1052,7 +1047,8 @@ def execute_filter_list_on_object_get_first_success( Args: obj (Any): Any object. filter_list (list[ObjectFilter]): A list of ObjectFilters to execute - on `obj`. + on `obj`. Plain dicts in the list are implicitly converted to the + appropriate _LogicalExpressionBase subclass. sanitize (bool, optional): Whether or not to santize the ObjectFilters before execution. Defaults to True. @@ -1096,11 +1092,7 @@ def method(*args, **kwargs): raise AttributeError(f"Not all objects have the attribute '{name}'") else: if hasattr(self._obj, name): - attr = getattr(self._obj, name) # If the attribute is a method, return it directly - if callable(attr): - return attr - else: - return attr + return getattr(self._obj, name) else: raise AttributeError(f"'{type(self._obj).__name__}' object has no attribute '{name}'") diff --git a/tests/test_object_filtering.py b/tests/test_object_filtering.py index 2dc5bf3..310e837 100644 --- a/tests/test_object_filtering.py +++ b/tests/test_object_filtering.py @@ -32,14 +32,22 @@ def __init__(self, x: int | float, y: int | float) -> None: @object_filtering.filter_criterion def area(self) -> int: return 0 - + @object_filtering.filter_criterion def volume(self, z: int | float = 0) -> int: return 0 - + def secret_method(self) -> None: return +class Circle: + def __init__(self, radius: int | float) -> None: + self.radius: int | float = radius + + @object_filtering.filter_criterion + def area(self) -> float: + return 3.14159 * self.radius ** 2 + SHAPE_1 = Shape(1, 2) SHAPE_2 = Shape(2, 4) SHAPE_3 = Shape(3, 6) @@ -513,6 +521,37 @@ def secret_method(self) -> None: "logical_expression": CLASS_CONDITIONAL } +# Filter with type-specific attributes: x for Shape, radius for Circle +CLASS_DISJOINT_ATTRS_FILTER = { + "name": "Disjoint Attributes Filter", + "description": "Checks x for Shapes and radius for Circles.", + "priority": 0, + "object_types": ["Shape", "Circle"], + "logical_expression": { + "if": { + "criterion": "$CLASS$", + "operator": "==", + "comparison_value": "Shape", + "parameters": [], + "multi_value_behavior": "none" + }, + "then": { + "criterion": "x", + "operator": ">=", + "comparison_value": 2, + "parameters": [], + "multi_value_behavior": "none" + }, + "else": { + "criterion": "radius", + "operator": ">=", + "comparison_value": 1, + "parameters": [], + "multi_value_behavior": "none" + } + } +} + HIGH_X_FILTER = object_filtering.ObjectFilter( name="High X", description="Checks for a high x value.", @@ -570,26 +609,24 @@ def test_single_object(self): class TestLogicalExpressionValidity(unittest.TestCase): def test_rule(self): - assert object_filtering.is_rule_valid(RULE_X, SHAPE_BIG) - assert object_filtering.is_rule_valid(RULE_Y, SHAPE_BIG) - assert object_filtering.is_rule_valid(RULE_VOLUME, SHAPE_BIG) - with pytest.raises(object_filtering.FilterError): - object_filtering.is_rule_valid(RULE_SECRET, SHAPE_BIG) # not decorated with @object_filtering.filter_criterion + assert object_filtering.is_rule_valid(RULE_X) + assert object_filtering.is_rule_valid(RULE_Y) + assert object_filtering.is_rule_valid(RULE_VOLUME) + assert object_filtering.is_rule_valid(RULE_SECRET) def test_conditional(self): - assert object_filtering.is_conditional_expression_valid(CONDITIONAL_1, SHAPE_BIG) - assert object_filtering.is_conditional_expression_valid(CONDITIONAL_2, SHAPE_BIG) + assert object_filtering.is_conditional_expression_valid(CONDITIONAL_1) + assert object_filtering.is_conditional_expression_valid(CONDITIONAL_2) def test_group(self): - assert object_filtering.is_group_expression_valid(GROUP_1, SHAPE_BIG) - assert object_filtering.is_group_expression_valid(GROUP_2, SHAPE_BIG) + assert object_filtering.is_group_expression_valid(GROUP_1) + assert object_filtering.is_group_expression_valid(GROUP_2) def test_logical(self): logical_expressions = [RULE_X, RULE_Y, RULE_AREA, RULE_VOLUME, CONDITIONAL_1, CONDITIONAL_2, GROUP_1, GROUP_2] for exp in logical_expressions: - assert object_filtering.is_logical_expression_valid(exp, SHAPE_BIG) - with pytest.raises(object_filtering.FilterError): - object_filtering.is_logical_expression_valid(RULE_SECRET, SHAPE_BIG) + assert object_filtering.is_logical_expression_valid(exp) + assert object_filtering.is_logical_expression_valid(RULE_SECRET) class TestLogicalExpressionResult(unittest.TestCase): def test_rule(self): @@ -1266,12 +1303,12 @@ def test_execute_filter_on_object_with_dict(self): def test_is_logical_expression_valid_with_dict(self): d = dict(self.RULE_DICT) - assert object_filtering.is_logical_expression_valid(d, SHAPE_1) + assert object_filtering.is_logical_expression_valid(d) self._assert_still_plain_dict(d) def test_is_rule_valid_with_dict(self): d = dict(self.RULE_DICT) - assert object_filtering.is_rule_valid(d, SHAPE_1) + assert object_filtering.is_rule_valid(d) self._assert_still_plain_dict(d) def test_is_conditional_expression_valid_with_dict(self): @@ -1280,7 +1317,7 @@ def test_is_conditional_expression_valid_with_dict(self): "then": True, "else": False } - assert object_filtering.is_conditional_expression_valid(d, SHAPE_1) + assert object_filtering.is_conditional_expression_valid(d) self._assert_still_plain_dict(d) def test_is_group_expression_valid_with_dict(self): @@ -1288,12 +1325,12 @@ def test_is_group_expression_valid_with_dict(self): "logical_operator": "and", "logical_expressions": [dict(self.RULE_DICT)] } - assert object_filtering.is_group_expression_valid(d, SHAPE_1) + assert object_filtering.is_group_expression_valid(d) self._assert_still_plain_dict(d) def test_is_filter_valid_with_dict(self): d = dict(self.FILTER_DICT) - assert object_filtering.is_filter_valid(d, SHAPE_1) + assert object_filtering.is_filter_valid(d) self._assert_still_plain_dict(d) def test_get_value_with_dict(self): @@ -1372,10 +1409,10 @@ def test_sort_filter_list(self): class TestClassVariable(unittest.TestCase): def test_class_variable_rule_validity(self): - assert object_filtering.is_rule_valid(RULE_CLASS_EQ, SHAPE_BIG) - assert object_filtering.is_rule_valid(RULE_CLASS_NEQ, SHAPE_BIG) + assert object_filtering.is_rule_valid(RULE_CLASS_EQ) + assert object_filtering.is_rule_valid(RULE_CLASS_NEQ) with pytest.raises(object_filtering.FilterError): - object_filtering.is_rule_valid(RULE_CLASS_INVALID_OP, SHAPE_BIG) + object_filtering.is_rule_valid(RULE_CLASS_INVALID_OP) def test_class_variable_invalid_operators(self): for op in ("<", "<=", ">=", ">"): @@ -1387,7 +1424,7 @@ def test_class_variable_invalid_operators(self): "multi_value_behavior": "none" } with pytest.raises(object_filtering.FilterError): - object_filtering.is_rule_valid(rule, SHAPE_BIG) + object_filtering.is_rule_valid(rule) def test_class_variable_rule_execution(self): assert object_filtering.execute_rule_on_object(SHAPE_BIG, RULE_CLASS_EQ) @@ -1414,6 +1451,20 @@ def test_class_variable_conditional_filter(self): point = Point(0, 0) assert object_filtering.execute_filter_on_object(point, CLASS_CONDITIONAL_FILTER) + def test_class_variable_disjoint_attributes(self): + """Filter with type-specific attributes should not fail validation + when the object lacks an attribute guarded by a $CLASS$ check.""" + # Shape(3,4) has x=3 >= 2, passes + assert object_filtering.execute_filter_on_object(SHAPE_BIG, CLASS_DISJOINT_ATTRS_FILTER) + # Shape(1,1) has x=1 < 2, fails + assert not object_filtering.execute_filter_on_object(SHAPE_SMALL, CLASS_DISJOINT_ATTRS_FILTER) + # Circle(5) has radius=5 >= 1, passes (Circle has no x attribute) + circle = Circle(5) + assert object_filtering.execute_filter_on_object(circle, CLASS_DISJOINT_ATTRS_FILTER) + # Circle(0.5) has radius=0.5 < 1, fails + small_circle = Circle(0.5) + assert not object_filtering.execute_filter_on_object(small_circle, CLASS_DISJOINT_ATTRS_FILTER) + def test_class_variable_with_object_wrapper(self): wrapper = object_filtering.ObjectWrapper(SHAPE_1) assert object_filtering.get_value(wrapper, RULE_CLASS_EQ) == "Shape" diff --git a/uv.lock b/uv.lock index 3c5883d..c588378 100644 --- a/uv.lock +++ b/uv.lock @@ -185,7 +185,7 @@ wheels = [ [[package]] name = "object-filtering" -version = "0.3.0" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },